ckpt.py

import os
import torch
from tools.utils.logging import get_logger


def save_ckpt(
    model,
    cfg,
    optimizer,
    lr_scheduler,
    epoch,
    global_step,
    metrics,
    is_best=False,
    logger=None,
    prefix=None,
):
    """
    Save a training checkpoint.

    :param epoch: current epoch number
    :param global_step: current global training step
    :param metrics: evaluation metrics of the current epoch
    :param is_best: if True, save to 'best.pth' (weights only, without optimizer/scheduler state)
    :param prefix: optional filename prefix; overrides the 'best'/'latest' naming
    """
    if logger is None:
        logger = get_logger()
    if prefix is None:
        if is_best:
            save_path = os.path.join(cfg["Global"]["output_dir"], "best.pth")
        else:
            save_path = os.path.join(cfg["Global"]["output_dir"], "latest.pth")
    else:
        save_path = os.path.join(cfg["Global"]["output_dir"], prefix + ".pth")
    # Unwrap DistributedDataParallel so the saved keys carry no "module." prefix.
    state_dict = model.module.state_dict() if cfg["Global"]["distributed"] else model.state_dict()
    state = {
        "epoch": epoch,
        "global_step": global_step,
        "state_dict": state_dict,
        # The best checkpoint stores weights only; optimizer/scheduler states are
        # kept only in the resumable "latest" checkpoint.
        "optimizer": None if is_best else optimizer.state_dict(),
        "scheduler": None if is_best else lr_scheduler.state_dict(),
        "config": cfg,
        "metrics": metrics,
    }
    torch.save(state, save_path)
    logger.info(f"save ckpt to {save_path}")


def load_ckpt(model, cfg, optimizer=None, lr_scheduler=None, logger=None):
    """
    Resume from 'Global.checkpoints' if it exists, otherwise load
    'Global.pretrained_model' for fine-tuning, otherwise train from scratch.
    :return: dict with 'epoch', 'global_step' and 'metrics' when resuming, else empty
    """
    if logger is None:
        logger = get_logger()
    checkpoints = cfg["Global"].get("checkpoints")
    pretrained_model = cfg["Global"].get("pretrained_model")
    status = {}
    if checkpoints and os.path.exists(checkpoints):
        # Full resume: restore weights, optimizer and scheduler states.
        checkpoint = torch.load(checkpoints, map_location=torch.device("cpu"))
        model.load_state_dict(checkpoint["state_dict"], strict=True)
        if optimizer is not None:
            optimizer.load_state_dict(checkpoint["optimizer"])
        if lr_scheduler is not None:
            lr_scheduler.load_state_dict(checkpoint["scheduler"])
        logger.info(f"resume from checkpoint {checkpoints} (epoch {checkpoint['epoch']})")
        status["global_step"] = checkpoint["global_step"]
        status["epoch"] = checkpoint["epoch"] + 1
        status["metrics"] = checkpoint["metrics"]
    elif pretrained_model and os.path.exists(pretrained_model):
        # Fine-tune: load weights only, tolerating missing keys.
        load_pretrained_params(model, pretrained_model, logger)
        logger.info(f"finetune from checkpoint {pretrained_model}")
    else:
        logger.info("train from scratch")
    return status


def load_pretrained_params(model, pretrained_model, logger):
    """Load pretrained weights non-strictly and report model keys missing from them."""
    checkpoint = torch.load(pretrained_model, map_location=torch.device("cpu"))
    model.load_state_dict(checkpoint["state_dict"], strict=False)
    for name in model.state_dict().keys():
        if name not in checkpoint["state_dict"]:
            logger.info(f"{name} is not in pretrained model")
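

# --- Illustrative usage (added sketch, not part of the original module) ---
# A minimal example of how save_ckpt/load_ckpt could be wired into a training
# loop. The toy model, optimizer, scheduler and cfg values below are
# hypothetical placeholders, not the project's real configuration.
if __name__ == "__main__":
    import torch.nn as nn

    cfg = {
        "Global": {
            "output_dir": "./output",               # hypothetical output directory
            "distributed": False,                   # model is not wrapped in DDP here
            "checkpoints": "./output/latest.pth",   # resumed from if the file exists
            "pretrained_model": None,
        }
    }
    os.makedirs(cfg["Global"]["output_dir"], exist_ok=True)

    model = nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)

    # Resume if a checkpoint already exists; otherwise this logs "train from scratch".
    status = load_ckpt(model, cfg, optimizer, lr_scheduler)
    start_epoch = status.get("epoch", 0)

    for epoch in range(start_epoch, start_epoch + 2):
        # ... one epoch of training would go here ...
        metrics = {"placeholder_metric": 0.0}
        # Save the resumable "latest" checkpoint and a weights-only "best" checkpoint.
        save_ckpt(model, cfg, optimizer, lr_scheduler, epoch, global_step=0, metrics=metrics)
        save_ckpt(model, cfg, optimizer, lr_scheduler, epoch, global_step=0,
                  metrics=metrics, is_best=True)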