toonnx.py 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. import os
  2. import sys
  3. __dir__ = os.path.dirname(os.path.abspath(__file__))
  4. sys.path.append(__dir__)
  5. sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
  6. import torch
  7. from tools.engine.config import Config
  8. from tools.utility import ArgsParser
  9. from tools.utils.logging import get_logger
  10. def to_onnx(model, dummy_input, dynamic_axes, sava_path='model.onnx'):
  11. input_axis_name = ['batch_size', 'channel', 'in_width', 'int_height']
  12. output_axis_name = ['batch_size', 'channel', 'out_width', 'out_height']
  13. torch.onnx.export(
  14. model.to('cpu'),
  15. dummy_input,
  16. sava_path,
  17. input_names=['input'],
  18. output_names=['output'], # the model's output names
  19. dynamic_axes={
  20. 'input': {axis: input_axis_name[axis]
  21. for axis in dynamic_axes},
  22. 'output': {axis: output_axis_name[axis]
  23. for axis in dynamic_axes},
  24. },
  25. )
  26. def main(cfg):
  27. _cfg = cfg.cfg
  28. logger = get_logger()
  29. global_config = _cfg['Global']
  30. export_dir = global_config.get('export_dir', '')
  31. if _cfg['Architecture']['algorithm'] == 'SVTRv2_mobile':
  32. from tools.infer_rec import OpenRecognizer
  33. model = OpenRecognizer(_cfg).model
  34. dynamic_axes = [0, 3]
  35. dummy_input = torch.randn([1, 3, 48, 320], device='cpu')
  36. if not export_dir:
  37. export_dir = os.path.join(
  38. global_config.get('output_dir', 'output'), 'export_rec')
  39. save_path = os.path.join(export_dir, 'rec_model.onnx')
  40. if _cfg['Architecture']['algorithm'] == 'DB_mobile':
  41. from tools.infer_det import OpenDetector
  42. model = OpenDetector(_cfg).model
  43. dynamic_axes = [0, 2, 3]
  44. dummy_input = torch.randn([1, 3, 960, 960], device='cpu')
  45. if not export_dir:
  46. export_dir = os.path.join(
  47. global_config.get('output_dir', 'output'), 'export_det')
  48. save_path = os.path.join(export_dir, 'det_model.onnx')
  49. os.makedirs(export_dir, exist_ok=True)
  50. to_onnx(model, dummy_input, dynamic_axes, save_path)
  51. logger.info(f'finish export model to {save_path}')
  52. def parse_args():
  53. parser = ArgsParser()
  54. args = parser.parse_args()
  55. return args
  56. if __name__ == '__main__':
  57. FLAGS = parse_args()
  58. cfg = Config(FLAGS.config)
  59. FLAGS = vars(FLAGS)
  60. opt = FLAGS.pop('opt')
  61. cfg.merge_dict(FLAGS)
  62. cfg.merge_dict(opt)
  63. main(cfg)