Source code for diffengine.engine.hooks.controlnet_save_hook

import os.path as osp
from collections import OrderedDict

from mmengine.hooks import Hook
from mmengine.model import is_model_wrapper
from mmengine.registry import HOOKS
from mmengine.runner import Runner


@HOOKS.register_module()
class ControlNetSaveHook(Hook):
    """Hook that exports ControlNet weights when a checkpoint is saved.

    On every checkpoint save the ControlNet sub-module is written out via its
    ``save_pretrained`` method, and the checkpoint's ``state_dict`` is reduced
    to the entries whose key contains ``"controlnet"``.
    """

    priority = "VERY_LOW"

    def before_save_checkpoint(self, runner: Runner, checkpoint: dict) -> None:
        """Export the ControlNet and trim the checkpoint before it is saved.

        Args:
        ----
            runner (Runner): The runner of the training, validation or testing
                process. Its ``model``, ``work_dir`` and ``iter`` attributes
                are read.
            checkpoint (dict): Model's checkpoint. Its ``"state_dict"`` entry
                is replaced in place with a filtered ``OrderedDict``.
        """
        # Unwrap the model if a model wrapper is around it, so that the
        # ``controlnet`` attribute is reachable directly.
        wrapped = runner.model
        net = wrapped.module if is_model_wrapper(wrapped) else wrapped

        # Weights go to ``<work_dir>/step<iter>/controlnet``.
        export_dir = osp.join(
            runner.work_dir, f"step{runner.iter}", "controlnet")
        net.controlnet.save_pretrained(export_dir)

        # Keep only the ControlNet entries (the original note: do not save
        # keys that have no grad); order of the surviving keys is preserved.
        full_state = checkpoint["state_dict"]
        checkpoint["state_dict"] = OrderedDict(
            (name, full_state[name])
            for name in full_state.keys()
            if "controlnet" in name
        )