diffengine.models.losses.debias_estimation_loss
Module Contents
Classes
DeBiasEstimationLoss | DeBias Estimation loss.
- class diffengine.models.losses.debias_estimation_loss.DeBiasEstimationLoss(loss_weight=1.0, reduction='mean', loss_name='debias_estimation')
Bases: diffengine.models.losses.base.BaseLoss
DeBias Estimation loss.
https://arxiv.org/abs/2310.08442
Args:
- loss_weight (float): Weight of this loss item. Defaults to 1.0.
- reduction (str): The reduction method for the loss. Defaults to 'mean'.
- loss_name (str, optional): Name of the loss item. If you want this loss item to be included in the backward graph, its name must start with the prefix loss_. Defaults to 'debias_estimation'.
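The linked paper reweights the denoising MSE per timestep. A minimal sketch of that weighting, assuming the form commonly used in open-source implementations, w(t) = 1 / sqrt(SNR(t)) with SNR(t) = alphas_cumprod[t] / (1 - alphas_cumprod[t]). The helper `debias_weights` and the toy schedule are hypothetical and not part of diffengine:

```python
import math
from typing import List, Sequence


def debias_weights(alphas_cumprod: Sequence[float], timesteps: Sequence[int]) -> List[float]:
    """Hypothetical helper: per-timestep weights 1 / sqrt(SNR(t)).

    SNR(t) = a_t / (1 - a_t), where a_t = alphas_cumprod[t] comes from
    the noise scheduler. Sketch only; not diffengine's implementation.
    """
    weights = []
    for t in timesteps:
        a_t = alphas_cumprod[t]
        snr = a_t / (1.0 - a_t)  # signal-to-noise ratio at timestep t
        weights.append(1.0 / math.sqrt(snr))
    return weights


# Toy schedule: nearly clean at t=0, balanced at t=1, nearly pure noise at t=2.
w = debias_weights([0.99, 0.5, 0.01], [0, 1, 2])
print(w)  # weights grow as SNR falls; w[1] is exactly 1.0
```

Low-noise timesteps (high SNR) receive small weights and high-noise timesteps receive large ones, which is the per-timestep rebalancing this loss is named for.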
- forward(pred, gt, timesteps, alphas_cumprod, prediction_type, weight=None)
Compute the DeBias Estimation loss.
Args:
- pred (torch.Tensor): The predicted tensor.
- gt (torch.Tensor): The ground-truth tensor.
- timesteps (torch.Tensor): The timestep tensor.
- alphas_cumprod (torch.Tensor): The alphas_cumprod values from the scheduler.
- prediction_type (str): The prediction type from the scheduler (e.g. 'epsilon' or 'v_prediction').
- weight (torch.Tensor | None, optional): The loss weight. Defaults to None.
Returns:
torch.Tensor: The loss.
- Parameters:
pred (torch.Tensor)
gt (torch.Tensor)
timesteps (torch.Tensor)
alphas_cumprod (torch.Tensor)
prediction_type (str)
weight (torch.Tensor | None)
- Return type:
torch.Tensor
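A minimal PyTorch sketch of what `forward` might compute, assuming a per-sample MSE averaged over all non-batch dimensions, multiplied by 1 / sqrt(SNR(t)), then reduced with 'mean' and scaled by `loss_weight`. The function name `debias_estimation_loss`, the per-sample shape (B,) for `weight`, and the v-prediction branch (adding 1 to the SNR before weighting) are assumptions; this page does not document how `prediction_type` changes the weighting.

```python
from typing import Optional

import torch
import torch.nn.functional as F


def debias_estimation_loss(
    pred: torch.Tensor,
    gt: torch.Tensor,
    timesteps: torch.Tensor,
    alphas_cumprod: torch.Tensor,
    prediction_type: str,
    weight: Optional[torch.Tensor] = None,
    loss_weight: float = 1.0,
) -> torch.Tensor:
    """Hypothetical sketch of a 1/sqrt(SNR)-weighted MSE with 'mean' reduction."""
    # Cumulative alpha for each sample's timestep, shape (B,).
    a_t = alphas_cumprod.to(timesteps.device)[timesteps]
    snr = a_t / (1.0 - a_t)
    if prediction_type == "v_prediction":
        # Assumption: shift the SNR by 1 for v-prediction targets.
        snr = snr + 1.0
    sample_weights = 1.0 / torch.sqrt(snr)

    # Squared error averaged over every non-batch dimension -> shape (B,).
    mse = F.mse_loss(pred.float(), gt.float(), reduction="none")
    per_sample = mse.reshape(mse.shape[0], -1).mean(dim=1) * sample_weights
    if weight is not None:
        # Assumption: `weight` holds one extra weight per sample, shape (B,).
        per_sample = per_sample * weight
    # 'mean' reduction over the batch, scaled by loss_weight.
    return loss_weight * per_sample.mean()


torch.manual_seed(0)
pred, gt = torch.randn(2, 4, 8, 8), torch.randn(2, 4, 8, 8)
alphas_cumprod = torch.linspace(0.99, 0.01, 1000)  # toy schedule, not a real scheduler's
timesteps = torch.tensor([10, 900])
loss = debias_estimation_loss(pred, gt, timesteps, alphas_cumprod, "epsilon")
```

Averaging each sample's error before applying its timestep weight keeps the weight independent of the latent's spatial size; whether diffengine orders the operations this way is not stated on this page.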
- Parameters (DeBiasEstimationLoss constructor):
loss_weight (float)
reduction (str)
loss_name (str)