Source code for diffengine.models.utils.timesteps

import torch
from diffusers import DDPMScheduler
from torch import nn

from diffengine.registry import MODELS


@MODELS.register_module()
class TimeSteps(nn.Module):
    """Uniform time step sampler.

    Each sample is drawn independently and uniformly from the integer range
    ``[0, scheduler.config.num_train_timesteps)``.
    """

    def forward(self,
                scheduler: DDPMScheduler,
                num_batches: int,
                device: str,
                ) -> torch.Tensor:
        """Sample one training time step per batch element.

        Args:
        ----
            scheduler (DDPMScheduler): Scheduler for training diffusion model.
                Only ``scheduler.config.num_train_timesteps`` is read.
            num_batches (int): Batch size, i.e. number of time steps to draw.
            device (str): Device on which the result is created.

        Returns:
        -------
            torch.Tensor: 1-D ``long`` tensor of length ``num_batches``.
        """
        upper_bound = scheduler.config.num_train_timesteps
        sample_shape = (num_batches, )
        sampled = torch.randint(
            low=0, high=upper_bound, size=sample_shape, device=device)
        return sampled.long()
@MODELS.register_module()
class LaterTimeSteps(nn.Module):
    """Later biased Time Steps module.

    Time steps are drawn from a categorical distribution in which the last
    ``bias_portion`` fraction of ``[0, num_train_timesteps)`` is
    ``bias_multiplier`` times more likely than the rest.

    Args:
    ----
        bias_multiplier (float): Relative weight applied to the biased
            (later) time steps. Defaults to 5.
        bias_portion (float): Fraction of the latest time steps to bias,
            in [0, 1]. Defaults to 0.25.
    """

    def __init__(self,
                 bias_multiplier: float = 5.,
                 bias_portion: float = 0.25,
                 ) -> None:
        super().__init__()
        lower_limit = 0.
        upper_limit = 1.
        assert lower_limit <= bias_portion <= upper_limit, \
            "bias_portion must be in [0, 1]"
        self.bias_multiplier = bias_multiplier
        self.bias_portion = bias_portion

    def forward(self,
                scheduler: DDPMScheduler,
                num_batches: int,
                device: str,
                ) -> torch.Tensor:
        """Forward pass.

        Generates time steps for the given batches.

        Args:
        ----
            scheduler (DDPMScheduler): Scheduler for training diffusion model.
                Only ``scheduler.config.num_train_timesteps`` is read.
            num_batches (int): Batch size.
            device (str): Device.

        Returns:
        -------
            torch.Tensor: 1-D ``long`` tensor of ``num_batches`` time steps.
        """
        num_train_timesteps = scheduler.config.num_train_timesteps
        weights = torch.ones(num_train_timesteps, device=device)
        num_to_bias = int(self.bias_portion * num_train_timesteps)
        # Use an explicit start index: ``slice(-num_to_bias, None)`` would
        # turn into ``slice(0, None)`` (every time step) when num_to_bias == 0.
        bias_indices = slice(num_train_timesteps - num_to_bias, None)
        weights[bias_indices] *= self.bias_multiplier
        weights /= weights.sum()
        timesteps = torch.multinomial(weights, num_batches, replacement=True)
        return timesteps.long()
