diffengine.models.editors.wuerstchen.wuerstchen_prior

Module Contents

Classes

WuerstchenPriorModel

Wuerstchen Prior.

class diffengine.models.editors.wuerstchen.wuerstchen_prior.WuerstchenPriorModel(tokenizer, scheduler, text_encoder, image_encoder, prior, decoder_model='warp-ai/wuerstchen', prior_model='warp-ai/wuerstchen-prior', loss=None, prior_lora_config=None, text_encoder_lora_config=None, prior_loss_weight=1.0, data_preprocessor=None, noise_generator=None, timesteps_generator=None, input_perturbation_gamma=0.0, *, finetune_text_encoder=False, gradient_checkpointing=False)

Bases: mmengine.model.BaseModel

Wuerstchen Prior (https://arxiv.org/abs/2306.00637).

Args:

tokenizer (dict): Config of tokenizer.

scheduler (dict): Config of scheduler.

text_encoder (dict): Config of text encoder.

image_encoder (dict): Config of image encoder.

prior (dict): Config of prior.

decoder_model (str): Pretrained decoder model name of Wuerstchen. Defaults to 'warp-ai/wuerstchen'.

prior_model (str): Pretrained prior model name of Wuerstchen. Defaults to 'warp-ai/wuerstchen-prior'.

loss (dict): Config of loss. Defaults to dict(type='L2Loss', loss_weight=1.0).

prior_lora_config (dict, optional): The LoRA config dict for the prior. For example, dict(type='LoRA', r=4). type is chosen from LoRA, LoHa, and LoKr; the other keys are the same as in the PEFT config (https://github.com/huggingface/peft). Defaults to None.

text_encoder_lora_config (dict, optional): The LoRA config dict for the text encoder. For example, dict(type='LoRA', r=4). type is chosen from LoRA, LoHa, and LoKr; the other keys are the same as in the PEFT config (https://github.com/huggingface/peft). Defaults to None.

prior_loss_weight (float): The weight of the prior preservation loss. It is used when training DreamBooth with class images. Defaults to 1.0.

data_preprocessor (dict, optional): The pre-process config of SDDataPreprocessor.

noise_generator (dict, optional): The noise generator config. Defaults to dict(type='WhiteNoise').

timesteps_generator (dict, optional): The timesteps generator config. Defaults to dict(type='WuerstchenRandomTimeSteps').

input_perturbation_gamma (float): The gamma of input perturbation. A value of 0.1 is recommended to enable Input Perturbation; 0.0 disables it. Defaults to 0.0.

finetune_text_encoder (bool, optional): Whether to fine-tune the text encoder. Defaults to False.

gradient_checkpointing (bool): Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass. Defaults to False.
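A minimal sketch of a model config in mmengine's dict style, using only the defaults and example values listed above. The tokenizer, scheduler, text_encoder, image_encoder, and prior entries are hypothetical placeholders: their real contents depend on your setup and are not documented in this reference.

```python
# Sketch of a WuerstchenPriorModel config in mmengine's dict style.
# The five component configs below are hypothetical placeholders.
model = dict(
    type="WuerstchenPriorModel",
    tokenizer=dict(),      # placeholder: tokenizer config
    scheduler=dict(),      # placeholder: scheduler config
    text_encoder=dict(),   # placeholder: text encoder config
    image_encoder=dict(),  # placeholder: image encoder config
    prior=dict(),          # placeholder: prior config
    decoder_model="warp-ai/wuerstchen",
    prior_model="warp-ai/wuerstchen-prior",
    loss=dict(type="L2Loss", loss_weight=1.0),
    # Train adapters on the prior only; type is one of LoRA, LoHa, LoKr.
    prior_lora_config=dict(type="LoRA", r=4),
    noise_generator=dict(type="WhiteNoise"),
    timesteps_generator=dict(type="WuerstchenRandomTimeSteps"),
    # 0.1 is the value recommended above for Input Perturbation.
    input_perturbation_gamma=0.1,
    # Trade a slower backward pass for lower memory use.
    gradient_checkpointing=True,
)
```

Keyword arguments left out of the dict, such as finetune_text_encoder, keep the defaults shown in the class signature.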

property device: torch.device

Get device information.

Returns:

The device of the model.

Return type:

torch.device

set_lora()

Set LoRA for the model.

Return type:

None
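The LoRA config dicts accepted by the constructor pair a type key (LoRA, LoHa, or LoKr) with PEFT options. The hypothetical helper split_lora_config below sketches how such a dict could be validated and split before the options are handed to PEFT; it is an illustration under that assumption, not the actual body of set_lora.

```python
from typing import Any

# Adapter types documented for prior_lora_config / text_encoder_lora_config.
SUPPORTED_LORA_TYPES = ("LoRA", "LoHa", "LoKr")


def split_lora_config(cfg: dict[str, Any]) -> tuple[str, dict[str, Any]]:
    """Hypothetical helper: separate the adapter type from the PEFT options."""
    options = dict(cfg)  # copy so the caller's config is left untouched
    lora_type = options.pop("type", None)
    if lora_type not in SUPPORTED_LORA_TYPES:
        raise ValueError(
            f"type must be one of {SUPPORTED_LORA_TYPES}, got {lora_type!r}"
        )
    # Everything else (e.g. r) would be forwarded to the matching PEFT config.
    return lora_type, options
```

For example, split_lora_config(dict(type="LoRA", r=4)) returns ("LoRA", {"r": 4}).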

prepare_model()

Prepare the model for training.

Disables gradients for some of the sub-models.

Return type:

None

train(*, mode=True)

Convert the model into training mode.

Parameters:

mode (bool) – Whether to set training mode (True) or evaluation mode (False). Defaults to True.

Return type:

None

infer(prompt, negative_prompt=None, height=None, width=None, num_inference_steps=50, output_type='pil', **kwargs)

Inference function.

Args:

prompt (List[str]): The prompt or prompts to guide the image generation.

negative_prompt (Optional[str]): The prompt or prompts not to guide the image generation. Defaults to None.

height (int, optional): The height in pixels of the generated image. Defaults to None.

width (int, optional): The width in pixels of the generated image. Defaults to None.

num_inference_steps (int): Number of inference steps. Defaults to 50.

output_type (str): The output format of the generated image. Choose between 'pil' and 'latent'. Defaults to 'pil'.

**kwargs: Other arguments.

Parameters:
  • prompt (list[str]) –

  • negative_prompt (str | None) –

  • height (int | None) –

  • width (int | None) –

  • num_inference_steps (int) –

  • output_type (str) –

Return type:

list[numpy.ndarray]

val_step(data)

Validation step.

Parameters:

data (Union[tuple, dict, list]) –

Return type:

list

test_step(data)

Test step.

Parameters:

data (Union[tuple, dict, list]) –

Return type:

list

loss(model_pred, noise, timesteps, weight=None)

Calculate loss.

Parameters:
  • model_pred (torch.Tensor) –

  • noise (torch.Tensor) –

  • timesteps (torch.Tensor) –

  • weight (torch.Tensor | None) –

Return type:

dict[str, torch.Tensor]
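A stdlib-only sketch of a weighted L2 loss between model_pred and noise, assuming the default L2Loss is a mean squared error, that weight holds one factor per sample (for example prior_loss_weight for DreamBooth class images and 1.0 for instance images), and that the returned dict uses the key "loss". All three are assumptions; the real method works on torch.Tensor inputs.

```python
from typing import Optional


def weighted_l2_loss(
    model_pred: list[list[float]],
    noise: list[list[float]],
    weight: Optional[list[float]] = None,
    loss_weight: float = 1.0,
) -> dict[str, float]:
    """Mean squared error between predicted and true noise, one row per sample."""
    if weight is None:
        weight = [1.0] * len(model_pred)  # unweighted: every sample counts once
    total, count = 0.0, 0
    for pred_row, noise_row, w in zip(model_pred, noise, weight):
        for p, n in zip(pred_row, noise_row):
            total += w * (p - n) ** 2  # the sample weight scales each squared error
            count += 1
    # Average over all elements, then apply the loss-level weight.
    return {"loss": loss_weight * total / count}
```

Under this sketch, a DreamBooth batch whose second half holds class images would use weight = [1.0] * n + [prior_loss_weight] * n.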

_preprocess_model_input(latents, noise, timesteps)

Preprocess model input.

Parameters:
  • latents (torch.Tensor) –

  • noise (torch.Tensor) –

  • timesteps (torch.Tensor) –

Return type:

torch.Tensor
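A stdlib-only sketch of this preprocessing step, under two assumptions not stated in this reference: input perturbation follows the usual formulation (the noise fed to the model becomes noise + gamma * fresh Gaussian noise, while the regression target stays noise), and noising uses a Wuerstchen-style cosine schedule with offset s = 0.008 over continuous timesteps in [0, 1], as in the DDPMWuerstchenScheduler shipped with diffusers. The function names are hypothetical.

```python
import math
import random
from typing import Optional

S = 0.008  # cosine-schedule offset (assumed, as in diffusers' Wuerstchen scheduler)
INIT_ALPHA = math.cos(S / (1 + S) * math.pi * 0.5) ** 2


def alpha_cumprod(t: float) -> float:
    """Cumulative signal level at continuous timestep t in [0, 1]."""
    alpha = math.cos((t + S) / (1 + S) * math.pi * 0.5) ** 2 / INIT_ALPHA
    return min(max(alpha, 1e-4), 0.9999)  # clamp away from exactly 0 and 1


def preprocess_model_input(
    latents: list[float],
    noise: list[float],
    t: float,
    input_perturbation_gamma: float = 0.0,
    rng: Optional[random.Random] = None,
) -> list[float]:
    """Hypothetical helper: return noisy latents for one sample at timestep t."""
    if rng is None:
        rng = random.Random()
    if input_perturbation_gamma > 0:
        # Perturb only the noise used to build the input; the target stays `noise`.
        noise = [n + input_perturbation_gamma * rng.gauss(0.0, 1.0) for n in noise]
    a = alpha_cumprod(t)
    # Forward diffusion: sqrt(a) * x0 + sqrt(1 - a) * eps
    return [math.sqrt(a) * x + math.sqrt(1 - a) * n for x, n in zip(latents, noise)]
```

Passing a seeded random.Random makes the perturbation reproducible; with input_perturbation_gamma=0.0 the function reduces to plain forward noising.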

forward(inputs, data_samples=None, mode='loss')

Forward function.

Args:

inputs (dict): The input dict.

data_samples (Optional[list], optional): The data samples. Defaults to None.

mode (str, optional): The mode. Defaults to 'loss'.

Returns:

dict: The loss dict.

Parameters:
  • inputs (dict) –

  • data_samples (Optional[list]) –

  • mode (str) –

Return type:

dict

Parameters:
  • tokenizer (dict) –

  • scheduler (dict) –

  • text_encoder (dict) –

  • image_encoder (dict) –

  • prior (dict) –

  • decoder_model (str) –

  • prior_model (str) –

  • loss (dict | None) –

  • prior_lora_config (dict | None) –

  • text_encoder_lora_config (dict | None) –

  • prior_loss_weight (float) –

  • data_preprocessor (dict | torch.nn.Module | None) –

  • noise_generator (dict | None) –

  • timesteps_generator (dict | None) –

  • input_perturbation_gamma (float) –

  • finetune_text_encoder (bool) –

  • gradient_checkpointing (bool) –