diffengine.models.editors.distill_sd
Submodules
Package Contents
Classes
- DistillSDXL: Distill Stable Diffusion XL.
- class diffengine.models.editors.distill_sd.DistillSDXL(*args, model_type, unet_lora_config=None, text_encoder_lora_config=None, finetune_text_encoder=False, **kwargs)
Bases: diffengine.models.editors.stable_diffusion_xl.StableDiffusionXL
Distill Stable Diffusion XL.
Args:
- model_type (str): The type of distilled model to use. Choose from sd_tiny or sd_small.
- unet_lora_config (dict, optional): The LoRA config dict for the UNet, for example dict(type="LoRA", r=4). type is one of LoRA, LoHa, or LoKr; the other keys are the same as in the PEFT config (https://github.com/huggingface/peft). Defaults to None.
- text_encoder_lora_config (dict, optional): The LoRA config dict for the text encoder, for example dict(type="LoRA", r=4). type is one of LoRA, LoHa, or LoKr; the other keys are the same as in the PEFT config (https://github.com/huggingface/peft). Defaults to None.
- finetune_text_encoder (bool, optional): Whether to fine-tune the text encoder. This should be False when training ControlNet. Defaults to False.
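Each LoRA config dict above pairs a type key with PEFT keyword arguments. A minimal sketch of how such a dict could be split into a PEFT config class name and the remaining keyword arguments, assuming every key other than type is forwarded to PEFT unchanged. split_lora_config and the _PEFT_CONFIG_CLASSES table are hypothetical names for illustration, not diffengine's API; LoraConfig, LoHaConfig, and LoKrConfig are the config class names PEFT itself provides.

```python
from typing import Any

# Hypothetical mapping from the documented `type` values to PEFT config class names.
_PEFT_CONFIG_CLASSES = {
    "LoRA": "LoraConfig",
    "LoHa": "LoHaConfig",
    "LoKr": "LoKrConfig",
}


def split_lora_config(config: dict[str, Any]) -> tuple[str, dict[str, Any]]:
    """Return (PEFT config class name, remaining kwargs) for a LoRA config dict.

    Hypothetical helper: assumes every key except `type` is passed to PEFT as-is.
    """
    kwargs = dict(config)  # copy so the caller's dict is not mutated
    adapter_type = kwargs.pop("type", None)
    if adapter_type not in _PEFT_CONFIG_CLASSES:
        raise ValueError(
            f"type must be one of {sorted(_PEFT_CONFIG_CLASSES)}, got {adapter_type!r}"
        )
    return _PEFT_CONFIG_CLASSES[adapter_type], kwargs


if __name__ == "__main__":
    print(split_lora_config(dict(type="LoRA", r=4)))  # ('LoraConfig', {'r': 4})
```

Copying the dict before popping type keeps a shared config (for example one reused for both the UNet and the text encoder) intact.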
- prepare_model()
Prepare the model for training.
Disables gradients for some of the submodels.
- Return type:
None
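The docstring does not say which submodels are frozen. A minimal PyTorch sketch of the general pattern, assuming a hypothetical toy model whose VAE is frozen, whose text encoder is frozen unless finetune_text_encoder is True, and which keeps a frozen copy of the UNet as a distillation teacher. All three choices are assumptions for illustration, not a description of diffengine's implementation; nn.Module.requires_grad_ is the standard PyTorch call for toggling gradients on every parameter of a module.

```python
import copy

from torch import nn


class ToyDistillModel(nn.Module):
    """Hypothetical stand-in with tiny Linear layers in place of real submodels."""

    def __init__(self, finetune_text_encoder: bool = False) -> None:
        super().__init__()
        self.unet = nn.Linear(4, 4)          # student, stays trainable
        self.vae = nn.Linear(4, 4)           # assumed frozen
        self.text_encoder = nn.Linear(4, 4)  # frozen unless fine-tuned
        self.finetune_text_encoder = finetune_text_encoder

    def prepare_model(self) -> None:
        # Assumption: keep a frozen copy of the UNet to act as the teacher.
        self.teacher_unet = copy.deepcopy(self.unet).requires_grad_(False)
        self.vae.requires_grad_(False)
        if not self.finetune_text_encoder:
            self.text_encoder.requires_grad_(False)


def trainable(module: nn.Module) -> bool:
    """True if every parameter of `module` requires gradients."""
    return all(p.requires_grad for p in module.parameters())
```

Freezing with requires_grad_ rather than wrapping calls in torch.no_grad keeps the frozen modules usable inside a normal forward pass while ensuring the optimizer never sees gradients for them.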
- forward(inputs, data_samples=None, mode='loss')
Forward function.
Args:
- inputs (dict): The input dict.
- data_samples (Optional[list], optional): The data samples. Defaults to None.
- mode (str, optional): The mode. Defaults to "loss".
Returns:
dict: The loss dict.
- Parameters:
inputs (dict)
data_samples (Optional[list])
mode (str)
- Return type:
dict
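In mode='loss', forward returns a dict of losses. A minimal sketch of what a distillation loss dict can contain, assuming a knowledge-distillation objective that combines a denoising loss against the target noise, an output-level term against a frozen teacher's prediction, and a feature-level term over matched intermediate features. The function name distill_loss_dict and its key names are hypothetical, not diffengine's. If the model follows MMEngine's BaseModel convention, parse_losses sums every entry whose key contains "loss", so each term below would contribute to the total.

```python
import torch
import torch.nn.functional as F


def distill_loss_dict(
    student_pred: torch.Tensor,
    teacher_pred: torch.Tensor,
    target: torch.Tensor,
    student_feats: list[torch.Tensor],
    teacher_feats: list[torch.Tensor],
) -> dict[str, torch.Tensor]:
    """Hypothetical knowledge-distillation loss dict (not diffengine's code)."""
    if len(student_feats) != len(teacher_feats):
        raise ValueError("student and teacher must expose the same number of features")
    # Teacher outputs act as fixed targets, so no gradients flow into them.
    teacher_pred = teacher_pred.detach()
    # Scalar accumulator with the same dtype/device as the predictions.
    loss_feat = student_pred.new_zeros(())
    for s, t in zip(student_feats, teacher_feats):
        loss_feat = loss_feat + F.mse_loss(s, t.detach())
    return {
        "loss_sd": F.mse_loss(student_pred, target),               # denoising loss
        "loss_kd_output": F.mse_loss(student_pred, teacher_pred),  # output-level KD
        "loss_kd_feat": loss_feat,                                 # feature-level KD
    }
```

Returning separate entries instead of one summed scalar keeps each term visible in training logs, and per-term weights can be applied before the dict is returned if needed.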
- Parameters (DistillSDXL constructor):
model_type (str)
unet_lora_config (dict | None)
text_encoder_lora_config (dict | None)
finetune_text_encoder (bool)
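The constructor arguments above are usually supplied through an MMEngine-style model config dict. A minimal sketch, assuming the class is registered under its own name (type="DistillSDXL") and that the StableDiffusionXL base-class arguments are passed alongside but are omitted here. check_distill_sdxl_config is a hypothetical helper that enforces only the documented constraints: model_type is required and must be sd_tiny or sd_small, and finetune_text_encoder is a bool.

```python
from typing import Any

# Documented choices for DistillSDXL's model_type argument.
VALID_MODEL_TYPES = ("sd_tiny", "sd_small")


def check_distill_sdxl_config(cfg: dict[str, Any]) -> dict[str, Any]:
    """Validate the DistillSDXL-specific keys of a model config dict.

    Hypothetical helper: checks only what the docs state and returns `cfg`.
    """
    if cfg.get("type") != "DistillSDXL":
        raise ValueError(f"expected type='DistillSDXL', got {cfg.get('type')!r}")
    # model_type is keyword-only with no default, so it must be present.
    if cfg.get("model_type") not in VALID_MODEL_TYPES:
        raise ValueError(
            f"model_type must be one of {VALID_MODEL_TYPES}, got {cfg.get('model_type')!r}"
        )
    if not isinstance(cfg.get("finetune_text_encoder", False), bool):
        raise TypeError("finetune_text_encoder must be a bool")
    return cfg


model = dict(
    type="DistillSDXL",
    model_type="sd_tiny",
    finetune_text_encoder=False,
    # ...plus the arguments of the StableDiffusionXL base class (not shown).
)
check_distill_sdxl_config(model)
```

Validating the config before building the model surfaces a typo such as model_type="sd_base" immediately, rather than partway through model construction.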