diffengine.models.losses

Submodules

Package Contents

Classes

CrossEntropyLoss

CrossEntropy loss.

DeBiasEstimationLoss

DeBias Estimation loss.

HuberLoss

Huber loss.

L2Loss

L2 loss.

SNRL2Loss

L2 loss with SNR gamma weighting.

class diffengine.models.losses.CrossEntropyLoss(loss_weight=1.0, reduction='mean', ignore_index=-100, loss_name='cross_entropy')

Bases: diffengine.models.losses.base.BaseLoss

CrossEntropy loss.

Args:

loss_weight (float, optional): Weight of this loss item. Defaults to 1.0.

reduction (str): The reduction method for the loss. Defaults to 'mean'.

ignore_index (int): Specifies a target value that is ignored and does not contribute to the loss. Defaults to -100.

loss_name (str, optional): Name of the loss item. If you want this loss item to be included in the backward graph, its name must start with the prefix loss_. Defaults to 'cross_entropy'.

forward(pred, gt, weight=None)

Forward function.

Args:

pred (torch.Tensor): The predicted tensor.

gt (torch.Tensor): The ground truth tensor.

weight (torch.Tensor | None, optional): The loss weight. Defaults to None.

Returns:

torch.Tensor: loss

Parameters:
  • pred (torch.Tensor) –

  • gt (torch.Tensor) –

  • weight (torch.Tensor | None) –

Return type:

torch.Tensor

Parameters:
  • loss_weight (float) –

  • reduction (str) –

  • ignore_index (int) –

  • loss_name (str) –
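To make the interaction between ignore_index, reduction='mean' and loss_weight concrete, here is a minimal plain-Python sketch. It assumes logits arrive as one list of class scores per sample and targets as class indices; the function name cross_entropy_mean is hypothetical, and scaling the reduced loss by loss_weight is an assumption about how the weight is applied. It is not diffengine's tensor implementation.

```python
import math


def cross_entropy_mean(logits, targets, ignore_index=-100, loss_weight=1.0):
    """Mean cross entropy over targets that are not ignore_index.

    Plain-Python sketch: logits is a list of per-class score lists,
    targets is a list of class indices.
    """
    total = 0.0
    count = 0
    for row, target in zip(logits, targets):
        if target == ignore_index:
            continue  # ignored samples add neither loss nor count
        # Numerically stable log-sum-exp: subtract the row maximum first.
        m = max(row)
        log_z = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_z - row[target]  # equals -log(softmax(row)[target])
        count += 1
    if count == 0:
        return float("nan")  # a mean over zero samples is undefined
    return loss_weight * total / count


logits = [[2.0, 0.0], [0.0, 0.0], [1.0, 3.0]]
targets = [0, -100, 1]
loss = cross_entropy_mean(logits, targets)
```

Here the middle sample is skipped, and each of the two remaining samples has loss log(1 + e^-2), so the mean equals that value as well: ignored targets shrink the denominator instead of contributing zeros.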

class diffengine.models.losses.DeBiasEstimationLoss(loss_weight=1.0, reduction='mean', loss_name='debias_estimation')

Bases: diffengine.models.losses.base.BaseLoss

DeBias Estimation loss.

https://arxiv.org/abs/2310.08442

Args:

loss_weight (float): Weight of this loss item. Defaults to 1.0.

reduction (str): The reduction method for the loss. Defaults to 'mean'.

loss_name (str, optional): Name of the loss item. If you want this loss item to be included in the backward graph, its name must start with the prefix loss_. Defaults to 'debias_estimation'.

property use_snr: bool

Whether or not this loss uses SNR.

Return type:

bool
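Because SNR-aware losses take extra forward arguments (timesteps, alphas_cumprod and prediction_type), calling code can branch on use_snr to decide which arguments to pass. The helper below is a hypothetical sketch of that dispatch, not part of diffengine; the name compute_loss is illustrative only.

```python
def compute_loss(loss_module, pred, gt, timesteps, alphas_cumprod,
                 prediction_type, weight=None):
    """Call loss_module with the arguments its forward() expects.

    Hypothetical helper: modules whose use_snr is True also receive the
    timesteps, the scheduler's alphas_cumprod and its prediction_type.
    Objects without a use_snr attribute are treated as plain losses.
    """
    if getattr(loss_module, "use_snr", False):
        return loss_module(pred, gt, timesteps, alphas_cumprod,
                           prediction_type, weight=weight)
    return loss_module(pred, gt, weight=weight)
```

Using getattr with a False default keeps the helper working even for a loss class that does not expose use_snr; whether every diffengine loss defines the property is not stated in this reference.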

forward(pred, gt, timesteps, alphas_cumprod, prediction_type, weight=None)

Forward function.

Args:

pred (torch.Tensor): The predicted tensor.

gt (torch.Tensor): The ground truth tensor.

timesteps (torch.Tensor): The timestep tensor.

alphas_cumprod (torch.Tensor): The alphas_cumprod from the scheduler.

prediction_type (str): The prediction type from the scheduler.

weight (torch.Tensor | None, optional): The loss weight. Defaults to None.

Returns:

torch.Tensor: loss

Parameters:
  • pred (torch.Tensor) –

  • gt (torch.Tensor) –

  • timesteps (torch.Tensor) –

  • alphas_cumprod (torch.Tensor) –

  • prediction_type (str) –

  • weight (torch.Tensor | None) –

Return type:

torch.Tensor

Parameters:
  • loss_weight (float) –

  • reduction (str) –

  • loss_name (str) –
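As a rough illustration of how timesteps and alphas_cumprod can feed a per-sample weight, the plain-Python sketch below computes SNR(t) = alphas_cumprod[t] / (1 - alphas_cumprod[t]) and scales each sample's mean squared error by 1 / sqrt(SNR(t)). That weighting is an assumption based on how the cited paper is commonly implemented in open-source training scripts, not a transcription of diffengine's code; the prediction_type-specific adjustment is omitted, and the names compute_snr and debiased_l2 are hypothetical.

```python
import math


def compute_snr(timesteps, alphas_cumprod):
    """SNR(t) = alpha_bar_t / (1 - alpha_bar_t), looked up per timestep.

    Assumes every alphas_cumprod[t] lies strictly between 0 and 1.
    """
    return [alphas_cumprod[t] / (1.0 - alphas_cumprod[t]) for t in timesteps]


def debiased_l2(pred, gt, timesteps, alphas_cumprod):
    """Batch mean of per-sample MSE scaled by 1 / sqrt(SNR(t)).

    Sketch only: each sample is a flat list of floats, and any
    prediction_type-specific adjustment is deliberately omitted.
    """
    snr = compute_snr(timesteps, alphas_cumprod)
    losses = []
    for p, g, s in zip(pred, gt, snr):
        mse = sum((a - b) ** 2 for a, b in zip(p, g)) / len(p)
        # High-SNR (low-noise) timesteps receive smaller weights.
        losses.append(mse / math.sqrt(s))
    return sum(losses) / len(losses)
```

For example, with alphas_cumprod = [0.8, 0.5] the two timesteps have SNR 4 and 1, so their weights are 0.5 and 1.0.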

class diffengine.models.losses.HuberLoss(delta=1.0, loss_weight=1.0, reduction='mean', loss_name='l2')

Bases: diffengine.models.losses.base.BaseLoss

Huber loss.

Args:

delta (float, optional): Specifies the threshold at which to change between delta-scaled L1 and L2 loss. The value must be positive. Defaults to 1.0.

loss_weight (float, optional): Weight of this loss item. Defaults to 1.0.

reduction (str): The reduction method for the loss. Defaults to 'mean'.

loss_name (str, optional): Name of the loss item. If you want this loss item to be included in the backward graph, its name must start with the prefix loss_. Defaults to 'l2'.

forward(pred, gt, weight=None)

Forward function.

Args:

pred (torch.Tensor): The predicted tensor.

gt (torch.Tensor): The ground truth tensor.

weight (torch.Tensor | None, optional): The loss weight. Defaults to None.

Returns:

torch.Tensor: loss

Parameters:
  • pred (torch.Tensor) –

  • gt (torch.Tensor) –

  • weight (torch.Tensor | None) –

Return type:

torch.Tensor

Parameters:
  • delta (float) –

  • loss_weight (float) –

  • reduction (str) –

  • loss_name (str) –
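The delta description above matches the piecewise definition documented for PyTorch's torch.nn.HuberLoss: 0.5 * e^2 when the absolute error e is below delta, and delta * (e - 0.5 * delta) otherwise. Whether diffengine delegates to PyTorch is an assumption; the plain-Python sketch below (hypothetical name huber, mean reduction, loss_weight applied last) only illustrates the formula.

```python
def huber(pred, gt, delta=1.0, loss_weight=1.0):
    """Mean Huber loss over paired values (plain-Python sketch)."""
    if delta <= 0:
        raise ValueError("delta must be positive")
    total = 0.0
    for p, g in zip(pred, gt):
        err = abs(p - g)
        if err < delta:
            total += 0.5 * err ** 2  # quadratic (L2-like) region
        else:
            total += delta * (err - 0.5 * delta)  # linear, delta-scaled L1 region
    return loss_weight * total / len(pred)
```

Both branches give 0.5 * delta^2 at err == delta, so the loss is continuous; a larger delta keeps more errors in the quadratic region.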

class diffengine.models.losses.L2Loss(loss_weight=1.0, reduction='mean', loss_name='l2')

Bases: diffengine.models.losses.base.BaseLoss

L2 loss.

Args:

loss_weight (float, optional): Weight of this loss item. Defaults to 1.0.

reduction (str): The reduction method for the loss. Defaults to 'mean'.

loss_name (str, optional): Name of the loss item. If you want this loss item to be included in the backward graph, its name must start with the prefix loss_. Defaults to 'l2'.

forward(pred, gt, weight=None)

Forward function.

Args:

pred (torch.Tensor): The predicted tensor.

gt (torch.Tensor): The ground truth tensor.

weight (torch.Tensor | None, optional): The loss weight. Defaults to None.

Returns:

torch.Tensor: loss

Parameters:
  • pred (torch.Tensor) –

  • gt (torch.Tensor) –

  • weight (torch.Tensor | None) –

Return type:

torch.Tensor

Parameters:
  • loss_weight (float) –

  • reduction (str) –

  • loss_name (str) –
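A minimal plain-Python sketch of a weighted L2 (squared-error) loss follows. It assumes the optional weight multiplies the element-wise squared error before reduction, that 'mean', 'sum' and 'none' are the supported reductions, and that loss_weight scales the result; none of these details is confirmed by this reference, and the name l2_loss is hypothetical.

```python
def l2_loss(pred, gt, weight=None, reduction="mean", loss_weight=1.0):
    """Element-wise squared error, optionally weighted, then reduced."""
    errs = [(p - g) ** 2 for p, g in zip(pred, gt)]
    if weight is not None:
        # Assumed: per-element weights are applied before the reduction.
        errs = [e * w for e, w in zip(errs, weight)]
    if reduction == "none":
        return [loss_weight * e for e in errs]
    if reduction == "sum":
        return loss_weight * sum(errs)
    if reduction == "mean":
        return loss_weight * sum(errs) / len(errs)
    raise ValueError(f"unsupported reduction: {reduction}")
```

For pred = [1.0, 3.0] and gt = [0.0, 0.0] the squared errors are 1 and 9, so 'mean' gives 5.0, 'sum' gives 10.0, and a weight of [1.0, 0.0] drops the second element before averaging.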

class diffengine.models.losses.SNRL2Loss(loss_weight=1.0, snr_gamma=5.0, reduction='mean', loss_name='snrl2')

Bases: diffengine.models.losses.base.BaseLoss

L2 loss with SNR gamma weighting.

https://arxiv.org/abs/2303.09556

Args:

loss_weight (float): Weight of this loss item. Defaults to 1.0.

snr_gamma (float): SNR weighting gamma used to rebalance the loss. More details: https://arxiv.org/abs/2303.09556. Defaults to 5.0.

reduction (str): The reduction method for the loss. Defaults to 'mean'.

loss_name (str, optional): Name of the loss item. If you want this loss item to be included in the backward graph, its name must start with the prefix loss_. Defaults to 'snrl2'.

property use_snr: bool

Whether or not this loss uses SNR.

Return type:

bool

forward(pred, gt, timesteps, alphas_cumprod, prediction_type, weight=None)

Forward function.

Args:

pred (torch.Tensor): The predicted tensor.

gt (torch.Tensor): The ground truth tensor.

timesteps (torch.Tensor): The timestep tensor.

alphas_cumprod (torch.Tensor): The alphas_cumprod from the scheduler.

prediction_type (str): The prediction type from the scheduler.

weight (torch.Tensor | None, optional): The loss weight. Defaults to None.

Returns:

torch.Tensor: loss

Parameters:
  • pred (torch.Tensor) –

  • gt (torch.Tensor) –

  • timesteps (torch.Tensor) –

  • alphas_cumprod (torch.Tensor) –

  • prediction_type (str) –

  • weight (torch.Tensor | None) –

Return type:

torch.Tensor

Parameters:
  • loss_weight (float) –

  • snr_gamma (float) –

  • reduction (str) –

  • loss_name (str) –
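The cited paper describes the Min-SNR weighting strategy: each timestep's loss is weighted by min(SNR, gamma) / SNR for epsilon prediction and by min(SNR, gamma) / (SNR + 1) for v prediction, where SNR(t) = alphas_cumprod[t] / (1 - alphas_cumprod[t]). The plain-Python sketch below applies those weights to per-sample mean squared errors and averages over the batch. The function names are hypothetical, the batch-mean reduction is an assumption, and other prediction types are rejected rather than guessed at; this is not diffengine's implementation.

```python
def min_snr_weights(snr, snr_gamma=5.0, prediction_type="epsilon"):
    """Per-sample Min-SNR-gamma loss weights."""
    clipped = [min(s, snr_gamma) for s in snr]  # cap SNR at gamma
    if prediction_type == "epsilon":
        return [c / s for c, s in zip(clipped, snr)]
    if prediction_type == "v_prediction":
        return [c / (s + 1.0) for c, s in zip(clipped, snr)]
    raise ValueError(f"unsupported prediction_type: {prediction_type}")


def snr_l2(pred, gt, timesteps, alphas_cumprod, prediction_type, snr_gamma=5.0):
    """Batch mean of Min-SNR-weighted per-sample MSE (plain-Python sketch).

    Assumes every alphas_cumprod[t] lies strictly between 0 and 1.
    """
    snr = [alphas_cumprod[t] / (1.0 - alphas_cumprod[t]) for t in timesteps]
    weights = min_snr_weights(snr, snr_gamma, prediction_type)
    per_sample = [
        sum((a - b) ** 2 for a, b in zip(p, g)) / len(p) for p, g in zip(pred, gt)
    ]
    return sum(w * l for w, l in zip(weights, per_sample)) / len(per_sample)
```

With the default snr_gamma of 5.0 and epsilon prediction, a sample at SNR 10 gets weight 0.5 while a sample at SNR 1 keeps weight 1.0, so easy low-noise timesteps no longer dominate training.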