Learn about Configs

The config system has a modular and inheritance design, and more details can be found in mmengine docs: CONFIG.

Usually, we use python files as config file. All configuration files are placed under the configs folder, and the directory structure is as follows:

DiffEngine/diffengine/
    ├── configs/
    │   ├── _base_/                       # primitive configuration folder
    │   │   ├── datasets/                      # primitive datasets
    │   │   ├── models/                        # primitive models
    │   │   ├── schedules/                     # primitive schedules
    │   │   └── default_runtime.py             # primitive runtime setting
    │   ├── stable_diffusion/             # Stable Diffusion Algorithms Folder
    │   ├── stable_diffusion_xl/          # Stable Diffusion XL Algorithms Folder
    │   ├── ...
    └── ...

Config Structure

There are four kinds of basic component files in the configs/_base_ folders, namely:

We call all the config files in the _base_ folder as primitive config files. You can easily build your training config file by inheriting some primitive config files.

For easy understanding, we use stable_diffusion_v15_pokemon_blip config file as an example and comment on each line.

from mmengine.config import read_base

with read_base():  # This config file will inherit all config files in `_base_`.
    from .._base_.datasets.pokemon_blip import *           # model settings
    from .._base_.default_runtime import *                 # data settings
    from .._base_.models.stable_diffusion_v15 import *     # schedule settings
    from .._base_.schedules.stable_diffusion_50e import *  # runtime settings

We will explain the four primitive config files separately below.

Model settings

This primitive config file includes a dict variable model, which mainly includes information such as network structure and loss function:

Usually, we use the type field to specify the class of the component and use other fields to pass the initialization arguments of the class.

Following is the model primitive config of the stable_diffusion_v15 config file in configs/_base_/models/stable_diffusion_v15.py:

from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer

from diffengine.models.editors import StableDiffusion

base_model = "runwayml/stable-diffusion-v1-5"  # pretrained model name of stable diffusion
model = dict(type=StableDiffusion,  # The type of the main model.
             model=base_model,
             tokenizer=dict(  # tokenizer settings
                type=CLIPTokenizer.from_pretrained,
                pretrained_model_name_or_path=base_model,
                subfolder="tokenizer"),
             scheduler=dict(  # scheduler settings
                type=DDPMScheduler.from_pretrained,
                pretrained_model_name_or_path=base_model,
                subfolder="scheduler"),
             text_encoder=dict(  # text encoder settings
                type=CLIPTextModel.from_pretrained,
                pretrained_model_name_or_path=base_model,
                subfolder="text_encoder"),
             vae=dict(  # vae settings
                type=AutoencoderKL.from_pretrained,
                pretrained_model_name_or_path=base_model,
                subfolder="vae"),
             unet=dict(  # unet settings
                type=UNet2DConditionModel.from_pretrained,
                pretrained_model_name_or_path=base_model,
                subfolder="unet"))

Data settings

This primitive config file includes information to construct the dataloader:

Following is the data primitive config of the stable_diffusion_v15 config in [configs/_base_/datasets/pokemon_blip.py]https://github.com/okotaku/diffengine/blob/main/diffengine/configs/base/datasets/pokemon_blip.py):

import torchvision
from mmengine.dataset import DefaultSampler

from diffengine.datasets import HFDataset
from diffengine.datasets.transforms import (
    PackInputs,
    RandomCrop,
    RandomHorizontalFlip,
    TorchVisonTransformWrapper,
)
from diffengine.engine.hooks import SDCheckpointHook, VisualizationHook

train_pipeline = [  # augmentation settings
    dict(type=TorchVisonTransformWrapper,
         transform=torchvision.transforms.Resize,
         size=512, interpolation="bilinear"),
    dict(type=RandomCrop, size=512),
    dict(type=RandomHorizontalFlip, p=0.5),
    dict(type=TorchVisonTransformWrapper,
         transform=torchvision.transforms.ToTensor),
    dict(type=TorchVisonTransformWrapper,
         transform=torchvision.transforms.Normalize, mean=[0.5], std=[0.5]),
    dict(type=PackInputs),
]
train_dataloader = dict(
    batch_size=4,  # batch size
    num_workers=4,
    dataset=dict(
        type=HFDataset,  # The type of dataset
        dataset="lambdalabs/pokemon-blip-captions",  #  Dataset name or path.
        pipeline=train_pipeline),
    sampler=dict(type=DefaultSampler, shuffle=True),
)

