import torchvision
from mmengine.dataset import InfiniteSampler
from diffengine.datasets import HFDreamBoothDataset
from diffengine.datasets.transforms import (
ComputePixArtImgInfo,
DumpImage,
PackInputs,
RandomCrop,
RandomHorizontalFlip,
SaveImageShape,
T5TextPreprocess,
TorchVisonTransformWrapper,
)
from diffengine.engine.hooks import PeftSaveHook, VisualizationHook
# Training transform pipeline, listed in the order the entries appear below.
# Each entry is a config dict: ``type`` names the transform class and the
# remaining keys are the settings given for it. torchvision transforms are
# referenced through ``TorchVisonTransformWrapper`` via its ``transform`` key.
train_pipeline = [
    dict(type=SaveImageShape),
    # Resize with size=1024, then crop to 1024.
    dict(
        type=TorchVisonTransformWrapper,
        transform=torchvision.transforms.Resize,
        size=1024,
        interpolation="bilinear",
    ),
    dict(type=RandomCrop, size=1024),
    dict(type=RandomHorizontalFlip, p=0.5),
    dict(type=ComputePixArtImgInfo),
    dict(
        type=TorchVisonTransformWrapper,
        transform=torchvision.transforms.ToTensor,
    ),
    # DumpImage is placed after ToTensor and before Normalize.
    dict(type=DumpImage, max_imgs=5, dump_dir="work_dirs/dump"),
    dict(
        type=TorchVisonTransformWrapper,
        transform=torchvision.transforms.Normalize,
        mean=[0.5],
        std=[0.5],
    ),
    dict(type=T5TextPreprocess),
    # Keys handed to PackInputs as ``input_keys``.
    dict(
        type=PackInputs,
        input_keys=["img", "text", "resolution", "aspect_ratio"],
    ),
]
# Training dataloader config: batch_size=1 and num_workers=4, reading
# ``data/cat_waterpainting`` through HFDreamBoothDataset with a fixed instance
# prompt and no class prompt, and sampled by a shuffling InfiniteSampler.
train_dataloader = dict(
    batch_size=1,
    num_workers=4,
    dataset=dict(
        type=HFDreamBoothDataset,
        dataset="data/cat_waterpainting",
        instance_prompt="A cat in szn style",
        pipeline=train_pipeline,
        class_prompt=None,
    ),
    sampler=dict(type=InfiniteSampler, shuffle=True),
)
# NOTE(review): ``val_dataloader`` and ``val_evaluator`` were referenced here
# but never defined or imported in this file, which raises NameError when the
# config is loaded. They are set to None on the assumption that this config
# declares no validation data -- confirm against the intended setup.
val_dataloader = None
val_evaluator = None

# The test settings reuse the validation settings.
test_dataloader = val_dataloader
test_evaluator = val_evaluator
# Extra hooks for this config: a VisualizationHook given four copies of one
# prompt at 1024x1024 with interval=100 and by_epoch=False, plus a
# PeftSaveHook with no extra settings.
custom_hooks = [
    dict(
        type=VisualizationHook,
        prompt=["A man in szn style"] * 4,
        by_epoch=False,
        interval=100,
        height=1024,
        width=1024,
    ),
    dict(type=PeftSaveHook),
]