Source code for diffengine.models.archs.ip_adapter

import copy
from collections import OrderedDict

import torch
import torch.nn.functional as F  # noqa
from diffusers.models.attention_processor import (
    AttnProcessor,
    AttnProcessor2_0,
    IPAdapterAttnProcessor,
    IPAdapterAttnProcessor2_0,
)
from diffusers.models.embeddings import ImageProjection, IPAdapterPlusImageProjection
from diffusers.utils import _get_model_file
from safetensors import safe_open
from torch import nn


[docs]def set_unet_ip_adapter(unet: nn.Module) -> None: """Set IP-Adapter for Unet. Args: ---- unet (nn.Module): The unet to set IP-Adapter. """ attn_procs = {} key_id = 1 for name in unet.attn_processors: cross_attention_dim = None if name.endswith( "attn1.processor") else unet.config.cross_attention_dim if name.startswith("mid_block"): hidden_size = unet.config.block_out_channels[-1] elif name.startswith("up_blocks"): block_id = int(name[len("up_blocks.")]) hidden_size = list(reversed(unet.config.block_out_channels))[block_id] elif name.startswith("down_blocks"): block_id = int(name[len("down_blocks.")]) hidden_size = unet.config.block_out_channels[block_id] if cross_attention_dim is None or "motion_modules" in name: attn_processor_class = ( AttnProcessor2_0 if hasattr( F, "scaled_dot_product_attention") else AttnProcessor ) attn_procs[name] = attn_processor_class() else: attn_processor_class = ( IPAdapterAttnProcessor2_0 if hasattr( F, "scaled_dot_product_attention", ) else IPAdapterAttnProcessor ) attn_procs[name] = attn_processor_class( hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, ).to(dtype=unet.dtype, device=unet.device) key_id += 2 unet.set_attn_processor(attn_procs)
[docs]def load_ip_adapter( # noqa: C901, PLR0912 unet: nn.Module, image_projection: nn.Module, pretrained_adapter: str, subfolder: str, weights_name: str) -> None: """Load IP-Adapter pretrained weights. Reference to diffusers/loaders/ip_adapter.py. and diffusers/loaders/unet.py. """ model_file = _get_model_file( pretrained_adapter, subfolder=subfolder, weights_name=weights_name, cache_dir=None, force_download=False, resume_download=False, proxies=None, local_files_only=None, token=None, revision=None, user_agent={ "file_type": "attn_procs_weights", "framework": "pytorch", }) if weights_name.endswith(".safetensors"): state_dict: dict = {"image_proj": {}, "ip_adapter": {}} with safe_open(model_file, framework="pt", device="cpu") as f: for key in f: if key.startswith("image_proj."): state_dict["image_proj"][ key.replace("image_proj.", "")] = f.get_tensor(key) elif key.startswith("ip_adapter."): state_dict["ip_adapter"][ key.replace("ip_adapter.", "")] = f.get_tensor(key) else: state_dict = torch.load(model_file, map_location="cpu") key_id = 1 for name, attn_proc in unet.attn_processors.items(): cross_attention_dim = None if name.endswith( "attn1.processor") else unet.config.cross_attention_dim if cross_attention_dim is None or "motion_modules" in name: continue value_dict = {} value_dict.update( {"to_k_ip.0.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]}) value_dict.update( {"to_v_ip.0.weight": state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]}) attn_proc.load_state_dict(value_dict) key_id += 2 image_proj_state_dict = {} if "proj.weight" in state_dict["image_proj"]: # IP-Adapter for key, value in state_dict["image_proj"].items(): diffusers_name = key.replace("proj", "image_embeds") image_proj_state_dict[diffusers_name] = value elif "proj.3.weight" in state_dict["image_proj"]: # IP-Adapter Full for key, value in state_dict["image_proj"].items(): diffusers_name = key.replace("proj.0", "ff.net.0.proj") diffusers_name = diffusers_name.replace("proj.2", "ff.net.2") diffusers_name = diffusers_name.replace("proj.3", "norm") image_proj_state_dict[diffusers_name] = value else: # IP-Adapter Plus for key, value in state_dict["image_proj"].items(): diffusers_name = key.replace("0.to", "2.to") diffusers_name = diffusers_name.replace("1.0.weight", "3.0.weight") diffusers_name = diffusers_name.replace("1.0.bias", "3.0.bias") diffusers_name = diffusers_name.replace( "1.1.weight", "3.1.net.0.proj.weight") diffusers_name = diffusers_name.replace("1.3.weight", "3.1.net.2.weight") if "norm1" in diffusers_name: image_proj_state_dict[ diffusers_name.replace("0.norm1", "0")] = value elif "norm2" in diffusers_name: image_proj_state_dict[ diffusers_name.replace("0.norm2", "1")] = value elif "to_kv" in diffusers_name: v_chunk = value.chunk(2, dim=0) image_proj_state_dict[ diffusers_name.replace("to_kv", "to_k")] = v_chunk[0] image_proj_state_dict[ diffusers_name.replace("to_kv", "to_v")] = v_chunk[1] elif "to_out" in diffusers_name: image_proj_state_dict[ diffusers_name.replace("to_out", "to_out.0")] = value else: image_proj_state_dict[diffusers_name] = value image_projection.load_state_dict(image_proj_state_dict) del image_proj_state_dict, state_dict torch.cuda.empty_cache()
[docs]def process_ip_adapter_state_dict( # noqa: PLR0915, C901, PLR0912 unet: nn.Module, image_projection: nn.Module) -> dict: """Process IP-Adapter state dict.""" adapter_modules = torch.nn.ModuleList([ v if isinstance(v, nn.Module) else nn.Identity( ) for v in copy.deepcopy(unet.attn_processors).values()]) adapter_state_dict = OrderedDict() for k, v in adapter_modules.state_dict().items(): new_k = k.replace(".0.weight", ".weight") adapter_state_dict[new_k] = v # not save no grad key ip_image_projection_state_dict = OrderedDict() if isinstance(image_projection, ImageProjection): for k, v in image_projection.state_dict().items(): new_k = k.replace("image_embeds.", "proj.") ip_image_projection_state_dict[new_k] = v elif isinstance(image_projection, IPAdapterPlusImageProjection): for k, v in image_projection.state_dict().items(): if "2.to" in k: new_k = k.replace("2.to", "0.to") elif "layers.3.0.weight" in k: new_k = k.replace("layers.3.0.weight", "layers.3.0.norm1.weight") elif "layers.3.0.bias" in k: new_k = k.replace("layers.3.0.bias", "layers.3.0.norm1.bias") elif "layers.3.1.weight" in k: new_k = k.replace("layers.3.1.weight", "layers.3.0.norm2.weight") elif "layers.3.1.bias" in k: new_k = k.replace("layers.3.1.bias", "layers.3.0.norm2.bias") elif "3.0.weight" in k: new_k = k.replace("3.0.weight", "1.0.weight") elif "3.0.bias" in k: new_k = k.replace("3.0.bias", "1.0.bias") elif "3.0.weight" in k: new_k = k.replace("3.0.weight", "1.0.weight") elif "3.1.net.0.proj.weight" in k: new_k = k.replace("3.1.net.0.proj.weight", "1.1.weight") elif "3.1.net.2.weight" in k: new_k = k.replace("3.1.net.2.weight", "1.3.weight") elif "layers.0.0" in k: new_k = k.replace("layers.0.0", "layers.0.0.norm1") elif "layers.0.1" in k: new_k = k.replace("layers.0.1", "layers.0.0.norm2") elif "layers.1.0" in k: new_k = k.replace("layers.1.0", "layers.1.0.norm1") elif "layers.1.1" in k: new_k = k.replace("layers.1.1", "layers.1.0.norm2") elif "layers.2.0" in k: new_k = k.replace("layers.2.0", "layers.2.0.norm1") elif "layers.2.1" in k: new_k = k.replace("layers.2.1", "layers.2.0.norm2") else: new_k = k if "norm_cross" in new_k: ip_image_projection_state_dict[new_k.replace("norm_cross", "norm1")] = v elif "layer_norm" in new_k: ip_image_projection_state_dict[new_k.replace("layer_norm", "norm2")] = v elif "to_k" in new_k: ip_image_projection_state_dict[ new_k.replace("to_k", "to_kv")] = torch.cat([ v, image_projection.state_dict()[k.replace("to_k", "to_v")]], dim=0) elif "to_v" in new_k: continue elif "to_out.0" in new_k: ip_image_projection_state_dict[new_k.replace("to_out.0", "to_out")] = v else: ip_image_projection_state_dict[new_k] = v return {"image_proj": ip_image_projection_state_dict, "ip_adapter": adapter_state_dict}