diffengine.models.editors.ssd_1b
Submodules
Package Contents
Classes
SSD1B
- class diffengine.models.editors.ssd_1b.SSD1B(tokenizer_one, tokenizer_two, scheduler, text_encoder_one, text_encoder_two, vae, teacher_unet, student_unet, model='stabilityai/stable-diffusion-xl-base-1.0', loss=None, unet_lora_config=None, text_encoder_lora_config=None, prior_loss_weight=1.0, prediction_type=None, data_preprocessor=None, noise_generator=None, timesteps_generator=None, input_perturbation_gamma=0.0, vae_batch_size=8, *, finetune_text_encoder=False, gradient_checkpointing=False, pre_compute_text_embeddings=False, enable_xformers=False, student_weight_from_teacher=False)
Bases: diffengine.models.editors.stable_diffusion_xl.StableDiffusionXL
SSD1B.
Refer to official implementation: https://github.com/segmind/SSD-1B/blob/main/distill_sdxl.py
Args:
- tokenizer_one (dict): Config of tokenizer one.
- tokenizer_two (dict): Config of tokenizer two.
- scheduler (dict): Config of scheduler.
- text_encoder_one (dict): Config of text encoder one.
- text_encoder_two (dict): Config of text encoder two.
- vae (dict): Config of the VAE.
- teacher_unet (dict): Config of the teacher UNet.
- student_unet (dict): Config of the student UNet.
- model (str): Pretrained model name of Stable Diffusion XL.
Defaults to 'stabilityai/stable-diffusion-xl-base-1.0'.
- vae_model (str, optional): Path to pretrained VAE model with better
numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038. Defaults to None.
- loss (dict): Config of loss. Defaults to
dict(type='L2Loss', loss_weight=1.0).
- unet_lora_config (dict, optional): The LoRA config dict for the UNet,
for example dict(type="LoRA", r=4). type is chosen from LoRA, LoHa, LoKr. Other settings are the same as in the PEFT config: https://github.com/huggingface/peft. Defaults to None.
- text_encoder_lora_config (dict, optional): The LoRA config dict for the
text encoder, for example dict(type="LoRA", r=4). type is chosen from LoRA, LoHa, LoKr. Other settings are the same as in the PEFT config: https://github.com/huggingface/peft. Defaults to None.
- prior_loss_weight (float): The weight of prior preservation loss.
It only takes effect when training DreamBooth with class images. Defaults to 1.0.
- prediction_type (str): The prediction type used for
training. Choose 'epsilon' or 'v_prediction', or leave as None. If None, the scheduler's default prediction type, noise_scheduler.config.prediction_type, is used. Defaults to None.
- data_preprocessor (dict, optional): The pre-process config of
SDXLDataPreprocessor.
- noise_generator (dict, optional): The noise generator config.
Defaults to dict(type='WhiteNoise').
- timesteps_generator (dict, optional): The timesteps generator config.
Defaults to dict(type='TimeSteps').
- input_perturbation_gamma (float): The gamma of input perturbation.
The recommended value is 0.1 for Input Perturbation. Defaults to 0.0.
- vae_batch_size (int): The batch size of the VAE. Defaults to 8.
- finetune_text_encoder (bool, optional): Whether to fine-tune the text
encoder. Defaults to False.
- gradient_checkpointing (bool): Whether or not to use gradient
checkpointing to save memory at the expense of slower backward pass. Defaults to False.
- pre_compute_text_embeddings (bool): Whether or not to pre-compute text
embeddings to save memory. Defaults to False.
- enable_xformers (bool): Whether or not to enable memory efficient
attention. Defaults to False.
- student_weight_from_teacher (bool): Whether or not to initialize
student model with teacher model. Defaults to False.
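The input_perturbation_gamma argument can be illustrated with a minimal sketch. It assumes the common formulation of Input Perturbation, in which fresh Gaussian noise scaled by gamma is added to the sampled noise before the noisy latents are built, while the training target remains the unperturbed noise. The helper name perturb_noise is hypothetical, plain Python lists stand in for tensors, and whether SSD1B applies the perturbation in exactly this way is an assumption.

```python
import random


def perturb_noise(noise, gamma, rng):
    """Return noise + gamma * fresh standard-normal noise, element-wise.

    `noise` is a flat list of floats standing in for a noise tensor.
    With gamma == 0.0 (the documented default) the noise is returned unchanged.
    """
    if gamma <= 0.0:
        return list(noise)
    return [n + gamma * rng.gauss(0.0, 1.0) for n in noise]


rng = random.Random(0)
noise = [rng.gauss(0.0, 1.0) for _ in range(4)]

unchanged = perturb_noise(noise, 0.0, rng)   # default gamma: no perturbation
perturbed = perturb_noise(noise, 0.1, rng)   # recommended value from the docs
```

In this formulation only the noise used to corrupt the latents is perturbed; the loss is still computed against the original noise.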
- prepare_model()
Prepare model for training.
Disable gradient for some models.
- Return type:
None
- _forward_vae(img, num_batches)
Forward vae.
- Parameters:
img (torch.Tensor) –
num_batches (int) –
- Return type:
torch.Tensor
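The interaction between num_batches and the vae_batch_size constructor argument can be pictured with a small sketch. It assumes that _forward_vae encodes the images in slices of at most vae_batch_size and concatenates the results; the page does not spell this out. The names encode_in_chunks and fake_encode are hypothetical, and lists stand in for tensors.

```python
def encode_in_chunks(images, num_batches, chunk_size, encode):
    """Encode images[0:num_batches] in slices of at most chunk_size items."""
    latents = []
    for start in range(0, num_batches, chunk_size):
        stop = min(start + chunk_size, num_batches)
        latents.extend(encode(images[start:stop]))
    return latents


chunk_sizes = []


def fake_encode(chunk):
    # Stand-in for a VAE encoder: record the chunk size, return dummy latents.
    chunk_sizes.append(len(chunk))
    return [x * 2 for x in chunk]


images = list(range(20))
latents = encode_in_chunks(images, num_batches=20, chunk_size=8, encode=fake_encode)
```

Chunking like this bounds peak memory during encoding, at the cost of several smaller encoder calls.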
- forward(inputs, data_samples=None, mode='loss')
Forward function.
Args:
- inputs (dict): The input dict.
- data_samples (Optional[list], optional): The data samples.
Defaults to None.
- mode (str, optional): The mode. Defaults to "loss".
Returns:
dict: The loss dict.
- Parameters:
inputs (dict) –
data_samples (Optional[list]) –
mode (str) –
- Return type:
dict
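As a rough picture of what a loss dict for teacher/student distillation might contain, the sketch below combines a task loss against the training target, an output-level distillation loss against the teacher prediction, and a feature-level loss over intermediate activations. The three-term structure and the key names loss_task, loss_kd and loss_features are assumptions made for illustration, not something this page confirms; the real terms and weights are defined by the implementation and the referenced distill_sdxl.py script. Lists stand in for tensors.

```python
def mse(a, b):
    """Mean squared error between two equal-length lists of floats."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def distillation_losses(student_pred, teacher_pred, target,
                        student_feats, teacher_feats):
    """Return a dict of illustrative distillation loss terms (hypothetical keys)."""
    return {
        # Student prediction vs. the usual training target.
        "loss_task": mse(student_pred, target),
        # Student prediction vs. teacher prediction.
        "loss_kd": mse(student_pred, teacher_pred),
        # Matching intermediate features, summed over blocks.
        "loss_features": sum(
            mse(s, t) for s, t in zip(student_feats, teacher_feats)
        ),
    }


losses = distillation_losses(
    student_pred=[1.0, 2.0],
    teacher_pred=[1.0, 4.0],
    target=[0.0, 2.0],
    student_feats=[[1.0, 1.0]],
    teacher_feats=[[1.0, 3.0]],
)
total = sum(losses.values())
```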
- Parameters:
tokenizer_one (dict) –
tokenizer_two (dict) –
scheduler (dict) –
text_encoder_one (dict) –
text_encoder_two (dict) –
vae (dict) –
teacher_unet (dict) –
student_unet (dict) –
model (str) –
loss (dict | None) –
unet_lora_config (dict | None) –
text_encoder_lora_config (dict | None) –
prior_loss_weight (float) –
prediction_type (str | None) –
data_preprocessor (dict | torch.nn.Module | None) –
noise_generator (dict | None) –
timesteps_generator (dict | None) –
input_perturbation_gamma (float) –
vae_batch_size (int) –
finetune_text_encoder (bool) –
gradient_checkpointing (bool) –
pre_compute_text_embeddings (bool) –
enable_xformers (bool) –
student_weight_from_teacher (bool) –
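The constructor arguments listed above can be gathered into a keyword dict, sketched below. The optional values are only the defaults and example values quoted on this page. The eight required component configs are left to the caller, because their contents are not documented here; the helper build_kwargs and the empty placeholder dicts are hypothetical and are not valid component configs.

```python
# Optional SSD1B arguments, using only values quoted in the documentation.
ssd1b_optional_kwargs = dict(
    model="stabilityai/stable-diffusion-xl-base-1.0",
    loss=dict(type="L2Loss", loss_weight=1.0),
    unet_lora_config=None,
    text_encoder_lora_config=None,
    prior_loss_weight=1.0,
    prediction_type=None,
    data_preprocessor=None,
    noise_generator=dict(type="WhiteNoise"),
    timesteps_generator=dict(type="TimeSteps"),
    input_perturbation_gamma=0.0,
    vae_batch_size=8,
    finetune_text_encoder=False,
    gradient_checkpointing=False,
    pre_compute_text_embeddings=False,
    enable_xformers=False,
    student_weight_from_teacher=False,
)

# Required positional arguments from the class signature.
required_names = [
    "tokenizer_one", "tokenizer_two", "scheduler",
    "text_encoder_one", "text_encoder_two",
    "vae", "teacher_unet", "student_unet",
]


def build_kwargs(components):
    """Merge caller-supplied component configs with the optional defaults."""
    missing = [name for name in required_names if name not in components]
    if missing:
        raise KeyError(f"missing component configs: {missing}")
    return {**components, **ssd1b_optional_kwargs}


# Placeholder component configs; real ones must be supplied by the caller.
placeholder_components = {name: dict() for name in required_names}
kwargs = build_kwargs(placeholder_components)
```

Keeping the required components separate from the optional settings makes it easy to see which values still need a real configuration.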