diffengine.models.editors.esd.esd_xl
Module Contents
Classes
ESDXL | Erasing Concepts from Diffusion Models (ESD) for Stable Diffusion XL.
- class diffengine.models.editors.esd.esd_xl.ESDXL(*args, finetune_text_encoder=False, pre_compute_text_embeddings=True, height=1024, width=1024, negative_guidance=1.0, train_method='full', prediction_type=None, data_preprocessor=None, **kwargs)
Bases: diffengine.models.editors.stable_diffusion_xl.StableDiffusionXL
Erasing Concepts from Diffusion Models (ESD) for Stable Diffusion XL.
Args:
height (int): Image height. Defaults to 1024.
width (int): Image width. Defaults to 1024.
negative_guidance (float): Negative guidance scale used in the loss. Defaults to 1.0.
train_method (str): Training method. One of 'full', 'xattn', 'noxattn', 'selfattn'. Defaults to 'full'.
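The Args above do not say how negative_guidance enters the loss. Assuming ESDXL follows the objective of the ESD paper (Gandikota et al., "Erasing Concepts from Diffusion Models"), the trainable UNet's noise prediction for the concept prompt c is regressed onto ε_θ*(x_t, t) − η[ε_θ*(x_t, c, t) − ε_θ*(x_t, t)], where θ* is a frozen copy of the original model and η is negative_guidance. The sketch below shows that arithmetic on flat lists of floats; the function names and the "loss" key are hypothetical, not diffengine's API.

```python
from typing import Dict, List, Sequence


def esd_target(eps_cond: Sequence[float], eps_uncond: Sequence[float],
               negative_guidance: float) -> List[float]:
    """Elementwise eps_uncond - eta * (eps_cond - eps_uncond).

    eps_cond / eps_uncond: the frozen model's predictions with and without
    the concept prompt. negative_guidance plays the role of eta (assumed).
    """
    if len(eps_cond) != len(eps_uncond):
        raise ValueError("conditional and unconditional predictions differ in size")
    return [u - negative_guidance * (c - u) for c, u in zip(eps_cond, eps_uncond)]


def esd_loss(model_pred: Sequence[float], eps_cond: Sequence[float],
             eps_uncond: Sequence[float],
             negative_guidance: float = 1.0) -> Dict[str, float]:
    """Mean squared error between the trainable model's prediction and the ESD target."""
    target = esd_target(eps_cond, eps_uncond, negative_guidance)
    if len(model_pred) != len(target) or not target:
        raise ValueError("predictions must be non-empty and equally sized")
    mse = sum((p - t) ** 2 for p, t in zip(model_pred, target)) / len(target)
    # Hypothetical key name; MMEngine's BaseModel.parse_losses sums every
    # entry whose key contains "loss".
    return {"loss": mse}
```

Larger values of negative_guidance push the target further from the direction the concept prompt adds to the unconditional prediction; with negative_guidance=0.0 the target is just the unconditional prediction.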
- prepare_model()
Prepare the model for training.
Disable gradients for the submodels that are not trained.
- Return type:
None
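prepare_model is documented only as disabling gradients for some submodels, and the page does not say which UNet parameters each train_method keeps trainable. In diffusers-style UNet naming, attn1 is the self-attention and attn2 the cross-attention inside each transformer block. The helper below is a hypothetical sketch of how a train_method value could map to trainable parameter names; it is not diffengine's code, and the reference ESD training script may apply additional exclusions for 'noxattn' (for example time-embedding layers) that this sketch ignores.

```python
from typing import Iterable, List

# Hypothetical helper, not part of diffengine.
VALID_METHODS = ("full", "xattn", "noxattn", "selfattn")


def is_trainable(param_name: str, train_method: str) -> bool:
    """Return True if the parameter should keep requires_grad=True.

    Assumes diffusers-style names: "attn1" marks self-attention and
    "attn2" marks cross-attention inside transformer blocks.
    """
    if train_method not in VALID_METHODS:
        raise ValueError(f"unknown train_method: {train_method!r}")
    if train_method == "full":
        return True
    if train_method == "xattn":
        return "attn2" in param_name
    if train_method == "selfattn":
        return "attn1" in param_name
    # "noxattn": everything except the cross-attention layers
    return "attn2" not in param_name


def trainable_names(names: Iterable[str], train_method: str) -> List[str]:
    """Filter parameter names down to the ones that stay trainable."""
    return [name for name in names if is_trainable(name, train_method)]
```

With a real UNet this could be applied as `for name, p in unet.named_parameters(): p.requires_grad_(is_trainable(name, "xattn"))`.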
- train(*, mode=True)
Convert the model into training mode.
- Parameters:
mode (bool) – Whether to set training mode (True) or evaluation mode (False).
- Return type:
None
- abstract _preprocess_model_input(latents, noise, timesteps)
Preprocess model input.
- Parameters:
latents (torch.Tensor) –
noise (torch.Tensor) –
timesteps (torch.Tensor) –
- Return type:
torch.Tensor
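The docstring only says "Preprocess model input." In Stable Diffusion trainers, a hook taking (latents, noise, timesteps) commonly applies the DDPM forward-noising step x_t = sqrt(ᾱ_t)·x_0 + sqrt(1 − ᾱ_t)·ε, which diffusers schedulers expose as add_noise(original_samples, noise, timesteps). Whether ESDXL does exactly this is not stated here (the method is listed as abstract), so the following is a hypothetical, dependency-free sketch on flat lists, with an assumed linear beta schedule and a single integer timestep.

```python
import math
from typing import List, Sequence


def linear_alphas_cumprod(num_steps: int = 1000,
                          beta_start: float = 1e-4,
                          beta_end: float = 0.02) -> List[float]:
    """Cumulative products alpha_bar_t for an assumed linear beta schedule."""
    if num_steps < 2:
        raise ValueError("num_steps must be at least 2")
    alphas_cumprod: List[float] = []
    prod = 1.0
    for i in range(num_steps):
        beta = beta_start + (beta_end - beta_start) * i / (num_steps - 1)
        prod *= 1.0 - beta
        alphas_cumprod.append(prod)
    return alphas_cumprod


def add_noise(latents: Sequence[float], noise: Sequence[float],
              timestep: int, alphas_cumprod: Sequence[float]) -> List[float]:
    """Hypothetical forward-noising step: sqrt(a_bar)*x0 + sqrt(1 - a_bar)*noise."""
    if len(latents) != len(noise):
        raise ValueError("latents and noise must have the same length")
    a_bar = alphas_cumprod[timestep]
    signal, sigma = math.sqrt(a_bar), math.sqrt(1.0 - a_bar)
    return [signal * x + sigma * n for x, n in zip(latents, noise)]
```

With torch tensors the same step would index alphas_cumprod by a batch of timesteps and broadcast over the latent dimensions; the list form just keeps the arithmetic visible.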
- forward(inputs, data_samples=None, mode='loss')
Forward function.
Args:
inputs (dict): The input dict.
data_samples (Optional[list], optional): The data samples. Defaults to None.
mode (str, optional): The mode. Defaults to "loss".
Returns:
dict: The loss dict.
- Parameters:
inputs (dict) –
data_samples (Optional[list]) –
mode (str) –
- Return type:
dict
- Parameters:
finetune_text_encoder (bool) –
pre_compute_text_embeddings (bool) –
height (int) –
width (int) –
negative_guidance (float) –
train_method (str) –
prediction_type (str | None) –
data_preprocessor (dict | torch.nn.Module | None) –
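For orientation, a model entry in an MMEngine-style config might look like the fragment below. It is a sketch: it assumes ESDXL is registered in diffengine's model registry under the string "ESDXL", and it omits the positional *args and **kwargs forwarded to StableDiffusionXL (such as the base checkpoint), whose names are not documented on this page. All values are the documented defaults except train_method.

```python
# Hypothetical MMEngine-style config fragment; only the keyword arguments
# documented for ESDXL are used. Arguments forwarded to StableDiffusionXL
# via *args / **kwargs (e.g. the base checkpoint) are intentionally omitted.
model = dict(
    type="ESDXL",  # assumed registry name
    finetune_text_encoder=False,
    pre_compute_text_embeddings=True,
    height=1024,
    width=1024,
    negative_guidance=1.0,
    train_method="xattn",  # one of "full", "xattn", "noxattn", "selfattn"
)
```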