Source code for diffengine.models.editors.ip_adapter.ip_adapter_xl

from typing import Optional

import numpy as np
import torch
from diffusers.models.embeddings import MultiIPAdapterImageProjection
from diffusers.utils import load_image
from PIL import Image
from torch import nn

from diffengine.models.archs import (
    load_ip_adapter,
    process_ip_adapter_state_dict,
    set_unet_ip_adapter,
)
from diffengine.models.editors.ip_adapter.pipeline import (
    StableDiffusionXLPipelineCustomIPAdapter,
)
from diffengine.models.editors.stable_diffusion_xl import StableDiffusionXL
from diffengine.registry import MODELS, TRANSFORMS


@MODELS.register_module()
[docs]class IPAdapterXL(StableDiffusionXL): """Stable Diffusion XL IP-Adapter. Args: ---- image_encoder (dict): The image encoder config. image_projection (dict): The image projection config. feature_extractor (dict): The feature extractor config. pretrained_adapter (str, optional): Path to pretrained IP-Adapter. Defaults to None. pretrained_adapter_subfolder (str, optional): Sub folder of pretrained IP-Adapter. Defaults to ''. pretrained_adapter_weights_name (str, optional): Weights name of pretrained IP-Adapter. Defaults to ''. unet_lora_config (dict, optional): The LoRA config dict for Unet. example. dict(type="LoRA", r=4). `type` is chosen from `LoRA`, `LoHa`, `LoKr`. Other config are same as the config of PEFT. https://github.com/huggingface/peft Defaults to None. text_encoder_lora_config (dict, optional): The LoRA config dict for Text Encoder. example. dict(type="LoRA", r=4). `type` is chosen from `LoRA`, `LoHa`, `LoKr`. Other config are same as the config of PEFT. https://github.com/huggingface/peft Defaults to None. finetune_text_encoder (bool, optional): Whether to fine-tune text encoder. This should be `False` when training ControlNet. Defaults to False. zeros_image_embeddings_prob (float): The probabilities to generate zeros image embeddings. Defaults to 0.1. data_preprocessor (dict, optional): The pre-process config of :class:`SDControlNetDataPreprocessor`. hidden_states_idx (int): Index of the hidden states to be used. Defaults to -2. """ def __init__(self, *args, image_encoder: dict, image_projection: dict, feature_extractor: dict, pretrained_adapter: str | None = None, pretrained_adapter_subfolder: str = "", pretrained_adapter_weights_name: str = "", unet_lora_config: dict | None = None, text_encoder_lora_config: dict | None = None, finetune_text_encoder: bool = False, zeros_image_embeddings_prob: float = 0.1, data_preprocessor: dict | nn.Module | None = None, hidden_states_idx: int = -2, **kwargs) -> None: if data_preprocessor is None: data_preprocessor = {"type": "IPAdapterXLDataPreprocessor"} assert unet_lora_config is None, \ "`unet_lora_config` should be None when training IPAdapter" assert text_encoder_lora_config is None, \ "`text_encoder_lora_config` should be None when training IPAdapter" assert not finetune_text_encoder, \ "`finetune_text_encoder` should be False when training IPAdapter" self.image_encoder_config = image_encoder self.image_projection_config = image_projection self.pretrained_adapter = pretrained_adapter self.pretrained_adapter_subfolder = pretrained_adapter_subfolder self.pretrained_adapter_weights_name = pretrained_adapter_weights_name self.zeros_image_embeddings_prob = zeros_image_embeddings_prob self.hidden_states_idx = hidden_states_idx self.feature_extractor = TRANSFORMS.build(feature_extractor) super().__init__( *args, unet_lora_config=unet_lora_config, text_encoder_lora_config=text_encoder_lora_config, finetune_text_encoder=finetune_text_encoder, data_preprocessor=data_preprocessor, **kwargs) # type: ignore[misc] self.set_ip_adapter()
[docs] def set_lora(self) -> None: """Set LORA for model."""
[docs] def prepare_model(self) -> None: """Prepare model for training. Disable gradient for some models. """ self.image_encoder = MODELS.build(self.image_encoder_config) self.image_projection = MODELS.build( self.image_projection_config, default_args={ "cross_attention_dim": self.unet.config.cross_attention_dim, "image_embed_dim": self.image_encoder.config.projection_dim}) self.image_encoder.requires_grad_(requires_grad=False) super().prepare_model()
[docs] def set_ip_adapter(self) -> None: """Set IP-Adapter for model.""" self.unet.requires_grad_(requires_grad=False) set_unet_ip_adapter(self.unet) if self.pretrained_adapter is not None: load_ip_adapter(self.unet, self.image_projection, self.pretrained_adapter, self.pretrained_adapter_subfolder, self.pretrained_adapter_weights_name)
@torch.no_grad()
[docs] def infer(self, prompt: list[str], example_image: list[str | Image.Image], negative_prompt: str | None = None, height: int | None = None, width: int | None = None, num_inference_steps: int = 50, output_type: str = "pil", **kwargs) -> list[np.ndarray]: """Inference function. Args: ---- prompt (`List[str]`): The prompt or prompts to guide the image generation. example_image (`List[Union[str, Image.Image]]`): The image prompt or prompts to guide the image generation. negative_prompt (`Optional[str]`): The prompt or prompts to guide the image generation. Defaults to None. height (int, optional): The height in pixels of the generated image. Defaults to None. width (int, optional): The width in pixels of the generated image. Defaults to None. num_inference_steps (int): Number of inference steps. Defaults to 50. output_type (str): The output format of the generate image. Choose between 'pil' and 'latent'. Defaults to 'pil'. **kwargs: Other arguments. """ assert len(prompt) == len(example_image) orig_encoder_hid_proj = self.unet.encoder_hid_proj orig_encoder_hid_dim_type = self.unet.config.encoder_hid_dim_type pipeline = StableDiffusionXLPipelineCustomIPAdapter.from_pretrained( self.model, vae=self.vae, text_encoder=self.text_encoder_one, text_encoder_2=self.text_encoder_two, tokenizer=self.tokenizer_one, tokenizer_2=self.tokenizer_two, unet=self.unet, image_encoder=self.image_encoder, feature_extractor=self.feature_extractor, torch_dtype=(torch.float16 if self.device != torch.device("cpu") else torch.float32), hidden_states_idx=self.hidden_states_idx, ) adapter_state_dict = process_ip_adapter_state_dict( self.unet, self.image_projection) # convert IP-Adapter Image Projection layers to diffusers image_projection_layers = [] for state_dict in [adapter_state_dict]: image_projection_layer = ( pipeline.unet._convert_ip_adapter_image_proj_to_diffusers( # noqa state_dict["image_proj"])) image_projection_layer.to( device=pipeline.unet.device, dtype=pipeline.unet.dtype) image_projection_layers.append(image_projection_layer) pipeline.unet.encoder_hid_proj = MultiIPAdapterImageProjection( image_projection_layers) pipeline.unet.config.encoder_hid_dim_type = "ip_image_proj" if self.prediction_type is not None: # set prediction_type of scheduler if defined scheduler_args = {"prediction_type": self.prediction_type} pipeline.scheduler = pipeline.scheduler.from_config( pipeline.scheduler.config, **scheduler_args) pipeline.to(self.device) pipeline.set_progress_bar_config(disable=True) images = [] for p, img in zip(prompt, example_image, strict=True): pil_img = load_image(img) if isinstance(img, str) else img pil_img = pil_img.convert("RGB") image = pipeline( p, ip_adapter_image=pil_img, negative_prompt=negative_prompt, num_inference_steps=num_inference_steps, height=height, width=width, output_type=output_type, **kwargs).images[0] if output_type == "latent": images.append(image) else: images.append(np.array(image)) del pipeline, adapter_state_dict torch.cuda.empty_cache() self.unet.encoder_hid_proj = orig_encoder_hid_proj self.unet.config.encoder_hid_dim_type = orig_encoder_hid_dim_type return images
[docs] def forward( self, inputs: dict, data_samples: Optional[list] = None, # noqa mode: str = "loss") -> dict: """Forward function. Args: ---- inputs (dict): The input dict. data_samples (Optional[list], optional): The data samples. Defaults to None. mode (str, optional): The mode. Defaults to "loss". Returns: ------- dict: The loss dict. """ assert mode == "loss" inputs["text_one"] = self.tokenizer_one( inputs["text"], max_length=self.tokenizer_one.model_max_length, padding="max_length", truncation=True, return_tensors="pt").input_ids.to(self.device) inputs["text_two"] = self.tokenizer_two( inputs["text"], max_length=self.tokenizer_two.model_max_length, padding="max_length", truncation=True, return_tensors="pt").input_ids.to(self.device) num_batches = len(inputs["img"]) if "result_class_image" in inputs: # use prior_loss_weight weight = torch.cat([ torch.ones((num_batches // 2, )), torch.ones((num_batches // 2, )) * self.prior_loss_weight, ]).float().reshape(-1, 1, 1, 1) else: weight = None latents = self._forward_vae(inputs["img"], num_batches) noise = self.noise_generator(latents) timesteps = self.timesteps_generator(self.scheduler, num_batches, self.device) noisy_latents = self._preprocess_model_input(latents, noise, timesteps) prompt_embeds, pooled_prompt_embeds = self.encode_prompt( inputs["text_one"], inputs["text_two"]) unet_added_conditions = { "time_ids": inputs["time_ids"], "text_embeds": pooled_prompt_embeds, } # encode image image_embeds = self.image_encoder(inputs["clip_img"]).image_embeds # random zeros image embeddings mask = torch.multinomial( torch.Tensor([ self.zeros_image_embeddings_prob, 1 - self.zeros_image_embeddings_prob, ]), len(image_embeds), replacement=True).to(image_embeds) image_embeds = (image_embeds * mask.view(-1, 1)).view(num_batches, 1, 1, -1) ip_tokens = self.image_projection(image_embeds) model_pred = self.unet( noisy_latents, timesteps, (prompt_embeds, ip_tokens), added_cond_kwargs=unet_added_conditions).sample return self.loss(model_pred, noise, latents, timesteps, weight)
@MODELS.register_module()
[docs]class IPAdapterXLPlus(IPAdapterXL): """Stable Diffusion XL IP-Adapter Plus."""
[docs] def prepare_model(self) -> None: """Prepare model for training. Disable gradient for some models. """ self.image_encoder = MODELS.build(self.image_encoder_config) self.image_projection = MODELS.build( self.image_projection_config, default_args={ "embed_dims": self.image_encoder.config.hidden_size, "output_dims": self.unet.config.cross_attention_dim}) self.image_encoder.requires_grad_(requires_grad=False) super(IPAdapterXL, self).prepare_model()
[docs] def forward( self, inputs: dict, data_samples: Optional[list] = None, # noqa mode: str = "loss") -> dict: """Forward function. Args: ---- inputs (dict): The input dict. data_samples (Optional[list], optional): The data samples. Defaults to None. mode (str, optional): The mode. Defaults to "loss". Returns: ------- dict: The loss dict. """ assert mode == "loss" inputs["text_one"] = self.tokenizer_one( inputs["text"], max_length=self.tokenizer_one.model_max_length, padding="max_length", truncation=True, return_tensors="pt").input_ids.to(self.device) inputs["text_two"] = self.tokenizer_two( inputs["text"], max_length=self.tokenizer_two.model_max_length, padding="max_length", truncation=True, return_tensors="pt").input_ids.to(self.device) num_batches = len(inputs["img"]) if "result_class_image" in inputs: # use prior_loss_weight weight = torch.cat([ torch.ones((num_batches // 2, )), torch.ones((num_batches // 2, )) * self.prior_loss_weight, ]).float().reshape(-1, 1, 1, 1) else: weight = None latents = self._forward_vae(inputs["img"], num_batches) noise = self.noise_generator(latents) timesteps = self.timesteps_generator(self.scheduler, num_batches, self.device) noisy_latents = self._preprocess_model_input(latents, noise, timesteps) prompt_embeds, pooled_prompt_embeds = self.encode_prompt( inputs["text_one"], inputs["text_two"]) unet_added_conditions = { "time_ids": inputs["time_ids"], "text_embeds": pooled_prompt_embeds, } # random zeros image clip_img = inputs["clip_img"] mask = torch.multinomial( torch.Tensor([ self.zeros_image_embeddings_prob, 1 - self.zeros_image_embeddings_prob, ]), len(clip_img), replacement=True).to(clip_img) clip_img = clip_img * mask.view(-1, 1, 1, 1) # encode image if self.hidden_states_idx == -1: image_embeds = self.image_encoder( clip_img, output_hidden_states=True, ).last_hidden_state else: image_embeds = self.image_encoder( clip_img, output_hidden_states=True, ).hidden_states[self.hidden_states_idx] ip_tokens = self.image_projection(image_embeds) model_pred = self.unet( noisy_latents, timesteps, (prompt_embeds, ip_tokens), added_cond_kwargs=unet_added_conditions).sample return self.loss(model_pred, noise, latents, timesteps, weight)