Source code for train

import argparse
import os
import os.path as osp

from mmengine.config import Config, DictAction
from mmengine.registry import RUNNERS
from mmengine.runner import Runner

from diffengine.configs import cfgs_name_path


def parse_args():
    """Build the training command-line interface and parse the arguments.

    As a side effect, the parsed local rank is exported to the
    ``LOCAL_RANK`` environment variable when that variable is not already
    set.

    Returns:
        argparse.Namespace: Parsed arguments with the attributes ``config``,
        ``work_dir``, ``resume``, ``amp``, ``cfg_options``, ``launcher`` and
        ``local_rank``.
    """
    cfg_options_help = (
        'override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.'
    )

    cli = argparse.ArgumentParser(description="Train a model")
    cli.add_argument("config", help="train config file path")
    cli.add_argument("--work-dir", help="the dir to save logs and models")
    cli.add_argument(
        "--resume",
        action="store_true",
        help="Whether to resume checkpoint.",
    )
    cli.add_argument(
        "--amp",
        action="store_true",
        help="enable automatic-mixed-precision training",
    )
    cli.add_argument(
        "--cfg-options",
        nargs="+",
        action=DictAction,
        help=cfg_options_help,
    )
    cli.add_argument(
        "--launcher",
        choices=["none", "pytorch", "slurm", "mpi"],
        default="none",
        help="job launcher",
    )
    # With PyTorch >= 2.0.0, `torch.distributed.launch` passes
    # `--local-rank` to `tools/train.py` instead of `--local_rank`, so both
    # spellings are accepted and stored as `local_rank`.
    cli.add_argument("--local_rank", "--local-rank", type=int, default=0)

    parsed = cli.parse_args()

    # Keep an existing LOCAL_RANK; only fall back to the CLI value.
    os.environ.setdefault("LOCAL_RANK", str(parsed.local_rank))
    return parsed
def merge_args(cfg, args):
    """Merge CLI arguments to config.

    Args:
        cfg: Config object loaded from the config file. It is modified in
            place.
        args: Parsed command-line arguments providing ``launcher``,
            ``work_dir``, ``config``, ``amp``, ``resume`` and
            ``cfg_options``.

    Returns:
        The same ``cfg`` object, after the CLI values have been merged in.

    Raises:
        AssertionError: If ``--amp`` is requested while ``cfg.optim_wrapper``
            declares a type other than ``OptimWrapper`` or
            ``AmpOptimWrapper``.
    """
    cfg.launcher = args.launcher

    # work_dir is determined in this priority:
    # CLI > value in the config file > config filename.
    if args.work_dir is not None:
        cfg.work_dir = args.work_dir
    elif cfg.get("work_dir", None) is None:
        config_stem = osp.splitext(osp.basename(args.config))[0]
        cfg.work_dir = osp.join("./work_dirs", config_stem)

    # enable automatic-mixed-precision training
    if args.amp:
        optim_wrapper = cfg.optim_wrapper.get("type", "OptimWrapper")
        if optim_wrapper not in ("OptimWrapper", "AmpOptimWrapper"):
            # Raised explicitly instead of via `assert` so the check is not
            # stripped when Python runs with `-O`. The exception type stays
            # AssertionError so existing callers observe the same error.
            msg = (
                "`--amp` is not supported with custom optimizer wrapper "
                f"type `{optim_wrapper}`."
            )
            raise AssertionError(msg)
        cfg.optim_wrapper.type = "AmpOptimWrapper"
        cfg.optim_wrapper.setdefault("loss_scale", "dynamic")

    # resume training
    if args.resume:
        cfg.resume = True
        cfg.load_from = None

    # `--cfg-options` are merged last, so they override the values set above.
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    return cfg
def main() -> None:
    """Parse the CLI, load the config, build a runner and start training.

    Raises:
        FileNotFoundError: If the ``config`` argument is neither an existing
            file nor a key of ``cfgs_name_path``.
    """
    args = parse_args()

    # A config argument that is not a file on disk is looked up by name in
    # `cfgs_name_path` and replaced with the mapped path.
    config_name = args.config
    if not osp.isfile(config_name):
        try:
            args.config = cfgs_name_path[config_name]
        except KeyError as exc:
            msg = f"Cannot find {config_name}"
            raise FileNotFoundError(msg) from exc

    # Load the config file, then layer the CLI arguments on top of it.
    cfg = merge_args(Config.fromfile(args.config), args)

    # A config declaring `runner_type` is built through the RUNNERS registry;
    # otherwise the default Runner is constructed from the config.
    if "runner_type" in cfg:
        runner = RUNNERS.build(cfg)
    else:
        runner = Runner.from_cfg(cfg)

    # start training
    runner.train()
# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__": main()