Wuerstchen Training

See also the configs/wuerstchen/README.md file.

Configs

All configuration files are placed under the configs/wuerstchen folder.

The following example config is taken from configs/wuerstchen/wuerstchen_prior_pokemon_blip.py:

from mmengine.config import read_base

with read_base():
    from .._base_.datasets.pokemon_blip_wuerstchen import *
    from .._base_.default_runtime import *
    from .._base_.models.wuerstchen_prior import *
    from .._base_.schedules.stable_diffusion_50e import *

optim_wrapper.update(
    optimizer=dict(lr=1e-5),
    accumulative_counts=4)  # accumulate gradients over 4 iterations per optimizer step
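
To illustrate what accumulative_counts=4 means in principle, here is a minimal sketch in plain Python. It is not the mmengine implementation: it assumes a toy one-parameter model and equal-sized micro-batches, and the names grad_mse and accumulated_grad are hypothetical. Under those assumptions, averaging the gradients of four micro-batches gives the same gradient as one batch four times as large.

```python
# Toy model: y = w * x, loss = mean squared error.

def grad_mse(w, batch):
    """Gradient of the mean squared error over `batch` with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


def accumulated_grad(w, micro_batches):
    """Sum per-micro-batch gradients, each scaled by 1 / number of micro-batches."""
    total = 0.0
    for batch in micro_batches:
        total += grad_mse(w, batch) / len(micro_batches)
    return total


data = [(float(i), 3.0 * i) for i in range(1, 9)]        # 8 samples of y = 3x
micro_batches = [data[i:i + 2] for i in range(0, 8, 2)]  # 4 micro-batches of 2

w = 0.0
g_accumulated = accumulated_grad(w, micro_batches)  # one "step" after 4 micro-batches
g_full_batch = grad_mse(w, data)                    # gradient of a single batch of 8

print(abs(g_accumulated - g_full_batch) < 1e-9)
```

In practice this means the effective batch size is roughly four times the per-iteration batch size, at the cost of four forward/backward passes per parameter update.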

Run training

Run training with the following commands:

# single GPU
$ diffengine train ${CONFIG_FILE}
# multiple GPUs
$ NPROC_PER_NODE=${GPU_NUM} diffengine train ${CONFIG_FILE}

# Example
$ diffengine train wuerstchen_prior_pokemon_blip
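
If you prefer to launch training from a Python script, the shell commands above could be wrapped roughly as follows. This is only a sketch: build_train_command and run_training are hypothetical helper names, and the code assumes the diffengine command is available on PATH. It uses only the `diffengine train` command and the NPROC_PER_NODE variable shown above.

```python
import os
import subprocess


def build_train_command(config, gpu_num=1):
    """Return (argv, env) mirroring the shell commands shown above."""
    env = dict(os.environ)
    if gpu_num > 1:
        # Multi-GPU runs set NPROC_PER_NODE, as in the shell example.
        env['NPROC_PER_NODE'] = str(gpu_num)
    return ['diffengine', 'train', config], env


def run_training(config, gpu_num=1):
    """Launch training and raise CalledProcessError if it exits with an error."""
    argv, env = build_train_command(config, gpu_num)
    subprocess.run(argv, env=env, check=True)


# Example (not executed here):
# run_training('wuerstchen_prior_pokemon_blip', gpu_num=2)
```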

Inference with diffusers

Once you have trained a model, pass the path of the saved checkpoint to diffusers and use it for inference.

import torch
from diffusers import AutoPipelineForText2Image
from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS, WuerstchenPrior

checkpoint = 'work_dirs/wuerstchen_prior_pokemon_blip/step10450'
prompt = 'A robot pokemon, 4k photo'

prior = WuerstchenPrior.from_pretrained(
    checkpoint, subfolder='prior', torch_dtype=torch.float16)

pipe = AutoPipelineForText2Image.from_pretrained(
    'warp-ai/wuerstchen', prior_prior=prior, torch_dtype=torch.float16)
pipe.to('cuda')

image = pipe(
    prompt,
    prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
    height=768,
    width=768,
    num_inference_steps=50,
).images[0]
image.save('demo.png')
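
The checkpoint path above hard-codes a step number. As a convenience, the most recent checkpoint could be located with a small helper like the one below. This is a sketch under an assumption: that checkpoints are saved as `step<N>` subdirectories of the work directory, as the example path suggests. The name latest_step_dir is hypothetical.

```python
import re
from pathlib import Path


def latest_step_dir(work_dir):
    """Return the `step<N>` subdirectory with the largest N, or None if absent."""
    best, best_step = None, -1
    for path in Path(work_dir).iterdir():
        match = re.fullmatch(r'step(\d+)', path.name)
        # Compare numerically so that step10450 ranks above step9.
        if path.is_dir() and match and int(match.group(1)) > best_step:
            best, best_step = path, int(match.group(1))
    return best


# Example (assumes the work directory exists):
# checkpoint = str(latest_step_dir('work_dirs/wuerstchen_prior_pokemon_blip'))
```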

Results Example

wuerstchen_prior_pokemon_blip

[Image: example1]

You can check configs/wuerstchen/README.md for more details.