diffengine.engine.hooks.unet_ema_hook

Module Contents

Classes

UnetEMAHook

Unet EMA Hook.

class diffengine.engine.hooks.unet_ema_hook.UnetEMAHook(ema_type='ExponentialMovingAverage', strict_load=False, begin_iter=0, begin_epoch=0, **kwargs)

Bases: mmengine.hooks.ema_hook.EMAHook

Unet EMA Hook.

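Parameters:
  • ema_type (str) – Registered name of the averaged-model class used to build the EMA copy. Defaults to 'ExponentialMovingAverage'.

  • strict_load (bool) – Whether the keys of the EMA state dict loaded from a checkpoint must match exactly. Defaults to False.

  • begin_iter (int) – Iteration at which EMA updates start. Defaults to 0.

  • begin_epoch (int) – Epoch at which EMA updates start. Defaults to 0.

  • **kwargs – Extra keyword arguments (for example momentum) passed to the averaged-model class selected by ema_type.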
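A minimal sketch of registering the hook in an mmengine-style config, assuming the usual custom_hooks list. The momentum and priority values are illustrative assumptions, not defaults documented on this page; momentum is assumed to reach the class named by ema_type through **kwargs. The helper ema_update is hypothetical and only shows the update rule of mmengine's ExponentialMovingAverage, in which momentum weights the newest value.

```python
# Hypothetical config fragment; `momentum` and `priority` are illustrative values.
custom_hooks = [
    dict(
        type="UnetEMAHook",
        ema_type="ExponentialMovingAverage",
        momentum=1e-4,  # assumed to be forwarded via **kwargs to the EMA class
        begin_iter=0,
        priority="ABOVE_NORMAL",
    ),
]


def ema_update(averaged: float, current: float, momentum: float) -> float:
    """One EMA step in mmengine's convention: momentum weights the new value."""
    return (1.0 - momentum) * averaged + momentum * current
```

With a small momentum such as 1e-4, each call to ema_update moves the averaged value only slightly toward the current weight, which is what makes the EMA copy a smoothed version of the trained parameters.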

before_run(runner)

Create an EMA copy of the model.

Parameters:

runner (mmengine.runner.Runner) – The runner of the training process.

Return type:

None
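The sketch below illustrates what before_run plausibly does, assuming (from the class name, not from this page) that only the model's unet submodule is averaged and that a distributed wrapper exposes the real model as .module. ToyUNet, ToyDiffusionModel, ToyWrapper and build_unet_ema are hypothetical stand-ins, and copy.deepcopy substitutes for building the averaged model named by ema_type.

```python
import copy
from dataclasses import dataclass, field


@dataclass
class ToyUNet:
    # Hypothetical stand-in for a UNet's parameters.
    weights: dict = field(default_factory=lambda: {"conv.weight": 1.0})


@dataclass
class ToyDiffusionModel:
    # Hypothetical model that stores its UNet under the attribute `unet`.
    unet: ToyUNet = field(default_factory=ToyUNet)


@dataclass
class ToyWrapper:
    # Hypothetical distributed wrapper exposing the real model as `.module`.
    module: ToyDiffusionModel


def build_unet_ema(model):
    """Return (source UNet, independent EMA copy of it)."""
    # Unwrap first so `.unet` resolves on the underlying model.
    if hasattr(model, "module"):
        model = model.module
    src_model = model.unet
    # Deep copy stands in for constructing the `ema_type` averaged model.
    ema_model = copy.deepcopy(src_model)
    return src_model, ema_model
```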

_swap_ema_state_dict(checkpoint)

Swap the state dict values of the model with those of ema_model.

Parameters:

checkpoint (dict) –

Return type:

None
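A minimal pure-Python sketch of such a swap on plain dictionaries, assuming EMA keys carry a 'module.' prefix (the averaged model wraps its source as module) and the UNet's keys in the full model carry a 'unet.' prefix. The function name and key layout are illustrative assumptions. For context, mmengine's EMAHook swaps these values when saving, so a saved checkpoint holds the averaged weights in state_dict.

```python
def swap_ema_state_dict(checkpoint: dict) -> None:
    """Swap UNet weights between `state_dict` and `ema_state_dict` in place."""
    model_state = checkpoint["state_dict"]
    ema_state = checkpoint["ema_state_dict"]
    for key in ema_state:
        # Entries without the prefix (e.g. a step counter) are left untouched.
        if not key.startswith("module."):
            continue
        # Map 'module.<name>' in the EMA dict to 'unet.<name>' in the model dict.
        model_key = "unet." + key[len("module."):]
        ema_state[key], model_state[model_key] = model_state[model_key], ema_state[key]
```

Because the swap is its own inverse, applying it once on save and once on resume returns each set of weights to its original place.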

after_load_checkpoint(runner, checkpoint)

Resume EMA parameters from the checkpoint.

Parameters:
  • runner (mmengine.runner.Runner) – The runner that loaded the checkpoint.

  • checkpoint (dict) – The model's checkpoint.

Return type:

None
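The branch logic can be sketched on plain dictionaries, assuming the hook mirrors mmengine's EMAHook: when resuming from a checkpoint that contains ema_state_dict, the weights are swapped back and the EMA copy is restored from them; otherwise the EMA copy is seeded from the UNet weights in state_dict. restore_ema_state and its resuming flag are hypothetical (a real hook would consult the runner), and the 'unet.' and 'module.' key prefixes are assumed.

```python
import copy


def restore_ema_state(checkpoint: dict, resuming: bool) -> dict:
    """Return UNet-relative weights for the EMA copy; may swap entries in place."""
    model_state = checkpoint["state_dict"]
    if resuming and "ema_state_dict" in checkpoint:
        ema_state = checkpoint["ema_state_dict"]
        # Saved checkpoints keep the averaged weights in `state_dict`; swap them
        # back so the model gets its raw weights and the EMA copy its averages.
        for key in ema_state:
            if key.startswith("module."):
                model_key = "unet." + key[len("module."):]
                ema_state[key], model_state[model_key] = model_state[model_key], ema_state[key]
        return {k[len("module."):]: v for k, v in ema_state.items() if k.startswith("module.")}
    # Nothing to resume: start the EMA copy from the current UNet weights.
    return copy.deepcopy(
        {k[len("unet."):]: v for k, v in model_state.items() if k.startswith("unet.")}
    )
```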