# Source code for the pokemon_blip_wuerstchen config.

import torchvision
from mmengine.dataset import DefaultSampler

from diffengine.datasets import HFDataset
from diffengine.datasets.transforms import (
    PackInputs,
    RandomCrop,
    RandomHorizontalFlip,
    TorchVisonTransformWrapper,
)
from diffengine.engine.hooks import PriorSaveHook, VisualizationHook

# Training-time image pipeline, applied in order:
# resize -> random crop -> random horizontal flip -> to tensor -> normalize
# -> pack into model inputs.
train_pipeline = [
    # torchvision's Resize with an int ``size`` scales the shorter edge to
    # 768 px. NOTE(review): the string "bilinear" is presumably converted to
    # an InterpolationMode by TorchVisonTransformWrapper — confirm.
    dict(
        type=TorchVisonTransformWrapper,
        transform=torchvision.transforms.Resize,
        size=768,
        interpolation="bilinear",
    ),
    dict(type=RandomCrop, size=768),
    dict(type=RandomHorizontalFlip, p=0.5),
    dict(
        type=TorchVisonTransformWrapper,
        transform=torchvision.transforms.ToTensor,
    ),
    # These are the standard ImageNet per-channel mean/std values.
    # NOTE(review): presumably chosen to match the image encoder the
    # Wuerstchen model expects — confirm before changing.
    dict(
        type=TorchVisonTransformWrapper,
        transform=torchvision.transforms.Normalize,
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
    dict(type=PackInputs),
]
# Training dataloader.
# - Dataset: the Hugging Face "lambdalabs/pokemon-blip-captions" dataset,
#   loaded through HFDataset and processed with ``train_pipeline``.
# - Sampler: MMEngine's DefaultSampler with shuffling enabled.
train_dataloader = dict(
    batch_size=4,
    num_workers=4,
    dataset=dict(
        type=HFDataset,
        dataset="lambdalabs/pokemon-blip-captions",
        pipeline=train_pipeline,
    ),
    sampler=dict(type=DefaultSampler, shuffle=True),
)
# No validation loader or evaluator is configured (both left as None).
# The test settings simply reuse the validation ones, so they are None too.
val_dataloader = None
val_evaluator = None
test_dataloader = val_dataloader
test_evaluator = val_evaluator
# Extra training hooks.
custom_hooks = [
    # Sample images from a fixed prompt, repeated 4 times, at 768x768 (the
    # same size the training pipeline crops to). When the hook fires is
    # decided by VisualizationHook's own defaults — confirm there.
    dict(
        type=VisualizationHook,
        prompt=["A robot pokemon, 4k photo"] * 4,
        height=768,
        width=768,
    ),
    # NOTE(review): presumably restricts saved checkpoints to the prior
    # model's weights — confirm against PriorSaveHook.
    dict(type=PriorSaveHook),
]