# Source code for stable_diffusion_1k: schedule config (AdamW + AmpOptimWrapper, 1000 training iterations).

from mmengine.hooks import CheckpointHook
from mmengine.optim import AmpOptimWrapper
from mmengine.runner import IterBasedTrainLoop
from torch.optim import AdamW

# Optimizer: AdamW (lr=1e-4, weight_decay=1e-2) wrapped in mmengine's
# AmpOptimWrapper with dtype float16, plus gradient clipping at max_norm=1.0.
optim_wrapper = dict(
    type=AmpOptimWrapper,
    dtype="float16",
    optimizer=dict(type=AdamW, lr=1e-4, weight_decay=1e-2),
    clip_grad=dict(max_norm=1.0),
)

# train, val, test setting
# Iteration-based schedule: 1000 training iterations in total.
train_cfg = dict(type=IterBasedTrainLoop, max_iters=1000)
# No validation or test loop is configured for this schedule.
val_cfg = None
test_cfg = None

# Checkpointing counts iterations (by_epoch=False): save every 100 iterations
# and keep at most 3 checkpoints on disk.
default_hooks = dict(
    checkpoint=dict(
        type=CheckpointHook,
        interval=100,
        by_epoch=False,
        max_keep_ckpts=3,
    ),
)
# Log by iteration rather than by epoch, consistent with the iteration-based
# train loop above.
log_processor = dict(by_epoch=False)