diffengine.engine.hooks.fast_norm_hook

Module Contents

Classes

FastNormHook

Fast Normalization Hook.

Functions

_fast_gn_forward(self, x)

Faster group normalization forward.

Attributes

apex

diffengine.engine.hooks.fast_norm_hook.apex

diffengine.engine.hooks.fast_norm_hook._fast_gn_forward(self, x)

Faster group normalization forward.

Copied from https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/fast_norm.py

Parameters:

x (torch.Tensor) –

Return type:

torch.Tensor
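The timm helper cited above handles mixed precision by casting the input and affine parameters to the autocast dtype once, then running group normalization with autocast disabled. The sketch below illustrates that pattern under the assumption that `_fast_gn_forward` is bound to an `nn.GroupNorm` instance as `self`. The name `fast_gn_forward_sketch` is hypothetical, and this is not the library's exact code.

```python
import torch
import torch.nn.functional as F
from torch import nn


def fast_gn_forward_sketch(self: nn.GroupNorm, x: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for _fast_gn_forward; `self` is an nn.GroupNorm."""
    weight, bias = self.weight, self.bias
    if torch.is_autocast_enabled():
        # Under CUDA autocast, cast everything to the low-precision dtype once.
        dtype = torch.get_autocast_gpu_dtype()
        x = x.to(dtype)
        weight = weight.to(dtype) if weight is not None else None
        bias = bias.to(dtype) if bias is not None else None
    # Disable autocast so group_norm runs in the dtype chosen above.
    with torch.autocast(device_type=x.device.type, enabled=False):
        return F.group_norm(x, self.num_groups, weight, bias, self.eps)


gn = nn.GroupNorm(num_groups=4, num_channels=8)
x = torch.randn(2, 8, 16, 16)
out = fast_gn_forward_sketch(gn, x)
```

Outside of autocast the sketch reduces to an ordinary `F.group_norm` call, so it should match the stock `nn.GroupNorm` forward.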

class diffengine.engine.hooks.fast_norm_hook.FastNormHook(*, fuse_text_encoder_ln=False, fuse_main_ln=True, fuse_gn=False)

Bases: mmengine.hooks.Hook

Fast Normalization Hook.

Replace the normalization layer with a faster one.

Args:

fuse_text_encoder_ln (bool): Whether to fuse the text encoder layer normalization. Defaults to False.

fuse_main_ln (bool): Whether to replace the layer normalization in the main module, such as the UNet or transformer. Defaults to True.

fuse_gn (bool): Whether to replace the group normalization. Defaults to False.
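Because the class derives from `mmengine.hooks.Hook`, it would typically be enabled from a config file. The fragment below is a sketch that assumes the hook is registered under the name `FastNormHook` and that the config uses MMEngine's `custom_hooks` list. The values shown are simply the documented defaults.

```python
# Hypothetical excerpt from an MMEngine-style config file.
# Assumes the hook is registered under the name "FastNormHook".
custom_hooks = [
    dict(
        type="FastNormHook",
        fuse_text_encoder_ln=False,  # leave the text encoder LayerNorm as-is
        fuse_main_ln=True,           # swap LayerNorm in the UNet/transformer
        fuse_gn=False,               # leave GroupNorm as-is
    ),
]
```

All three arguments are keyword-only, as the `*` in the signature indicates.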

priority = VERY_LOW

_replace_ln(module, name, device)

Replace the layer normalization with a fused one.

Parameters:
  • module (torch.nn.Module) –

  • name (str) –

  • device (str) –

Return type:

None
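One way such a replacement could work is a recursive walk over child modules that swaps each `nn.LayerNorm` for a substitute and copies its weights. The sketch below is an assumption about the general technique, not the hook's actual code. `StandInFusedLayerNorm` is a hypothetical placeholder for whatever fused layer the hook installs, and `replace_ln_sketch` is likewise an invented name.

```python
import torch
from torch import nn


class StandInFusedLayerNorm(nn.LayerNorm):
    """Hypothetical placeholder for a fused LayerNorm implementation."""


def replace_ln_sketch(module: nn.Module, name: str, device: str) -> None:
    """Recursively swap plain nn.LayerNorm children for the stand-in class."""
    for child_name, child in list(module.named_children()):
        full_name = f"{name}.{child_name}" if name else child_name
        if type(child) is nn.LayerNorm:
            fused = StandInFusedLayerNorm(
                child.normalized_shape,
                eps=child.eps,
                elementwise_affine=child.elementwise_affine,
            )
            # Carry over the learned scale and shift so outputs stay the same.
            fused.load_state_dict(child.state_dict())
            setattr(module, child_name, fused.to(device))
        else:
            replace_ln_sketch(child, full_name, device)


model = nn.Sequential(nn.Linear(4, 4), nn.LayerNorm(4), nn.Sequential(nn.LayerNorm(4)))
x = torch.randn(2, 4)
before = model(x)
replace_ln_sketch(model, "model", "cpu")
after = model(x)
```

Since the stand-in inherits the stock forward pass and the weights are copied, the model output should be unchanged after the swap.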

_replace_gn(module, name, device)

Replace the group normalization with a fused one.

Parameters:
  • module (torch.nn.Module) –

  • name (str) –

  • device (str) –

Return type:

None

_replace_gn_forward(module, name)

Replace the group normalization forward with a faster one.

Parameters:
  • module (torch.nn.Module) –

  • name (str) –

Return type:

None
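Replacing only the forward method, rather than the layer itself, can be done by binding a new function to each existing `nn.GroupNorm` instance. The sketch below shows that idea under stated assumptions: `gn_forward_stand_in` and `replace_gn_forward_sketch` are hypothetical names, and unlike the documented method (which returns None) the sketch returns the patched module names purely for illustration.

```python
import types

import torch
import torch.nn.functional as F
from torch import nn


def gn_forward_stand_in(self: nn.GroupNorm, x: torch.Tensor) -> torch.Tensor:
    """Hypothetical replacement forward; a real one would be the faster variant."""
    return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)


def replace_gn_forward_sketch(module: nn.Module, name: str) -> list:
    """Patch the forward of every GroupNorm under `module`."""
    patched = []
    for child_name, child in module.named_children():
        full_name = f"{name}.{child_name}" if name else child_name
        if isinstance(child, nn.GroupNorm):
            # Bind the function to this instance so `self` is the layer.
            child.forward = types.MethodType(gn_forward_stand_in, child)
            patched.append(full_name)
        else:
            patched.extend(replace_gn_forward_sketch(child, full_name))
    return patched


net = nn.Sequential(
    nn.Conv2d(3, 8, 3),
    nn.GroupNorm(4, 8),
    nn.Sequential(nn.GroupNorm(2, 8)),
)
x = torch.randn(1, 3, 8, 8)
before = net(x)
patched = replace_gn_forward_sketch(net, "")
after = net(x)
```

Patching the instance keeps the existing parameters and state dict keys intact, which may be why a forward-only replacement is offered alongside full layer replacement.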

before_train(runner)

Replace the normalization layer with a faster one.

Parameters:

runner (mmengine.runner.Runner) – The runner of the training process.

Return type:

None

Parameters:
  • fuse_text_encoder_ln (bool) –

  • fuse_main_ln (bool) –

  • fuse_gn (bool) –