PixArt-α Training

See also the configs/pixart_alpha/README.md file.

Configs

All configuration files are placed under the configs/pixart_alpha folder.

The following example config, adapted from the stable_diffusion_xl_pokemon_blip config file, is located at configs/pixart_alpha/pixart_alpha_1024_pokemon_blip.py:

from mmengine.config import read_base

with read_base():
    from .._base_.datasets.pokemon_blip_pixart import *
    from .._base_.default_runtime import *
    from .._base_.models.pixart_alpha_1024 import *
    from .._base_.schedules.stable_diffusion_50e import *

optim_wrapper.update(
    dtype="bfloat16",
    optimizer=dict(lr=2e-6, weight_decay=3e-2),
    clip_grad=dict(max_norm=0.01))
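
To make the effect of optim_wrapper.update more concrete, the sketch below merges the same overrides into a stand-in base dictionary. It assumes nested dictionaries are merged key by key, so that keys the override does not mention (such as the optimizer type) survive; verify this against your installed mmengine version. The base values and the merge helper are hypothetical placeholders, not the contents of stable_diffusion_50e.

```python
import copy


def merge(base: dict, override: dict) -> dict:
    """Hypothetical helper: recursively merge `override` into a copy of `base`."""
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            # Both sides are dicts: merge them key by key.
            merged[key] = merge(merged[key], value)
        else:
            # New key or non-dict value: the override wins.
            merged[key] = copy.deepcopy(value)
    return merged


# Placeholder base settings (assumed for illustration only).
base_optim_wrapper = dict(
    dtype="float16",
    optimizer=dict(type="AdamW", lr=1e-5, weight_decay=1e-2),
)

# The overrides from the config above.
overrides = dict(
    dtype="bfloat16",
    optimizer=dict(lr=2e-6, weight_decay=3e-2),
    clip_grad=dict(max_norm=0.01),
)

optim_wrapper = merge(base_optim_wrapper, overrides)
print(optim_wrapper)
```

Under that assumption, the PixArt-α config only changes the precision, learning rate, weight decay, and gradient clipping, and inherits everything else from the base schedule.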

Run training

Run training with the following commands:

# single GPU
$ diffengine train ${CONFIG_FILE}
# multiple GPUs
$ NPROC_PER_NODE=${GPU_NUM} diffengine train ${CONFIG_FILE}

# Example.
$ diffengine train pixart_alpha_1024_pokemon_blip

Inference with diffusers

Once you have trained a model, point a diffusers pipeline at the saved model to run inference.

Before running inference, convert the weights to the diffusers format:

$ diffengine convert ${CONFIG_FILE} ${INPUT_FILENAME} ${OUTPUT_DIR} --save-keys ${SAVE_KEYS}
# Example
$ diffengine convert pixart_alpha_1024_pokemon_blip work_dirs/pixart_alpha_1024_pokemon_blip/epoch_50.pth work_dirs/pixart_alpha_1024_pokemon_blip --save-keys transformer
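
The example above follows a pattern: the checkpoint sits at work_dirs/<config name>/epoch_<N>.pth and the converted weights are written back into the same directory. The sketch below assembles the command for a given config name and epoch, assuming that directory layout holds for your run. build_convert_command is a hypothetical helper, not part of diffengine.

```python
import shlex
from pathlib import Path


def build_convert_command(
    config_name: str,
    epoch: int,
    save_keys: str = "transformer",
    work_root: str = "work_dirs",
) -> list[str]:
    """Hypothetical helper: build the `diffengine convert` argument list."""
    # Assumed layout: <work_root>/<config_name>/epoch_<epoch>.pth
    work_dir = Path(work_root) / config_name
    checkpoint = work_dir / f"epoch_{epoch}.pth"
    return [
        "diffengine", "convert", config_name,
        checkpoint.as_posix(), work_dir.as_posix(),
        "--save-keys", save_keys,
    ]


cmd = build_convert_command("pixart_alpha_1024_pokemon_blip", 50)
print(shlex.join(cmd))
```

The resulting list can be passed to subprocess.run(cmd, check=True) if you prefer to drive the conversion from a script.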

Then we can run inference.

from pathlib import Path

import torch
from diffusers import PixArtAlphaPipeline, AutoencoderKL, Transformer2DModel

checkpoint = Path('work_dirs/pixart_alpha_1024_pokemon_blip')
prompt = 'yoda pokemon'

vae = AutoencoderKL.from_pretrained(
    'stabilityai/sd-vae-ft-ema',
)
transformer = Transformer2DModel.from_pretrained(checkpoint, subfolder='transformer')
pipe = PixArtAlphaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS",
    vae=vae,
    transformer=transformer,
    torch_dtype=torch.float32,
).to("cuda")

img = pipe(
    prompt,
    width=1024,
    height=1024,
    num_inference_steps=50,
).images[0]
img.save("demo.png")

Results Example

pixart_alpha_1024_pokemon_blip

example1

You can check configs/pixart_alpha/README.md for more details.