import torch
import torch.nn.functional as F # noqa
from diffengine.models.losses.base import BaseLoss
from diffengine.models.losses.utils import compute_snr
from diffengine.registry import MODELS
@MODELS.register_module()
class SNRL2Loss(BaseLoss):
    """SNR weighting gamma L2 loss.

    Implements the Min-SNR-gamma weighting strategy on top of a mean squared
    error loss. See https://arxiv.org/abs/2303.09556

    Args:
    ----
        loss_weight (float): Weight of this loss item.
            Defaults to ``1.``.
        snr_gamma (float): SNR weighting gamma to be used if re balancing the
            loss. More details here: https://arxiv.org/abs/2303.09556.
            Defaults to ``5.``.
        reduction (str): The reduction method for the loss. Must be
            ``'mean'`` or ``'none'``. Defaults to 'mean'.
        loss_name (str, optional): Name of the loss item. If you want this loss
            item to be included into the backward graph, `loss_` must be the
            prefix of the name. Defaults to 'snrl2'.
    """

    def __init__(self,
                 loss_weight: float = 1.0,
                 snr_gamma: float = 5.0,
                 reduction: str = "mean",
                 loss_name: str = "snrl2") -> None:
        super().__init__()
        # Kept as an assert so the exception type seen by callers
        # (AssertionError) does not change.
        assert reduction in ["mean", "none"], (
            f"reduction should be 'mean' or 'none', got {reduction}"
        )
        self.loss_weight = loss_weight
        self.snr_gamma = snr_gamma
        self.reduction = reduction
        self._loss_name = loss_name

    @property
    def use_snr(self) -> bool:
        """Whether or not this loss uses SNR."""
        return True

    def forward(self,
                pred: torch.Tensor,
                gt: torch.Tensor,
                timesteps: torch.Tensor,
                alphas_cumprod: torch.Tensor,
                prediction_type: str,
                weight: torch.Tensor | None = None) -> torch.Tensor:
        """Forward function.

        Args:
        ----
            pred (torch.Tensor): The predicted tensor.
            gt (torch.Tensor): The ground truth tensor.
            timesteps (torch.Tensor): The timestep tensor.
            alphas_cumprod (torch.Tensor): The alphas_cumprod from the
                scheduler.
            prediction_type (str): The prediction type from scheduler.
                ``"v_prediction"`` divides the clipped SNR by ``snr + 1``;
                every other value divides it by ``snr``.
            weight (torch.Tensor | None, optional): The loss weight. It is
                multiplied with the loss after the non-leading dimensions
                have been averaged out. Defaults to None.

        Returns:
        -------
            torch.Tensor: The weighted loss. A scalar when ``reduction`` is
                ``'mean'``; otherwise the unreduced tensor.
        """
        # NOTE(review): assumes compute_snr returns one value per entry of
        # `timesteps` (both 1-D), since they are stacked along dim=1 below.
        snr = compute_snr(timesteps, alphas_cumprod)

        # Element-wise min(snr, snr_gamma), i.e. the SNR clipped at gamma.
        clipped_snr = torch.stack(
            [snr, self.snr_gamma * torch.ones_like(timesteps)],
            dim=1).min(dim=1)[0]

        if prediction_type == "v_prediction":
            # Velocity objective requires that we add one to SNR values
            # before we divide by them. The +1 belongs to the divisor only:
            # the clip above must see the raw SNR, otherwise the weight
            # becomes min(snr + 1, gamma) / (snr + 1) instead of the
            # Min-SNR weight min(snr, gamma) / (snr + 1).
            mse_loss_weights = clipped_snr / (snr + 1)
        else:
            mse_loss_weights = clipped_snr / snr

        loss = F.mse_loss(pred, gt, reduction="none")
        # Average over every dimension except the first, then apply the
        # per-timestep SNR weight.
        loss = loss.mean(
            dim=list(range(1, len(loss.shape)))) * mse_loss_weights
        if weight is not None:
            loss = loss * weight
        if self.reduction == "mean":
            loss = loss.mean()
        return loss * self.loss_weight