diffengine.models.editors.ssd_1b


Package Contents

Classes

SSD1B

SSD1B.

class diffengine.models.editors.ssd_1b.SSD1B(tokenizer_one, tokenizer_two, scheduler, text_encoder_one, text_encoder_two, vae, teacher_unet, student_unet, model='stabilityai/stable-diffusion-xl-base-1.0', loss=None, unet_lora_config=None, text_encoder_lora_config=None, prior_loss_weight=1.0, prediction_type=None, data_preprocessor=None, noise_generator=None, timesteps_generator=None, input_perturbation_gamma=0.0, vae_batch_size=8, *, finetune_text_encoder=False, gradient_checkpointing=False, pre_compute_text_embeddings=False, enable_xformers=False, student_weight_from_teacher=False)

Bases: diffengine.models.editors.stable_diffusion_xl.StableDiffusionXL

SSD1B: distills a smaller student UNet from a teacher Stable Diffusion XL UNet.

Refer to the official implementation: https://github.com/segmind/SSD-1B/blob/main/distill_sdxl.py
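SSD1B trains a student UNet to imitate a frozen teacher UNet. The sketch below shows one common way of combining a task loss with output-level and feature-level distillation terms, using plain Python lists in place of tensors. The function names, the dictionary keys, and the output_weight / feature_weight parameters are hypothetical; the terms and weighting that SSD1B or distill_sdxl.py actually use may differ.

```python
from typing import Dict, Sequence


def l2_loss(pred: Sequence[float], target: Sequence[float], loss_weight: float = 1.0) -> float:
    """Weighted mean squared error (cf. the default dict(type='L2Loss', loss_weight=1.0))."""
    if len(pred) != len(target) or not pred:
        raise ValueError("pred and target must be non-empty and the same length")
    return loss_weight * sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def distillation_losses(
    student_out: Sequence[float],
    teacher_out: Sequence[float],
    target: Sequence[float],
    student_feats: Sequence[Sequence[float]] = (),
    teacher_feats: Sequence[Sequence[float]] = (),
    output_weight: float = 1.0,
    feature_weight: float = 1.0,
) -> Dict[str, float]:
    """Hypothetical task + output-KD + feature-KD loss terms, kept as separate keys."""
    return {
        # Task loss: student prediction vs. the diffusion target (e.g. the added noise).
        "loss": l2_loss(student_out, target),
        # Output-level KD: student prediction vs. the frozen teacher's prediction.
        "loss_kd_output": l2_loss(student_out, teacher_out, output_weight),
        # Feature-level KD: match intermediate activations pair by pair.
        "loss_kd_feat": sum(
            (l2_loss(s, t, feature_weight) for s, t in zip(student_feats, teacher_feats)),
            0.0,
        ),
    }
```

Keeping the terms under separate keys lets a trainer log each one and sum them into the total loss.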

Args:

tokenizer_one (dict): Config of tokenizer one.

tokenizer_two (dict): Config of tokenizer two.

scheduler (dict): Config of scheduler.

text_encoder_one (dict): Config of text encoder one.

text_encoder_two (dict): Config of text encoder two.

vae (dict): Config of the VAE.

teacher_unet (dict): Config of the teacher UNet.

student_unet (dict): Config of the student UNet.

model (str): Pretrained model name of Stable Diffusion XL. Defaults to 'stabilityai/stable-diffusion-xl-base-1.0'.

vae_model (str, optional): Path to a pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038. Defaults to None.

loss (dict): Config of loss. Defaults to dict(type='L2Loss', loss_weight=1.0).

unet_lora_config (dict, optional): The LoRA config dict for the UNet, e.g. dict(type="LoRA", r=4). type is chosen from LoRA, LoHa, and LoKr; the other keys are the same as in the PEFT config (https://github.com/huggingface/peft). Defaults to None.

text_encoder_lora_config (dict, optional): The LoRA config dict for the text encoder, e.g. dict(type="LoRA", r=4). type is chosen from LoRA, LoHa, and LoKr; the other keys are the same as in the PEFT config (https://github.com/huggingface/peft). Defaults to None.

prior_loss_weight (float): The weight of the prior preservation loss. It only takes effect when training DreamBooth with class images. Defaults to 1.0.

prediction_type (str, optional): The prediction type used for training. Choose 'epsilon' or 'v_prediction', or leave it as None to use the scheduler's default, noise_scheduler.config.prediction_type. Defaults to None.

data_preprocessor (dict, optional): The preprocessing config of SDXLDataPreprocessor. Defaults to None.

noise_generator (dict, optional): The noise generator config. Defaults to dict(type='WhiteNoise').

timesteps_generator (dict, optional): The timesteps generator config. Defaults to dict(type='TimeSteps').

input_perturbation_gamma (float): The gamma of input perturbation. The recommended value is 0.1 when using input perturbation. Defaults to 0.0, which disables it.

vae_batch_size (int): The batch size used when running the VAE. Defaults to 8.

finetune_text_encoder (bool, optional): Whether to fine-tune the text encoder. Defaults to False.

gradient_checkpointing (bool): Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass. Defaults to False.

pre_compute_text_embeddings (bool): Whether or not to pre-compute text embeddings to save memory. Defaults to False.

enable_xformers (bool): Whether or not to enable xFormers memory-efficient attention. Defaults to False.

student_weight_from_teacher (bool): Whether or not to initialize the student model with the teacher model's weights. Defaults to False.
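Several of the options above (noise_generator, input_perturbation_gamma, prediction_type) interact when one training pair is built. The sketch below is a minimal illustration under stated assumptions: WhiteNoise is taken to draw i.i.d. standard-normal noise, latents are plain Python lists, the cumulative noise-schedule value alpha_bar is passed in directly instead of being derived from a sampled timestep, and input perturbation follows the recipe in which only the noise added to the model input is perturbed while the target keeps the original noise. All function names are hypothetical.

```python
import math
import random
from typing import List, Optional, Sequence, Tuple

Vector = List[float]


def resolve_prediction_type(prediction_type: Optional[str], scheduler_default: str) -> str:
    # prediction_type=None falls back to the scheduler's configured default.
    resolved = scheduler_default if prediction_type is None else prediction_type
    if resolved not in ("epsilon", "v_prediction"):
        raise ValueError(f"unknown prediction type: {resolved!r}")
    return resolved


def white_noise(n: int, rng: random.Random) -> Vector:
    # Assumed WhiteNoise behaviour: i.i.d. standard-normal samples.
    return [rng.gauss(0.0, 1.0) for _ in range(n)]


def add_noise(x0: Sequence[float], noise: Sequence[float], alpha_bar: float) -> Vector:
    # Forward diffusion: sqrt(alpha_bar) * x0 + sqrt(1 - alpha_bar) * noise.
    a, s = math.sqrt(alpha_bar), math.sqrt(1.0 - alpha_bar)
    return [a * x + s * n for x, n in zip(x0, noise)]


def training_target(pred_type: str, x0: Sequence[float], noise: Sequence[float],
                    alpha_bar: float) -> Vector:
    if pred_type == "epsilon":
        return list(noise)
    if pred_type == "v_prediction":
        # v target: sqrt(alpha_bar) * noise - sqrt(1 - alpha_bar) * x0.
        a, s = math.sqrt(alpha_bar), math.sqrt(1.0 - alpha_bar)
        return [a * n - s * x for x, n in zip(x0, noise)]
    raise ValueError(f"unknown prediction type: {pred_type!r}")


def make_training_pair(x0: Sequence[float], alpha_bar: float, gamma: float,
                       pred_type: str, rng: random.Random) -> Tuple[Vector, Vector]:
    """Return (noisy_input, target) for one sample."""
    noise = white_noise(len(x0), rng)
    # Input perturbation: gamma > 0 perturbs only the noise fed into the input;
    # gamma = 0.0 leaves it unchanged.
    if gamma > 0:
        input_noise = [n + gamma * rng.gauss(0.0, 1.0) for n in noise]
    else:
        input_noise = noise
    noisy = add_noise(x0, input_noise, alpha_bar)
    # The target is always built from the unperturbed noise.
    return noisy, training_target(pred_type, x0, noise, alpha_bar)
```

For example, make_training_pair(latents, alpha_bar, 0.1, resolve_prediction_type(None, "epsilon"), random.Random(0)) would use the scheduler's default prediction type with the recommended gamma of 0.1.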
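Assuming the student is a reduced SDXL UNet whose remaining layers keep the teacher's parameter names, student_weight_from_teacher can be understood as copying over the weights the two networks share. The sketch below shows that idea with plain dictionaries standing in for state dicts; the function name and the name-and-size matching rule are assumptions, not diffengine's implementation. In PyTorch, the analogous step would filter teacher.state_dict() down to matching keys and pass the result to student.load_state_dict(..., strict=False).

```python
from typing import Dict, List

# Parameter name -> flattened values (a stand-in for tensors in a state dict).
StateDict = Dict[str, List[float]]


def init_student_from_teacher(student: StateDict, teacher: StateDict) -> List[str]:
    """Overwrite student parameters with teacher ones where name and size match.

    Returns the names that were copied; every other student parameter keeps
    its own initialization.
    """
    copied = []
    for name, value in teacher.items():
        # Layers removed from the student simply have no entry to receive weights.
        if name in student and len(student[name]) == len(value):
            student[name] = list(value)  # copy so later updates don't alias the teacher
            copied.append(name)
    return copied
```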

set_lora()

Set LoRA for the model.

Return type:

None
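unet_lora_config and text_encoder_lora_config carry an adapter type plus PEFT-style keyword arguments in a single dict. A minimal sketch of splitting such a dict before it is handed to an adapter builder; split_adapter_config is a hypothetical helper, not part of diffengine.

```python
from typing import Any, Dict, Optional, Tuple

# Adapter types named in the unet_lora_config / text_encoder_lora_config docs.
SUPPORTED_ADAPTERS = ("LoRA", "LoHa", "LoKr")


def split_adapter_config(
    config: Optional[Dict[str, Any]],
) -> Optional[Tuple[str, Dict[str, Any]]]:
    """Split e.g. dict(type="LoRA", r=4) into ("LoRA", {"r": 4})."""
    if config is None:
        # The documented default: no adapter is attached.
        return None
    kwargs = dict(config)  # shallow copy; the caller's config is left untouched
    adapter_type = kwargs.pop("type")
    if adapter_type not in SUPPORTED_ADAPTERS:
        raise ValueError(f"unsupported adapter type: {adapter_type!r}")
    return adapter_type, kwargs
```

The remaining keyword arguments (for example r or target_modules) are what would then be passed to PEFT's LoraConfig, LoHaConfig, or LoKrConfig.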

prepare_model()

Prepare the model for training.

Disables gradients for the sub-models that are kept frozen.

Return type:

None

_cast_hook()
Return type:

None

set_xformers()

Enable xFormers memory-efficient attention for the model.

Return type:

None

_forward_vae(img, num_batches)

Run the VAE forward pass, encoding img into latents.

Parameters:
  • img (torch.Tensor) –

  • num_batches (int) –

Return type:

torch.Tensor
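The vae_batch_size argument suggests that the VAE pass encodes the batch in slices rather than all at once. A minimal sketch of that chunking, with a caller-supplied encode function standing in for the VAE encoder; encode_in_chunks is a hypothetical name, and details such as latent sampling and scaling by the VAE's scaling factor are omitted.

```python
from typing import Callable, List, Sequence, TypeVar

T = TypeVar("T")  # input item type (an image, here)
L = TypeVar("L")  # output item type (a latent, here)


def encode_in_chunks(
    images: Sequence[T],
    encode: Callable[[Sequence[T]], List[L]],
    vae_batch_size: int = 8,
) -> List[L]:
    """Run `encode` on slices of at most vae_batch_size images and concatenate."""
    if vae_batch_size < 1:
        raise ValueError("vae_batch_size must be at least 1")
    latents: List[L] = []
    for start in range(0, len(images), vae_batch_size):
        # Peak memory is bounded by one slice rather than by the whole batch.
        latents.extend(encode(images[start:start + vae_batch_size]))
    return latents
```

With 20 images and the default vae_batch_size of 8, the encoder is called three times, on slices of 8, 8, and 4 images.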

forward(inputs, data_samples=None, mode='loss')

Forward function.

Args:

inputs (dict): The input dict.

data_samples (Optional[list], optional): The data samples. Defaults to None.

mode (str, optional): The mode. Defaults to "loss".

Returns:

dict: The loss dict.

Parameters:
  • inputs (dict) –

  • data_samples (Optional[list]) –

  • mode (str) –

Return type:

dict
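prior_loss_weight only matters for DreamBooth-style training with class images. A minimal sketch of how such a term can be added to a loss dict like the one forward() returns in 'loss' mode, assuming each batch has already been split into instance and class samples; the function and key names are hypothetical.

```python
from typing import Dict, Optional, Sequence


def mse(pred: Sequence[float], target: Sequence[float]) -> float:
    # Mean squared error over paired elements.
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def dreambooth_loss_dict(
    instance_pred: Sequence[float],
    instance_target: Sequence[float],
    class_pred: Optional[Sequence[float]] = None,
    class_target: Optional[Sequence[float]] = None,
    prior_loss_weight: float = 1.0,
) -> Dict[str, float]:
    """Instance loss plus an optional prior-preservation term on class images."""
    losses = {"loss": mse(instance_pred, instance_target)}
    if class_pred and class_target:
        # Only batches that contain class images get the prior-preservation term.
        losses["loss_prior"] = prior_loss_weight * mse(class_pred, class_target)
    return losses
```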

Parameters:
  • tokenizer_one (dict) –

  • tokenizer_two (dict) –

  • scheduler (dict) –

  • text_encoder_one (dict) –

  • text_encoder_two (dict) –

  • vae (dict) –

  • teacher_unet (dict) –

  • student_unet (dict) –

  • model (str) –

  • loss (dict | None) –

  • unet_lora_config (dict | None) –

  • text_encoder_lora_config (dict | None) –

  • prior_loss_weight (float) –

  • prediction_type (str | None) –

  • data_preprocessor (dict | torch.nn.Module | None) –

  • noise_generator (dict | None) –

  • timesteps_generator (dict | None) –

  • input_perturbation_gamma (float) –

  • vae_batch_size (int) –

  • finetune_text_encoder (bool) –

  • gradient_checkpointing (bool) –

  • pre_compute_text_embeddings (bool) –

  • enable_xformers (bool) –

  • student_weight_from_teacher (bool) –