"""Noise-generation modules (source of ``diffengine.models.utils.noise``)."""

import random

import torch
from torch import nn

from diffengine.registry import MODELS


@MODELS.register_module()
class WhiteNoise(nn.Module):
    """White noise module.

    Draws standard-normal noise shaped like the input latents
    (``torch.randn_like`` keeps the input's shape, dtype and device).
    """

    def forward(self, latents: torch.Tensor) -> torch.Tensor:
        """Forward pass. Generates noise for the given latents.

        Args:
        ----
            latents (torch.Tensor): Latent vectors; only their shape, dtype
                and device are used, never their values.

        Returns:
        -------
            torch.Tensor: Standard-normal noise matching ``latents``.
        """
        sample = torch.randn_like(latents)
        return sample
@MODELS.register_module()
class OffsetNoise(nn.Module):
    """Offset noise module.

    Adds a per-sample, per-channel constant offset (shared across every
    remaining dimension) on top of white noise.

    https://www.crosslabs.org/blog/diffusion-with-offset-noise

    Args:
    ----
        offset_weight (float): Noise offset weight. Defaults to 0.05.
    """

    def __init__(self, offset_weight: float = 0.05) -> None:
        super().__init__()
        self.offset_weight = offset_weight

    def forward(self, latents: torch.Tensor) -> torch.Tensor:
        """Forward pass. Generates noise for the given latents.

        Args:
        ----
            latents (torch.Tensor): Latent vectors laid out as
                ``(B, C, ...)``; must have at least two dimensions.

        Returns:
        -------
            torch.Tensor: White noise plus ``offset_weight`` times a
            per-(sample, channel) Gaussian offset, with the same shape,
            dtype and device as ``latents``.
        """
        noise = torch.randn_like(latents)
        # One offset value per (sample, channel), broadcast over all trailing
        # dims. Deriving the shape from ``latents.ndim`` keeps this correct
        # for non-4D latents (e.g. video ``(B, C, T, H, W)``), where a
        # hard-coded ``(B, C, 1, 1)`` would mis-broadcast or raise.
        offset_shape = (*latents.shape[:2], *([1] * (latents.ndim - 2)))
        # Sample in the latents' dtype so fp16/bf16 inputs are not silently
        # promoted to float32 by the addition below.
        offset = torch.randn(
            offset_shape, dtype=latents.dtype, device=latents.device)
        return noise + self.offset_weight * offset
@MODELS.register_module()
class PyramidNoise(nn.Module):
    """Pyramid noise module.

    https://wandb.ai/johnowhitaker/multires_noise/reports/
    Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2

    Args:
    ----
        discount (float): Per-level decay factor; the noise added at pyramid
            level ``i`` is weighted by ``discount ** i``. Defaults to 0.9.
        random_multiplier (bool): Whether to use random multiplier. If True,
            each level draws its downscale base ``r`` uniformly from
            ``[2, 4)``; otherwise ``r`` is fixed at 2. Defaults to True.
    """

    def __init__(self,
                 discount: float = 0.9,
                 *,
                 random_multiplier: bool = True) -> None:
        super().__init__()
        self.discount = discount
        self.random_multiplier = random_multiplier

    def forward(self, latents: torch.Tensor) -> torch.Tensor:
        """Forward pass. Generates noise for the given latents.

        Args:
        ----
            latents (torch.Tensor): Latent vectors; must be 4-D
                ``(B, C, H, W)`` because the pyramid is built with bilinear
                2-D upsampling.

        Returns:
        -------
            torch.Tensor: Multi-resolution noise shaped like ``latents``,
            divided by its global standard deviation (computed over the
            whole tensor, all samples together).
        """
        noise = torch.randn_like(latents)
        b, c, height, width = latents.shape
        upsample = nn.Upsample(size=(height, width), mode="bilinear")
        for i in range(16):
            r = random.random() * 2 + 2 if self.random_multiplier else 2  # noqa: S311
            # Each level is scaled from the ORIGINAL resolution. Overwriting
            # the full-size dims here would compound the shrink to
            # r**(1 + 2 + ... + i) and collapse the pyramid after a few
            # levels.
            level_h = max(1, int(height / (r**i)))
            level_w = max(1, int(width / (r**i)))
            # Sample directly on the target device/dtype to avoid a CPU
            # allocation plus host->device copy per level.
            level_noise = torch.randn(
                b, c, level_h, level_w,
                device=latents.device, dtype=latents.dtype)
            noise += upsample(level_noise) * self.discount**i
            # Stop once a level reaches a single row/column.
            if level_h == 1 or level_w == 1:
                break
        return noise / noise.std()