diffengine.models.editors.esd
Submodules
Package Contents
Classes
ESDXL | Stable Diffusion XL Erasing Concepts from Diffusion Models.
ESDXLDataPreprocessor | ESDXLDataPreprocessor.
- class diffengine.models.editors.esd.ESDXL(*args, finetune_text_encoder=False, pre_compute_text_embeddings=True, height=1024, width=1024, negative_guidance=1.0, train_method='full', prediction_type=None, data_preprocessor=None, **kwargs)
Bases: diffengine.models.editors.stable_diffusion_xl.StableDiffusionXL
Stable Diffusion XL model for Erasing Concepts from Diffusion Models (ESD).
Args:
height (int): Image height. Defaults to 1024.
width (int): Image width. Defaults to 1024.
negative_guidance (float): Negative guidance scale used in the loss. Defaults to 1.0.
train_method (str): Training method. Choose from full, xattn, noxattn, selfattn. Defaults to full.
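The `negative_guidance` scale can be illustrated with the objective described in the ESD paper: a frozen copy of the original UNet predicts noise with and without the concept prompt, and the UNet being fine-tuned is regressed toward the unconditional prediction pushed away from the concept direction. The sketch below assumes ESDXL follows that objective; it is not taken from ESDXL's source. The function `esd_loss` is hypothetical, and NumPy arrays stand in for the `torch.Tensor` noise predictions.

```python
import numpy as np


def esd_loss(
    pred: np.ndarray,
    orig_uncond: np.ndarray,
    orig_cond: np.ndarray,
    negative_guidance: float = 1.0,
) -> float:
    """Hypothetical ESD-style loss on noise predictions.

    pred:        noise predicted by the UNet being fine-tuned, given the concept prompt.
    orig_uncond: noise predicted by the frozen original UNet with an empty prompt.
    orig_cond:   noise predicted by the frozen original UNet with the concept prompt.
    """
    # Move the target away from the concept direction (cond - uncond),
    # scaled by negative_guidance.
    target = orig_uncond - negative_guidance * (orig_cond - orig_uncond)
    # Mean squared error between the fine-tuned prediction and the target.
    return float(np.mean((pred - target) ** 2))


uncond = np.zeros((1, 4, 2, 2))
cond = np.ones((1, 4, 2, 2))
# With negative_guidance=1.0 the target is 0 - 1 * (1 - 0) = -1 everywhere.
print(esd_loss(np.full((1, 4, 2, 2), -1.0), uncond, cond))  # 0.0
print(esd_loss(cond, uncond, cond))  # (1 - (-1)) ** 2 = 4.0
```

Setting `negative_guidance=0.0` reduces the target to the unconditional prediction; larger values push the fine-tuned model further away from the concept being erased.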
- prepare_model()
Prepare the model for training.
Disable gradients for some models.
- Return type:
None
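`prepare_model` disables gradients for some models, and the `train_method` option (full, xattn, noxattn, selfattn) suggests that only part of the UNet is fine-tuned. One plausible, simplified name-based rule for that choice is sketched below. It assumes diffusers-style UNet parameter names, in which `attn1` is self-attention and `attn2` is cross-attention; the helper `is_trainable` is hypothetical and ESDXL's actual selection logic may differ.

```python
TRAIN_METHODS = ("full", "xattn", "noxattn", "selfattn")


def is_trainable(param_name: str, train_method: str) -> bool:
    """Hypothetical rule: should this UNet parameter receive gradients?

    Assumes diffusers-style names, where "attn1" is self-attention and
    "attn2" is cross-attention inside each transformer block.
    """
    if train_method not in TRAIN_METHODS:
        raise ValueError(
            f"train_method must be one of {TRAIN_METHODS}, got {train_method!r}")
    if train_method == "full":
        return True  # fine-tune every parameter
    if train_method == "xattn":
        return "attn2" in param_name  # only cross-attention
    if train_method == "noxattn":
        return "attn2" not in param_name  # everything except cross-attention
    return "attn1" in param_name  # "selfattn": only self-attention


names = [
    "down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_q.weight",
    "down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k.weight",
    "conv_in.weight",
]
print([n for n in names if is_trainable(n, "xattn")])
```

With a real `torch.nn.Module`, such a rule could be applied as `for name, p in unet.named_parameters(): p.requires_grad_(is_trainable(name, train_method))`.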
- train(*, mode=True)
Convert the model into training mode.
- Parameters:
mode (bool) – Whether to set training mode (True) or evaluation mode (False).
- Return type:
None
- abstract _preprocess_model_input(latents, noise, timesteps)
Preprocess model input.
- Parameters:
latents (torch.Tensor)
noise (torch.Tensor)
timesteps (torch.Tensor)
- Return type:
torch.Tensor
- forward(inputs, data_samples=None, mode='loss')
Forward function.
Args:
inputs (dict): The input dict.
data_samples (Optional[list], optional): The data samples. Defaults to None.
mode (str, optional): The mode. Defaults to "loss".
Returns:
dict: The loss dict.
- Parameters:
inputs (dict)
data_samples (Optional[list])
mode (str)
- Return type:
dict
- Parameters (ESDXL constructor):
finetune_text_encoder (bool)
pre_compute_text_embeddings (bool)
height (int)
width (int)
negative_guidance (float)
train_method (str)
prediction_type (str | None)
data_preprocessor (dict | torch.nn.Module | None)
- class diffengine.models.editors.esd.ESDXLDataPreprocessor(non_blocking=False)
Bases: mmengine.model.base_model.data_preprocessor.BaseDataPreprocessor
Data preprocessor for ESDXL.
- Parameters:
non_blocking (Optional[bool])
- forward(data, training=False)
Preprocesses the data into the model input format.
After the data preprocessing of cast_data(), forward stacks the input tensor list into a batch tensor along the first dimension.
Args:
data (dict): Data returned by the dataloader.
training (bool): Whether to enable training-time augmentation.
Returns:
dict or list: Data in the same format as the model input.
- Parameters:
data (dict)
training (bool)
- Return type:
Union[dict, list]
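The stacking behaviour described for `forward` can be pictured with a minimal sketch. It assumes the dataloader yields a dict whose `"inputs"` entry maps keys to per-sample lists, uses NumPy arrays in place of `torch.Tensor`, and omits the device transfer that `cast_data()` performs. The function `stack_inputs` and the keys `prompt_embeds` and `pooled_prompt_embeds` are hypothetical illustrations, not confirmed parts of ESDXLDataPreprocessor.

```python
import numpy as np


def stack_inputs(data: dict) -> dict:
    """Hypothetical sketch: stack per-sample lists into batch arrays.

    Every list-valued entry of data["inputs"] is stacked along a new first
    (batch) dimension; other entries are passed through unchanged. A new dict
    is returned, so the input is not modified.
    """
    inputs = {
        key: np.stack(value, axis=0) if isinstance(value, list) else value
        for key, value in data["inputs"].items()
    }
    return {**data, "inputs": inputs}


# Two samples; shapes mirror SDXL text embeddings (2048-dim tokens, 1280-dim pooled).
batch = {
    "inputs": {
        "prompt_embeds": [np.zeros((77, 2048)), np.ones((77, 2048))],
        "pooled_prompt_embeds": [np.zeros(1280), np.ones(1280)],
    }
}
out = stack_inputs(batch)
print(out["inputs"]["prompt_embeds"].shape)  # (2, 77, 2048)
print(out["inputs"]["pooled_prompt_embeds"].shape)  # (2, 1280)
```

Returning a new dict rather than mutating the dataloader output keeps the sketch free of side effects; a real preprocessor would also move the stacked tensors to the model's device.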