diffengine.models.editors.esd

Submodules

Package Contents

Classes

ESDXL: Stable Diffusion XL model for Erasing Concepts from Diffusion Models (ESD).

ESDXLDataPreprocessor: Data preprocessor for ESDXL.

class diffengine.models.editors.esd.ESDXL(*args, finetune_text_encoder=False, pre_compute_text_embeddings=True, height=1024, width=1024, negative_guidance=1.0, train_method='full', prediction_type=None, data_preprocessor=None, **kwargs)

Bases: diffengine.models.editors.stable_diffusion_xl.StableDiffusionXL

Stable Diffusion XL model for Erasing Concepts from Diffusion Models (ESD).

Args:

height (int): Image height. Defaults to 1024.
width (int): Image width. Defaults to 1024.
negative_guidance (float): Negative guidance scale used in the loss. Defaults to 1.0.
train_method (str): Training method. One of full, xattn, noxattn, selfattn. Defaults to full.
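For context, the ESD paper (Gandikota et al., 2023, "Erasing Concepts from Diffusion Models") fine-tunes a copy of the model so that its noise prediction for the concept $c$ being erased matches a target built from the frozen original model $\theta^*$. Assuming negative_guidance plays the role of the paper's guidance scale $\eta$ (this mapping is an assumption, not stated in the Args above), the training target is:

```latex
\epsilon_{\theta}(x_t, c, t) \leftarrow \epsilon_{\theta^*}(x_t, t) - \eta \left[ \epsilon_{\theta^*}(x_t, c, t) - \epsilon_{\theta^*}(x_t, t) \right]
```

A larger $\eta$ pushes the prediction further away from the concept direction; with $\eta = 0$ the target reduces to the unconditional prediction $\epsilon_{\theta^*}(x_t, t)$.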

prepare_model()

Prepare model for training.

Disable gradients for the sub-models that are not trained.

Return type:

None

_freeze_unet()

Return type:

None
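_freeze_unet has no description here. Given the train_method choices listed in the class Args (full, xattn, noxattn, selfattn), a plausible reading is that it decides which UNet parameters remain trainable. The sketch below is a hypothetical, stdlib-only illustration loosely modeled on the original ESD reference code. It assumes diffusers-style parameter names, where attn1 is self-attention and attn2 is cross-attention, and it is not confirmed to match ESDXL._freeze_unet; is_trainable and trainable_names are illustrative names.

```python
from typing import Iterable, List

# Assumption: diffusers-style UNet parameter names, where "attn1" is
# self-attention and "attn2" is cross-attention (text conditioning).


def is_trainable(param_name: str, train_method: str) -> bool:
    """Return True if a UNet parameter would stay trainable.

    Hypothetical rules loosely modeled on the original ESD reference code;
    not confirmed to match ESDXL._freeze_unet.
    """
    if train_method == "full":
        return True
    if train_method == "xattn":
        # Only cross-attention layers are updated.
        return "attn2" in param_name
    if train_method == "selfattn":
        # Only self-attention layers are updated.
        return "attn1" in param_name
    if train_method == "noxattn":
        # Everything except cross-attention and the time embedding.
        return "attn2" not in param_name and "time_embed" not in param_name
    raise ValueError(f"Unknown train_method: {train_method!r}")


def trainable_names(names: Iterable[str], train_method: str) -> List[str]:
    """Filter parameter names down to the ones left trainable."""
    return [n for n in names if is_trainable(n, train_method)]


if __name__ == "__main__":
    names = [
        "down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_q.weight",
        "down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k.weight",
        "time_embedding.linear_1.weight",
        "conv_in.weight",
    ]
    for method in ("full", "xattn", "selfattn", "noxattn"):
        print(method, trainable_names(names, method))
```

Restricting updates to cross-attention (xattn) ties the edit to the text conditioning, while noxattn leaves cross-attention frozen and updates the rest of the UNet.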

set_xformers()

Set up xFormers memory-efficient attention for the model.

Return type:

None

train(*, mode=True)

Convert the model into training mode.

Parameters:

mode (bool) – Whether to set training mode (True) or evaluation mode (False). Defaults to True.

Return type:

None

abstract _preprocess_model_input(latents, noise, timesteps)

Preprocess model input.

Parameters:
  • latents (torch.Tensor) –

  • noise (torch.Tensor) –

  • timesteps (torch.Tensor) –

Return type:

torch.Tensor

forward(inputs, data_samples=None, mode='loss')

Forward function.

Args:

inputs (dict): The input dict.
data_samples (Optional[list], optional): The data samples. Defaults to None.
mode (str, optional): The mode. Defaults to "loss".

Returns:

dict: The loss dict.

Parameters:
  • inputs (dict) –

  • data_samples (Optional[list]) –

  • mode (str) –

Return type:

dict
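forward is documented only as returning a loss dict. As a hedged illustration, assuming the loss follows the original ESD paper (a mean squared error between the trained UNet's noise prediction and a target built from the frozen model's unconditional and concept-conditioned predictions, scaled by negative_guidance), a stdlib-only sketch looks like this. esd_target and esd_loss_dict are hypothetical names, plain lists stand in for tensors, and the "loss" key is an assumption about the dict layout.

```python
from typing import Dict, List, Sequence


def esd_target(
    uncond: Sequence[float],
    cond: Sequence[float],
    negative_guidance: float = 1.0,
) -> List[float]:
    """Element-wise ESD target: uncond - eta * (cond - uncond).

    ``uncond`` / ``cond`` stand in for the frozen model's noise predictions
    without / with the concept prompt, flattened to plain lists.
    """
    if len(uncond) != len(cond):
        raise ValueError("uncond and cond must have the same length")
    return [u - negative_guidance * (c - u) for u, c in zip(uncond, cond)]


def esd_loss_dict(
    pred: Sequence[float],
    uncond: Sequence[float],
    cond: Sequence[float],
    negative_guidance: float = 1.0,
) -> Dict[str, float]:
    """Mean squared error between the trained model's prediction and the target.

    The "loss" key is an assumption about the loss-dict layout.
    """
    target = esd_target(uncond, cond, negative_guidance)
    if len(pred) != len(target):
        raise ValueError("pred and target must have the same length")
    mse = sum((p - t) ** 2 for p, t in zip(pred, target)) / len(target)
    return {"loss": mse}


if __name__ == "__main__":
    print(esd_loss_dict([0.0, 0.0], uncond=[1.0, 2.0], cond=[3.0, 2.0]))
    # {'loss': 2.5}
```

In the example, the concept shifts only the first element, so the target moves that element away from the concept (to -1.0) and leaves the second element at its unconditional value (2.0).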

Parameters:
  • finetune_text_encoder (bool) –

  • pre_compute_text_embeddings (bool) –

  • height (int) –

  • width (int) –

  • negative_guidance (float) –

  • train_method (str) –

  • prediction_type (str | None) –

  • data_preprocessor (dict | torch.nn.Module | None) –

class diffengine.models.editors.esd.ESDXLDataPreprocessor(non_blocking=False)

Bases: mmengine.model.base_model.data_preprocessor.BaseDataPreprocessor

Data preprocessor for ESDXL.

Parameters:

non_blocking (Optional[bool]) – Whether to transfer data to the device without blocking the current process. Defaults to False.

forward(data, training=False)

Preprocesses the data into the model input format.

After casting the data with cast_data(), forward stacks the list of input tensors into a single batch tensor along the first dimension.

Args:

data (dict): Data returned by the dataloader.
training (bool): Whether to enable training-time augmentation.

Returns:

dict or list: Data in the same format as the model input.

Parameters:
  • data (dict) –

  • training (bool) –

Return type:

Union[dict, list]
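The stacking step described for forward can be illustrated with a stdlib-only sketch. Nested Python lists stand in for tensors (a real preprocessor would operate on tensors, for example with torch.stack), and shape_of and stack_inputs are hypothetical helpers, not part of ESDXLDataPreprocessor. The point is the shape rule: n samples of shape S become one batch of shape (n, *S).

```python
from typing import Any, Dict, Tuple


def shape_of(x: Any) -> Tuple[int, ...]:
    """Return the shape of a rectangular nested list (non-lists are scalars)."""
    if isinstance(x, list):
        if not x:
            return (0,)
        inner = {shape_of(e) for e in x}
        if len(inner) != 1:
            raise ValueError("ragged nested list: elements differ in shape")
        return (len(x),) + inner.pop()
    return ()


def stack_inputs(inputs: Dict[str, list]) -> Dict[str, list]:
    """Stack each key's list of per-sample items along a new first dimension.

    Nested lists stand in for tensors: n samples of shape S become one batch
    of shape (n, *S). Like torch.stack, mismatched sample shapes are an error.
    """
    batch: Dict[str, list] = {}
    for key, samples in inputs.items():
        sample_shapes = {shape_of(s) for s in samples}
        if len(sample_shapes) > 1:
            raise ValueError(f"{key!r}: samples have different shapes {sample_shapes}")
        batch[key] = [s for s in samples]  # outer list = new batch dimension
    return batch


if __name__ == "__main__":
    # Two hypothetical 2x3 "latents" plus two prompts (strings are scalars here).
    inputs = {
        "latents": [[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [9, 10, 11]]],
        "text": ["a photo of a cat", "a painting"],
    }
    batch = stack_inputs(inputs)
    print({k: shape_of(v) for k, v in batch.items()})
    # {'latents': (2, 2, 3), 'text': (2,)}
```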