diffengine.models.editors.esd.esd_xl

Module Contents

Classes

ESDXL

Stable Diffusion XL Erasing Concepts from Diffusion Models.

class diffengine.models.editors.esd.esd_xl.ESDXL(*args, finetune_text_encoder=False, pre_compute_text_embeddings=True, height=1024, width=1024, negative_guidance=1.0, train_method='full', prediction_type=None, data_preprocessor=None, **kwargs)

Bases: diffengine.models.editors.stable_diffusion_xl.StableDiffusionXL

Stable Diffusion XL Erasing Concepts from Diffusion Models.

Args:

height (int): Image height. Defaults to 1024.

width (int): Image width. Defaults to 1024.

negative_guidance (float): Negative guidance scale used in the loss. Defaults to 1.0.

train_method (str): Training method. Choose from 'full', 'xattn', 'noxattn', or 'selfattn'. Defaults to 'full'.
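The train_method options decide which UNet parameters are fine-tuned. The sketch below is a simplified, hypothetical filter (is_trainable is not part of diffengine); it assumes diffusers-style parameter names, where "attn1" denotes self-attention and "attn2" denotes cross-attention. The actual selection done by ESDXL may differ, and real implementations may exclude further modules (for example time-embedding layers) under 'noxattn'.

```python
def is_trainable(param_name: str, train_method: str) -> bool:
    """Return True if a UNet parameter should be fine-tuned (assumed mapping)."""
    # Assumption: diffusers naming, "attn1" = self-attention, "attn2" = cross-attention.
    if train_method == "full":
        return True  # every parameter is updated
    if train_method == "xattn":
        return "attn2" in param_name  # cross-attention layers only
    if train_method == "selfattn":
        return "attn1" in param_name  # self-attention layers only
    if train_method == "noxattn":
        return "attn2" not in param_name  # everything except cross-attention
    raise ValueError(f"Unknown train_method: {train_method!r}")
```

Filtering by name keeps the choice declarative: the same UNet is loaded in every case, and only the set of parameters with gradients enabled changes.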

prepare_model()

Prepare model for training.

Disable gradients for some of the sub-models.

Return type:

None

_freeze_unet()
Return type:

None

set_xformers()

Set xformers for model.

Return type:

None

train(*, mode=True)

Convert the model into training mode.

Parameters:

mode (bool) –

Return type:

None

abstract _preprocess_model_input(latents, noise, timesteps)

Preprocess model input.

Parameters:
  • latents (torch.Tensor) –

  • noise (torch.Tensor) –

  • timesteps (torch.Tensor) –

Return type:

torch.Tensor
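The signature (latents, noise, timesteps) matches the usual forward-diffusion step x_t = sqrt(ᾱ_t) · x_0 + sqrt(1 − ᾱ_t) · ε. The sketch below illustrates that step on plain lists of floats; it assumes a DDPM linear beta schedule (β from 1e-4 to 0.02 over 1000 steps), whereas SDXL's scheduler uses a different schedule, and since this method is abstract here, what a concrete implementation actually returns is not confirmed by this page. Both helpers are hypothetical.

```python
import math


def linear_alpha_bars(num_steps=1000, beta_start=1e-4, beta_end=0.02):
    """Cumulative products of (1 - beta_t) for a linear beta schedule."""
    alpha_bars = []
    prod = 1.0
    for i in range(num_steps):
        beta = beta_start + (beta_end - beta_start) * i / (num_steps - 1)
        prod *= 1.0 - beta
        alpha_bars.append(prod)
    return alpha_bars


def add_noise(latents, noise, timestep, alpha_bars):
    """Forward diffusion: x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps."""
    a = alpha_bars[timestep]
    return [math.sqrt(a) * x + math.sqrt(1.0 - a) * e for x, e in zip(latents, noise)]
```

At small timesteps the output stays close to the clean latents; at the last timestep it is almost pure noise.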

forward(inputs, data_samples=None, mode='loss')

Forward function.

Args:

inputs (dict): The input dict.

data_samples (Optional[list], optional): The data samples. Defaults to None.

mode (str, optional): The mode. Defaults to "loss".

Returns:

dict: The loss dict.

Parameters:
  • inputs (dict) –

  • data_samples (Optional[list]) –

  • mode (str) –

Return type:

dict
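In the ESD paper ("Erasing Concepts from Diffusion Models"), the fine-tuned model's prediction for the erased concept c, ε_θ(x_t, c, t), is regressed onto ε_θ*(x_t, t) − η[ε_θ*(x_t, c, t) − ε_θ*(x_t, t)], where ε_θ* is a frozen copy of the original model. Assuming negative_guidance plays the role of η, the loss can be sketched on flat lists of floats as below. esd_target and esd_loss are hypothetical helpers, and the "loss" key is an assumption about the returned loss dict.

```python
def esd_target(eps_uncond, eps_cond, negative_guidance=1.0):
    """Negatively guided target: eps*(x_t, t) - eta * (eps*(x_t, c, t) - eps*(x_t, t))."""
    return [u - negative_guidance * (c - u) for u, c in zip(eps_uncond, eps_cond)]


def mse(pred, target):
    """Mean squared error between two equal-length lists."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def esd_loss(eps_trained_cond, eps_frozen_uncond, eps_frozen_cond, negative_guidance=1.0):
    """Regress the trained model's conditional prediction onto the ESD target."""
    # Both frozen predictions come from the original (non-trainable) model.
    target = esd_target(eps_frozen_uncond, eps_frozen_cond, negative_guidance)
    return {"loss": mse(eps_trained_cond, target)}  # key name is an assumption
```

With negative_guidance=0 the target reduces to the unconditional prediction, so the concept prompt is simply mapped to "no prompt"; larger values push the prediction further away from the concept.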

Parameters:
  • finetune_text_encoder (bool) –

  • pre_compute_text_embeddings (bool) –

  • height (int) –

  • width (int) –

  • negative_guidance (float) –

  • train_method (str) –

  • prediction_type (str | None) –

  • data_preprocessor (dict | torch.nn.Module | None) –
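One plausible role for height and width, assumed here rather than documented on this page, is to fix the resolution of the latents the model works on and SDXL's size micro-conditioning. The sketch relies on two SDXL facts: its VAE downsamples each side by 8 into 4 latent channels, and its added time ids concatenate original_size, crops_coords_top_left, and target_size. latent_shape and sdxl_time_ids are hypothetical helpers.

```python
VAE_SCALE_FACTOR = 8  # SDXL's VAE downsamples height and width by 8
LATENT_CHANNELS = 4   # SDXL latents have 4 channels


def latent_shape(batch_size, height=1024, width=1024):
    """Shape of the latent tensor for a given image size (assumed usage)."""
    if height % VAE_SCALE_FACTOR or width % VAE_SCALE_FACTOR:
        raise ValueError("height and width must be multiples of 8")
    return (batch_size, LATENT_CHANNELS, height // VAE_SCALE_FACTOR, width // VAE_SCALE_FACTOR)


def sdxl_time_ids(height=1024, width=1024):
    """SDXL micro-conditioning: original size, crop top-left, target size."""
    original_size = (height, width)
    crops_coords_top_left = (0, 0)  # assume no cropping
    target_size = (height, width)
    return list(original_size + crops_coords_top_left + target_size)
```

With the defaults, this yields 128 × 128 latents and time ids [1024, 1024, 0, 0, 1024, 1024].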