# Source code for stable_diffusion_3e

from mmengine.hooks import CheckpointHook
from mmengine.optim import AmpOptimWrapper
from torch.optim import AdamW

# Optimizer: AdamW (lr=1e-4, weight_decay=1e-2) wrapped in mmengine's
# automatic-mixed-precision optimizer wrapper with dtype float16, and with
# gradient clipping to a maximum norm of 1.0.
optim_wrapper = dict(
    type=AmpOptimWrapper,
    dtype="float16",
    optimizer=dict(type=AdamW, lr=1e-4, weight_decay=1e-2),
    clip_grad=dict(max_norm=1.0),
)

# train, val, test setting
# Epoch-based training loop that runs for 3 epochs.
train_cfg = dict(by_epoch=True, max_epochs=3)
# No validation or test loop is configured for this schedule.
val_cfg = None
test_cfg = None

# Checkpointing: save with interval=1 (presumably once per epoch, given the
# epoch-based train loop above -- confirm against CheckpointHook defaults)
# and keep at most the 3 most recent checkpoints.
default_hooks = dict(
    checkpoint=dict(
        type=CheckpointHook,
        interval=1,
        max_keep_ckpts=3,
    ),
)