from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
from transformers import AutoTokenizer, CLIPTextModel, CLIPTextModelWithProjection
from diffengine.models.editors import LatentConsistencyModelsXL
from diffengine.models.losses import HuberLoss
# Identifier of the base checkpoint; passed as `model=` to the model config.
# NOTE(review): presumably a Hugging Face Hub repo id -- confirm.
base_model = "stabilityai/stable-diffusion-xl-base-1.0"
# Configuration for the LatentConsistencyModelsXL editor. Each nested dict
# names a callable under `type` together with the keyword arguments listed
# next to it.
# NOTE(review): presumably the consumer of this config calls each `type` with
# those kwargs and supplies the base checkpoint to the loaders that only give
# a `subfolder` -- confirm against LatentConsistencyModelsXL.
model = dict(
    type=LatentConsistencyModelsXL,
    model=base_model,
    tokenizer_one=dict(
        type=AutoTokenizer.from_pretrained,
        subfolder="tokenizer",
        use_fast=False,
    ),
    tokenizer_two=dict(
        type=AutoTokenizer.from_pretrained,
        subfolder="tokenizer_2",
        use_fast=False,
    ),
    scheduler=dict(
        type=DDPMScheduler.from_pretrained,
        subfolder="scheduler",
    ),
    text_encoder_one=dict(
        type=CLIPTextModel.from_pretrained,
        subfolder="text_encoder",
    ),
    text_encoder_two=dict(
        type=CLIPTextModelWithProjection.from_pretrained,
        subfolder="text_encoder_2",
    ),
    # Unlike the other components, the VAE names its own checkpoint explicitly
    # instead of a subfolder.
    vae=dict(
        type=AutoencoderKL.from_pretrained,
        pretrained_model_name_or_path="madebyollin/sdxl-vae-fp16-fix",
    ),
    unet=dict(
        type=UNet2DConditionModel.from_pretrained,
        subfolder="unet",
    ),
    loss=dict(type=HuberLoss),
    pre_compute_text_embeddings=True,
    gradient_checkpointing=True,
    # LoRA settings for the UNet: rank 8, alpha 8, applied to the module
    # names listed in `target_modules`.
    # NOTE(review): presumably matched against UNet submodule names by the
    # LoRA implementation -- confirm the matching rule.
    unet_lora_config=dict(
        type="LoRA",
        r=8,
        lora_alpha=8,
        target_modules=[
            "to_q",
            "to_k",
            "to_v",
            "to_out.0",
            "proj_in",
            "proj_out",
            "ff.net.0.proj",
            "ff.net.2",
            "conv1",
            "conv2",
            "conv_shortcut",
            "downsamplers.0.conv",
            "upsamplers.0.conv",
            "time_emb_proj",
        ],
    ),
)