import inspect
import math
from copy import deepcopy
from typing import Optional, Union
import numpy as np
import torch
from diffusers import AmusedPipeline
from mmengine import print_log
from mmengine.model import BaseModel
from peft import get_peft_model
from torch import nn
from diffengine.models.archs import create_peft_config
from diffengine.registry import MODELS
@MODELS.register_module()
class AMUSEd(BaseModel):
    """aMUSEd.

    Masked-image-modelling text-to-image model: images are encoded to
    discrete VQ token ids, a random subset of tokens is replaced by a mask
    id, and the transformer is trained to predict the masked tokens
    conditioned on the text encoder outputs.

    Args:
    ----
        tokenizer (dict): Config of tokenizer.
        text_encoder (dict): Config of text encoder.
        vae (dict): Config of vae.
        transformer (dict): Config of transformer.
        model (str): pretrained model name.
            Defaults to "amused/amused-512".
        loss (dict): Config of loss. Defaults to
            ``dict(type='CrossEntropyLoss', loss_weight=1.0)``.
        transformer_lora_config (dict, optional): The LoRA config dict for
            Transformer.
            example. dict(type="LoRA", r=4). `type` is chosen from `LoRA`,
            `LoHa`, `LoKr`. Other config are same as the config of PEFT.
            https://github.com/huggingface/peft
            Defaults to None.
        text_encoder_lora_config (dict, optional): The LoRA config dict for
            Text Encoder. example. dict(type="LoRA", r=4). `type` is chosen
            from `LoRA`, `LoHa`, `LoKr`. Other config are same as the config
            of PEFT. https://github.com/huggingface/peft
            Defaults to None.
        prior_loss_weight (float): The weight of prior preservation loss.
            It works when training dreambooth with class images.
            Defaults to 1.
        data_preprocessor (dict, optional): The pre-process config.
            Defaults to ``dict(type='AMUSEdPreprocessor')``.
        vae_batch_size (int): The batch size of vae. Defaults to 8.
        finetune_text_encoder (bool, optional): Whether to fine-tune text
            encoder. Defaults to False.
        gradient_checkpointing (bool): Whether or not to use gradient
            checkpointing to save memory at the expense of slower backward
            pass. Defaults to False.
        enable_xformers (bool): Whether or not to enable memory efficient
            attention. Defaults to False.
    """

    def __init__(
        self,
        tokenizer: dict,
        text_encoder: dict,
        vae: dict,
        transformer: dict,
        model: str = "amused/amused-512",
        loss: dict | None = None,
        transformer_lora_config: dict | None = None,
        text_encoder_lora_config: dict | None = None,
        prior_loss_weight: float = 1.,
        data_preprocessor: dict | nn.Module | None = None,
        vae_batch_size: int = 8,
        *,
        finetune_text_encoder: bool = False,
        gradient_checkpointing: bool = False,
        enable_xformers: bool = False,
    ) -> None:
        if data_preprocessor is None:
            data_preprocessor = {"type": "AMUSEdPreprocessor"}
        if loss is None:
            loss = {}
        super().__init__(data_preprocessor=data_preprocessor)

        # LoRA on both transformer and text encoder implies the text encoder
        # is being trained, so switch the flag on instead of failing below.
        if (
            transformer_lora_config is not None) and (
                text_encoder_lora_config is not None) and (
                    not finetune_text_encoder):
            print_log(
                "You are using LoRA for Transformer and text encoder. "
                "But you are not set `finetune_text_encoder=True`. "
                "We will set `finetune_text_encoder=True` for you.")
            finetune_text_encoder = True
        # NOTE: these stay as asserts (not ValueError) so that callers that
        # already catch AssertionError keep working.
        if text_encoder_lora_config is not None:
            assert finetune_text_encoder, (
                "If you want to use LoRA for text encoder, "
                "you should set finetune_text_encoder=True."
            )
        if finetune_text_encoder and transformer_lora_config is not None:
            assert text_encoder_lora_config is not None, (
                "If you want to finetune text encoder with LoRA Transformer, "
                "you should set text_encoder_lora_config."
            )

        self.model = model
        # Deep-copied so later processing of the configs cannot mutate the
        # dicts owned by the caller.
        self.transformer_lora_config = deepcopy(transformer_lora_config)
        self.text_encoder_lora_config = deepcopy(text_encoder_lora_config)
        self.finetune_text_encoder = finetune_text_encoder
        self.prior_loss_weight = prior_loss_weight
        self.gradient_checkpointing = gradient_checkpointing
        self.enable_xformers = enable_xformers
        self.vae_batch_size = vae_batch_size

        if not isinstance(loss, nn.Module):
            loss = MODELS.build(
                loss,
                default_args={"type": "CrossEntropyLoss", "loss_weight": 1.0})
        self.loss_module: nn.Module = loss

        self.tokenizer = self._build_pretrained(tokenizer, model)
        self.text_encoder = self._build_pretrained(text_encoder, model)
        self.vae = self._build_pretrained(vae, model)
        self.transformer = self._build_pretrained(transformer, model)

        self.vae_scale_factor = 2 ** (
            len(self.vae.config.block_out_channels) - 1)
        # The last id of the transformer vocabulary is used as the mask token.
        self.mask_id = self.transformer.config.vocab_size - 1
        self.codebook_size = self.transformer.config.codebook_size

        self.prepare_model()
        self.set_lora()
        self.set_xformers()

    @staticmethod
    def _build_pretrained(cfg: dict, model: str):
        """Build one sub-module from ``cfg`` through the ``MODELS`` registry.

        When ``cfg["type"]`` is not a class, ``model`` is supplied as the
        default ``pretrained_model_name_or_path``; when it is a class, no
        default arguments are passed.
        """
        default_args = None
        if not inspect.isclass(cfg.get("type")):
            default_args = {"pretrained_model_name_or_path": model}
        return MODELS.build(cfg, default_args=default_args)

    def set_lora(self) -> None:
        """Set LORA for model.

        Wraps the text encoder and/or the transformer with PEFT when the
        corresponding LoRA config was given, and prints the trainable
        parameter summary of each wrapped module.
        """
        if self.text_encoder_lora_config is not None:
            text_encoder_lora_config = create_peft_config(
                self.text_encoder_lora_config)
            self.text_encoder = get_peft_model(
                self.text_encoder, text_encoder_lora_config)
            self.text_encoder.print_trainable_parameters()

        if self.transformer_lora_config is not None:
            transformer_lora_config = create_peft_config(
                self.transformer_lora_config)
            self.transformer = get_peft_model(
                self.transformer, transformer_lora_config)
            self.transformer.print_trainable_parameters()

    def prepare_model(self) -> None:
        """Prepare model for training.

        Enables gradient checkpointing when requested, always freezes the
        VAE, and freezes the text encoder unless it is being fine-tuned.
        """
        if self.gradient_checkpointing:
            self.transformer.enable_gradient_checkpointing()
            if self.finetune_text_encoder:
                self.text_encoder.gradient_checkpointing_enable()

        self.vae.requires_grad_(requires_grad=False)
        print_log("Set VAE untrainable.", "current")
        if not self.finetune_text_encoder:
            self.text_encoder.requires_grad_(requires_grad=False)
            print_log("Set Text Encoder untrainable.", "current")

    @property
    def device(self) -> torch.device:
        """Get device information.

        Returns
        -------
            torch.device: device of the first parameter of this module.
        """
        return next(self.parameters()).device

    @torch.no_grad()
    def infer(self,
              prompt: list[str],
              negative_prompt: str | None = None,
              height: int | None = None,
              width: int | None = None,
              num_inference_steps: int = 12,
              output_type: str = "pil",
              **kwargs) -> list[np.ndarray]:
        """Inference function.

        Args:
        ----
            prompt (`List[str]`):
                The prompt or prompts to guide the image generation.
            negative_prompt (`Optional[str]`):
                The prompt or prompts not to guide the image generation.
                Defaults to None.
            height (int, optional):
                The height in pixels of the generated image. Defaults to None.
            width (int, optional):
                The width in pixels of the generated image. Defaults to None.
            num_inference_steps (int): Number of inference steps.
                Defaults to 12.
            output_type (str): The output format of the generate image.
                Choose between 'pil' and 'latent'. Defaults to 'pil'.
            **kwargs: Other arguments, forwarded to the pipeline call.

        Returns:
        -------
            list: One entry per prompt. Each entry is converted with
            ``np.array`` unless ``output_type == "latent"``, in which case
            the pipeline output is appended unconverted.
        """
        # Half precision is requested on every device except CPU.
        pipeline = AmusedPipeline.from_pretrained(
            self.model,
            vqvae=self.vae,
            tokenizer=self.tokenizer,
            text_encoder=self.text_encoder,
            transformer=self.transformer,
            torch_dtype=(torch.float16 if self.device != torch.device("cpu")
                         else torch.float32),
        )
        pipeline.to(self.device)
        pipeline.set_progress_bar_config(disable=True)
        images = []
        # Prompts are generated one at a time; the first image of each
        # pipeline call is kept.
        for p in prompt:
            image = pipeline(
                p,
                negative_prompt=negative_prompt,
                num_inference_steps=num_inference_steps,
                height=height,
                width=width,
                output_type=output_type,
                **kwargs).images[0]
            if output_type == "latent":
                images.append(image)
            else:
                images.append(np.array(image))

        # Drop the temporary pipeline and release cached CUDA memory.
        del pipeline
        torch.cuda.empty_cache()

        return images

    def val_step(
            self,
            data: Union[tuple, dict, list]  # noqa
    ) -> list:
        """Val step.

        Raises
        ------
            NotImplementedError: always; use :meth:`infer` instead.
        """
        msg = "val_step is not implemented now, please use infer."
        raise NotImplementedError(msg)

    def test_step(
            self,
            data: Union[tuple, dict, list]  # noqa
    ) -> list:
        """Test step.

        Raises
        ------
            NotImplementedError: always; use :meth:`infer` instead.
        """
        msg = "test_step is not implemented now, please use infer."
        raise NotImplementedError(msg)

    def _forward_vae(self, img: torch.Tensor, num_batches: int,
                     ) -> torch.Tensor:
        """Forward vae.

        Encodes ``img`` in chunks of ``self.vae_batch_size`` and returns the
        quantized token ids, one flattened row per image.

        Args:
        ----
            img (torch.Tensor): Batch of images, indexed along dim 0.
            num_batches (int): Number of images to encode; the caller in
                this class passes ``len(img)``.

        Returns:
        -------
            torch.Tensor: Token ids of all chunks concatenated along dim 0.
        """
        latents = []
        for start in range(0, num_batches, self.vae_batch_size):
            chunk = img[start:start + self.vae_batch_size]
            encoded = self.vae.encode(chunk).latents
            # Flatten per image using the size of *this chunk*. Reshaping to
            # ``num_batches`` rows is only valid when everything fits in a
            # single chunk and corrupts the layout (or raises) otherwise.
            token_ids = self.vae.quantize(encoded)[2][2].reshape(
                chunk.shape[0], -1)
            latents.append(token_ids)
        return torch.cat(latents, dim=0)

    def forward(
            self,
            inputs: dict,
            data_samples: Optional[list] = None,  # noqa
            mode: str = "loss") -> dict:
        """Forward function.

        Args:
        ----
            inputs (dict): The input dict. Reads ``"text"``, ``"img"`` and
                ``"micro_conds"``; the presence of ``"result_class_image"``
                enables the prior-preservation weighting. ``inputs["text"]``
                is replaced in place by the tokenized ids.
            data_samples (Optional[list], optional): The data samples.
                Defaults to None.
            mode (str, optional): The mode. Defaults to "loss".

        Returns:
        -------
            dict: The loss dict, with the single key ``"loss"``.

        Raises:
        ------
            AssertionError: If ``mode`` is not ``"loss"``.
        """
        assert mode == "loss"
        inputs["text"] = self.tokenizer(
            inputs["text"],
            max_length=self.tokenizer.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt").input_ids.to(self.device)
        num_batches = len(inputs["img"])
        if "result_class_image" in inputs:
            # use prior_loss_weight
            # First half of the batch gets weight 1, second half gets
            # prior_loss_weight.
            # NOTE(review): weight is shaped (B, 1, 1, 1) while logits and
            # labels below are flattened per token - confirm the configured
            # loss module broadcasts this as intended.
            weight = torch.cat([
                torch.ones((num_batches // 2, )),
                torch.ones((num_batches // 2, )) * self.prior_loss_weight,
            ]).to(self.device).float().reshape(-1, 1, 1, 1)
        else:
            weight = None

        latents = self._forward_vae(inputs["img"], num_batches)

        timesteps = torch.rand(num_batches, device=self.device)

        # mask
        # Cosine schedule: the masked fraction is cos(t * pi / 2), and at
        # least one token is masked per sample.
        seq_len = latents.shape[1]
        mask_prob = torch.cos(timesteps * math.pi * 0.5)
        mask_prob = mask_prob.clip(0.0)
        num_token_masked = (seq_len * mask_prob).round().clamp(min=1)
        # argsort of uniform noise gives an independent random permutation
        # per row; positions ranked below num_token_masked are masked.
        batch_randperm = torch.rand(
            num_batches, seq_len, device=self.device).argsort(dim=-1)
        mask = batch_randperm < num_token_masked.unsqueeze(-1)
        input_ids = torch.where(mask, self.mask_id, latents)
        h, w = inputs["img"].shape[-2:]
        input_ids = input_ids.reshape(
            num_batches,
            h // self.vae_scale_factor,
            w // self.vae_scale_factor)
        # Only masked positions keep their target id; the rest are set to
        # -100 (presumably skipped by the loss as torch's default
        # ignore_index - verify against the configured loss).
        labels = torch.where(mask, latents, -100)

        outputs = self.text_encoder(
            inputs["text"], return_dict=True, output_hidden_states=True)
        # The second-to-last hidden state conditions the transformer; the
        # first output element is passed as the pooled text embedding.
        encoder_hidden_states = outputs.hidden_states[-2]
        cond_embeds = outputs[0]

        # Reshape the transformer output to (B, codebook, tokens), then
        # flatten to (B * tokens, codebook) to line up with labels.view(-1).
        logits = self.transformer(
            input_ids=input_ids,
            encoder_hidden_states=encoder_hidden_states,
            micro_conds=inputs["micro_conds"],
            pooled_text_emb=cond_embeds).reshape(
                num_batches, self.codebook_size, -1).permute(
                    0, 2, 1).reshape(
                        -1, self.codebook_size)

        loss = self.loss_module(
            logits, labels.view(-1), weight=weight)
        return {"loss": loss}