# Source code for stable_diffusion_1e
from mmengine.hooks import CheckpointHook
from mmengine.optim import AmpOptimWrapper
from torch.optim import AdamW
# Optimizer wrapper config: AdamW wrapped in mmengine's AmpOptimWrapper
# (automatic mixed-precision training) with gradient-norm clipping.
# NOTE: the original line carried a stray "[docs]" prefix (a Sphinx viewcode
# link artifact) that made this assignment a SyntaxError; it has been removed.
optim_wrapper = dict(
    type=AmpOptimWrapper,
    # Precision used for the autocast region of mixed-precision training.
    dtype="float16",
    # AdamW with learning rate 1e-4 and decoupled weight decay 1e-2.
    optimizer=dict(type=AdamW, lr=1e-4, weight_decay=1e-2),
    # Clip gradients so that their total norm does not exceed 1.0.
    clip_grad=dict(max_norm=1.0),
)
# train, val, test setting