import os.path as osp
from collections import OrderedDict
from mmengine.hooks import Hook
from mmengine.model import is_model_wrapper
from mmengine.registry import HOOKS
from mmengine.runner import Runner
@HOOKS.register_module()
[docs]class T2IAdapterSaveHook(Hook):
"""T2I Adapter Save Hook.
Save T2I-Adapter weights with diffusers format and pick up weights from
checkpoint.
"""
[docs] def before_save_checkpoint(self, runner: Runner, checkpoint: dict) -> None:
"""Before save checkpoint hook.
Args:
----
runner (Runner): The runner of the training, validation or testing
process.
checkpoint (dict): Model's checkpoint.
"""
model = runner.model
if is_model_wrapper(model):
model = model.module
ckpt_path = osp.join(runner.work_dir, f"step{runner.iter}")
model.adapter.save_pretrained(osp.join(ckpt_path, "adapter"))
# not save no grad key
new_ckpt = OrderedDict()
sd_keys = checkpoint["state_dict"].keys()
for k in sd_keys:
if k.startswith("adapter"):
new_ckpt[k] = checkpoint["state_dict"][k]
checkpoint["state_dict"] = new_ckpt