val_dataloader = None
val_evaluator = None
test_dataloader = val_dataloader
test_evaluator = val_evaluator

custom_hooks = [
    dict(type=VisualizationHook, prompt=['yoda pokemon'] * 4),  # validation visualize prompt
    dict(type=SDCheckpointHook)
]

Schedule settings

This primitive config file mainly contains training strategy settings and the settings of training, val and test loops:

Following is the schedule primitive config of the stable_diffusion_v15 config in configs/_base_/schedules/stable_diffusion_50e.py

from mmengine.hooks import CheckpointHook
from mmengine.optim import AmpOptimWrapper
from torch.optim import AdamW

optim_wrapper = dict(
    type=AmpOptimWrapper, dtype="float16",  # fp16 optimization
    optimizer=dict(type=AdamW, lr=1e-5, weight_decay=1e-2),  # Use AdamW optimizer to optimize parameters.
    clip_grad=dict(max_norm=1.0))

# Training configuration, iterate 50 epochs.
# 'by_epoch=True' means to use `EpochBaseTrainLoop`, 'by_epoch=False' means to use IterBaseTrainLoop.
train_cfg = dict(by_epoch=True, max_epochs=50)
val_cfg = None
test_cfg = None

default_hooks = dict(
    # save checkpoint per epoch and keep 3 checkpoints.
    checkpoint=dict(
        type=CheckpointHook,
        interval=1,
        max_keep_ckpts=3,
    ))

Runtime settings

This part mainly includes saving the checkpoint strategy, log configuration, training parameters, breakpoint weight path, working directory, etc.

Here is the runtime primitive config file ‘configs/base/default_runtime.py’ file used by almost all configs:

default_scope = 'diffengine'

# configure environment
env_cfg = dict(
    # whether to enable cudnn benchmark
    cudnn_benchmark=False,

    # set multi-process parameters
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=4),

    # set distributed parameters
    dist_cfg=dict(backend='nccl'),
)

load_from = None
resume = False
randomness = dict(seed=None, deterministic=False)

Inherit and Modify Config File

For easy understanding, we recommend contributors inherit from existing config files. But do not abuse the inheritance. Usually, for all config files, we recommend the maximum inheritance level is 3.

For example, if your config file is based on ResNet with some other modification, you can first inherit the basic stable_diffusion_v15_pokemon_blip structure, dataset and other training settings by specifying _base_ ='./stable_diffusion_v15_pokemon_blip.py' (The path relative to your config file), and then modify the necessary parameters in the config file. A more specific example, now we want to use almost all configs in configs/stable_diffusion/stable_diffusion_v15_pokemon_blip.py, but changing the number of training epochs from 50 to 300, modify pretrained model, modify the learning rate schedule, and modify the dataset path, you can create a new config file configs/stable_diffusion/stable_diffusion_v15_pokemon_blip-300e.py with content as below:

from mmengine.config import read_base

with read_base():  # This config file will inherit all config files in `_base_`.
    from diffengine.configs.stable_diffusion.stable_diffusion_v15_pokemon_blip import * 

# trains more epochs
train_cfg.update(max_epochs=300)  # Train for 300 epochs
param_scheduler = [
    dict(
        type='LinearLR',
        start_factor=1e-3,
        by_epoch=True,
        begin=0,
        end=5,
        convert_to_iter_based=True),
    dict(
        type='CosineAnnealingLR',
        T_max=295,
        eta_min=1e-5,
        by_epoch=True,
        begin=5,
        end=300)
]

# Use your own dataset directory
train_dataloader.update(
    dataset=dict(dataset='mydata/pokemon-blip-captions'),
)

Acknowledgement

This content refers to mmengine docs: CONFIG. Thank you for the great docs.