@MODELS.register_module()
class EarlierTimeSteps(nn.Module):
    """Earlier biased Time Steps module.

    Time steps are drawn from a categorical distribution in which the first
    ``bias_portion`` fraction of ``[0, num_train_timesteps)`` is
    ``bias_multiplier`` times more likely than the rest.

    Args:
    ----
        bias_multiplier (float): Relative weight applied to the biased
            (earlier) time steps. Defaults to 5.
        bias_portion (float): Fraction of the earliest time steps to bias,
            in [0, 1]. Defaults to 0.25.
    """

    def __init__(self,
                 bias_multiplier: float = 5.,
                 bias_portion: float = 0.25,
                 ) -> None:
        super().__init__()
        lower_limit = 0.
        upper_limit = 1.
        assert lower_limit <= bias_portion <= upper_limit, \
            "bias_portion must be in [0, 1]"
        self.bias_multiplier = bias_multiplier
        self.bias_portion = bias_portion

    def forward(self,
                scheduler: DDPMScheduler,
                num_batches: int,
                device: str,
                ) -> torch.Tensor:
        """Forward pass.

        Generates time steps for the given batches.

        Args:
        ----
            scheduler (DDPMScheduler): Scheduler for training diffusion model.
                Only ``scheduler.config.num_train_timesteps`` is read.
            num_batches (int): Batch size.
            device (str): Device.

        Returns:
        -------
            torch.Tensor: 1-D ``long`` tensor of ``num_batches`` time steps.
        """
        num_train_timesteps = scheduler.config.num_train_timesteps
        weights = torch.ones(num_train_timesteps, device=device)
        num_to_bias = int(self.bias_portion * num_train_timesteps)
        bias_indices = slice(0, num_to_bias)
        weights[bias_indices] *= self.bias_multiplier
        weights /= weights.sum()
        timesteps = torch.multinomial(weights, num_batches, replacement=True)
        return timesteps.long()
@MODELS.register_module()
class RangeTimeSteps(nn.Module):
    """Range biased Time Steps module.

    Time steps whose index falls in
    ``[bias_begin * T, bias_end * T)`` (``T = num_train_timesteps``, bounds
    truncated to int) are ``bias_multiplier`` times more likely than the rest.

    Args:
    ----
        bias_multiplier (float): Relative weight applied to the biased
            time steps. Defaults to 5.
        bias_begin (float): Start of the biased range as a fraction of
            ``num_train_timesteps``, in [0, 1]. Defaults to 0.25.
        bias_end (float): End (exclusive) of the biased range as a fraction
            of ``num_train_timesteps``, in [0, 1]. Defaults to 0.75.
    """

    def __init__(self,
                 bias_multiplier: float = 5.,
                 bias_begin: float = 0.25,
                 bias_end: float = 0.75) -> None:
        super().__init__()
        lower_limit = 0.
        upper_limit = 1.
        assert bias_begin < bias_end, "bias_begin must be less than bias_end"
        assert lower_limit <= bias_begin <= upper_limit, \
            "bias_begin must be in [0, 1]"
        assert lower_limit <= bias_end <= upper_limit, \
            "bias_end must be in [0, 1]"
        self.bias_multiplier = bias_multiplier
        self.bias_begin = bias_begin
        self.bias_end = bias_end

    def forward(self,
                scheduler: DDPMScheduler,
                num_batches: int,
                device: str,
                ) -> torch.Tensor:
        """Forward pass.

        Generates time steps for the given batches.

        Args:
        ----
            scheduler (DDPMScheduler): Scheduler for training diffusion model.
                Only ``scheduler.config.num_train_timesteps`` is read.
            num_batches (int): Batch size.
            device (str): Device.

        Returns:
        -------
            torch.Tensor: 1-D ``long`` tensor of ``num_batches`` time steps.
        """
        num_train_timesteps = scheduler.config.num_train_timesteps
        weights = torch.ones(num_train_timesteps, device=device)
        bias_begin = int(self.bias_begin * num_train_timesteps)
        bias_end = int(self.bias_end * num_train_timesteps)
        bias_indices = slice(bias_begin, bias_end)
        weights[bias_indices] *= self.bias_multiplier
        weights /= weights.sum()
        timesteps = torch.multinomial(weights, num_batches, replacement=True)
        return timesteps.long()
