diffengine.models.editors

Subpackages

Package Contents

Classes

AMUSEd

aMUSEd.

AMUSEdPreprocessor

AMUSEdPreprocessor.

DeepFloydIF

DeepFloyd/IF.

DistillSDXL

Distill Stable Diffusion XL.

ESDXL

Stable Diffusion XL Erasing Concepts from Diffusion Models.

ESDXLDataPreprocessor

ESDXLDataPreprocessor.

StableDiffusionXLInstructPix2Pix

Stable Diffusion XL Instruct Pix2Pix.

IPAdapterXL

Stable Diffusion XL IP-Adapter.

IPAdapterXLPlus

Stable Diffusion XL IP-Adapter Plus.

IPAdapterXLDataPreprocessor

IPAdapterXLDataPreprocessor.

TimmIPAdapterXLPlus

Stable Diffusion XL IP-Adapter Plus.

KandinskyV22Prior

KandinskyV22 Prior.

KandinskyV22Decoder

KandinskyV22 Decoder.

KandinskyV22DecoderDataPreprocessor

KandinskyV22DecoderDataPreprocessor.

KandinskyV3

KandinskyV3.

LatentConsistencyModelsXL

Stable Diffusion XL Latent Consistency Models.

PixArtAlpha

PixArt Alpha.

PixArtAlphaDataPreprocessor

PixArtAlphaDataPreprocessor.

SSD1B

SSD1B.

StableDiffusion

Stable Diffusion.

SDDataPreprocessor

SDDataPreprocessor.

StableDiffusionControlNet

Stable Diffusion ControlNet.

SDControlNetDataPreprocessor

SDControlNetDataPreprocessor.

SDInpaintDataPreprocessor

SDInpaintDataPreprocessor.

StableDiffusionInpaint

Stable Diffusion Inpaint.

StableDiffusionXL

Stable Diffusion XL.

SDXLDataPreprocessor

SDXLDataPreprocessor.

SDXLControlNetDataPreprocessor

SDXLControlNetDataPreprocessor.

StableDiffusionXLControlNet

Stable Diffusion XL ControlNet.

StableDiffusionXLDPO

Stable Diffusion XL DPO.

SDXLDPODataPreprocessor

SDXLDPODataPreprocessor.

StableDiffusionXLInpaint

Stable Diffusion XL Inpaint.

SDXLInpaintDataPreprocessor

SDXLInpaintDataPreprocessor.

StableDiffusionXLT2IAdapter

Stable Diffusion XL T2I Adapter.

WuerstchenPriorModel

Wuerstchen Prior.

class diffengine.models.editors.AMUSEd(tokenizer, text_encoder, vae, transformer, model='amused/amused-512', loss=None, transformer_lora_config=None, text_encoder_lora_config=None, prior_loss_weight=1.0, data_preprocessor=None, vae_batch_size=8, *, finetune_text_encoder=False, gradient_checkpointing=False, enable_xformers=False)[source]

Bases: mmengine.model.BaseModel

aMUSEd.

Args:

    tokenizer (dict): Config of tokenizer.
    text_encoder (dict): Config of text encoder.
    vae (dict): Config of vae.
    transformer (dict): Config of transformer.
    model (str): Pretrained model name. Defaults to "amused/amused-512".
    loss (dict): Config of loss. Defaults to
        dict(type='L2Loss', loss_weight=1.0).
    transformer_lora_config (dict, optional): The LoRA config dict for the
        Transformer, e.g. dict(type="LoRA", r=4). type is chosen from LoRA,
        LoHa, or LoKr; the remaining keys follow the PEFT configs
        (https://github.com/huggingface/peft). Defaults to None.
    text_encoder_lora_config (dict, optional): The LoRA config dict for the
        Text Encoder, e.g. dict(type="LoRA", r=4). type is chosen from LoRA,
        LoHa, or LoKr; the remaining keys follow the PEFT configs
        (https://github.com/huggingface/peft). Defaults to None.
    prior_loss_weight (float): The weight of the prior preservation loss.
        Applies when training DreamBooth with class images. Defaults to 1.0.
    data_preprocessor (dict, optional): The pre-process config of
        SDDataPreprocessor.
    vae_batch_size (int): The batch size of vae. Defaults to 8.
    finetune_text_encoder (bool, optional): Whether to fine-tune the text
        encoder. Defaults to False.
    gradient_checkpointing (bool): Whether to use gradient checkpointing to
        save memory at the expense of a slower backward pass.
        Defaults to False.
    enable_xformers (bool): Whether to enable memory-efficient attention.
        Defaults to False.
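For illustration, a minimal sketch of the LoRA config dicts described above. Only type and r appear in this documentation; the remaining keys follow PEFT's LoraConfig and are assumptions here, not diffengine-specific API:

    # Hedged sketch of LoRA config dicts as documented above. `type` selects
    # LoRA / LoHa / LoKr; keys other than `type` and `r` follow PEFT
    # (https://github.com/huggingface/peft) and are assumptions.
    transformer_lora_config = dict(
        type="LoRA",
        r=4,                              # low-rank dimension
        lora_alpha=4,                     # PEFT scaling factor (assumed key)
        target_modules=["to_q", "to_v"],  # attention projections (assumed key)
    )
    text_encoder_lora_config = dict(type="LoRA", r=4)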

property device: torch.device

Get device information.

Returns:

torch.device

Return type:

device.

set_lora()[source]

Set LORA for model.

Return type:

None

prepare_model()[source]

Prepare model for training.

Disable gradient for some models.

Return type:

None

set_xformers()[source]

Set xformers for model.

Return type:

None

infer(prompt, negative_prompt=None, height=None, width=None, num_inference_steps=12, output_type='pil', **kwargs)[source]

Inference function.

Args:

    prompt (List[str]): The prompt or prompts to guide the image
        generation.
    negative_prompt (Optional[str]): The prompt or prompts to steer the
        image generation away from. Defaults to None.
    height (int, optional): The height in pixels of the generated image.
        Defaults to None.
    width (int, optional): The width in pixels of the generated image.
        Defaults to None.
    num_inference_steps (int): Number of inference steps. Defaults to 12.
    output_type (str): The output format of the generated image.
        Choose between 'pil' and 'latent'. Defaults to 'pil'.
    **kwargs: Other arguments.

Parameters:
  • prompt (list[str]) –

  • negative_prompt (str | None) –

  • height (int | None) –

  • width (int | None) –

  • num_inference_steps (int) –

  • output_type (str) –

Return type:

list[numpy.ndarray]
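A hypothetical usage sketch based only on the signature above; editor stands for an already-constructed AMUSEd instance:

    # Hypothetical call to AMUSEd.infer(); `editor` is an assumed,
    # already-built AMUSEd instance. Values mirror the documented defaults.
    images = editor.infer(
        prompt=["a photo of an astronaut riding a horse"],
        negative_prompt="low quality",
        height=512,
        width=512,
        num_inference_steps=12,   # documented default
        output_type="pil",
    )
    # `images` is a list[numpy.ndarray], one entry per prompt.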

val_step(data)[source]

Val step.

Parameters:

data (Union[tuple, dict, list]) –

Return type:

list

test_step(data)[source]

Test step.

Parameters:

data (Union[tuple, dict, list]) –

Return type:

list

_forward_vae(img, num_batches)[source]

Forward vae.

Parameters:
  • img (torch.Tensor) –

  • num_batches (int) –

Return type:

torch.Tensor

forward(inputs, data_samples=None, mode='loss')[source]

Forward function.

Args:

    inputs (dict): The input dict.
    data_samples (Optional[list], optional): The data samples.
        Defaults to None.
    mode (str, optional): The mode. Defaults to "loss".

Returns:

    dict: The loss dict.

Parameters:
  • inputs (dict) –

  • data_samples (Optional[list]) –

  • mode (str) –

Return type:

dict

Parameters:
  • tokenizer (dict) –

  • text_encoder (dict) –

  • vae (dict) –

  • transformer (dict) –

  • model (str) –

  • loss (dict | None) –

  • transformer_lora_config (dict | None) –

  • text_encoder_lora_config (dict | None) –

  • prior_loss_weight (float) –

  • data_preprocessor (dict | torch.nn.Module | None) –

  • vae_batch_size (int) –

  • finetune_text_encoder (bool) –

  • gradient_checkpointing (bool) –

  • enable_xformers (bool) –

class diffengine.models.editors.AMUSEdPreprocessor(non_blocking=False)[source]

Bases: mmengine.model.base_model.data_preprocessor.BaseDataPreprocessor

AMUSEdPreprocessor.

Parameters:

non_blocking (Optional[bool]) –

forward(data, training=False)[source]

Preprocesses the data into the model input format.

After cast_data() pre-processing, forward stacks the input tensor list into a batch tensor along the first dimension.

Args:

    data (dict): Data returned by the dataloader.
    training (bool): Whether to enable training-time augmentation.

Returns:

    dict or list: Data in the same format as the model input.

Parameters:
  • data (dict) –

  • training (bool) –

Return type:

dict | list
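A sketch of the documented stacking behavior; the "inputs"/"img" key names are assumptions for illustration, not a guaranteed schema:

    import torch

    from diffengine.models.editors import AMUSEdPreprocessor

    preprocessor = AMUSEdPreprocessor()
    # Assumed dataloader output: a dict holding a list of image tensors.
    data = {"inputs": {"img": [torch.rand(3, 256, 256) for _ in range(4)]}}
    batch = preprocessor(data)  # cast_data(), then stack along dim 0
    # Expected (assumed) layout: batch["inputs"]["img"].shape == (4, 3, 256, 256)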

class diffengine.models.editors.DeepFloydIF(tokenizer, scheduler, text_encoder, unet, model='DeepFloyd/IF-I-XL-v1.0', loss=None, unet_lora_config=None, text_encoder_lora_config=None, prior_loss_weight=1.0, tokenizer_max_length=77, prediction_type=None, data_preprocessor=None, noise_generator=None, timesteps_generator=None, input_perturbation_gamma=0.0, *, finetune_text_encoder=False, gradient_checkpointing=False, enable_xformers=False)[source]

Bases: mmengine.model.BaseModel

DeepFloyd/IF.

Args:

    tokenizer (dict): Config of tokenizer.
    scheduler (dict): Config of scheduler.
    text_encoder (dict): Config of text encoder.
    unet (dict): Config of unet.
    model (str): Pretrained model name of stable diffusion.
        Defaults to 'DeepFloyd/IF-I-XL-v1.0'.
    loss (dict): Config of loss. Defaults to
        dict(type='L2Loss', loss_weight=1.0).
    unet_lora_config (dict, optional): The LoRA config dict for Unet,
        e.g. dict(type="LoRA", r=4). type is chosen from LoRA, LoHa, or
        LoKr; the remaining keys follow the PEFT configs
        (https://github.com/huggingface/peft). Defaults to None.
    text_encoder_lora_config (dict, optional): The LoRA config dict for the
        Text Encoder, e.g. dict(type="LoRA", r=4). type is chosen from LoRA,
        LoHa, or LoKr; the remaining keys follow the PEFT configs
        (https://github.com/huggingface/peft). Defaults to None.
    prior_loss_weight (float): The weight of the prior preservation loss.
        Applies when training DreamBooth with class images. Defaults to 1.0.
    tokenizer_max_length (int): The max length of the tokenizer.
        Defaults to 77.
    prediction_type (str): The prediction type used for training. Choose
        between 'epsilon' and 'v_prediction', or leave it None. If None,
        the scheduler's default, noise_scheduler.config.prediction_type,
        is used. Defaults to None.
    data_preprocessor (dict, optional): The pre-process config of
        SDDataPreprocessor.
    noise_generator (dict, optional): The noise generator config.
        Defaults to dict(type='WhiteNoise').
    timesteps_generator (dict, optional): The timesteps generator config.
        Defaults to dict(type='TimeSteps').
    input_perturbation_gamma (float): The gamma of input perturbation.
        The recommended value is 0.1 for Input Perturbation.
        Defaults to 0.0.
    finetune_text_encoder (bool, optional): Whether to fine-tune the text
        encoder. Defaults to False.
    gradient_checkpointing (bool): Whether to use gradient checkpointing to
        save memory at the expense of a slower backward pass.
        Defaults to False.
    enable_xformers (bool): Whether to enable memory-efficient attention.
        Defaults to False.
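The target selection implied by prediction_type follows the diffusers convention; a sketch (not diffengine's verbatim training code; prediction_type, noise, latents, timesteps, and scheduler are placeholders):

    # Sketch of loss-target selection for the documented `prediction_type`
    # values (diffusers convention; assumed, not verbatim from diffengine).
    if prediction_type == "epsilon":
        target = noise                      # predict the added noise
    elif prediction_type == "v_prediction":
        # velocity target, as computed by diffusers schedulers
        target = scheduler.get_velocity(latents, noise, timesteps)
    else:
        msg = f"Unknown prediction type {prediction_type}"
        raise ValueError(msg)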

property device: torch.device

Get device information.

Returns:

torch.device

Return type:

device.

set_lora()[source]

Set LORA for model.

Return type:

None

prepare_model()[source]

Prepare model for training.

Disable gradient for some models.

Return type:

None

set_xformers()[source]

Set xformers for model.

Return type:

None

infer(prompt, negative_prompt=None, height=None, width=None, num_inference_steps=50, output_type='pil', **kwargs)[source]

Inference function.

Args:

    prompt (List[str]): The prompt or prompts to guide the image
        generation.
    negative_prompt (Optional[str]): The prompt or prompts to steer the
        image generation away from. Defaults to None.
    height (int, optional): The height in pixels of the generated image.
        Defaults to None.
    width (int, optional): The width in pixels of the generated image.
        Defaults to None.
    num_inference_steps (int): Number of inference steps. Defaults to 50.
    output_type (str): The output format of the generated image.
        Choose between 'pil' and 'pt'. Defaults to 'pil'.
    **kwargs: Other arguments.

Parameters:
  • prompt (list[str]) –

  • negative_prompt (str | None) –

  • height (int | None) –

  • width (int | None) –

  • num_inference_steps (int) –

  • output_type (str) –

Return type:

list[numpy.ndarray]

val_step(data)[source]

Val step.

Parameters:

data (Union[tuple, dict, list]) –

Return type:

list

test_step(data)[source]

Test step.

Parameters:

data (Union[tuple, dict, list]) –

Return type:

list

loss(model_pred, noise, latents, timesteps, weight=None)[source]

Calculate loss.

Parameters:
  • model_pred (torch.Tensor) –

  • noise (torch.Tensor) –

  • latents (torch.Tensor) –

  • timesteps (torch.Tensor) –

  • weight (torch.Tensor | None) –

Return type:

dict[str, torch.Tensor]

_preprocess_model_input(latents, noise, timesteps)[source]

Preprocess model input.

Parameters:
  • latents (torch.Tensor) –

  • noise (torch.Tensor) –

  • timesteps (torch.Tensor) –

Return type:

torch.Tensor

forward(inputs, data_samples=None, mode='loss')[source]

Forward function.

Args:

    inputs (dict): The input dict.
    data_samples (Optional[list], optional): The data samples.
        Defaults to None.
    mode (str, optional): The mode. Defaults to "loss".

Returns:

    dict: The loss dict.

Parameters:
  • inputs (dict) –

  • data_samples (Optional[list]) –

  • mode (str) –

Return type:

dict

Parameters:
  • tokenizer (dict) –

  • scheduler (dict) –

  • text_encoder (dict) –

  • unet (dict) –

  • model (str) –

  • loss (dict | None) –

  • unet_lora_config (dict | None) –

  • text_encoder_lora_config (dict | None) –

  • prior_loss_weight (float) –

  • tokenizer_max_length (int) –

  • prediction_type (str | None) –

  • data_preprocessor (dict | torch.nn.Module | None) –

  • noise_generator (dict | None) –

  • timesteps_generator (dict | None) –

  • input_perturbation_gamma (float) –

  • finetune_text_encoder (bool) –

  • gradient_checkpointing (bool) –

  • enable_xformers (bool) –

class diffengine.models.editors.DistillSDXL(*args, model_type, unet_lora_config=None, text_encoder_lora_config=None, finetune_text_encoder=False, **kwargs)[source]

Bases: diffengine.models.editors.stable_diffusion_xl.StableDiffusionXL

Distill Stable Diffusion XL.

Args:

    model_type (str): The type of model to use. Choose from sd_tiny and
        sd_small.
    unet_lora_config (dict, optional): The LoRA config dict for Unet,
        e.g. dict(type="LoRA", r=4). type is chosen from LoRA, LoHa, or
        LoKr; the remaining keys follow the PEFT configs
        (https://github.com/huggingface/peft). Defaults to None.
    text_encoder_lora_config (dict, optional): The LoRA config dict for the
        Text Encoder, e.g. dict(type="LoRA", r=4). type is chosen from LoRA,
        LoHa, or LoKr; the remaining keys follow the PEFT configs
        (https://github.com/huggingface/peft). Defaults to None.
    finetune_text_encoder (bool, optional): Whether to fine-tune the text
        encoder. This should be False when training ControlNet.
        Defaults to False.

set_lora()[source]

Set LORA for model.

Return type:

None

prepare_model()[source]

Prepare model for training.

Disable gradient for some models.

Return type:

None

_prepare_student()[source]
Return type:

None

_cast_hook()[source]
Return type:

None
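_cast_hook suggests intermediate activations are captured for distillation. A conceptual sketch of that pattern using forward hooks (assumed and simplified, not diffengine's code; teacher_block and student_block are placeholder UNet sub-modules):

    import torch
    import torch.nn.functional as F

    feats: dict[str, torch.Tensor] = {}

    def save_output(name):
        def hook(module, args, output):
            feats[name] = output  # cache this block's activation
        return hook

    teacher_block.register_forward_hook(save_output("teacher"))
    student_block.register_forward_hook(save_output("student"))

    # After running both UNets on the same batch, align the student's
    # features with the frozen teacher's:
    kd_loss = F.mse_loss(feats["student"], feats["teacher"].detach())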

set_xformers()[source]

Set xformers for model.

Return type:

None

forward(inputs, data_samples=None, mode='loss')[source]

Forward function.

Args:

    inputs (dict): The input dict.
    data_samples (Optional[list], optional): The data samples.
        Defaults to None.
    mode (str, optional): The mode. Defaults to "loss".

Returns:

    dict: The loss dict.

Parameters:
  • inputs (dict) –

  • data_samples (Optional[list]) –

  • mode (str) –

Return type:

dict

Parameters:
  • model_type (str) –

  • unet_lora_config (dict | None) –

  • text_encoder_lora_config (dict | None) –

  • finetune_text_encoder (bool) –

class diffengine.models.editors.ESDXL(*args, finetune_text_encoder=False, pre_compute_text_embeddings=True, height=1024, width=1024, negative_guidance=1.0, train_method='full', prediction_type=None, data_preprocessor=None, **kwargs)[source]

Bases: diffengine.models.editors.stable_diffusion_xl.StableDiffusionXL

Stable Diffusion XL Erasing Concepts from Diffusion Models.

Args:

    height (int): Image height. Defaults to 1024.
    width (int): Image width. Defaults to 1024.
    negative_guidance (float): Negative guidance for the loss.
        Defaults to 1.0.
    train_method (str): Training method. Choose from full, xattn, noxattn,
        and selfattn. Defaults to full.
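For reference, the erasing objective from the underlying paper steers the trainable UNet's concept-conditioned prediction away from the frozen model's unconditional one; a sketch with placeholder names (assumed formulation, not diffengine's verbatim code):

    import torch
    import torch.nn.functional as F

    # Sketch of the ESD objective (Gandikota et al., "Erasing Concepts from
    # Diffusion Models"); `frozen_unet`, `unet`, `uncond_embeds`, and
    # `concept_embeds` are placeholders.
    with torch.no_grad():
        e_uncond = frozen_unet(noisy_latents, timesteps, uncond_embeds)
        e_concept = frozen_unet(noisy_latents, timesteps, concept_embeds)
    target = e_uncond - negative_guidance * (e_concept - e_uncond)
    loss = F.mse_loss(unet(noisy_latents, timesteps, concept_embeds), target)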

prepare_model()[source]

Prepare model for training.

Disable gradient for some models.

Return type:

None

_freeze_unet()[source]
Return type:

None

set_xformers()[source]

Set xformers for model.

Return type:

None

train(*, mode=True)[source]

Convert the model into training mode.

Parameters:

mode (bool) –

Return type:

None

abstract _preprocess_model_input(latents, noise, timesteps)[source]

Preprocess model input.

Parameters:
  • latents (torch.Tensor) –

  • noise (torch.Tensor) –

  • timesteps (torch.Tensor) –

Return type:

torch.Tensor

forward(inputs, data_samples=None, mode='loss')[source]

Forward function.

Args:

    inputs (dict): The input dict.
    data_samples (Optional[list], optional): The data samples.
        Defaults to None.
    mode (str, optional): The mode. Defaults to "loss".

Returns:

    dict: The loss dict.

Parameters:
  • inputs (dict) –

  • data_samples (Optional[list]) –

  • mode (str) –

Return type:

dict

Parameters:
  • finetune_text_encoder (bool) –

  • pre_compute_text_embeddings (bool) –

  • height (int) –

  • width (int) –

  • negative_guidance (float) –

  • train_method (str) –

  • prediction_type (str | None) –

  • data_preprocessor (dict | torch.nn.Module | None) –

class diffengine.models.editors.ESDXLDataPreprocessor(non_blocking=False)[source]

Bases: mmengine.model.base_model.data_preprocessor.BaseDataPreprocessor

ESDXLDataPreprocessor.

Parameters:

non_blocking (Optional[bool]) –

forward(data, training=False)[source]

Preprocesses the data into the model input format.

After cast_data() pre-processing, forward stacks the input tensor list into a batch tensor along the first dimension.

Args:

    data (dict): Data returned by the dataloader.
    training (bool): Whether to enable training-time augmentation.

Returns:

    dict or list: Data in the same format as the model input.

Parameters:
  • data (dict) –

  • training (bool) –

Return type:

Union[dict, list]

class diffengine.models.editors.StableDiffusionXLInstructPix2Pix(*args, zeros_image_embeddings_prob=0.1, unet_lora_config=None, text_encoder_lora_config=None, finetune_text_encoder=False, data_preprocessor=None, **kwargs)[source]

Bases: diffengine.models.editors.stable_diffusion_xl.StableDiffusionXL

Stable Diffusion XL Instruct Pix2Pix.

Args:

    zeros_image_embeddings_prob (float): The probability of generating
        zero image embeddings. Defaults to 0.1.
    unet_lora_config (dict, optional): The LoRA config dict for Unet,
        e.g. dict(type="LoRA", r=4). type is chosen from LoRA, LoHa, or
        LoKr; the remaining keys follow the PEFT configs
        (https://github.com/huggingface/peft). Defaults to None.
    text_encoder_lora_config (dict, optional): The LoRA config dict for the
        Text Encoder, e.g. dict(type="LoRA", r=4). type is chosen from LoRA,
        LoHa, or LoKr; the remaining keys follow the PEFT configs
        (https://github.com/huggingface/peft). Defaults to None.
    finetune_text_encoder (bool, optional): Whether to fine-tune the text
        encoder. This should be False when training ControlNet.
        Defaults to False.
    data_preprocessor (dict, optional): The pre-process config of
        SDControlNetDataPreprocessor.

set_lora()[source]

Set LORA for model.

Return type:

None

prepare_model()[source]

Prepare model for training.

Disable gradient for some models.

Return type:

None

infer(prompt, condition_image, negative_prompt=None, height=None, width=None, num_inference_steps=50, output_type='pil', **kwargs)[source]

Inference function.

Args:

    prompt (List[str]): The prompt or prompts to guide the image
        generation.
    condition_image (List[Union[str, Image.Image]]): The condition image
        for ControlNet.
    negative_prompt (Optional[str]): The prompt or prompts to steer the
        image generation away from. Defaults to None.
    height (int, optional): The height in pixels of the generated image.
        Defaults to None.
    width (int, optional): The width in pixels of the generated image.
        Defaults to None.
    num_inference_steps (int): Number of inference steps. Defaults to 50.
    output_type (str): The output format of the generated image.
        Choose between 'pil' and 'latent'. Defaults to 'pil'.
    **kwargs: Other arguments.

Parameters:
  • prompt (list[str]) –

  • condition_image (list[str | PIL.Image.Image]) –

  • negative_prompt (str | None) –

  • height (int | None) –

  • width (int | None) –

  • num_inference_steps (int) –

  • output_type (str) –

Return type:

list[numpy.ndarray]
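A hypothetical usage sketch; editor is an assumed, already-built instance and the image path is a placeholder:

    # Hypothetical call to StableDiffusionXLInstructPix2Pix.infer().
    images = editor.infer(
        prompt=["make the sky a sunset"],
        condition_image=["input.png"],  # path or PIL.Image.Image (placeholder)
        num_inference_steps=50,         # documented default
    )
    # Returns list[numpy.ndarray].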

forward(inputs, data_samples=None, mode='loss')[source]

Forward function.

Args:

    inputs (dict): The input dict.
    data_samples (Optional[list], optional): The data samples.
        Defaults to None.
    mode (str, optional): The mode. Defaults to "loss".

Returns:

    dict: The loss dict.

Parameters:
  • inputs (dict) –

  • data_samples (Optional[list]) –

  • mode (str) –

Return type:

dict

Parameters:
  • zeros_image_embeddings_prob (float) –

  • unet_lora_config (dict | None) –

  • text_encoder_lora_config (dict | None) –

  • finetune_text_encoder (bool) –

  • data_preprocessor (dict | torch.nn.Module | None) –

class diffengine.models.editors.IPAdapterXL(*args, image_encoder, image_projection, feature_extractor, pretrained_adapter=None, pretrained_adapter_subfolder='', pretrained_adapter_weights_name='', unet_lora_config=None, text_encoder_lora_config=None, finetune_text_encoder=False, zeros_image_embeddings_prob=0.1, data_preprocessor=None, hidden_states_idx=-2, **kwargs)[source]

Bases: diffengine.models.editors.stable_diffusion_xl.StableDiffusionXL

Stable Diffusion XL IP-Adapter.

Args:

    image_encoder (dict): The image encoder config.
    image_projection (dict): The image projection config.
    feature_extractor (dict): The feature extractor config.
    pretrained_adapter (str, optional): Path to a pretrained IP-Adapter.
        Defaults to None.
    pretrained_adapter_subfolder (str, optional): Subfolder of the
        pretrained IP-Adapter. Defaults to ''.
    pretrained_adapter_weights_name (str, optional): Weights name of the
        pretrained IP-Adapter. Defaults to ''.
    unet_lora_config (dict, optional): The LoRA config dict for Unet,
        e.g. dict(type="LoRA", r=4). type is chosen from LoRA, LoHa, or
        LoKr; the remaining keys follow the PEFT configs
        (https://github.com/huggingface/peft). Defaults to None.
    text_encoder_lora_config (dict, optional): The LoRA config dict for the
        Text Encoder, e.g. dict(type="LoRA", r=4). type is chosen from LoRA,
        LoHa, or LoKr; the remaining keys follow the PEFT configs
        (https://github.com/huggingface/peft). Defaults to None.
    finetune_text_encoder (bool, optional): Whether to fine-tune the text
        encoder. This should be False when training ControlNet.
        Defaults to False.
    zeros_image_embeddings_prob (float): The probability of generating
        zero image embeddings. Defaults to 0.1.
    data_preprocessor (dict, optional): The pre-process config of
        SDControlNetDataPreprocessor.
    hidden_states_idx (int): Index of the hidden states to be used.
        Defaults to -2.

set_lora()[source]

Set LORA for model.

Return type:

None

prepare_model()[source]

Prepare model for training.

Disable gradient for some models.

Return type:

None

set_ip_adapter()[source]

Set IP-Adapter for model.

Return type:

None

infer(prompt, example_image, negative_prompt=None, height=None, width=None, num_inference_steps=50, output_type='pil', **kwargs)[source]

Inference function.

Args:

    prompt (List[str]): The prompt or prompts to guide the image
        generation.
    example_image (List[Union[str, Image.Image]]): The image prompt or
        prompts to guide the image generation.
    negative_prompt (Optional[str]): The prompt or prompts to steer the
        image generation away from. Defaults to None.
    height (int, optional): The height in pixels of the generated image.
        Defaults to None.
    width (int, optional): The width in pixels of the generated image.
        Defaults to None.
    num_inference_steps (int): Number of inference steps. Defaults to 50.
    output_type (str): The output format of the generated image.
        Choose between 'pil' and 'latent'. Defaults to 'pil'.
    **kwargs: Other arguments.

Parameters:
  • prompt (list[str]) –

  • example_image (list[str | PIL.Image.Image]) –

  • negative_prompt (str | None) –

  • height (int | None) –

  • width (int | None) –

  • num_inference_steps (int) –

  • output_type (str) –

Return type:

list[numpy.ndarray]
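A hypothetical usage sketch; editor and the image path are placeholders:

    # Hypothetical call to IPAdapterXL.infer(); `editor` is an assumed,
    # already-built instance.
    images = editor.infer(
        prompt=["a cat wearing sunglasses"],
        example_image=["style_reference.png"],  # image prompt (placeholder)
        num_inference_steps=50,
    )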

forward(inputs, data_samples=None, mode='loss')[source]

Forward function.

Args:

    inputs (dict): The input dict.
    data_samples (Optional[list], optional): The data samples.
        Defaults to None.
    mode (str, optional): The mode. Defaults to "loss".

Returns:

    dict: The loss dict.

Parameters:
  • inputs (dict) –

  • data_samples (Optional[list]) –

  • mode (str) –

Return type:

dict

Parameters:
  • image_encoder (dict) –

  • image_projection (dict) –

  • feature_extractor (dict) –

  • pretrained_adapter (str | None) –

  • pretrained_adapter_subfolder (str) –

  • pretrained_adapter_weights_name (str) –

  • unet_lora_config (dict | None) –

  • text_encoder_lora_config (dict | None) –

  • finetune_text_encoder (bool) –

  • zeros_image_embeddings_prob (float) –

  • data_preprocessor (dict | torch.nn.Module | None) –

  • hidden_states_idx (int) –

class diffengine.models.editors.IPAdapterXLPlus(*args, image_encoder, image_projection, feature_extractor, pretrained_adapter=None, pretrained_adapter_subfolder='', pretrained_adapter_weights_name='', unet_lora_config=None, text_encoder_lora_config=None, finetune_text_encoder=False, zeros_image_embeddings_prob=0.1, data_preprocessor=None, hidden_states_idx=-2, **kwargs)[source]

Bases: IPAdapterXL

Stable Diffusion XL IP-Adapter Plus.

Parameters:
  • image_encoder (dict) –

  • image_projection (dict) –

  • feature_extractor (dict) –

  • pretrained_adapter (str | None) –

  • pretrained_adapter_subfolder (str) –

  • pretrained_adapter_weights_name (str) –

  • unet_lora_config (dict | None) –

  • text_encoder_lora_config (dict | None) –

  • finetune_text_encoder (bool) –

  • zeros_image_embeddings_prob (float) –

  • data_preprocessor (dict | torch.nn.Module | None) –

  • hidden_states_idx (int) –

prepare_model()[source]

Prepare model for training.

Disable gradient for some models.

Return type:

None

forward(inputs, data_samples=None, mode='loss')[source]

Forward function.

Args:

    inputs (dict): The input dict.
    data_samples (Optional[list], optional): The data samples.
        Defaults to None.
    mode (str, optional): The mode. Defaults to "loss".

Returns:

    dict: The loss dict.

Parameters:
  • inputs (dict) –

  • data_samples (Optional[list]) –

  • mode (str) –

Return type:

dict

class diffengine.models.editors.IPAdapterXLDataPreprocessor(non_blocking=False)[source]

Bases: mmengine.model.base_model.data_preprocessor.BaseDataPreprocessor

IPAdapterXLDataPreprocessor.

Parameters:

non_blocking (Optional[bool]) –

forward(data, training=False)[source]

Preprocesses the data into the model input format.

After cast_data() pre-processing, forward stacks the input tensor list into a batch tensor along the first dimension.

Args:

    data (dict): Data returned by the dataloader.
    training (bool): Whether to enable training-time augmentation.

Returns:

    dict or list: Data in the same format as the model input.

Parameters:
  • data (dict) –

  • training (bool) –

Return type:

dict | list

class diffengine.models.editors.TimmIPAdapterXLPlus(*args, image_encoder, image_projection, feature_extractor, pretrained_adapter=None, pretrained_adapter_subfolder='', pretrained_adapter_weights_name='', unet_lora_config=None, text_encoder_lora_config=None, finetune_text_encoder=False, zeros_image_embeddings_prob=0.1, data_preprocessor=None, hidden_states_idx=-2, **kwargs)[source]

Bases: diffengine.models.editors.ip_adapter.ip_adapter_xl.IPAdapterXLPlus

Stable Diffusion XL IP-Adapter Plus.

Parameters:
  • image_encoder (dict) –

  • image_projection (dict) –

  • feature_extractor (dict) –

  • pretrained_adapter (str | None) –

  • pretrained_adapter_subfolder (str) –

  • pretrained_adapter_weights_name (str) –

  • unet_lora_config (dict | None) –

  • text_encoder_lora_config (dict | None) –

  • finetune_text_encoder (bool) –

  • zeros_image_embeddings_prob (float) –

  • data_preprocessor (dict | torch.nn.Module | None) –

  • hidden_states_idx (int) –

prepare_model()[source]

Prepare model for training.

Disable gradient for some models.

Return type:

None

infer(prompt, example_image, negative_prompt=None, height=None, width=None, num_inference_steps=50, output_type='pil', **kwargs)[source]

Inference function.

Args:

    prompt (List[str]): The prompt or prompts to guide the image
        generation.
    example_image (List[Union[str, Image.Image]]): The image prompt or
        prompts to guide the image generation.
    negative_prompt (Optional[str]): The prompt or prompts to steer the
        image generation away from. Defaults to None.
    height (int, optional): The height in pixels of the generated image.
        Defaults to None.
    width (int, optional): The width in pixels of the generated image.
        Defaults to None.
    num_inference_steps (int): Number of inference steps. Defaults to 50.
    output_type (str): The output format of the generated image.
        Choose between 'pil' and 'latent'. Defaults to 'pil'.
    **kwargs: Other arguments.

Parameters:
  • prompt (list[str]) –

  • example_image (list[str | PIL.Image.Image]) –

  • negative_prompt (str | None) –

  • height (int | None) –

  • width (int | None) –

  • num_inference_steps (int) –

  • output_type (str) –

Return type:

list[numpy.ndarray]

forward(inputs, data_samples=None, mode='loss')[source]

Forward function.

Args:

    inputs (dict): The input dict.
    data_samples (Optional[list], optional): The data samples.
        Defaults to None.
    mode (str, optional): The mode. Defaults to "loss".

Returns:

    dict: The loss dict.

Parameters:
  • inputs (dict) –

  • data_samples (list | None) –

  • mode (str) –

Return type:

dict

class diffengine.models.editors.KandinskyV22Prior(tokenizer, scheduler, text_encoder, image_encoder, prior, decoder_model='kandinsky-community/kandinsky-2-2-decoder', prior_model='kandinsky-community/kandinsky-2-2-prior', loss=None, prior_lora_config=None, prior_loss_weight=1.0, data_preprocessor=None, noise_generator=None, timesteps_generator=None, input_perturbation_gamma=0.0, *, gradient_checkpointing=False, enable_xformers=False)[source]

Bases: mmengine.model.BaseModel

KandinskyV22 Prior.

Args:

    tokenizer (dict): Config of tokenizer.
    scheduler (dict): Config of scheduler.
    text_encoder (dict): Config of text encoder.
    image_encoder (dict): Config of image encoder.
    prior (dict): Config of prior.
    decoder_model (str): Pretrained model name of the decoder.
        Defaults to "kandinsky-community/kandinsky-2-2-decoder".
    prior_model (str): Pretrained model name of the prior.
        Defaults to "kandinsky-community/kandinsky-2-2-prior".
    loss (dict): Config of loss. Defaults to
        dict(type='L2Loss', loss_weight=1.0).
    prior_lora_config (dict, optional): The LoRA config dict for the Prior,
        e.g. dict(type="LoRA", r=4). type is chosen from LoRA, LoHa, or
        LoKr; the remaining keys follow the PEFT configs
        (https://github.com/huggingface/peft). Defaults to None.
    prior_loss_weight (float): The weight of the prior preservation loss.
        Applies when training DreamBooth with class images. Defaults to 1.0.
    data_preprocessor (dict, optional): The pre-process config of
        SDDataPreprocessor.
    noise_generator (dict, optional): The noise generator config.
        Defaults to dict(type='WhiteNoise').
    timesteps_generator (dict, optional): The timesteps generator config.
        Defaults to dict(type='TimeSteps').
    input_perturbation_gamma (float): The gamma of input perturbation.
        The recommended value is 0.1 for Input Perturbation.
        Defaults to 0.0.
    gradient_checkpointing (bool): Whether to use gradient checkpointing to
        save memory at the expense of a slower backward pass.
        Defaults to False.
    enable_xformers (bool): Whether to enable memory-efficient attention.
        Defaults to False.

property device: torch.device

Get device information.

Returns:

torch.device

Return type:

device.

set_lora()[source]

Set LORA for model.

Return type:

None

prepare_model()[source]

Prepare model for training.

Disable gradient for some models.

Return type:

None

set_xformers()[source]

Set xformers for model.

Return type:

None

infer(prompt, negative_prompt=None, height=None, width=None, num_inference_steps=50, output_type='pil', **kwargs)[source]

Inference function.

Args:

    prompt (List[str]): The prompt or prompts to guide the image
        generation.
    negative_prompt (Optional[str]): The prompt or prompts to steer the
        image generation away from. Defaults to None.
    height (int, optional): The height in pixels of the generated image.
        Defaults to None.
    width (int, optional): The width in pixels of the generated image.
        Defaults to None.
    num_inference_steps (int): Number of inference steps. Defaults to 50.
    output_type (str): The output format of the generated image.
        Choose between 'pil' and 'latent'. Defaults to 'pil'.
    **kwargs: Other arguments.

Parameters:
  • prompt (list[str]) –

  • negative_prompt (str | None) –

  • height (int | None) –

  • width (int | None) –

  • num_inference_steps (int) –

  • output_type (str) –

Return type:

list[numpy.ndarray]

val_step(data)[source]

Val step.

Parameters:

data (Union[tuple, dict, list]) –

Return type:

list

test_step(data)[source]

Test step.

Parameters:

data (Union[tuple, dict, list]) –

Return type:

list

loss(model_pred, noise, latents, timesteps, weight=None)[source]

Calculate loss.

Parameters:
  • model_pred (torch.Tensor) –

  • noise (torch.Tensor) –

  • latents (torch.Tensor) –

  • timesteps (torch.Tensor) –

  • weight (torch.Tensor | None) –

Return type:

dict[str, torch.Tensor]

_preprocess_model_input(latents, noise, timesteps)[source]

Preprocess model input.

Parameters:
  • latents (torch.Tensor) –

  • noise (torch.Tensor) –

  • timesteps (torch.Tensor) –

Return type:

torch.Tensor

forward(inputs, data_samples=None, mode='loss')[source]

Forward function.

Args:

    inputs (dict): The input dict.
    data_samples (Optional[list], optional): The data samples.
        Defaults to None.
    mode (str, optional): The mode. Defaults to "loss".

Returns:

    dict: The loss dict.

Parameters:
  • inputs (dict) –

  • data_samples (Optional[list]) –

  • mode (str) –

Return type:

dict

Parameters:
  • tokenizer (dict) –

  • scheduler (dict) –

  • text_encoder (dict) –

  • image_encoder (dict) –

  • prior (dict) –

  • decoder_model (str) –

  • prior_model (str) –

  • loss (dict | None) –

  • prior_lora_config (dict | None) –

  • prior_loss_weight (float) –

  • data_preprocessor (dict | torch.nn.Module | None) –

  • noise_generator (dict | None) –

  • timesteps_generator (dict | None) –

  • input_perturbation_gamma (float) –

  • gradient_checkpointing (bool) –

  • enable_xformers (bool) –

class diffengine.models.editors.KandinskyV22Decoder(scheduler, image_encoder, vae, unet, decoder_model='kandinsky-community/kandinsky-2-2-decoder', prior_model='kandinsky-community/kandinsky-2-2-prior', loss=None, unet_lora_config=None, prior_loss_weight=1.0, prediction_type=None, data_preprocessor=None, noise_generator=None, timesteps_generator=None, input_perturbation_gamma=0.0, vae_batch_size=8, *, gradient_checkpointing=False, enable_xformers=False)[source]

Bases: mmengine.model.BaseModel

KandinskyV22 Decoder.

Args:

    scheduler (dict): Config of scheduler.
    image_encoder (dict): Config of image encoder.
    vae (dict): Config of vae.
    unet (dict): Config of unet.
    decoder_model (str): Pretrained model name of the decoder.
        Defaults to "kandinsky-community/kandinsky-2-2-decoder".
    prior_model (str): Pretrained model name of the prior.
        Defaults to "kandinsky-community/kandinsky-2-2-prior".
    loss (dict): Config of loss. Defaults to
        dict(type='L2Loss', loss_weight=1.0).
    unet_lora_config (dict, optional): The LoRA config dict for Unet,
        e.g. dict(type="LoRA", r=4). type is chosen from LoRA, LoHa, or
        LoKr; the remaining keys follow the PEFT configs
        (https://github.com/huggingface/peft). Defaults to None.
    prior_loss_weight (float): The weight of the prior preservation loss.
        Applies when training DreamBooth with class images. Defaults to 1.0.
    prediction_type (str): The prediction type used for training. Choose
        between 'epsilon' and 'v_prediction', or leave it None. If None,
        the scheduler's default prediction type is used. Defaults to None.
    data_preprocessor (dict, optional): The pre-process config of
        SDDataPreprocessor.
    noise_generator (dict, optional): The noise generator config.
        Defaults to dict(type='WhiteNoise').
    timesteps_generator (dict, optional): The timesteps generator config.
        Defaults to dict(type='TimeSteps').
    input_perturbation_gamma (float): The gamma of input perturbation.
        The recommended value is 0.1 for Input Perturbation.
        Defaults to 0.0.
    vae_batch_size (int): The batch size of vae. Defaults to 8.
    gradient_checkpointing (bool): Whether to use gradient checkpointing to
        save memory at the expense of a slower backward pass.
        Defaults to False.
    enable_xformers (bool): Whether to enable memory-efficient attention.
        Defaults to False.

property device: torch.device

Get device information.

Returns:

torch.device

Return type:

device.

set_lora()[source]

Set LORA for model.

Return type:

None

prepare_model()[source]

Prepare model for training.

Disable gradient for some models.

Return type:

None

set_xformers()[source]

Set xformers for model.

Return type:

None

infer(prompt, negative_prompt=None, height=None, width=None, num_inference_steps=50, output_type='pil', **kwargs)[source]

Inference function.

Args:

    prompt (List[str]): The prompt or prompts to guide the image
        generation.
    negative_prompt (Optional[str]): The prompt or prompts to steer the
        image generation away from. Defaults to None.
    height (int, optional): The height in pixels of the generated image.
        Defaults to None.
    width (int, optional): The width in pixels of the generated image.
        Defaults to None.
    num_inference_steps (int): Number of inference steps. Defaults to 50.
    output_type (str): The output format of the generated image.
        Choose between 'pil' and 'latent'. Defaults to 'pil'.
    **kwargs: Other arguments.

Parameters:
  • prompt (list[str]) –

  • negative_prompt (str | None) –

  • height (int | None) –

  • width (int | None) –

  • num_inference_steps (int) –

  • output_type (str) –

Return type:

list[numpy.ndarray]

val_step(data)[source]

Val step.

Parameters:

data (Union[tuple, dict, list]) –

Return type:

list

test_step(data)[source]

Test step.

Parameters:

data (Union[tuple, dict, list]) –

Return type:

list

loss(model_pred, noise, latents, timesteps, weight=None)[source]

Calculate loss.

Parameters:
  • model_pred (torch.Tensor) –

  • noise (torch.Tensor) –

  • latents (torch.Tensor) –

  • timesteps (torch.Tensor) –

  • weight (torch.Tensor | None) –

Return type:

dict[str, torch.Tensor]

_preprocess_model_input(latents, noise, timesteps)[source]

Preprocess model input.

Parameters:
  • latents (torch.Tensor) –

  • noise (torch.Tensor) –

  • timesteps (torch.Tensor) –

Return type:

torch.Tensor

_forward_vae(img, num_batches)[source]

Forward vae.

Parameters:
  • img (torch.Tensor) –

  • num_batches (int) –

Return type:

torch.Tensor

forward(inputs, data_samples=None, mode='loss')[source]

Forward function.

Args:

    inputs (dict): The input dict.
    data_samples (Optional[list], optional): The data samples.
        Defaults to None.
    mode (str, optional): The mode. Defaults to "loss".

Returns:

    dict: The loss dict.

Parameters:
  • inputs (dict) –

  • data_samples (Optional[list]) –

  • mode (str) –

Return type:

dict

Parameters:
  • scheduler (dict) –

  • image_encoder (dict) –

  • vae (dict) –

  • unet (dict) –

  • decoder_model (str) –

  • prior_model (str) –

  • loss (dict | None) –

  • unet_lora_config (dict | None) –

  • prior_loss_weight (float) –

  • prediction_type (str | None) –

  • data_preprocessor (dict | torch.nn.Module | None) –

  • noise_generator (dict | None) –

  • timesteps_generator (dict | None) –

  • input_perturbation_gamma (float) –

  • vae_batch_size (int) –

  • gradient_checkpointing (bool) –

  • enable_xformers (bool) –

class diffengine.models.editors.KandinskyV22DecoderDataPreprocessor(non_blocking=False)[source]

Bases: mmengine.model.base_model.data_preprocessor.BaseDataPreprocessor

KandinskyV22DecoderDataPreprocessor.

Parameters:

non_blocking (Optional[bool]) –

forward(data, training=False)[source]

Preprocesses the data into the model input format.

After cast_data() pre-processing, forward stacks the input tensor list into a batch tensor along the first dimension.

Args:

    data (dict): Data returned by the dataloader.
    training (bool): Whether to enable training-time augmentation.

Returns:

    dict or list: Data in the same format as the model input.

Parameters:
  • data (dict) –

  • training (bool) –

Return type:

dict | list

class diffengine.models.editors.KandinskyV3(tokenizer, scheduler, text_encoder, vae, unet, model='kandinsky-community/kandinsky-3', loss=None, unet_lora_config=None, prior_loss_weight=1.0, tokenizer_max_length=128, prediction_type=None, data_preprocessor=None, noise_generator=None, timesteps_generator=None, input_perturbation_gamma=0.0, vae_batch_size=8, *, gradient_checkpointing=False, enable_xformers=False)[source]

Bases: mmengine.model.BaseModel

KandinskyV3.

Args:

    tokenizer (dict): Config of tokenizer.
    scheduler (dict): Config of scheduler.
    text_encoder (dict): Config of text encoder.
    vae (dict): Config of vae.
    unet (dict): Config of unet.
    model (str): Pretrained model name.
        Defaults to "kandinsky-community/kandinsky-3".
    loss (dict): Config of loss. Defaults to
        dict(type='L2Loss', loss_weight=1.0).
    unet_lora_config (dict, optional): The LoRA config dict for Unet,
        e.g. dict(type="LoRA", r=4). type is chosen from LoRA, LoHa, or
        LoKr; the remaining keys follow the PEFT configs
        (https://github.com/huggingface/peft). Defaults to None.
    prior_loss_weight (float): The weight of the prior preservation loss.
        Applies when training DreamBooth with class images. Defaults to 1.0.
    tokenizer_max_length (int): The max length of the tokenizer.
        Defaults to 128.
    prediction_type (str): The prediction type used for training. Choose
        between 'epsilon' and 'v_prediction', or leave it None. If None,
        the scheduler's default prediction type is used. Defaults to None.
    data_preprocessor (dict, optional): The pre-process config of
        SDDataPreprocessor.
    noise_generator (dict, optional): The noise generator config.
        Defaults to dict(type='WhiteNoise').
    timesteps_generator (dict, optional): The timesteps generator config.
        Defaults to dict(type='TimeSteps').
    input_perturbation_gamma (float): The gamma of input perturbation.
        The recommended value is 0.1 for Input Perturbation.
        Defaults to 0.0.
    vae_batch_size (int): The batch size of vae. Defaults to 8.
    gradient_checkpointing (bool): Whether to use gradient checkpointing to
        save memory at the expense of a slower backward pass.
        Defaults to False.
    enable_xformers (bool): Whether to enable memory-efficient attention.
        Defaults to False.
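A sketch of input perturbation as controlled by input_perturbation_gamma (assumed formulation following the Input Perturbation paper; latents, noise, timesteps, and scheduler are placeholders):

    import torch

    # Perturb the noise used to corrupt the latents; the loss target
    # remains the original `noise`. Assumed, not diffengine's verbatim code.
    perturbed = noise + input_perturbation_gamma * torch.randn_like(noise)
    noisy_latents = scheduler.add_noise(latents, perturbed, timesteps)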

property device: torch.device

Get device information.

Returns:

torch.device

Return type:

device.

set_lora()[source]

Set LORA for model.

Return type:

None

prepare_model()[source]

Prepare model for training.

Disable gradient for some models.

Return type:

None

set_xformers()[source]

Set xformers for model.

Return type:

None

infer(prompt, negative_prompt=None, height=None, width=None, num_inference_steps=50, output_type='pil', **kwargs)[source]

Inference function.

Args:

    prompt (List[str]): The prompt or prompts to guide the image
        generation.
    negative_prompt (Optional[str]): The prompt or prompts to steer the
        image generation away from. Defaults to None.
    height (int, optional): The height in pixels of the generated image.
        Defaults to None.
    width (int, optional): The width in pixels of the generated image.
        Defaults to None.
    num_inference_steps (int): Number of inference steps. Defaults to 50.
    output_type (str): The output format of the generated image.
        Choose between 'pil' and 'latent'. Defaults to 'pil'.
    **kwargs: Other arguments.

Parameters:
  • prompt (list[str]) –

  • negative_prompt (str | None) –

  • height (int | None) –

  • width (int | None) –

  • num_inference_steps (int) –

  • output_type (str) –

Return type:

list[numpy.ndarray]

val_step(data)[source]

Val step.

Parameters:

data (Union[tuple, dict, list]) –

Return type:

list

test_step(data)[source]

Test step.

Parameters:

data (Union[tuple, dict, list]) –

Return type:

list

loss(model_pred, noise, latents, timesteps, weight=None)[source]

Calculate loss.

Parameters:
  • model_pred (torch.Tensor) –

  • noise (torch.Tensor) –

  • latents (torch.Tensor) –

  • timesteps (torch.Tensor) –

  • weight (torch.Tensor | None) –

Return type:

dict[str, torch.Tensor]

_preprocess_model_input(latents, noise, timesteps)[source]

Preprocess model input.

Parameters:
  • latents (torch.Tensor) –

  • noise (torch.Tensor) –

  • timesteps (torch.Tensor) –

Return type:

torch.Tensor

_forward_vae(img, num_batches)[source]

Forward vae.

Parameters:
  • img (torch.Tensor) –

  • num_batches (int) –

Return type:

torch.Tensor

forward(inputs, data_samples=None, mode='loss')[source]

Forward function.

Args:

    inputs (dict): The input dict.
    data_samples (Optional[list], optional): The data samples.
        Defaults to None.
    mode (str, optional): The mode. Defaults to "loss".

Returns:

    dict: The loss dict.

Parameters:
  • inputs (dict) –

  • data_samples (Optional[list]) –

  • mode (str) –

Return type:

dict

Parameters:
  • tokenizer (dict) –

  • scheduler (dict) –

  • text_encoder (dict) –

  • vae (dict) –

  • unet (dict) –

  • model (str) –

  • loss (dict | None) –

  • unet_lora_config (dict | None) –

  • prior_loss_weight (float) –

  • tokenizer_max_length (int) –

  • prediction_type (str | None) –

  • data_preprocessor (dict | torch.nn.Module | None) –

  • noise_generator (dict | None) –

  • timesteps_generator (dict | None) –

  • input_perturbation_gamma (float) –

  • vae_batch_size (int) –

  • gradient_checkpointing (bool) –

  • enable_xformers (bool) –

class diffengine.models.editors.LatentConsistencyModelsXL(*args, timesteps_generator=None, num_ddim_timesteps=50, w_min=3.0, w_max=15.0, ema_type='ExponentialMovingAverage', ema_momentum=0.05, **kwargs)[source]

Bases: diffengine.models.editors.stable_diffusion_xl.StableDiffusionXL

Stable Diffusion XL Latent Consistency Models.

Args:

    timesteps_generator (dict, optional): The timesteps generator config.
        Defaults to dict(type='DDIMTimeSteps').
    num_ddim_timesteps (int): Number of DDIM timesteps. Defaults to 50.
    w_min (float): Minimum guidance scale. Defaults to 3.0.
    w_max (float): Maximum guidance scale. Defaults to 15.0.
    ema_type (str): The type of EMA.
        Defaults to 'ExponentialMovingAverage'.
    ema_momentum (float): The EMA momentum. Defaults to 0.05.
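w_min and w_max bound the guidance scale sampled per training example in the latent consistency distillation recipe; a sketch (assumed, not diffengine's verbatim code; batch_size is a placeholder):

    import torch

    # Sample a per-example CFG scale w ~ U(w_min, w_max); it is folded into
    # the distillation target via classifier-free guidance on the teacher.
    w = w_min + (w_max - w_min) * torch.rand(batch_size)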

prepare_model()[source]

Prepare model for training.

Disable gradient for some models.

Return type:

None

set_xformers()[source]

Set xformers for model.

Return type:

None

infer(prompt, height=None, width=None, num_inference_steps=4, guidance_scale=1.0, output_type='pil', **kwargs)[source]

Inference function.

Args:

    prompt (List[str]): The prompt or prompts to guide the image
        generation.
    height (int, optional): The height in pixels of the generated image.
        Defaults to None.
    width (int, optional): The width in pixels of the generated image.
        Defaults to None.
    num_inference_steps (int): Number of inference steps. Defaults to 4.
    guidance_scale (float): The guidance scale. Defaults to 1.0.
    output_type (str): The output format of the generated image.
        Choose between 'pil' and 'latent'. Defaults to 'pil'.
    **kwargs: Other arguments.

Parameters:
  • prompt (list[str]) –

  • height (int | None) –

  • width (int | None) –

  • num_inference_steps (int) –

  • guidance_scale (float) –

  • output_type (str) –

Return type:

list[numpy.ndarray]

loss(model_pred, gt, timesteps, weight=None)[source]

Calculate loss.

Parameters:
  • model_pred (torch.Tensor) –

  • gt (torch.Tensor) –

  • timesteps (torch.Tensor) –

  • weight (torch.Tensor | None) –

Return type:

dict[str, torch.Tensor]

forward(inputs, data_samples=None, mode='loss')[source]

Forward function.

Args:

    inputs (dict): The input dict.
    data_samples (Optional[list], optional): The data samples.
        Defaults to None.
    mode (str, optional): The mode. Defaults to "loss".

Returns:

    dict: The loss dict.

Parameters:
  • inputs (dict) –

  • data_samples (Optional[list]) –

  • mode (str) –

Return type:

dict

_predicted_origin(model_output, timesteps, sample)[source]

Predict the origin of the model output.

Args:

    model_output (torch.Tensor): The model output.
    timesteps (torch.Tensor): The timesteps.
    sample (torch.Tensor): The sample.

Parameters:
  • model_output (torch.Tensor) –

  • timesteps (torch.Tensor) –

  • sample (torch.Tensor) –

Return type:

torch.Tensor
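For the epsilon parameterization, the predicted origin follows standard DDPM algebra; a sketch (assumed, not diffengine's verbatim code; alphas_cumprod is a placeholder for the scheduler's cumulative alphas):

    # x0 prediction from an epsilon-parameterized output:
    # x0 = (x_t - sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_bar_t)
    alpha_t = alphas_cumprod[timesteps] ** 0.5
    sigma_t = (1 - alphas_cumprod[timesteps]) ** 0.5
    pred_x0 = (sample - sigma_t * model_output) / alpha_t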

Parameters:
  • timesteps_generator (dict | None) –

  • num_ddim_timesteps (int) –

  • w_min (float) –

  • w_max (float) –

  • ema_type (str) –

  • ema_momentum (float) –

class diffengine.models.editors.PixArtAlpha(tokenizer, scheduler, text_encoder, vae, transformer, model='PixArt-alpha/PixArt-XL-2-1024-MS', loss=None, transformer_lora_config=None, text_encoder_lora_config=None, prior_loss_weight=1.0, tokenizer_max_length=120, prediction_type=None, data_preprocessor=None, noise_generator=None, timesteps_generator=None, input_perturbation_gamma=0.0, vae_batch_size=8, *, finetune_text_encoder=False, gradient_checkpointing=False, enable_xformers=False)[source]

Bases: mmengine.model.BaseModel

PixArt Alpha.

Args:

    tokenizer (dict): Config of tokenizer.
    scheduler (dict): Config of scheduler.
    text_encoder (dict): Config of text encoder.
    vae (dict): Config of vae.
    transformer (dict): Config of transformer.
    model (str): Pretrained model name of stable diffusion.
        Defaults to 'PixArt-alpha/PixArt-XL-2-1024-MS'.
    loss (dict): Config of loss. Defaults to
        dict(type='L2Loss', loss_weight=1.0).
    transformer_lora_config (dict, optional): The LoRA config dict for the
        Transformer, e.g. dict(type="LoRA", r=4). type is chosen from LoRA,
        LoHa, or LoKr; the remaining keys follow the PEFT configs
        (https://github.com/huggingface/peft). Defaults to None.
    text_encoder_lora_config (dict, optional): The LoRA config dict for the
        Text Encoder, e.g. dict(type="LoRA", r=4). type is chosen from LoRA,
        LoHa, or LoKr; the remaining keys follow the PEFT configs
        (https://github.com/huggingface/peft). Defaults to None.
    prior_loss_weight (float): The weight of the prior preservation loss.
        Applies when training DreamBooth with class images. Defaults to 1.0.
    tokenizer_max_length (int): The max length of the tokenizer.
        Defaults to 120.
    prediction_type (str): The prediction type used for training. Choose
        between 'epsilon' and 'v_prediction', or leave it None. If None,
        the scheduler's default prediction type is used. Defaults to None.
    data_preprocessor (dict, optional): The pre-process config of
        PixArtAlphaDataPreprocessor.
    noise_generator (dict, optional): The noise generator config.
        Defaults to dict(type='WhiteNoise').
    timesteps_generator (dict, optional): The timesteps generator config.
        Defaults to dict(type='TimeSteps').
    input_perturbation_gamma (float): The gamma of input perturbation.
        The recommended value is 0.1 for Input Perturbation.
        Defaults to 0.0.
    vae_batch_size (int): The batch size of vae. Defaults to 8.
    finetune_text_encoder (bool, optional): Whether to fine-tune the text
        encoder. Defaults to False.
    gradient_checkpointing (bool): Whether to use gradient checkpointing to
        save memory at the expense of a slower backward pass.
        Defaults to False.
    enable_xformers (bool): Whether to enable memory-efficient attention.
        Defaults to False.

property device: torch.device

Get device information.

Returns:

torch.device

Return type:

device.

set_lora()[source]

Set LORA for model.

Return type:

None

prepare_model()[source]

Prepare model for training.

Disable gradient for some models.

Return type:

None

set_xformers()[source]

Set xformers for model.

Return type:

None

infer(prompt, negative_prompt=None, height=None, width=None, num_inference_steps=50, output_type='pil', **kwargs)[source]

Inference function.

Args:

    prompt (List[str]): The prompt or prompts to guide the image
        generation.
    negative_prompt (Optional[str]): The prompt or prompts to steer the
        image generation away from. Defaults to None.
    height (int, optional): The height in pixels of the generated image.
        Defaults to None.
    width (int, optional): The width in pixels of the generated image.
        Defaults to None.
    num_inference_steps (int): Number of inference steps. Defaults to 50.
    output_type (str): The output format of the generated image.
        Choose between 'pil' and 'latent'. Defaults to 'pil'.
    **kwargs: Other arguments.

Parameters:
  • prompt (list[str]) –

  • negative_prompt (str | None) –

  • height (int | None) –

  • width (int | None) –

  • num_inference_steps (int) –

  • output_type (str) –

Return type:

list[numpy.ndarray]

val_step(data)[source]

Val step.

Parameters:

data (Union[tuple, dict, list]) –

Return type:

list

test_step(data)[source]

Test step.

Parameters:

data (Union[tuple, dict, list]) –

Return type:

list

loss(model_pred, noise, latents, timesteps, weight=None)[source]

Calculate loss.

Parameters:
  • model_pred (torch.Tensor) –

  • noise (torch.Tensor) –

  • latents (torch.Tensor) –

  • timesteps (torch.Tensor) –

  • weight (torch.Tensor | None) –

Return type:

dict[str, torch.Tensor]

_preprocess_model_input(latents, noise, timesteps)[source]

Preprocess model input.

Parameters:
  • latents (torch.Tensor) –

  • noise (torch.Tensor) –

  • timesteps (torch.Tensor) –

Return type:

torch.Tensor

_forward_vae(img, num_batches)[source]

Forward vae.

Parameters:
  • img (torch.Tensor) –

  • num_batches (int) –

Return type:

torch.Tensor

forward(inputs, data_samples=None, mode='loss')[source]

Forward function.

Args:

    inputs (dict): The input dict.
    data_samples (Optional[list], optional): The data samples.
        Defaults to None.
    mode (str, optional): The mode. Defaults to "loss".

Returns:

    dict: The loss dict.

Parameters:
  • inputs (dict) –

  • data_samples (Optional[list]) –

  • mode (str) –

Return type:

dict

Parameters:
  • tokenizer (dict) –

  • scheduler (dict) –

  • text_encoder (dict) –

  • vae (dict) –

  • transformer (dict) –

  • model (str) –

  • loss (dict | None) –

  • transformer_lora_config (dict | None) –

  • text_encoder_lora_config (dict | None) –

  • prior_loss_weight (float) –

  • tokenizer_max_length (int) –

  • prediction_type (str | None) –

  • data_preprocessor (dict | torch.nn.Module | None) –

  • noise_generator (dict | None) –

  • timesteps_generator (dict | None) –

  • input_perturbation_gamma (float) –

  • vae_batch_size (int) –

  • finetune_text_encoder (bool) –

  • gradient_checkpointing (bool) –

  • enable_xformers (bool) –

class diffengine.models.editors.PixArtAlphaDataPreprocessor(non_blocking=False)[source]

Bases: mmengine.model.base_model.data_preprocessor.BaseDataPreprocessor

PixArtAlphaDataPreprocessor.

Parameters:

non_blocking (Optional[bool]) –

forward(data, training=False)[source]

Preprocesses the data into the model input format.

After cast_data() pre-processing, forward stacks the input tensor list into a batch tensor along the first dimension.

Args:

    data (dict): Data returned by the dataloader.
    training (bool): Whether to enable training-time augmentation.

Returns:

    dict or list: Data in the same format as the model input.

Parameters:
  • data (dict) –

  • training (bool) –

Return type:

dict | list

class diffengine.models.editors.SSD1B(tokenizer_one, tokenizer_two, scheduler, text_encoder_one, text_encoder_two, vae, teacher_unet, student_unet, model='stabilityai/stable-diffusion-xl-base-1.0', loss=None, unet_lora_config=None, text_encoder_lora_config=None, prior_loss_weight=1.0, prediction_type=None, data_preprocessor=None, noise_generator=None, timesteps_generator=None, input_perturbation_gamma=0.0, vae_batch_size=8, *, finetune_text_encoder=False, gradient_checkpointing=False, pre_compute_text_embeddings=False, enable_xformers=False, student_weight_from_teacher=False)[source]

Bases: diffengine.models.editors.stable_diffusion_xl.StableDiffusionXL

SSD1B.

Refer to official implementation: https://github.com/segmind/SSD-1B/blob/main/distill_sdxl.py

Args:

tokenizer_one (dict): Config of tokenizer one.
tokenizer_two (dict): Config of tokenizer two.
scheduler (dict): Config of scheduler.
text_encoder_one (dict): Config of text encoder one.
text_encoder_two (dict): Config of text encoder two.
vae (dict): Config of vae.
teacher_unet (dict): Config of the teacher unet.
student_unet (dict): Config of the student unet.
model (str): Pretrained model name of Stable Diffusion XL. Defaults to 'stabilityai/stable-diffusion-xl-base-1.0'.
loss (dict): Config of loss. Defaults to dict(type='L2Loss', loss_weight=1.0).
unet_lora_config (dict, optional): The LoRA config dict for the Unet. For example, dict(type='LoRA', r=4). type is chosen from LoRA, LoHa, or LoKr; the other fields follow the PEFT config (https://github.com/huggingface/peft). Defaults to None.
text_encoder_lora_config (dict, optional): The LoRA config dict for the Text Encoder. Same format as unet_lora_config. Defaults to None.
prior_loss_weight (float): The weight of the prior-preservation loss. Used when training DreamBooth with class images. Defaults to 1.0.
prediction_type (str, optional): The prediction type used for training. Choose 'epsilon' or 'v_prediction', or leave None to use the scheduler's default (noise_scheduler.config.prediction_type). Defaults to None.
data_preprocessor (dict, optional): The pre-process config of SDXLDataPreprocessor.
noise_generator (dict, optional): The noise generator config. Defaults to dict(type='WhiteNoise').
timesteps_generator (dict, optional): The timesteps generator config. Defaults to dict(type='TimeSteps').
input_perturbation_gamma (float): The gamma of input perturbation. The recommended value is 0.1. Defaults to 0.0.
vae_batch_size (int): The batch size of the vae. Defaults to 8.
finetune_text_encoder (bool, optional): Whether to fine-tune the text encoder. Defaults to False.
gradient_checkpointing (bool): Whether to use gradient checkpointing to save memory at the expense of a slower backward pass. Defaults to False.
pre_compute_text_embeddings (bool): Whether to pre-compute text embeddings to save memory. Defaults to False.
enable_xformers (bool): Whether to enable memory-efficient attention. Defaults to False.
student_weight_from_teacher (bool): Whether to initialize the student model from the teacher weights. Defaults to False.
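
To make the distillation setup concrete, here is a minimal sketch of an output-matching step between a frozen teacher UNet and a trainable student, in the spirit of the official script linked above. The tensor names and the equal weighting of the two terms are assumptions for illustration, not the exact SSD-1B recipe (which also matches intermediate block features):

    import torch
    import torch.nn.functional as F

    def distill_step(student_pred, teacher_pred, target):
        # Task loss: the student still matches the diffusion target.
        loss_task = F.mse_loss(student_pred.float(), target.float())
        # Distillation loss: the student also matches the frozen teacher.
        loss_kd = F.mse_loss(student_pred.float(), teacher_pred.float())
        return loss_task + loss_kd  # equal weighting assumed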

set_lora()[source]

Set LoRA for model.

Return type:

None

prepare_model()[source]

Prepare model for training.

Disable gradient for some models.

Return type:

None

_cast_hook()[source]
Return type:

None

set_xformers()[source]

Set xformers for model.

Return type:

None

_forward_vae(img, num_batches)[source]

Forward vae.

Parameters:
  • img (torch.Tensor) –

  • num_batches (int) –

Return type:

torch.Tensor

forward(inputs, data_samples=None, mode='loss')[source]

Forward function.

Args:

inputs (dict): The input dict.
data_samples (Optional[list], optional): The data samples. Defaults to None.
mode (str, optional): The mode. Defaults to "loss".

Returns:

dict: The loss dict.

Parameters:
  • inputs (dict) –

  • data_samples (Optional[list]) –

  • mode (str) –

Return type:

dict

Parameters:
  • tokenizer_one (dict) –

  • tokenizer_two (dict) –

  • scheduler (dict) –

  • text_encoder_one (dict) –

  • text_encoder_two (dict) –

  • vae (dict) –

  • teacher_unet (dict) –

  • student_unet (dict) –

  • model (str) –

  • loss (dict | None) –

  • unet_lora_config (dict | None) –

  • text_encoder_lora_config (dict | None) –

  • prior_loss_weight (float) –

  • prediction_type (str | None) –

  • data_preprocessor (dict | torch.nn.Module | None) –

  • noise_generator (dict | None) –

  • timesteps_generator (dict | None) –

  • input_perturbation_gamma (float) –

  • vae_batch_size (int) –

  • finetune_text_encoder (bool) –

  • gradient_checkpointing (bool) –

  • pre_compute_text_embeddings (bool) –

  • enable_xformers (bool) –

  • student_weight_from_teacher (bool) –

class diffengine.models.editors.StableDiffusion(tokenizer, scheduler, text_encoder, vae, unet, model='runwayml/stable-diffusion-v1-5', loss=None, unet_lora_config=None, text_encoder_lora_config=None, prior_loss_weight=1.0, prediction_type=None, data_preprocessor=None, noise_generator=None, timesteps_generator=None, input_perturbation_gamma=0.0, vae_batch_size=8, *, finetune_text_encoder=False, gradient_checkpointing=False, enable_xformers=False)[source]

Bases: mmengine.model.BaseModel

Stable Diffusion.

Args:

tokenizer (dict): Config of tokenizer.
scheduler (dict): Config of scheduler.
text_encoder (dict): Config of text encoder.
vae (dict): Config of vae.
unet (dict): Config of unet.
model (str): Pretrained model name of Stable Diffusion. Defaults to 'runwayml/stable-diffusion-v1-5'.
loss (dict): Config of loss. Defaults to dict(type='L2Loss', loss_weight=1.0).
unet_lora_config (dict, optional): The LoRA config dict for the Unet. For example, dict(type='LoRA', r=4). type is chosen from LoRA, LoHa, or LoKr; the other fields follow the PEFT config (https://github.com/huggingface/peft). Defaults to None.
text_encoder_lora_config (dict, optional): The LoRA config dict for the Text Encoder. Same format as unet_lora_config. Defaults to None.
prior_loss_weight (float): The weight of the prior-preservation loss. Used when training DreamBooth with class images. Defaults to 1.0.
prediction_type (str, optional): The prediction type used for training. Choose 'epsilon' or 'v_prediction', or leave None to use the scheduler's default. Defaults to None.
data_preprocessor (dict, optional): The pre-process config of SDDataPreprocessor.
noise_generator (dict, optional): The noise generator config. Defaults to dict(type='WhiteNoise').
timesteps_generator (dict, optional): The timesteps generator config. Defaults to dict(type='TimeSteps').
input_perturbation_gamma (float): The gamma of input perturbation. The recommended value is 0.1. Defaults to 0.0.
vae_batch_size (int): The batch size of the vae. Defaults to 8.
finetune_text_encoder (bool, optional): Whether to fine-tune the text encoder. Defaults to False.
gradient_checkpointing (bool): Whether to use gradient checkpointing to save memory at the expense of a slower backward pass. Defaults to False.
enable_xformers (bool): Whether to enable memory-efficient attention. Defaults to False.

property device: torch.device

Get device information.

Returns:

device.

Return type:

torch.device

set_lora()[source]

Set LoRA for model.

Return type:

None

prepare_model()[source]

Prepare model for training.

Disable gradient for some models.

Return type:

None

set_xformers()[source]

Set xformers for model.

Return type:

None

infer(prompt, negative_prompt=None, height=None, width=None, num_inference_steps=50, output_type='pil', **kwargs)[source]

Inference function.

Args:

prompt (List[str]): The prompt or prompts to guide the image generation.
negative_prompt (Optional[str]): The prompt or prompts not to guide the image generation. Defaults to None.
height (int, optional): The height in pixels of the generated image. Defaults to None.
width (int, optional): The width in pixels of the generated image. Defaults to None.
num_inference_steps (int): Number of inference steps. Defaults to 50.
output_type (str): The output format of the generated image. Choose between 'pil' and 'latent'. Defaults to 'pil'.
**kwargs: Other arguments.

Parameters:
  • prompt (list[str]) –

  • negative_prompt (str | None) –

  • height (int | None) –

  • width (int | None) –

  • num_inference_steps (int) –

  • output_type (str) –

Return type:

list[numpy.ndarray]
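
A minimal usage sketch of infer. It assumes `sd` is an already-constructed StableDiffusion instance (built from the configs described in the Args above); the prompt strings are illustrative:

    # `sd` is assumed to be a built StableDiffusion instance.
    images = sd.infer(
        prompt=["a photo of a corgi on the beach"],
        negative_prompt="low quality",
        height=512,
        width=512,
        num_inference_steps=50,
        output_type="pil",
    )
    # Per the documented return type, `images` is a list[numpy.ndarray],
    # one array per prompt.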

val_step(data)[source]

Val step.

Parameters:

data (Union[tuple, dict, list]) –

Return type:

list

test_step(data)[source]

Test step.

Parameters:

data (Union[tuple, dict, list]) –

Return type:

list

loss(model_pred, noise, latents, timesteps, weight=None)[source]

Calculate the training loss between the model prediction and the scheduler target.

Parameters:
  • model_pred (torch.Tensor) –

  • noise (torch.Tensor) –

  • latents (torch.Tensor) –

  • timesteps (torch.Tensor) –

  • weight (torch.Tensor | None) –

Return type:

dict[str, torch.Tensor]
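
With the default dict(type='L2Loss', loss_weight=1.0), this reduces to a mean-squared error against the scheduler's target (the noise for 'epsilon' prediction, the velocity for 'v_prediction'). A minimal sketch of that reduction; the handling of the optional per-sample weight is an assumption about how SNR-style weighting typically enters:

    import torch
    import torch.nn.functional as F

    def l2_loss(model_pred, target, weight=None):
        if weight is None:
            return F.mse_loss(model_pred.float(), target.float())
        # Per-sample weighting: reduce non-batch dims first, then average.
        per_sample = F.mse_loss(
            model_pred.float(), target.float(), reduction="none",
        ).mean(dim=list(range(1, model_pred.ndim)))
        return (per_sample * weight.view(-1)).mean()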

_preprocess_model_input(latents, noise, timesteps)[source]

Preprocess model input.

Parameters:
  • latents (torch.Tensor) –

  • noise (torch.Tensor) –

  • timesteps (torch.Tensor) –

Return type:

torch.Tensor
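
_preprocess_model_input is where input_perturbation_gamma takes effect: when gamma > 0, the noise actually added to the latents is jittered before the scheduler's add_noise call, while the training target remains the clean noise. A sketch under that assumption (scheduler.add_noise is the standard diffusers scheduler method):

    import torch

    def perturb_and_add_noise(scheduler, latents, noise, timesteps, gamma):
        # gamma is input_perturbation_gamma; 0.0 disables the perturbation.
        new_noise = noise + gamma * torch.randn_like(noise) if gamma > 0 else noise
        # Standard diffusers scheduler call. Note the loss target stays the
        # clean `noise`, not `new_noise`.
        return scheduler.add_noise(latents, new_noise, timesteps)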

_forward_vae(img, num_batches)[source]

Forward vae.

Parameters:
  • img (torch.Tensor) –

  • num_batches (int) –

Return type:

torch.Tensor

forward(inputs, data_samples=None, mode='loss')[source]

Forward function.

Args:

inputs (dict): The input dict.
data_samples (Optional[list], optional): The data samples. Defaults to None.
mode (str, optional): The mode. Defaults to "loss".

Returns:

dict: The loss dict.

Parameters:
  • inputs (dict) –

  • data_samples (Optional[list]) –

  • mode (str) –

Return type:

dict

Parameters:
  • tokenizer (dict) –

  • scheduler (dict) –

  • text_encoder (dict) –

  • vae (dict) –

  • unet (dict) –

  • model (str) –

  • loss (dict | None) –

  • unet_lora_config (dict | None) –

  • text_encoder_lora_config (dict | None) –

  • prior_loss_weight (float) –

  • prediction_type (str | None) –

  • data_preprocessor (dict | torch.nn.Module | None) –

  • noise_generator (dict | None) –

  • timesteps_generator (dict | None) –

  • input_perturbation_gamma (float) –

  • vae_batch_size (int) –

  • finetune_text_encoder (bool) –

  • gradient_checkpointing (bool) –

  • enable_xformers (bool) –

class diffengine.models.editors.SDDataPreprocessor(non_blocking=False)[source]

Bases: mmengine.model.base_model.data_preprocessor.BaseDataPreprocessor

SDDataPreprocessor.

Parameters:

non_blocking (Optional[bool]) –

forward(data, training=False)[source]

Preprocesses the data into the model input format.

After the data pre-processing of cast_data(), forward stacks the list of input tensors into a batch tensor along the first dimension.

Args:

data (dict): Data returned by the dataloader.
training (bool): Whether to enable training-time augmentation.

Returns:

dict or list: Data in the same format as the model input.

Parameters:
  • data (dict) –

  • training (bool) –

Return type:

dict | list

class diffengine.models.editors.StableDiffusionControlNet(*args, controlnet_model=None, transformer_layers_per_block=None, unet_lora_config=None, text_encoder_lora_config=None, finetune_text_encoder=False, data_preprocessor=None, **kwargs)[source]

Bases: diffengine.models.editors.stable_diffusion.StableDiffusion

Stable Diffusion ControlNet.

Args:

controlnet_model (str, optional): Path to a pretrained ControlNet model. If None, the ControlNet is initialized from the Unet weights. Defaults to None.
transformer_layers_per_block (List[int], optional): The number of transformer layers per block. More details: https://huggingface.co/diffusers/controlnet-canny-sdxl-1.0-small. Defaults to None.
unet_lora_config (dict, optional): The LoRA config dict for the Unet. For example, dict(type='LoRA', r=4). type is chosen from LoRA, LoHa, or LoKr; the other fields follow the PEFT config (https://github.com/huggingface/peft). Defaults to None.
text_encoder_lora_config (dict, optional): The LoRA config dict for the Text Encoder. Same format as unet_lora_config. Defaults to None.
finetune_text_encoder (bool, optional): Whether to fine-tune the text encoder. This should be False when training ControlNet. Defaults to False.
data_preprocessor (dict, optional): The pre-process config of SDControlNetDataPreprocessor.

set_lora()[source]

Set LoRA for model.

Return type:

None

prepare_model()[source]

Prepare model for training.

Disable gradient for some models.

Return type:

None

set_xformers()[source]

Set xformers for model.

Return type:

None

infer(prompt, condition_image, negative_prompt=None, height=None, width=None, num_inference_steps=50, output_type='pil', **kwargs)[source]

Inference function.

Args:

prompt (List[str]): The prompt or prompts to guide the image generation.
condition_image (List[Union[str, Image.Image]]): The condition image for ControlNet.
negative_prompt (Optional[str]): The prompt or prompts not to guide the image generation. Defaults to None.
height (int, optional): The height in pixels of the generated image. Defaults to None.
width (int, optional): The width in pixels of the generated image. Defaults to None.
num_inference_steps (int): Number of inference steps. Defaults to 50.
output_type (str): The output format of the generated image. Choose between 'pil' and 'latent'. Defaults to 'pil'.
**kwargs: Other arguments.

Parameters:
  • prompt (list[str]) –

  • condition_image (list[str | PIL.Image.Image]) –

  • negative_prompt (str | None) –

  • height (int | None) –

  • width (int | None) –

  • num_inference_steps (int) –

  • output_type (str) –

Return type:

list[numpy.ndarray]
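
A usage sketch for ControlNet inference, assuming `controlnet` is a built StableDiffusionControlNet instance; the file path is illustrative (any path or PIL.Image accepted by the signature works):

    images = controlnet.infer(
        prompt=["a bird perched on a branch"],
        condition_image=["canny_edges.png"],  # e.g. a canny edge map
        num_inference_steps=50,
    )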

_forward_compile(noisy_latents, timesteps, encoder_hidden_states, inputs)[source]

Forward function for torch.compile.

Parameters:
  • noisy_latents (torch.Tensor) –

  • timesteps (torch.Tensor) –

  • encoder_hidden_states (torch.Tensor) –

  • inputs (dict) –

Return type:

torch.Tensor

forward(inputs, data_samples=None, mode='loss')[source]

Forward function.

Args:

inputs (dict): The input dict.
data_samples (Optional[list], optional): The data samples. Defaults to None.
mode (str, optional): The mode. Defaults to "loss".

Returns:

dict: The loss dict.

Parameters:
  • inputs (dict) –

  • data_samples (Optional[list]) –

  • mode (str) –

Return type:

dict

Parameters:
  • controlnet_model (str | None) –

  • transformer_layers_per_block (list[int] | None) –

  • unet_lora_config (dict | None) –

  • text_encoder_lora_config (dict | None) –

  • finetune_text_encoder (bool) –

  • data_preprocessor (dict | torch.nn.Module | None) –

class diffengine.models.editors.SDControlNetDataPreprocessor(non_blocking=False)[source]

Bases: mmengine.model.base_model.data_preprocessor.BaseDataPreprocessor

SDControlNetDataPreprocessor.

Parameters:

non_blocking (Optional[bool]) –

forward(data, training=False)[source]

Preprocesses the data into the model input format.

After the data pre-processing of cast_data(), forward stacks the list of input tensors into a batch tensor along the first dimension.

Args:

data (dict): Data returned by the dataloader.
training (bool): Whether to enable training-time augmentation.

Returns:

dict or list: Data in the same format as the model input.

Parameters:
  • data (dict) –

  • training (bool) –

Return type:

dict | list

class diffengine.models.editors.SDInpaintDataPreprocessor(non_blocking=False)[source]

Bases: mmengine.model.base_model.data_preprocessor.BaseDataPreprocessor

SDInpaintDataPreprocessor.

Parameters:

non_blocking (Optional[bool]) –

forward(data, training=False)[source]

Preprocesses the data into the model input format.

After the data pre-processing of cast_data(), forward stacks the list of input tensors into a batch tensor along the first dimension.

Args:

data (dict): Data returned by the dataloader.
training (bool): Whether to enable training-time augmentation.

Returns:

dict or list: Data in the same format as the model input.

Parameters:
  • data (dict) –

  • training (bool) –

Return type:

dict | list

class diffengine.models.editors.StableDiffusionInpaint(*args, model='runwayml/stable-diffusion-inpainting', data_preprocessor=None, **kwargs)[source]

Bases: diffengine.models.editors.stable_diffusion.StableDiffusion

Stable Diffusion Inpaint.

Args:

model (str): Pretrained model name of Stable Diffusion inpainting. Defaults to 'runwayml/stable-diffusion-inpainting'.
data_preprocessor (dict, optional): The pre-process config of SDInpaintDataPreprocessor.

prepare_model()[source]

Prepare model for training.

Disable gradient for some models.

Return type:

None

infer(prompt, image, mask, negative_prompt=None, height=None, width=None, num_inference_steps=50, output_type='pil', **kwargs)[source]

Inference function.

Args:

prompt (List[str]): The prompt or prompts to guide the image generation.
image (List[Union[str, Image.Image]]): The image for inpainting.
mask (List[Union[str, Image.Image]]): The mask for inpainting.
negative_prompt (Optional[str]): The prompt or prompts not to guide the image generation. Defaults to None.
height (int, optional): The height in pixels of the generated image. Defaults to None.
width (int, optional): The width in pixels of the generated image. Defaults to None.
num_inference_steps (int): Number of inference steps. Defaults to 50.
output_type (str): The output format of the generated image. Choose between 'pil' and 'latent'. Defaults to 'pil'.
**kwargs: Other arguments.

Parameters:
  • prompt (list[str]) –

  • image (list[str | PIL.Image.Image]) –

  • mask (list[str | PIL.Image.Image]) –

  • negative_prompt (str | None) –

  • height (int | None) –

  • width (int | None) –

  • num_inference_steps (int) –

  • output_type (str) –

Return type:

list[numpy.ndarray]
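
A usage sketch for inpainting, assuming `inpaint` is a built StableDiffusionInpaint instance; the paths are illustrative, and by the usual inpainting convention white mask pixels mark the region to repaint:

    images = inpaint.infer(
        prompt=["a red park bench"],
        image=["scene.png"],
        mask=["scene_mask.png"],
        num_inference_steps=50,
    )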

forward(inputs, data_samples=None, mode='loss')[source]

Forward function.

Args:

inputs (dict): The input dict.
data_samples (Optional[list], optional): The data samples. Defaults to None.
mode (str, optional): The mode. Defaults to "loss".

Returns:

dict: The loss dict.

Parameters:
  • inputs (dict) –

  • data_samples (Optional[list]) –

  • mode (str) –

Return type:

dict

Parameters:
  • model (str) –

  • data_preprocessor (dict | torch.nn.Module | None) –

class diffengine.models.editors.StableDiffusionXL(tokenizer_one, tokenizer_two, scheduler, text_encoder_one, text_encoder_two, vae, unet, model='stabilityai/stable-diffusion-xl-base-1.0', loss=None, unet_lora_config=None, text_encoder_lora_config=None, prior_loss_weight=1.0, prediction_type=None, data_preprocessor=None, noise_generator=None, timesteps_generator=None, input_perturbation_gamma=0.0, vae_batch_size=8, *, finetune_text_encoder=False, gradient_checkpointing=False, pre_compute_text_embeddings=False, enable_xformers=False)[source]

Bases: mmengine.model.BaseModel

`Stable Diffusion XL <https://huggingface.co/papers/2307.01952>`_

Args:

tokenizer_one (dict): Config of tokenizer one.
tokenizer_two (dict): Config of tokenizer two.
scheduler (dict): Config of scheduler.
text_encoder_one (dict): Config of text encoder one.
text_encoder_two (dict): Config of text encoder two.
vae (dict): Config of vae.
unet (dict): Config of unet.
model (str): Pretrained model name of Stable Diffusion XL. Defaults to 'stabilityai/stable-diffusion-xl-base-1.0'.
loss (dict): Config of loss. Defaults to dict(type='L2Loss', loss_weight=1.0).
unet_lora_config (dict, optional): The LoRA config dict for the Unet. For example, dict(type='LoRA', r=4). type is chosen from LoRA, LoHa, or LoKr; the other fields follow the PEFT config (https://github.com/huggingface/peft). Defaults to None.
text_encoder_lora_config (dict, optional): The LoRA config dict for the Text Encoder. Same format as unet_lora_config. Defaults to None.
prior_loss_weight (float): The weight of the prior-preservation loss. Used when training DreamBooth with class images. Defaults to 1.0.
prediction_type (str, optional): The prediction type used for training. Choose 'epsilon' or 'v_prediction', or leave None to use the scheduler's default (noise_scheduler.config.prediction_type). Defaults to None.
data_preprocessor (dict, optional): The pre-process config of SDXLDataPreprocessor.
noise_generator (dict, optional): The noise generator config. Defaults to dict(type='WhiteNoise').
timesteps_generator (dict, optional): The timesteps generator config. Defaults to dict(type='TimeSteps').
input_perturbation_gamma (float): The gamma of input perturbation. The recommended value is 0.1. Defaults to 0.0.
vae_batch_size (int): The batch size of the vae. Defaults to 8.
finetune_text_encoder (bool, optional): Whether to fine-tune the text encoder. Defaults to False.
gradient_checkpointing (bool): Whether to use gradient checkpointing to save memory at the expense of a slower backward pass. Defaults to False.
pre_compute_text_embeddings (bool): Whether to pre-compute text embeddings to save memory. Defaults to False.
enable_xformers (bool): Whether to enable memory-efficient attention. Defaults to False.

property device: torch.device

Get device information.

Returns:

device.

Return type:

torch.device

set_lora()[source]

Set LORA for model.

Return type:

None

prepare_model()[source]

Prepare model for training.

Disable gradient for some models.

Return type:

None

set_xformers()[source]

Set xformers for model.

Return type:

None

infer(prompt, negative_prompt=None, height=None, width=None, num_inference_steps=50, output_type='pil', **kwargs)[source]

Inference function.

Args:

prompt (List[str]): The prompt or prompts to guide the image generation.
negative_prompt (Optional[str]): The prompt or prompts not to guide the image generation. Defaults to None.
height (int, optional): The height in pixels of the generated image. Defaults to None.
width (int, optional): The width in pixels of the generated image. Defaults to None.
num_inference_steps (int): Number of inference steps. Defaults to 50.
output_type (str): The output format of the generated image. Choose between 'pil' and 'latent'. Defaults to 'pil'.
**kwargs: Other arguments.

Parameters:
  • prompt (list[str]) –

  • negative_prompt (str | None) –

  • height (int | None) –

  • width (int | None) –

  • num_inference_steps (int) –

  • output_type (str) –

Return type:

list[numpy.ndarray]

encode_prompt(text_one, text_two)[source]

Encode prompt.

Args:

text_one (torch.Tensor): Token ids from tokenizer one.
text_two (torch.Tensor): Token ids from tokenizer two.

Returns:

tuple[torch.Tensor, torch.Tensor]: Prompt embeddings.

Parameters:
  • text_one (torch.Tensor) –

  • text_two (torch.Tensor) –

Return type:

tuple[torch.Tensor, torch.Tensor]
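
A sketch of what SDXL-style dual-encoder prompt encoding amounts to: the penultimate hidden states of both CLIP text encoders are concatenated along the channel dimension, and the pooled embedding is taken from encoder two. This mirrors the common diffusers convention and is an assumption about the internals here, not a verbatim excerpt:

    import torch

    # text_one / text_two are token-id tensors from the two tokenizers;
    # text_encoder_one / text_encoder_two are the two CLIP text encoders.
    out_one = text_encoder_one(text_one, output_hidden_states=True)
    out_two = text_encoder_two(text_two, output_hidden_states=True)

    # Concatenate penultimate hidden states along the channel dimension.
    prompt_embeds = torch.cat(
        [out_one.hidden_states[-2], out_two.hidden_states[-2]], dim=-1)
    pooled_prompt_embeds = out_two[0]  # pooled output of encoder two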

val_step(data)[source]

Val step.

Parameters:

data (Union[tuple, dict, list]) –

Return type:

list

test_step(data)[source]

Test step.

Parameters:

data (Union[tuple, dict, list]) –

Return type:

list

loss(model_pred, noise, latents, timesteps, weight=None)[source]

Calculate loss.

Parameters:
  • model_pred (torch.Tensor) –

  • noise (torch.Tensor) –

  • latents (torch.Tensor) –

  • timesteps (torch.Tensor) –

  • weight (torch.Tensor | None) –

Return type:

dict[str, torch.Tensor]

_preprocess_model_input(latents, noise, timesteps)[source]

Preprocess model input.

Parameters:
  • latents (torch.Tensor) –

  • noise (torch.Tensor) –

  • timesteps (torch.Tensor) –

Return type:

torch.Tensor

_forward_vae(img, num_batches)[source]

Forward vae.

Parameters:
  • img (torch.Tensor) –

  • num_batches (int) –

Return type:

torch.Tensor

forward(inputs, data_samples=None, mode='loss')[source]

Forward function.

Args:

inputs (dict): The input dict.
data_samples (Optional[list], optional): The data samples. Defaults to None.
mode (str, optional): The mode. Defaults to "loss".

Returns:

dict: The loss dict.

Parameters:
  • inputs (dict) –

  • data_samples (Optional[list]) –

  • mode (str) –

Return type:

dict

Parameters:
  • tokenizer_one (dict) –

  • tokenizer_two (dict) –

  • scheduler (dict) –

  • text_encoder_one (dict) –

  • text_encoder_two (dict) –

  • vae (dict) –

  • unet (dict) –

  • model (str) –

  • loss (dict | None) –

  • unet_lora_config (dict | None) –

  • text_encoder_lora_config (dict | None) –

  • prior_loss_weight (float) –

  • prediction_type (str | None) –

  • data_preprocessor (dict | torch.nn.Module | None) –

  • noise_generator (dict | None) –

  • timesteps_generator (dict | None) –

  • input_perturbation_gamma (float) –

  • vae_batch_size (int) –

  • finetune_text_encoder (bool) –

  • gradient_checkpointing (bool) –

  • pre_compute_text_embeddings (bool) –

  • enable_xformers (bool) –

class diffengine.models.editors.SDXLDataPreprocessor(non_blocking=False)[source]

Bases: mmengine.model.base_model.data_preprocessor.BaseDataPreprocessor

SDXLDataPreprocessor.

Parameters:

non_blocking (Optional[bool]) –

forward(data, training=False)[source]

Preprocesses the data into the model input format.

After the data pre-processing of cast_data(), forward stacks the list of input tensors into a batch tensor along the first dimension.

Args:

data (dict): Data returned by the dataloader.
training (bool): Whether to enable training-time augmentation.

Returns:

dict or list: Data in the same format as the model input.

Parameters:
  • data (dict) –

  • training (bool) –

Return type:

dict | list

class diffengine.models.editors.SDXLControlNetDataPreprocessor(non_blocking=False)[source]

Bases: mmengine.model.base_model.data_preprocessor.BaseDataPreprocessor

SDXLControlNetDataPreprocessor.

Parameters:

non_blocking (Optional[bool]) –

forward(data, training=False)[source]

Preprocesses the data into the model input format.

After the data pre-processing of cast_data(), forward stacks the list of input tensors into a batch tensor along the first dimension.

Args:

data (dict): Data returned by the dataloader.
training (bool): Whether to enable training-time augmentation.

Returns:

dict or list: Data in the same format as the model input.

Parameters:
  • data (dict) –

  • training (bool) –

Return type:

dict | list

class diffengine.models.editors.StableDiffusionXLControlNet(*args, controlnet_model=None, transformer_layers_per_block=None, unet_lora_config=None, text_encoder_lora_config=None, finetune_text_encoder=False, data_preprocessor=None, **kwargs)[source]

Bases: diffengine.models.editors.stable_diffusion_xl.StableDiffusionXL

Stable Diffusion XL ControlNet.

Args:

controlnet_model (str, optional): Path to a pretrained ControlNet model. If None, the ControlNet is initialized from the Unet weights. Defaults to None.
transformer_layers_per_block (List[int], optional): The number of transformer layers per block. More details: https://huggingface.co/diffusers/controlnet-canny-sdxl-1.0-small. Defaults to None.
unet_lora_config (dict, optional): The LoRA config dict for the Unet. For example, dict(type='LoRA', r=4). type is chosen from LoRA, LoHa, or LoKr; the other fields follow the PEFT config (https://github.com/huggingface/peft). Defaults to None.
text_encoder_lora_config (dict, optional): The LoRA config dict for the Text Encoder. Same format as unet_lora_config. Defaults to None.
finetune_text_encoder (bool, optional): Whether to fine-tune the text encoder. This should be False when training ControlNet. Defaults to False.
data_preprocessor (dict, optional): The pre-process config of SDXLControlNetDataPreprocessor.

set_lora()[source]

Set LoRA for model.

Return type:

None

prepare_model()[source]

Prepare model for training.

Disable gradient for some models.

Return type:

None

set_xformers()[source]

Set xformers for model.

Return type:

None

infer(prompt, condition_image, negative_prompt=None, height=None, width=None, num_inference_steps=50, output_type='pil', **kwargs)[source]

Inference function.

Args:

prompt (List[str]): The prompt or prompts to guide the image generation.
condition_image (List[Union[str, Image.Image]]): The condition image for ControlNet.
negative_prompt (Optional[str]): The prompt or prompts not to guide the image generation. Defaults to None.
height (int, optional): The height in pixels of the generated image. Defaults to None.
width (int, optional): The width in pixels of the generated image. Defaults to None.
num_inference_steps (int): Number of inference steps. Defaults to 50.
output_type (str): The output format of the generated image. Choose between 'pil' and 'latent'. Defaults to 'pil'.
**kwargs: Other arguments.

Parameters:
  • prompt (list[str]) –

  • condition_image (list[str | PIL.Image.Image]) –

  • negative_prompt (str | None) –

  • height (int | None) –

  • width (int | None) –

  • num_inference_steps (int) –

  • output_type (str) –

Return type:

list[numpy.ndarray]

_forward_compile(noisy_latents, timesteps, prompt_embeds, unet_added_conditions, inputs)[source]

Forward function for torch.compile.

Parameters:
  • noisy_latents (torch.Tensor) –

  • timesteps (torch.Tensor) –

  • prompt_embeds (torch.Tensor) –

  • unet_added_conditions (dict) –

  • inputs (dict) –

Return type:

torch.Tensor

forward(inputs, data_samples=None, mode='loss')[source]

Forward function.

Args:

inputs (dict): The input dict.
data_samples (Optional[list], optional): The data samples. Defaults to None.
mode (str, optional): The mode. Defaults to "loss".

Returns:

dict: The loss dict.

Parameters:
  • inputs (dict) –

  • data_samples (Optional[list]) –

  • mode (str) –

Return type:

dict

Parameters:
  • controlnet_model (str | None) –

  • transformer_layers_per_block (list[int] | None) –

  • unet_lora_config (dict | None) –

  • text_encoder_lora_config (dict | None) –

  • finetune_text_encoder (bool) –

  • data_preprocessor (dict | torch.nn.Module | None) –

class diffengine.models.editors.StableDiffusionXLDPO(*args, beta_dpo=5000, loss=None, data_preprocessor=None, **kwargs)[source]

Bases: diffengine.models.editors.stable_diffusion_xl.StableDiffusionXL

Stable Diffusion XL DPO.

Args:

beta_dpo (int): DPO KL divergence penalty. Defaults to 5000.
loss (dict, optional): The loss config. Defaults to None.
data_preprocessor (dict, optional): The pre-process config of SDXLDPODataPreprocessor.

prepare_model()[source]

Prepare model for training.

Disable gradient for some models.

Return type:

None

loss(model_pred, ref_pred, noise, latents, timesteps, weight=None)[source]

Calculate loss.

Parameters:
  • model_pred (torch.Tensor) –

  • ref_pred (torch.Tensor) –

  • noise (torch.Tensor) –

  • latents (torch.Tensor) –

  • timesteps (torch.Tensor) –

  • weight (torch.Tensor | None) –

Return type:

dict[str, torch.Tensor]
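
A sketch of the Diffusion-DPO objective (Wallace et al., 2023) that this loss follows, with beta_dpo as the KL penalty. It assumes the batch is stacked as [preferred, rejected] pairs and that the -0.5 scale matches the reference implementation; treat it as an illustration rather than the exact code:

    import torch
    import torch.nn.functional as F

    def dpo_loss(model_pred, ref_pred, target, beta_dpo=5000):
        # Per-sample denoising errors for the trained and reference UNets.
        model_err = (model_pred - target).pow(2).mean(dim=[1, 2, 3])
        ref_err = (ref_pred - target).pow(2).mean(dim=[1, 2, 3])
        # First half of the batch: preferred (w); second half: rejected (l).
        model_w, model_l = model_err.chunk(2)
        ref_w, ref_l = ref_err.chunk(2)
        inside = -0.5 * beta_dpo * ((model_w - model_l) - (ref_w - ref_l))
        return -F.logsigmoid(inside).mean()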

forward(inputs, data_samples=None, mode='loss')[source]

Forward function.

Args:

inputs (dict): The input dict.
data_samples (Optional[list], optional): The data samples. Defaults to None.
mode (str, optional): The mode. Defaults to "loss".

Returns:

dict: The loss dict.

Parameters:
  • inputs (dict) –

  • data_samples (Optional[list]) –

  • mode (str) –

Return type:

dict

Parameters:
  • beta_dpo (int) –

  • loss (dict | None) –

  • data_preprocessor (dict | torch.nn.Module | None) –

class diffengine.models.editors.SDXLDPODataPreprocessor(non_blocking=False)[source]

Bases: mmengine.model.base_model.data_preprocessor.BaseDataPreprocessor

SDXLDPODataPreprocessor.

Parameters:

non_blocking (Optional[bool]) –

forward(data, training=False)[source]

Preprocesses the data into the model input format.

After the data pre-processing of cast_data(), forward stacks the list of input tensors into a batch tensor along the first dimension.

Args:

data (dict): Data returned by the dataloader.
training (bool): Whether to enable training-time augmentation.

Returns:

dict or list: Data in the same format as the model input.

Parameters:
  • data (dict) –

  • training (bool) –

Return type:

dict | list

class diffengine.models.editors.StableDiffusionXLInpaint(*args, model='diffusers/stable-diffusion-xl-1.0-inpainting-0.1', data_preprocessor=None, **kwargs)[source]

Bases: diffengine.models.editors.stable_diffusion_xl.StableDiffusionXL

Stable Diffusion XL Inpaint.

Args:

model (str): Pretrained model name of Stable Diffusion XL inpainting. Defaults to 'diffusers/stable-diffusion-xl-1.0-inpainting-0.1'.
data_preprocessor (dict, optional): The pre-process config of SDXLInpaintDataPreprocessor.

prepare_model()[source]

Prepare model for training.

Disable gradient for some models.

Return type:

None

infer(prompt, image, mask, negative_prompt=None, height=None, width=None, num_inference_steps=50, output_type='pil', **kwargs)[source]

Inference function.

Args:

prompt (List[str]): The prompt or prompts to guide the image generation.
image (List[Union[str, Image.Image]]): The image for inpainting.
mask (List[Union[str, Image.Image]]): The mask for inpainting.
negative_prompt (Optional[str]): The prompt or prompts not to guide the image generation. Defaults to None.
height (int, optional): The height in pixels of the generated image. Defaults to None.
width (int, optional): The width in pixels of the generated image. Defaults to None.
num_inference_steps (int): Number of inference steps. Defaults to 50.
output_type (str): The output format of the generated image. Choose between 'pil' and 'latent'. Defaults to 'pil'.
**kwargs: Other arguments.

Parameters:
  • prompt (list[str]) –

  • image (list[str | PIL.Image.Image]) –

  • mask (list[str | PIL.Image.Image]) –

  • negative_prompt (str | None) –

  • height (int | None) –

  • width (int | None) –

  • num_inference_steps (int) –

  • output_type (str) –

Return type:

list[numpy.ndarray]

forward(inputs, data_samples=None, mode='loss')[source]

Forward function.

Args:

inputs (dict): The input dict.
data_samples (Optional[list], optional): The data samples. Defaults to None.
mode (str, optional): The mode. Defaults to "loss".

Returns:

dict: The loss dict.

Parameters:
  • inputs (dict) –

  • data_samples (Optional[list]) –

  • mode (str) –

Return type:

dict

Parameters:
  • model (str) –

  • data_preprocessor (dict | torch.nn.Module | None) –

class diffengine.models.editors.SDXLInpaintDataPreprocessor(non_blocking=False)[source]

Bases: mmengine.model.base_model.data_preprocessor.BaseDataPreprocessor

SDXLInpaintDataPreprocessor.

Parameters:

non_blocking (Optional[bool]) –

forward(data, training=False)[source]

Preprocesses the data into the model input format.

After the data pre-processing of cast_data(), forward stacks the list of input tensors into a batch tensor along the first dimension.

Args:

data (dict): Data returned by the dataloader.
training (bool): Whether to enable training-time augmentation.

Returns:

dict or list: Data in the same format as the model input.

Parameters:
  • data (dict) –

  • training (bool) –

Return type:

dict | list

class diffengine.models.editors.StableDiffusionXLT2IAdapter(*args, adapter, unet_lora_config=None, text_encoder_lora_config=None, finetune_text_encoder=False, timesteps_generator=None, data_preprocessor=None, **kwargs)[source]

Bases: diffengine.models.editors.stable_diffusion_xl.StableDiffusionXL

Stable Diffusion XL T2I Adapter.

Args:

adapter (dict): The adapter config.
unet_lora_config (dict, optional): The LoRA config dict for the Unet. For example, dict(type='LoRA', r=4). type is chosen from LoRA, LoHa, or LoKr; the other fields follow the PEFT config (https://github.com/huggingface/peft). Defaults to None.
text_encoder_lora_config (dict, optional): The LoRA config dict for the Text Encoder. Same format as unet_lora_config. Defaults to None.
finetune_text_encoder (bool, optional): Whether to fine-tune the text encoder. This should be False when training the adapter. Defaults to False.
timesteps_generator (dict, optional): The timesteps generator config. Defaults to dict(type='CubicSamplingTimeSteps'); see the sketch after this list.
data_preprocessor (dict, optional): The pre-process config of SDControlNetDataPreprocessor.
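
A sketch of cubic timestep sampling as described in the T2I-Adapter paper: drawing t with a cubic bias concentrates training on the high-noise timesteps where the adapter's guidance matters most. The constants and shapes are illustrative:

    import torch

    num_train_timesteps = 1000  # assumed scheduler setting
    batch_size = 4

    u = torch.rand(batch_size)
    timesteps = ((1 - u ** 3) * num_train_timesteps).long()
    timesteps = timesteps.clamp(0, num_train_timesteps - 1)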

set_lora()[source]

Set LoRA for model.

Return type:

None

prepare_model()[source]

Prepare model for training.

Disable gradient for some models.

Return type:

None

set_xformers()[source]

Set xformers for model.

Return type:

None

infer(prompt, condition_image, negative_prompt=None, height=None, width=None, num_inference_steps=50, output_type='pil', **kwargs)[source]

Inference function.

Args:

prompt (List[str]): The prompt or prompts to guide the image generation.
condition_image (List[Union[str, Image.Image]]): The condition image for the adapter.
negative_prompt (Optional[str]): The prompt or prompts not to guide the image generation. Defaults to None.
height (int, optional): The height in pixels of the generated image. Defaults to None.
width (int, optional): The width in pixels of the generated image. Defaults to None.
num_inference_steps (int): Number of inference steps. Defaults to 50.
output_type (str): The output format of the generated image. Choose between 'pil' and 'latent'. Defaults to 'pil'.
**kwargs: Other arguments.

Parameters:
  • prompt (list[str]) –

  • condition_image (list[str | PIL.Image.Image]) –

  • negative_prompt (str | None) –

  • height (int | None) –

  • width (int | None) –

  • num_inference_steps (int) –

  • output_type (str) –

Return type:

list[numpy.ndarray]

_forward_compile(noisy_latents, timesteps, prompt_embeds, unet_added_conditions, inputs)[source]

Forward function for torch.compile.

Parameters:
  • noisy_latents (torch.Tensor) –

  • timesteps (torch.Tensor) –

  • prompt_embeds (torch.Tensor) –

  • unet_added_conditions (dict) –

  • inputs (dict) –

Return type:

torch.Tensor

forward(inputs, data_samples=None, mode='loss')[source]

Forward function.

Args:

inputs (dict): The input dict.
data_samples (Optional[list], optional): The data samples. Defaults to None.
mode (str, optional): The mode. Defaults to "loss".

Returns:

dict: The loss dict.

Parameters:
  • inputs (dict) –

  • data_samples (Optional[list]) –

  • mode (str) –

Return type:

dict

Parameters:
  • adapter (dict) –

  • unet_lora_config (dict | None) –

  • text_encoder_lora_config (dict | None) –

  • finetune_text_encoder (bool) –

  • timesteps_generator (dict | None) –

  • data_preprocessor (dict | torch.nn.Module | None) –

class diffengine.models.editors.WuerstchenPriorModel(tokenizer, scheduler, text_encoder, image_encoder, prior, decoder_model='warp-ai/wuerstchen', prior_model='warp-ai/wuerstchen-prior', loss=None, prior_lora_config=None, text_encoder_lora_config=None, prior_loss_weight=1.0, data_preprocessor=None, noise_generator=None, timesteps_generator=None, input_perturbation_gamma=0.0, *, finetune_text_encoder=False, gradient_checkpointing=False)[source]

Bases: mmengine.model.BaseModel

`Wuerstchen Prior <https://arxiv.org/abs/2306.00637>`_

Args:

tokenizer (dict): Config of tokenizer.
scheduler (dict): Config of scheduler.
text_encoder (dict): Config of text encoder.
image_encoder (dict): Config of image encoder.
prior (dict): Config of prior.
decoder_model (str): Pretrained decoder model name of Wuerstchen. Defaults to 'warp-ai/wuerstchen'.
prior_model (str): Pretrained prior model name of Wuerstchen. Defaults to 'warp-ai/wuerstchen-prior'.
loss (dict): Config of loss. Defaults to dict(type='L2Loss', loss_weight=1.0).
prior_lora_config (dict, optional): The LoRA config dict for the Prior. For example, dict(type='LoRA', r=4). type is chosen from LoRA, LoHa, or LoKr; the other fields follow the PEFT config (https://github.com/huggingface/peft). Defaults to None.
text_encoder_lora_config (dict, optional): The LoRA config dict for the Text Encoder. Same format as prior_lora_config. Defaults to None.
prior_loss_weight (float): The weight of the prior-preservation loss. Used when training DreamBooth with class images. Defaults to 1.0.
data_preprocessor (dict, optional): The pre-process config of SDDataPreprocessor.
noise_generator (dict, optional): The noise generator config. Defaults to dict(type='WhiteNoise').
timesteps_generator (dict, optional): The timesteps generator config. Defaults to dict(type='WuerstchenRandomTimeSteps').
input_perturbation_gamma (float): The gamma of input perturbation. The recommended value is 0.1. Defaults to 0.0.
finetune_text_encoder (bool, optional): Whether to fine-tune the text encoder. Defaults to False.
gradient_checkpointing (bool): Whether to use gradient checkpointing to save memory at the expense of a slower backward pass. Defaults to False.

property device: torch.device

Get device information.

Returns:

device.

Return type:

torch.device

set_lora()[source]

Set LoRA for model.

Return type:

None

prepare_model()[source]

Prepare model for training.

Disable gradient for some models.

Return type:

None

train(*, mode=True)[source]

Convert the model into training mode.

Parameters:

mode (bool) –

Return type:

None

infer(prompt, negative_prompt=None, height=None, width=None, num_inference_steps=50, output_type='pil', **kwargs)[source]

Inference function.

Args:

prompt (List[str]): The prompt or prompts to guide the image generation.
negative_prompt (Optional[str]): The prompt or prompts not to guide the image generation. Defaults to None.
height (int, optional): The height in pixels of the generated image. Defaults to None.
width (int, optional): The width in pixels of the generated image. Defaults to None.
num_inference_steps (int): Number of inference steps. Defaults to 50.
output_type (str): The output format of the generated image. Choose between 'pil' and 'latent'. Defaults to 'pil'.
**kwargs: Other arguments.

Parameters:
  • prompt (list[str]) –

  • negative_prompt (str | None) –

  • height (int | None) –

  • width (int | None) –

  • num_inference_steps (int) –

  • output_type (str) –

Return type:

list[numpy.ndarray]

val_step(data)[source]

Val step.

Parameters:

data (Union[tuple, dict, list]) –

Return type:

list

test_step(data)[source]

Test step.

Parameters:

data (Union[tuple, dict, list]) –

Return type:

list

loss(model_pred, noise, timesteps, weight=None)[source]

Calculate loss.

Parameters:
  • model_pred (torch.Tensor) –

  • noise (torch.Tensor) –

  • timesteps (torch.Tensor) –

  • weight (torch.Tensor | None) –

Return type:

dict[str, torch.Tensor]

_preprocess_model_input(latents, noise, timesteps)[source]

Preprocess model input.

Parameters:
  • latents (torch.Tensor) –

  • noise (torch.Tensor) –

  • timesteps (torch.Tensor) –

Return type:

torch.Tensor

forward(inputs, data_samples=None, mode='loss')[source]

Forward function.

Args:

inputs (dict): The input dict.
data_samples (Optional[list], optional): The data samples. Defaults to None.
mode (str, optional): The mode. Defaults to "loss".

Returns:

dict: The loss dict.

Parameters:
  • inputs (dict) –

  • data_samples (Optional[list]) –

  • mode (str) –

Return type:

dict

Parameters:
  • tokenizer (dict) –

  • scheduler (dict) –

  • text_encoder (dict) –

  • image_encoder (dict) –

  • prior (dict) –

  • decoder_model (str) –

  • prior_model (str) –

  • loss (dict | None) –

  • prior_lora_config (dict | None) –

  • text_encoder_lora_config (dict | None) –

  • prior_loss_weight (float) –

  • data_preprocessor (dict | torch.nn.Module | None) –

  • noise_generator (dict | None) –

  • timesteps_generator (dict | None) –

  • input_perturbation_gamma (float) –

  • finetune_text_encoder (bool) –

  • gradient_checkpointing (bool) –