export_rec.py

import os
import sys

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))

import torch

from openrec.modeling import build_model
from openrec.postprocess import build_post_process
from tools.engine.config import Config
from tools.infer_rec import build_rec_process
from tools.utility import ArgsParser
from tools.utils.ckpt import load_ckpt
from tools.utils.logging import get_logger


def to_onnx(model, dummy_input, dynamic_axes, save_path='model.onnx'):
    input_axis_name = ['batch_size', 'channel', 'in_width', 'in_height']
    output_axis_name = ['batch_size', 'channel', 'out_width', 'out_height']
    torch.onnx.export(
        model.to('cpu'),
        dummy_input,
        save_path,
        input_names=['input'],  # the model's input names
        output_names=['output'],  # the model's output names
        # label every axis index listed in dynamic_axes so its size stays dynamic
        dynamic_axes={
            'input': {axis: input_axis_name[axis]
                      for axis in dynamic_axes},
            'output': {axis: output_axis_name[axis]
                       for axis in dynamic_axes},
        },
    )

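
# Illustrative call (the input shape and axis choice below are assumptions, not
# read from any config): to_onnx(model, torch.randn(1, 3, 48, 320), [0, 3])
# exports the model with axis 0 (batch_size) and axis 3 dynamic on both the
# input and output tensors.
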
def export_single_model(model: torch.nn.Module, _cfg, export_dir,
                        export_config, logger, type):
    # Fuse re-parameterizable layers (those exposing a `rep()` method) before export.
    for layer in model.modules():
        if hasattr(layer, 'rep') and not getattr(layer, 'is_repped'):
            layer.rep()
    os.makedirs(export_dir, exist_ok=True)
    # Save a reduced config (post-process + preprocessing transforms) next to the model.
    export_cfg = {'PostProcess': _cfg['PostProcess']}
    export_cfg['Transforms'] = build_rec_process(_cfg)
    # `cfg` here is the module-level Config instance created in the __main__ block.
    cfg.save(os.path.join(export_dir, 'config.yaml'), export_cfg)
    dummy_input = torch.randn(*export_config['export_shape'], device='cpu')
    if type == 'script':
        save_path = os.path.join(export_dir, 'model.pt')
        trace_model = torch.jit.trace(model, dummy_input, strict=False)
        torch.jit.save(trace_model, save_path)
    elif type == 'onnx':
        save_path = os.path.join(export_dir, 'model.onnx')
        to_onnx(model, dummy_input, export_config.get('dynamic_axes', []),
                save_path)
    else:
        raise NotImplementedError(f'unsupported export type: {type}')
    logger.info(f'finish export model to {save_path}')

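
# A minimal `Export` section this script expects in the YAML config (the keys
# come from the code above; the values below are only illustrative):
# Export:
#   export_dir: ./output/export
#   export_shape: [1, 3, 48, 320]
#   dynamic_axes: [0, 3]
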
def main(cfg, type):
    _cfg = cfg.cfg
    logger = get_logger()
    global_config = _cfg['Global']
    export_config = _cfg['Export']

    # build post process
    post_process_class = build_post_process(_cfg['PostProcess'])
    # The decoder's output size must match the character dictionary of the post-process.
    char_num = len(post_process_class.character)
    _cfg['Architecture']['Decoder']['out_channels'] = char_num

    model = build_model(_cfg['Architecture'])
    load_ckpt(model, _cfg)
    model.eval()

    export_dir = export_config.get('export_dir', '')
    if not export_dir:
        export_dir = os.path.join(global_config.get('output_dir', 'output'),
                                  'export')

    if _cfg['Architecture']['algorithm'] in ['Distillation']:  # distillation model
        # Each sub-model of a distillation model is exported into its own sub-directory.
        _cfg['PostProcess']['name'] = post_process_class.__class__.__base__.__name__
        for model_name in model.model_list:
            sub_model_save_path = os.path.join(export_dir, model_name)
            export_single_model(
                model.model_list[model_name],
                _cfg,
                sub_model_save_path,
                export_config,
                logger,
                type,
            )
    else:
        export_single_model(model, _cfg, export_dir, export_config, logger,
                            type)

def parse_args():
    parser = ArgsParser()
    parser.add_argument('--type',
                        type=str,
                        default='onnx',
                        help='export type: onnx or script')
    args = parser.parse_args()
    return args

if __name__ == '__main__':
    FLAGS = parse_args()
    cfg = Config(FLAGS.config)
    FLAGS = vars(FLAGS)
    opt = FLAGS.pop('opt')
    cfg.merge_dict(FLAGS)
    cfg.merge_dict(opt)
    main(cfg, FLAGS['type'])
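
# Example invocation (assuming the script lives under tools/; the config path is illustrative):
#   python tools/export_rec.py --config configs/rec/svtrv2/svtrv2_rec.yml --type onnx
#   python tools/export_rec.py --config configs/rec/svtrv2/svtrv2_rec.yml --type script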