Source code for diffengine.engine.hooks.unet_ema_hook

import copy
import logging

from mmengine.hooks.ema_hook import EMAHook
from mmengine.logging import print_log
from mmengine.model import is_model_wrapper
from mmengine.registry import HOOKS, MODELS
from mmengine.runner import Runner


@HOOKS.register_module()
class UnetEMAHook(EMAHook):
    """EMA hook that tracks only the ``unet`` submodule of the model.

    The EMA model is built from ``model.unet`` rather than from the whole
    model, so the checkpoint handling below translates between the two key
    layouts used in this class: EMA entries prefixed with ``module.`` and
    model entries prefixed with ``unet.``.
    """

    def before_run(self, runner: Runner) -> None:
        """Create an ema copy of the model's UNet.

        Args:
        ----
            runner (Runner): The runner of the training process.
        """
        model = runner.model
        # Unwrap model wrappers so that ``unet`` is read from the bare model.
        if is_model_wrapper(model):
            model = model.module
        self.src_model = model.unet
        self.ema_model = MODELS.build(
            self.ema_cfg, default_args={"model": self.src_model})

    def _swap_ema_state_dict(self, checkpoint: dict) -> None:
        """Swap the state dict values of model with ema_model.

        Both ``checkpoint["state_dict"]`` and ``checkpoint["ema_state_dict"]``
        are modified in place. Only EMA entries whose key starts with
        ``module.`` are swapped; their counterpart in the model state dict is
        the same key with ``module.`` replaced by ``unet.``.

        Args:
        ----
            checkpoint (dict): Checkpoint holding the ``state_dict`` and
                ``ema_state_dict`` entries.

        Raises:
        ------
            KeyError: If a ``module.``-prefixed EMA key has no matching
                ``unet.``-prefixed key in the model state dict.
        """
        ema_prefix = "module."
        model_prefix = "unet."
        model_state = checkpoint["state_dict"]
        ema_state = checkpoint["ema_state_dict"]
        # Only values are reassigned, so iterating over the keys is safe.
        for ema_key in ema_state:
            if not ema_key.startswith(ema_prefix):
                continue
            # 'module.' -> 'unet.'
            model_key = model_prefix + ema_key[len(ema_prefix):]
            ema_state[ema_key], model_state[model_key] = (
                model_state[model_key], ema_state[ema_key])

    def after_load_checkpoint(self, runner: Runner, checkpoint: dict) -> None:
        """Resume ema parameters from checkpoint.

        When resuming from a checkpoint that contains ``ema_state_dict``, the
        model and EMA weights are swapped back and the EMA state is loaded
        into ``self.ema_model``. Otherwise the ``unet.``-prefixed entries of
        ``state_dict`` are used to initialise the EMA model.

        Args:
        ----
            runner (Runner): The runner of the testing process.
            checkpoint (dict): Model's checkpoint.
        """
        from mmengine.runner.checkpoint import load_state_dict

        if "ema_state_dict" in checkpoint and runner._resume:  # noqa
            # The original model parameters are actually saved in ema
            # field swap the weights back to resume ema state.
            self._swap_ema_state_dict(checkpoint)
            self.ema_model.load_state_dict(
                checkpoint["ema_state_dict"], strict=self.strict_load)
            return

        # Support load checkpoint without ema state dict.
        if runner._resume:  # noqa
            print_log(
                "There is no `ema_state_dict` in checkpoint. "
                "`EMAHook` will make a copy of `state_dict` as the "
                "initial `ema_state_dict`", "current", logging.WARNING)

        # Select the UNet entries first and copy only those: deep-copying
        # the whole state dict would also duplicate every entry that is
        # discarded right afterwards.
        prefix = "unet."
        unet_state = {
            key[len(prefix):]: value
            for key, value in checkpoint["state_dict"].items()
            if key.startswith(prefix)
        }
        load_state_dict(
            self.ema_model.module,
            copy.deepcopy(unet_state),
            strict=self.strict_load)