import os

import torch

from tools.utils.logging import get_logger


def save_ckpt(
    model,
    cfg,
    optimizer,
    lr_scheduler,
    epoch,
    global_step,
    metrics,
    is_best=False,
    logger=None,
    prefix=None,
):
- """
- Saving checkpoints
- :param epoch: current epoch number
- :param log: logging information of the epoch
- :param save_best: if True, rename the saved checkpoint to 'model_best.pth.tar'
- """
    if logger is None:
        logger = get_logger()
    if prefix is None:
        if is_best:
            save_path = os.path.join(cfg["Global"]["output_dir"], "best.pth")
        else:
            save_path = os.path.join(cfg["Global"]["output_dir"], "latest.pth")
    else:
        save_path = os.path.join(cfg["Global"]["output_dir"], prefix + ".pth")
    # Unwrap DistributedDataParallel so the saved keys carry no "module." prefix.
    state_dict = model.module.state_dict() if cfg["Global"]["distributed"] else model.state_dict()
    state = {
        "epoch": epoch,
        "global_step": global_step,
        "state_dict": state_dict,
        # Best checkpoints keep weights only; the resumable optimizer/scheduler
        # state lives in the regular ("latest") checkpoints.
        "optimizer": None if is_best else optimizer.state_dict(),
        "scheduler": None if is_best else lr_scheduler.state_dict(),
        "config": cfg,
        "metrics": metrics,
    }
    torch.save(state, save_path)
    logger.info(f"save ckpt to {save_path}")


def load_ckpt(model, cfg, optimizer=None, lr_scheduler=None, logger=None):
- """
- Resume from saved checkpoints
- :param checkpoint_path: Checkpoint path to be resumed
- """
    if logger is None:
        logger = get_logger()
    checkpoints = cfg["Global"].get("checkpoints")
    pretrained_model = cfg["Global"].get("pretrained_model")
    status = {}
    if checkpoints and os.path.exists(checkpoints):
        checkpoint = torch.load(checkpoints, map_location=torch.device("cpu"))
        model.load_state_dict(checkpoint["state_dict"], strict=True)
        # 'best' checkpoints store None for optimizer/scheduler state, so only
        # restore it when it is actually present.
        if optimizer is not None and checkpoint.get("optimizer") is not None:
            optimizer.load_state_dict(checkpoint["optimizer"])
        if lr_scheduler is not None and checkpoint.get("scheduler") is not None:
            lr_scheduler.load_state_dict(checkpoint["scheduler"])
        logger.info(f"resume from checkpoint {checkpoints} (epoch {checkpoint['epoch']})")
        status["global_step"] = checkpoint["global_step"]
        status["epoch"] = checkpoint["epoch"] + 1
        status["metrics"] = checkpoint["metrics"]
    elif pretrained_model and os.path.exists(pretrained_model):
        load_pretrained_params(model, pretrained_model, logger)
        logger.info(f"finetune from checkpoint {pretrained_model}")
    else:
        logger.info("train from scratch")
    return status


def load_pretrained_params(model, pretrained_model, logger):
    """Load weights non-strictly and report parameters missing from the checkpoint."""
    checkpoint = torch.load(pretrained_model, map_location=torch.device("cpu"))
    model.load_state_dict(checkpoint["state_dict"], strict=False)
    for name in model.state_dict().keys():
        if name not in checkpoint["state_dict"]:
            logger.info(f"{name} is not in pretrained model")