diffengine.engine.hooks.fast_norm_hook
Module Contents
Classes
FastNormHook | Fast Normalization Hook.
Functions
_fast_gn_forward(self, x) | Faster group normalization forward.
Attributes
- diffengine.engine.hooks.fast_norm_hook._fast_gn_forward(self, x)
Faster group normalization forward.
Copied from https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/fast_norm.py
- Parameters:
x (torch.Tensor) – Input tensor to normalize.
- Return type:
torch.Tensor
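A minimal sketch of what a faster group normalization forward can look like, modeled on the approach in timm's `fast_norm.py`: under CUDA autocast, `group_norm` would normally be run in float32, so the sketch casts the input and affine parameters to the autocast dtype once and then calls `F.group_norm` with autocast disabled. The names `fast_gn_forward` and `_autocast_dtype`, and the version fallback for the dtype query, are assumptions of this sketch, not diffengine's code.

```python
import types

import torch
import torch.nn.functional as F
from torch import nn


def _autocast_dtype() -> torch.dtype:
    # torch >= 2.4 exposes get_autocast_dtype(device_type); older releases
    # only provide get_autocast_gpu_dtype() for CUDA autocast.
    if hasattr(torch, "get_autocast_dtype"):
        return torch.get_autocast_dtype("cuda")
    return torch.get_autocast_gpu_dtype()


def fast_gn_forward(self: nn.GroupNorm, x: torch.Tensor) -> torch.Tensor:
    """GroupNorm forward that stays in the autocast dtype (sketch)."""
    weight, bias = self.weight, self.bias
    if torch.is_autocast_enabled():
        # Cast once to the low-precision autocast dtype (e.g. float16)
        # instead of letting autocast run group_norm in float32.
        dt = _autocast_dtype()
        x = x.to(dt)
        weight = weight.to(dt) if weight is not None else None
        bias = bias.to(dt) if bias is not None else None
    # Disable autocast so F.group_norm is not upcast again.
    with torch.autocast(device_type=x.device.type, enabled=False):
        return F.group_norm(x, self.num_groups, weight, bias, self.eps)


# Patch a single GroupNorm instance by binding the function as its forward.
gn = nn.GroupNorm(num_groups=2, num_channels=4)
reference = nn.GroupNorm(num_groups=2, num_channels=4)
reference.load_state_dict(gn.state_dict())
gn.forward = types.MethodType(fast_gn_forward, gn)

x = torch.randn(2, 4, 3, 3)
out = gn(x)
```

Binding with `types.MethodType` patches only this instance; `nn.GroupNorm.forward` on the class is left untouched. Outside autocast the patched forward computes the same result as the stock one, which is why it can be compared directly against `reference`.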
- class diffengine.engine.hooks.fast_norm_hook.FastNormHook(*, fuse_text_encoder_ln=False, fuse_main_ln=True, fuse_gn=False)
Bases: mmengine.hooks.Hook
Fast Normalization Hook.
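Replace normalization layers with faster implementations: fused layer normalization in the text encoder and/or main module, and a faster group normalization forward, as selected by the arguments below.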
Args:
- fuse_text_encoder_ln (bool): Whether to replace the layer normalization
in the text encoder with a fused one. Defaults to False.
- fuse_main_ln (bool): Whether to replace the layer normalization
in the main module, such as the UNet or transformer. Defaults to True.
- fuse_gn (bool): Whether to replace the group normalization with a
faster one. Defaults to False.
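MMEngine builds the hooks listed under a config's `custom_hooks` from dicts whose `type` names a registered hook class. Assuming `FastNormHook` is registered under its class name (this page does not show the registration), enabling it could look like the fragment below. The constructor is keyword-only (`*`), so every flag must be passed by name, which the dict form does anyway.

```python
# Hypothetical MMEngine config fragment enabling the hook.
custom_hooks = [
    dict(
        type="FastNormHook",
        fuse_text_encoder_ln=False,  # keep the text encoder's LayerNorm as-is
        fuse_main_ln=True,  # fuse LayerNorm in the main module (UNet/transformer)
        fuse_gn=True,  # also replace GroupNorm with the faster variant
    ),
]
```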
- _replace_ln(module, name, device)
Replace the layer normalization with a fused one.
- Parameters:
module (torch.nn.Module) –
name (str) –
device (str) –
- Return type:
None
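A minimal sketch of the module-swapping pattern a method like `_replace_ln` could follow: walk the children recursively, rebuild each `nn.LayerNorm` with a fused implementation of the same shape, copy its parameters, and move it to `device`. The helper `replace_layernorm`, its `factory` argument, and `StandInFusedLayerNorm` are hypothetical. In practice `factory` would be a fused LayerNorm class such as apex's `FusedLayerNorm`, assuming it accepts `normalized_shape`, `eps`, and `elementwise_affine` like `nn.LayerNorm` does.

```python
from typing import Callable, List

import torch
from torch import nn


def replace_layernorm(
    module: nn.Module,
    name: str,
    device: str,
    factory: Callable[..., nn.Module],
) -> List[str]:
    """Recursively swap each nn.LayerNorm under `module` for `factory(...)`.

    `name` is the qualified name of `module`; it is only used to build the
    returned list of replaced layer names (e.g. for logging).
    """
    replaced: List[str] = []
    for child_name, child in list(module.named_children()):
        qualified = f"{name}.{child_name}" if name else child_name
        if isinstance(child, nn.LayerNorm):
            new = factory(
                child.normalized_shape,
                eps=child.eps,
                elementwise_affine=child.elementwise_affine,
            )
            # Keep the learned affine parameters, then place on `device`.
            new.load_state_dict(child.state_dict())
            setattr(module, child_name, new.to(device))
            replaced.append(qualified)
        else:
            replaced.extend(replace_layernorm(child, qualified, device, factory))
    return replaced


class StandInFusedLayerNorm(nn.LayerNorm):
    """Hypothetical stand-in for a fused LayerNorm implementation."""


model = nn.Sequential(nn.Linear(4, 4), nn.LayerNorm(4), nn.Sequential(nn.LayerNorm(4)))
x = torch.randn(3, 4)
expected = model(x)
names = replace_layernorm(model, "unet", "cpu", StandInFusedLayerNorm)
```

Iterating over a snapshot (`list(...)`) keeps the loop independent of the `setattr` calls, and the `else` branch means freshly inserted modules are never revisited.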
- _replace_gn(module, name, device)
Replace the group normalization forward with a faster one.
- Parameters:
module (torch.nn.Module) –
name (str) –
device (str) –
- Return type:
None
- Parameters:
fuse_text_encoder_ln (bool) –
fuse_main_ln (bool) –
fuse_gn (bool) –
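A minimal sketch of the patching step a method like `_replace_gn` could perform, assuming (as the `self` parameter of `_fast_gn_forward` suggests) that the faster forward is bound onto existing GroupNorm instances rather than swapped in as new modules. `patch_groupnorm` and `plain_gn_forward` are hypothetical; in practice `_fast_gn_forward` or a similar function would be passed as `fast_forward`.

```python
import types
from typing import Callable

import torch
import torch.nn.functional as F
from torch import nn


def patch_groupnorm(
    module: nn.Module,
    fast_forward: Callable[[nn.GroupNorm, torch.Tensor], torch.Tensor],
) -> int:
    """Bind `fast_forward` as the forward of every nn.GroupNorm in `module`.

    The GroupNorm objects and their parameters stay in place; only the
    instance-level forward changes. Returns the number of patched layers.
    """
    count = 0
    for sub in module.modules():
        if isinstance(sub, nn.GroupNorm):
            sub.forward = types.MethodType(fast_forward, sub)
            count += 1
    return count


def plain_gn_forward(self: nn.GroupNorm, x: torch.Tensor) -> torch.Tensor:
    # Placeholder with stock semantics; a faster variant would go here.
    return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)


model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),
    nn.GroupNorm(4, 8),
    nn.SiLU(),
    nn.GroupNorm(2, 8),
)
num_patched = patch_groupnorm(model, plain_gn_forward)
out = model(torch.randn(1, 3, 6, 6))
```

Unlike the LayerNorm swap, this keeps every module object and its state dict keys unchanged, so checkpoints load the same before and after patching.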