diffengine.models.editors.stable_diffusion_xl_dpo
Submodules
Package Contents
Classes
| SDXLDPODataPreprocessor | SDXLDataPreprocessor. |
| StableDiffusionXLDPO | Stable Diffusion XL DPO. |
- class diffengine.models.editors.stable_diffusion_xl_dpo.SDXLDPODataPreprocessor(non_blocking=False)
Bases: mmengine.model.base_model.data_preprocessor.BaseDataPreprocessor
SDXLDataPreprocessor.
- Parameters:
non_blocking (Optional[bool]) –
- forward(data, training=False)
Preprocesses the data into the model input format.
After the data pre-processing of cast_data(), forward stacks the input tensor list into a batch tensor along the first dimension.
Args:
data (dict): Data returned by the dataloader.
training (bool): Whether to enable training-time augmentation.
Returns:
dict or list: Data in the same format as the model input.
- Parameters:
data (dict) –
training (bool) –
- Return type:
dict | list
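A minimal sketch of the documented stacking step, built on `torch.stack`. The pairing part is an assumption, not documented here: it supposes each sample stores its preferred and rejected image concatenated along the channel axis, and that the preprocessor moves the two images onto the batch axis with the preferred half first. `stack_and_split_pairs` is a hypothetical helper, not part of the package.

```python
import torch


def stack_and_split_pairs(images: list[torch.Tensor]) -> torch.Tensor:
    """Hypothetical sketch, not the package's implementation.

    Each element of ``images`` is assumed to hold a preferred and a rejected
    image concatenated along the channel axis, i.e. shape (2 * C, H, W).
    """
    # Documented step: stack the per-sample list into a batch at dim 0 -> (B, 2C, H, W).
    batch = torch.stack(images, dim=0)
    # Assumed DPO step: split the channels into the two images and place them
    # one after another on the batch axis -> (2B, C, H, W), preferred first.
    preferred, rejected = batch.chunk(2, dim=1)
    return torch.cat([preferred, rejected], dim=0)


# Three samples; the preferred image is filled with i, the rejected one with i + 10.
images = [
    torch.cat([torch.full((3, 4, 4), float(i)), torch.full((3, 4, 4), float(i + 10))])
    for i in range(3)
]
batch = stack_and_split_pairs(images)
print(batch.shape)  # torch.Size([6, 3, 4, 4])
```

With this layout, rows `0..B-1` of the batch are the preferred images and rows `B..2B-1` the matching rejected ones, which lets a loss split predictions back into pairs with a single `chunk(2)`.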
- class diffengine.models.editors.stable_diffusion_xl_dpo.StableDiffusionXLDPO(*args, beta_dpo=5000, loss=None, data_preprocessor=None, **kwargs)
Bases: diffengine.models.editors.stable_diffusion_xl.StableDiffusionXL
Stable Diffusion XL DPO.
Args:
beta_dpo (int): Coefficient of the KL-divergence penalty in the DPO objective. Defaults to 5000.
loss (dict, optional): The loss config. Defaults to None.
data_preprocessor (dict, optional): The pre-process config of
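For orientation, the Diffusion-DPO objective of Wallace et al. (2023), which `beta_dpo` appears to parameterize, is often implemented in the simplified form below. Treat it as an assumption about this class rather than its documented behaviour: the timestep weighting is folded into β, and the factor 1/2 is a common convention. Here w and l index the preferred and rejected sample of a pair, and e is the mean squared error between the sampled noise and the prediction of either the trainable model (θ) or the frozen reference model (ref).

```latex
e_{\theta}^{w} = \operatorname{MSE}\big(\epsilon^{w}, \epsilon_{\theta}(x_t^{w}, t)\big), \qquad
e_{\mathrm{ref}}^{w} = \operatorname{MSE}\big(\epsilon^{w}, \epsilon_{\mathrm{ref}}(x_t^{w}, t)\big) \quad (\text{likewise for } l)

\mathcal{L}_{\mathrm{DPO}} = -\log \sigma\!\Big( -\tfrac{\beta}{2}\big[(e_{\theta}^{w} - e_{\theta}^{l}) - (e_{\mathrm{ref}}^{w} - e_{\mathrm{ref}}^{l})\big] \Big)
```

The loss drops below log 2 once the trainable model lowers its error on the preferred sample, relative to the rejected one, by more than the reference model does. In the DPO derivation, β weights the KL penalty toward the reference model, so larger values keep the fine-tuned model closer to it.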
- prepare_model()
Prepare model for training.
Disable gradient for some models.
- Return type:
None
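The docstring does not say which models have their gradients disabled. Because `loss()` takes a `ref_pred` argument, one plausible assumption is that a frozen reference copy of the denoising network is among them. A minimal sketch of that pattern, with `make_frozen_reference` as a hypothetical helper and `nn.Linear` standing in for the UNet:

```python
import copy

from torch import nn


def make_frozen_reference(model: nn.Module) -> nn.Module:
    """Return an independent copy of ``model`` whose parameters need no gradients."""
    reference = copy.deepcopy(model)  # separate weights, unaffected by later optimizer steps on `model`
    reference.requires_grad_(False)   # disable autograd tracking for every parameter of the copy
    return reference


unet = nn.Linear(4, 4)  # stand-in for the trainable UNet
ref_unet = make_frozen_reference(unet)
```

Copying before freezing leaves the trainable model's `requires_grad` flags untouched, so only the copy is excluded from backpropagation.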
- loss(model_pred, ref_pred, noise, latents, timesteps, weight=None)
Calculate loss.
- Parameters:
model_pred (torch.Tensor) –
ref_pred (torch.Tensor) –
noise (torch.Tensor) –
latents (torch.Tensor) –
timesteps (torch.Tensor) –
weight (torch.Tensor | None) –
- Return type:
dict[str, torch.Tensor]
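A sketch of how a Diffusion-DPO loss can be computed from tensors like these. None of the following is documented above, so each point is an assumption. Epsilon-prediction is used, so the regression target is `noise`. The batch holds all preferred samples followed by all rejected samples. The scaling is −β/2, a common convention; the constant this class actually uses is not documented. `latents`, `timesteps` and `weight` are left out because their role here is not described. `diffusion_dpo_loss` and the `"loss"` key are hypothetical.

```python
import torch
import torch.nn.functional as F


def diffusion_dpo_loss(
    model_pred: torch.Tensor,
    ref_pred: torch.Tensor,
    noise: torch.Tensor,
    beta_dpo: float = 5000.0,
) -> dict[str, torch.Tensor]:
    """Hypothetical Diffusion-DPO loss sketch, not the package's implementation."""
    # Per-sample denoising error, averaged over every non-batch dimension.
    reduce_dims = tuple(range(1, model_pred.ndim))
    model_err = F.mse_loss(model_pred.float(), noise.float(), reduction="none").mean(dim=reduce_dims)
    ref_err = F.mse_loss(ref_pred.float(), noise.float(), reduction="none").mean(dim=reduce_dims)
    if model_err.shape[0] % 2:
        raise ValueError("batch must contain preferred/rejected pairs (even size)")

    # First half: preferred (w); second half: rejected (l).
    model_w, model_l = model_err.chunk(2)
    ref_w, ref_l = ref_err.chunk(2)

    # Positive when the trainable model favours the preferred sample more than the reference does.
    inside = -0.5 * beta_dpo * ((model_w - model_l) - (ref_w - ref_l))

    # -log(sigmoid(x)), averaged over preference pairs.
    return {"loss": -F.logsigmoid(inside).mean()}


noise = torch.zeros(2, 4)
same = torch.full((2, 4), 0.5)
tied = diffusion_dpo_loss(same, same, noise)["loss"]  # identical predictions give log 2

model_pred = torch.stack([torch.zeros(4), torch.ones(4)])  # lower error on the preferred row
better = diffusion_dpo_loss(model_pred, same, noise)["loss"]
```

When the trainable and reference predictions coincide, the loss equals log 2. It falls toward zero as the trainable model shifts its error away from the preferred sample and onto the rejected one.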
- forward(inputs, data_samples=None, mode='loss')
Forward function.
Args:
inputs (dict): The input dict.
data_samples (Optional[list], optional): The data samples. Defaults to None.
mode (str, optional): The mode. Defaults to "loss".
Returns:
dict: The loss dict.
- Parameters:
inputs (dict) –
data_samples (Optional[list]) –
mode (str) –
- Return type:
dict
- Parameters:
beta_dpo (int) –
loss (dict | None) –
data_preprocessor (dict | torch.nn.Module | None) –