diffengine.models.editors.stable_diffusion_xl_dpo


Package Contents

Classes

SDXLDPODataPreprocessor

SDXL DPO data preprocessor.

StableDiffusionXLDPO

Stable Diffusion XL DPO.

class diffengine.models.editors.stable_diffusion_xl_dpo.SDXLDPODataPreprocessor(non_blocking=False)

Bases: mmengine.model.base_model.data_preprocessor.BaseDataPreprocessor

SDXL DPO data preprocessor.

Parameters:

non_blocking (Optional[bool]) – Whether to transfer data to the target device asynchronously. Defaults to False.

forward(data, training=False)

Preprocesses the data into the model input format.

After the data have been processed by cast_data(), forward stacks each list of input tensors into a single batch tensor along the first dimension.

Args:

data (dict): Data returned by the dataloader.

training (bool): Whether to enable training-time augmentation. Defaults to False.

Returns:

dict or list: Data in the same format as the model input.

Parameters:
  • data (dict) –

  • training (bool) –

Return type:

dict | list
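The stacking step that forward performs can be sketched in plain Python, with nested lists standing in for tensors so the example runs without PyTorch. This is a minimal illustration of the shape semantics only (N samples of shape S become one batch of shape (N, *S)); the helper names and the input keys `img` and `time_ids` are hypothetical, and the real preprocessor operates on `torch.Tensor` objects.

```python
from typing import Any


def shape_of(x: Any) -> tuple[int, ...]:
    """Return the shape of a rectangular nested list (scalars have shape ())."""
    if not isinstance(x, list):
        return ()
    if not x:
        return (0,)
    inner = shape_of(x[0])
    # Every element must share the same shape for the nested list to be rectangular.
    if any(shape_of(item) != inner for item in x[1:]):
        raise ValueError("ragged nested list")
    return (len(x),) + inner


def stack_first_dim(samples: list[Any]) -> list[Any]:
    """Stack per-sample 'tensors' into one batch along a new first dimension."""
    if not samples:
        raise ValueError("cannot stack an empty list")
    expected = shape_of(samples[0])
    for s in samples[1:]:
        if shape_of(s) != expected:
            raise ValueError(f"shape mismatch: {shape_of(s)} != {expected}")
    return list(samples)


def stack_inputs(inputs: dict[str, list[Any]]) -> dict[str, list[Any]]:
    """Apply stack_first_dim to every list-valued entry of an inputs dict."""
    return {key: stack_first_dim(value) for key, value in inputs.items()}


# Hypothetical batch of two samples; the keys are illustrative only.
inputs = {
    "img": [[[0.0, 1.0], [2.0, 3.0]], [[4.0, 5.0], [6.0, 7.0]]],  # two 2x2 "images"
    "time_ids": [[1024, 1024, 0, 0, 1024, 1024], [1024, 1024, 0, 0, 1024, 1024]],
}
batch = stack_inputs(inputs)
print(shape_of(batch["img"]))       # (2, 2, 2)
print(shape_of(batch["time_ids"]))  # (2, 6)
```

Stacking fails loudly on mismatched sample shapes, which mirrors the requirement that all tensors in a list share one shape before they can form a batch.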

class diffengine.models.editors.stable_diffusion_xl_dpo.StableDiffusionXLDPO(*args, beta_dpo=5000, loss=None, data_preprocessor=None, **kwargs)

Bases: diffengine.models.editors.stable_diffusion_xl.StableDiffusionXL

Stable Diffusion XL DPO.

Args:

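beta_dpo (int): The β coefficient of the DPO objective, which scales the KL-divergence penalty that keeps the trained model close to the reference model. Defaults to 5000.

loss (dict, optional): The loss config. Defaults to None.

data_preprocessor (dict, optional): The pre-process config. Defaults to None.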
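As a rough illustration of how these arguments might be set, here is an MMEngine-style model config written as plain Python dicts. It is a sketch under assumptions: only `type`, `beta_dpo`, `loss`, and `data_preprocessor` come from the signature above. The `model` key with its Hugging Face model ID (assumed to be forwarded to StableDiffusionXL through `**kwargs`) and the `L2Loss` loss type are illustrative guesses that should be checked against the configs shipped with diffengine.

```python
# MMEngine-style model config, expressed as plain dicts.
# Assumptions (not taken from this page): the "model" key and its Hugging Face ID,
# and the "L2Loss" loss type. Check both against diffengine's own configs.
model = dict(
    type="StableDiffusionXLDPO",
    model="stabilityai/stable-diffusion-xl-base-1.0",  # assumed; passed via **kwargs
    beta_dpo=5000,  # DPO beta; 5000 is the documented default
    loss=dict(type="L2Loss"),  # assumed loss config
    data_preprocessor=dict(type="SDXLDPODataPreprocessor"),  # class from this module
)
```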

prepare_model()

Prepare model for training.

Disables gradients for the sub-models that stay frozen during training.

Return type:

None

loss(model_pred, ref_pred, noise, latents, timesteps, weight=None)

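Calculate the DPO loss from the predictions of the trained model (model_pred) and of the reference model (ref_pred).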

Parameters:
  • model_pred (torch.Tensor) –

  • ref_pred (torch.Tensor) –

  • noise (torch.Tensor) –

  • latents (torch.Tensor) –

  • timesteps (torch.Tensor) –

  • weight (torch.Tensor | None) –

Return type:

dict[str, torch.Tensor]
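The name `beta_dpo` points to a Diffusion-DPO style objective: the trained model is rewarded when it lowers its denoising error on the preferred ("winning") sample relative to the dispreferred ("losing") one by more than the reference model does. The sketch below computes such an objective in plain Python on flattened predictions. It is an assumption-laden illustration, not diffengine's implementation: it assumes the batch holds the winning samples in its first half and the losing samples in its second half, uses `noise` as the regression target (epsilon prediction), applies an assumed scaling factor of 0.5, and ignores `latents`, `timesteps`, and `weight`.

```python
import math


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def per_sample_mse(preds: list[list[float]], targets: list[list[float]]) -> list[float]:
    """Mean squared error of each flattened sample against its target."""
    return [
        sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)
        for pred, target in zip(preds, targets)
    ]


def dpo_loss_sketch(
    model_pred: list[list[float]],
    ref_pred: list[list[float]],
    noise: list[list[float]],
    beta_dpo: float = 5000.0,
) -> float:
    """Diffusion-DPO style loss on a batch laid out as [winners..., losers...].

    Assumption: the first half of the batch holds the preferred samples and the
    second half the dispreferred ones, paired by position.
    """
    if len(model_pred) % 2 != 0:
        raise ValueError("batch must contain winner/loser pairs")
    half = len(model_pred) // 2
    model_err = per_sample_mse(model_pred, noise)
    ref_err = per_sample_mse(ref_pred, noise)
    losses = []
    for i in range(half):
        # Winner error minus loser error; more negative means a stronger preference.
        model_diff = model_err[i] - model_err[i + half]
        ref_diff = ref_err[i] - ref_err[i + half]
        # Reward improving on the reference; 0.5 is an assumed scaling convention.
        inside = -0.5 * beta_dpo * (model_diff - ref_diff)
        losses.append(-log_sigmoid(inside))
    return sum(losses) / half


# Usage: one winner/loser pair of 2-element "predictions".
noise = [[0.1, -0.2], [0.3, 0.0]]  # targets for the winner and the loser
ref = [[0.2, -0.1], [0.4, 0.1]]    # reference predictions
print(dpo_loss_sketch(ref, ref, noise))  # log(2) when the model equals the reference
```

When the trained model matches the reference, the loss sits at log 2; it falls toward zero as the model prefers the winning sample more strongly than the reference does, and a large `beta_dpo` makes that transition sharp.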

forward(inputs, data_samples=None, mode='loss')

Forward function.

Args:

inputs (dict): The input dict.

data_samples (Optional[list], optional): The data samples. Defaults to None.

mode (str, optional): The mode. Defaults to 'loss'.

Returns:

dict: The loss dict.

Parameters:
  • inputs (dict) –

  • data_samples (Optional[list]) –

  • mode (str) –

Return type:

dict

Parameters:
  • beta_dpo (int) –

  • loss (dict | None) –

  • data_preprocessor (dict | torch.nn.Module | None) –