@MODELS.register_module()
class CubicSamplingTimeSteps(nn.Module):
    """Cubic Sampling Time Steps module.

    Maps a uniform sample ``u`` in [0, 1) to ``(1 - u**3) * T``, which puts
    more probability mass on later (larger) time steps. For the motivation,
    see section 3.4 of https://arxiv.org/abs/2302.08453
    """

    def forward(self,
                scheduler: DDPMScheduler,
                num_batches: int,
                device: str,
                ) -> torch.Tensor:
        """Draw cubically-skewed time steps.

        Args:
        ----
            scheduler (DDPMScheduler): Scheduler for training diffusion model.
                Only ``scheduler.config.num_train_timesteps`` is read.
            num_batches (int): Batch size.
            device (str): Device.

        Returns:
        -------
            torch.Tensor: 1-D ``long`` tensor of time steps clamped to
            ``[0, num_train_timesteps - 1]``.
        """
        total_steps = scheduler.config.num_train_timesteps
        uniform = torch.rand((num_batches, ), device=device)
        skewed = (1 - uniform.pow(3)) * total_steps
        # Truncation can produce exactly ``total_steps``; clamp back in range.
        return torch.clamp(skewed.long(), 0, total_steps - 1)
@MODELS.register_module()
class WuerstchenRandomTimeSteps(nn.Module):
    """Wuerstchen Random Time Steps module.

    Unlike the other samplers in this module, this one does not use a
    scheduler and returns continuous values rather than integer indices.
    """

    def forward(self,
                num_batches: int,
                device: str,
                ) -> torch.Tensor:
        """Forward pass.

        Generates time steps for the given batches.

        Args:
        ----
            num_batches (int): Batch size.
            device (str): Device.

        Returns:
        -------
            torch.Tensor: 1-D float tensor of ``num_batches`` values drawn
            uniformly from [0, 1).
        """
        return torch.rand((num_batches, ), device=device)
@MODELS.register_module()
class DDIMTimeSteps(nn.Module):
    """DDIM Time Steps module.

    Samples time steps from the DDIM sub-sequence
    ``k * step_ratio - 1`` for ``k`` in ``1..num_ddim_timesteps``, where
    ``step_ratio = num_train_timesteps // num_ddim_timesteps``.

    Args:
    ----
        num_ddim_timesteps (int): Number of DDIM timesteps. Must be positive.
            Defaults to 50.
    """

    def __init__(self, num_ddim_timesteps: int = 50) -> None:
        super().__init__()
        assert num_ddim_timesteps > 0, \
            "num_ddim_timesteps must be positive"
        self.num_ddim_timesteps = num_ddim_timesteps
        self.register_buffer("ddim_timesteps",
                             torch.arange(1, num_ddim_timesteps + 1))

    def forward(self,
                scheduler: DDPMScheduler,
                num_batches: int,
                device: str,
                ) -> torch.Tensor:
        """Forward pass.

        Generates time steps for the given batches.

        Args:
        ----
            scheduler (DDPMScheduler): Scheduler for training diffusion model.
                Only ``scheduler.config.num_train_timesteps`` is read.
            num_batches (int): Batch size.
            device (str): Device.

        Returns:
        -------
            torch.Tensor: 1-D ``long`` tensor of ``num_batches`` time steps
            on ``device``.

        Raises:
        ------
            ValueError: If ``num_ddim_timesteps`` exceeds
                ``num_train_timesteps`` (every time step would be ``-1``).
        """
        num_train_timesteps = scheduler.config.num_train_timesteps
        step_ratio = num_train_timesteps // self.num_ddim_timesteps
        if step_ratio < 1:
            msg = (f"num_ddim_timesteps ({self.num_ddim_timesteps}) must not "
                   f"exceed num_train_timesteps ({num_train_timesteps})")
            raise ValueError(msg)
        index = torch.randint(0, self.num_ddim_timesteps, (num_batches,),
                              device=device)
        # The buffer follows the module's device, which may differ from
        # ``device``; move it so indexing and the result agree on device.
        ddim_timesteps = self.ddim_timesteps.to(device)
        timesteps = ddim_timesteps[index] * step_ratio - 1
        return timesteps.long()