import torchvision
from mmengine.dataset import DefaultSampler
from diffengine.datasets import HFControlNetDataset
from diffengine.datasets.transforms import (
ComputeTimeIds,
DumpImage,
PackInputs,
RandomCrop,
RandomHorizontalFlip,
SaveImageShape,
TorchVisonTransformWrapper,
)
from diffengine.engine.hooks import SDCheckpointHook, VisualizationHook
# Training-time preprocessing pipeline. Transforms that take
# ``keys=["img", "condition_img"]`` are applied to both the target image and
# the conditioning image.
# NOTE(review): this assumes the random transforms (crop/flip) reuse the same
# random parameters for every key so the pair stays spatially aligned —
# confirm against diffengine's RandomCrop / RandomHorizontalFlip.
train_pipeline = [
    # NOTE(review): presumably records the original image shape so that
    # ComputeTimeIds (SDXL size conditioning) can use it — confirm.
    dict(type=SaveImageShape),
    # Bilinear resize of both images to 768 (torchvision Resize semantics:
    # an int size resizes the shorter edge).
    dict(
        type=TorchVisonTransformWrapper,
        transform=torchvision.transforms.Resize,
        size=768,
        interpolation="bilinear",
        keys=["img", "condition_img"]),
    dict(
        type=RandomCrop,
        size=768,
        keys=["img", "condition_img"],
        force_same_size=False),
    dict(type=RandomHorizontalFlip, p=0.5, keys=["img", "condition_img"]),
    # Runs after crop/flip so it can see the final geometry of the sample.
    dict(type=ComputeTimeIds),
    # ToTensor converts PIL images to float tensors scaled to [0, 1].
    dict(
        type=TorchVisonTransformWrapper,
        transform=torchvision.transforms.ToTensor,
        keys=["img", "condition_img"]),
    # Debugging aid: dumps at most 10 samples under work_dirs/dump (placed
    # before Normalize so dumped tensors are still in [0, 1]).
    dict(type=DumpImage, max_imgs=10, dump_dir="work_dirs/dump"),
    # (x - 0.5) / 0.5 maps [0, 1] to [-1, 1]. The conditioning image is
    # normalized as well, which presumably means the model VAE-encodes it
    # (InstructPix2Pix-style) — confirm before reusing this for ControlNet.
    dict(
        type=TorchVisonTransformWrapper,
        transform=torchvision.transforms.Normalize,
        mean=[0.5],
        std=[0.5],
        keys=["img", "condition_img"]),
    # Collects the fields consumed by the model's data preprocessor.
    dict(
        type=PackInputs,
        input_keys=["img", "condition_img", "text", "time_ids"]),
]
# Training dataloader: one sample per batch, four loader worker processes,
# shuffled via mmengine's DefaultSampler.
train_dataloader = dict(
    batch_size=1,
    num_workers=4,
    dataset=dict(
        type=HFControlNetDataset,
        # Hugging Face dataset id.
        dataset="instruction-tuning-sd/cartoonization",
        # Column mapping: the cartoonized image is the training target, the
        # original photo is the conditioning image, and the edit prompt is
        # used as the text caption.
        image_column="cartoonized_image",
        condition_column="original_image",
        caption_column="edit_prompt",
        pipeline=train_pipeline),
    sampler=dict(type=DefaultSampler, shuffle=True),
)
# No validation loop is configured for this dataset. mmengine requires the
# validation dataloader/evaluator to be either all set or all None, so they
# are defined explicitly here; otherwise the aliases below would raise
# NameError when the config is loaded. The test loop reuses these settings.
val_dataloader = None
val_evaluator = None
test_dataloader = val_dataloader
test_evaluator = val_evaluator
# Extra training hooks.
custom_hooks = [
    # Periodically renders samples for a fixed instruction applied to the
    # same conditioning image (4 copies), at 768x768 to match the training
    # resize/crop size in train_pipeline.
    dict(
        type=VisualizationHook,
        prompt=["Generate a cartoonized version of the natural image"] * 4,
        condition_image=[
            'https://hf.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png'  # noqa
        ] * 4,
        height=768,
        width=768),
    # NOTE(review): presumably saves checkpoints in a diffusers-compatible
    # form — confirm against diffengine.engine.hooks.
    dict(type=SDCheckpointHook),
]