IP-Adapter Training
You can also check the `configs/ip_adapter/README.md` file.
Configs
All configuration files are placed under the `configs/ip_adapter` folder.
The following example config is adapted from the `stable_diffusion_xl_pokemon_blip_ip_adapter` config file in `configs/ip_adapter/stable_diffusion_xl_pokemon_blip_ip_adapter.py`:
from mmengine.config import read_base

# Inherit the dataset, runtime, model and schedule settings from the shared
# base configs. The imports must be indented inside `read_base()` so mmengine
# treats them as config inheritance rather than ordinary Python imports.
with read_base():
    from .._base_.datasets.pokemon_blip_xl_ip_adapter import *
    from .._base_.default_runtime import *
    from .._base_.models.stable_diffusion_xl_ip_adapter import *
    from .._base_.schedules.stable_diffusion_xl_50e import *

# Override the inherited batch size: one sample per iteration.
train_dataloader.update(batch_size=1)
# Accumulate gradients over 4 iterations before each optimizer step, so each
# update sees 4 samples per process despite batch_size=1.
optim_wrapper.update(accumulative_counts=4)
Run training
Use one of the following commands to start training:
# Single GPU
$ diffengine train ${CONFIG_FILE}
# Multiple GPUs: set GPU_NUM to the number of GPUs to launch on
$ NPROC_PER_NODE=${GPU_NUM} diffengine train ${CONFIG_FILE}
# Example: train with the IP-Adapter config referenced by name
$ diffengine train stable_diffusion_xl_pokemon_blip_ip_adapter
Inference with diffengine
Once you have trained a model, specify the path to the saved model and use it for inference with the `diffusers` library, as shown below.
import torch
from diffusers import AutoencoderKL, DiffusionPipeline
from diffusers.utils import load_image
from transformers import CLIPVisionModelWithProjection

# The text prompt is left empty; the reference image passed below as
# `ip_adapter_image` supplies the conditioning for generation.
prompt = ''

# CLIP vision encoder consumed by the IP-Adapter, loaded in fp16 on the GPU.
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "h94/IP-Adapter",
    subfolder="sdxl_models/image_encoder",
    torch_dtype=torch.float16,
).to('cuda')

# Replacement SDXL VAE, loaded in fp16 (presumably chosen for fp16 stability,
# per the repository name — confirm if swapping it out).
vae = AutoencoderKL.from_pretrained(
    'madebyollin/sdxl-vae-fp16-fix',
    torch_dtype=torch.float16,
)

# SDXL base pipeline wired up with the custom image encoder and VAE.
pipe = DiffusionPipeline.from_pretrained(
    'stabilityai/stable-diffusion-xl-base-1.0',
    image_encoder=image_encoder,
    vae=vae,
    torch_dtype=torch.float16,
)
pipe.to('cuda')

# Load the trained IP-Adapter weights from the training work directory.
# "step41650" appears to be the checkpoint folder of the example run; point
# this at the checkpoint produced by your own training.
pipe.load_ip_adapter(
    "work_dirs/stable_diffusion_xl_pokemon_blip_ip_adapter/step41650",
    subfolder="",
    weight_name="ip_adapter.bin",
)

reference_image = load_image("https://github.com/LambdaLabsML/examples/blob/main/stable-diffusion-finetuning/README_files/README_2_0.png?raw=true")

# Generate a single 1024x1024 image guided by the reference image.
generated_image = pipe(
    prompt,
    ip_adapter_image=reference_image,
    height=1024,
    width=1024,
).images[0]
generated_image.save('demo.png')