Source code for diffengine.engine.hooks.lcm_ema_update_hook

from mmengine.hooks.hook import DATA_BATCH, Hook
from mmengine.model import is_model_wrapper
from mmengine.registry import HOOKS
from mmengine.runner import Runner


@HOOKS.register_module()
class LCMEMAUpdateHook(Hook):
    """Update the model's ``target_unet`` from its ``unet`` every iteration.

    The hook does not build an EMA model itself. It expects the trained model
    (after unwrapping any model wrapper) to already own:

    * ``unet`` -- the module whose weights are being optimized, and
    * ``target_unet`` -- an object exposing ``update_parameters(src_model)``.

    After each training iteration it calls
    ``target_unet.update_parameters(unet)``. Presumably ``target_unet`` is an
    EMA wrapper such as mmengine's ``ExponentialMovingAverage``; confirm
    against the model definition.
    """

    def before_run(self, runner: Runner) -> None:
        """Cache references to the source UNet and its EMA target.

        No copy of the model is created. Both attributes point at objects
        owned by ``runner.model``, so ``update_parameters`` later acts on the
        very ``target_unet`` instance the model holds.

        Args:
        ----
            runner (Runner): The runner of the training process.
        """
        model = runner.model
        # Look through wrappers (as detected by ``is_model_wrapper``) to reach
        # the user model that actually defines ``unet`` / ``target_unet``.
        if is_model_wrapper(model):
            model = model.module
        self.src_model = model.unet
        self.ema_model = model.target_unet

    def after_train_iter(self,
                         runner: Runner,  # noqa
                         batch_idx: int,  # noqa
                         data_batch: DATA_BATCH = None,  # noqa
                         outputs: dict | None = None) -> None:  # noqa
        """Update the EMA target from the current source UNet weights.

        Relies on ``before_run`` having populated ``self.src_model`` and
        ``self.ema_model``. The remaining arguments are part of the ``Hook``
        interface and are unused here.

        Args:
        ----
            runner (Runner): The runner of the training process.
            batch_idx (int): The index of the current batch in the train loop.
            data_batch (Sequence[dict], optional): Data from dataloader.
                Defaults to None.
            outputs (dict, optional): Outputs from model. Defaults to None.
        """
        self.ema_model.update_parameters(self.src_model)