Source code for publish_model2diffusers

import argparse
import os.path as osp
from pathlib import Path

import torch
from mmengine.config import Config
from mmengine.logging import print_log
from mmengine.registry import RUNNERS
from mmengine.runner import Runner

from diffengine.configs import cfgs_name_path


[docs]def parse_args(): parser = argparse.ArgumentParser( description="Process a checkpoint to be published") parser.add_argument("config", help="train config file path") parser.add_argument("in_file", help="input checkpoint filename") parser.add_argument("out_dir", help="output dir") parser.add_argument( "--save-keys", nargs="+", type=str, default=["unet", "text_encoder", "transformer"], help="keys to save in the published checkpoint") parser.add_argument( "--colossalai", action="store_true", help="whether the checkpoint is trained with colossalai") return parser.parse_args()
[docs]def process_checkpoint(runner, out_dir, save_keys=None) -> None: if save_keys is None: save_keys = ["unet", "text_encoder"] for k in save_keys: model = getattr(runner.model, k) model.save_pretrained(osp.join(out_dir, k)) print_log(f"The published model is saved at {out_dir}.", logger="current")
[docs]def main() -> None: args = parse_args() allowed_save_keys = ["unet", "text_encoder", "transformer"] if not set(args.save_keys).issubset(set(allowed_save_keys)): msg = (f"These metrics are supported: {allowed_save_keys}, " f"but got {args.save_keys}") raise KeyError(msg) # parse config if not osp.isfile(args.config): try: args.config = cfgs_name_path[args.config] except KeyError as exc: msg = f"Cannot find {args.config}" raise FileNotFoundError(msg) from exc cfg = Config.fromfile(args.config) cfg.work_dir = osp.join("./work_dirs", osp.splitext(osp.basename(args.config))[0]) if args.colossalai: cfg.strategy = None cfg.pop("runner_type") # build the runner from config runner = ( Runner.from_cfg(cfg) if "runner_type" not in cfg else RUNNERS.build(cfg)) if args.colossalai: from colossalai.checkpoint_io import GeneralCheckpointIO GeneralCheckpointIO().load_sharded_model(runner.model, Path(args.in_file)) else: state_dict = torch.load(args.in_file) if "state_dict" in state_dict: state_dict = state_dict["state_dict"] runner.model.load_state_dict(state_dict, strict=False) process_checkpoint(runner, args.out_dir, args.save_keys)
if __name__ == "__main__": main()