# Source code for gogh_esd_xl

from mmengine.dataset import InfiniteSampler

from diffengine.datasets import HFESDDatasetPreComputeEmbs
from diffengine.datasets.transforms import PackInputs
from diffengine.engine.hooks import SDCheckpointHook, VisualizationHook

# Data pipeline: pack the caption text together with its pre-computed
# prompt embeddings (regular and "null" variants) into the model inputs.
# NOTE(review): these keys are presumably produced by the
# HFESDDatasetPreComputeEmbs dataset in train_dataloader — confirm the
# dataset emits exactly this set.
# Keep this as plain `dict(...)` calls and avoid extra module-level helper
# names: every module-level name in an MMEngine config becomes a config field.
train_pipeline = [
    dict(
        type=PackInputs,
        input_keys=[
            "text",
            "prompt_embeds",
            "pooled_prompt_embeds",
            "null_prompt_embeds",
            "null_pooled_prompt_embeds",
        ],
    ),
]
# Training loader. The dataset is configured with the caption to forget
# ("Van Gogh") and the SDXL base model id. The model id is presumably used to
# pre-compute the embeddings packed by train_pipeline — TODO confirm against
# HFESDDatasetPreComputeEmbs. ("ESD" presumably refers to Erasing Stable
# Diffusion concept erasure — confirm.)
# InfiniteSampler with shuffle=True keeps yielding shuffled indices, so run
# length is presumably governed by an iteration-based train loop rather than
# by epochs over the dataset.
train_dataloader = dict(
    batch_size=1,
    num_workers=1,
    dataset=dict(
        type=HFESDDatasetPreComputeEmbs,
        forget_caption="Van Gogh",
        model="stabilityai/stable-diffusion-xl-base-1.0",
        pipeline=train_pipeline,
    ),
    sampler=dict(type=InfiniteSampler, shuffle=True),
)
# Validation and testing are disabled for this config; the test entries
# simply mirror the (disabled) validation entries.
val_dataloader = None
val_evaluator = None
test_dataloader = val_dataloader
test_evaluator = val_evaluator
# Custom hooks:
#  - VisualizationHook renders the same prompt 4 times at 1024x1024. With
#    by_epoch=False, the interval of 100 is counted in iterations (MMEngine
#    hook convention). This is presumably meant to let you watch whether the
#    Van Gogh style is being erased during training.
#  - SDCheckpointHook is used with its default settings.
custom_hooks = [
    dict(
        type=VisualizationHook,
        prompt=["The starry night by Van Gogh"] * 4,
        by_epoch=False,
        interval=100,
        height=1024,
        width=1024,
    ),
    dict(type=SDCheckpointHook),
]