1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677 |
- import os
- import sys
- __dir__ = os.path.dirname(os.path.abspath(__file__))
- sys.path.append(__dir__)
- sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
- import torch
- from tools.engine.config import Config
- from tools.utility import ArgsParser
- from tools.utils.logging import get_logger
def to_onnx(model, dummy_input, dynamic_axes, sava_path='model.onnx'):
    """Export ``model`` to an ONNX file through ``torch.onnx.export``.

    Args:
        model: Model to export; it is moved to the CPU before exporting.
        dummy_input: Example input handed to the exporter.
        dynamic_axes: Iterable of axis indices to mark as dynamic on both
            the ``input`` and the ``output`` tensor. Each index is looked up
            in a four-entry label list.
        sava_path: Path of the ONNX file to write.

    NOTE(review): if the tensors are NCHW, axis 2 would be the height and
    axis 3 the width, so the width/height labels below look swapped (and
    'int_height' looks like a typo). They are kept verbatim because the
    labels end up in the exported graph -- confirm before renaming.
    """
    in_labels = ['batch_size', 'channel', 'in_width', 'int_height']
    out_labels = ['batch_size', 'channel', 'out_width', 'out_height']

    cpu_model = model.to('cpu')

    # Build the two axis maps separately (input first), each with its own
    # pass over ``dynamic_axes``.
    input_axes = dict((idx, in_labels[idx]) for idx in dynamic_axes)
    output_axes = dict((idx, out_labels[idx]) for idx in dynamic_axes)

    torch.onnx.export(
        cpu_model,
        dummy_input,
        sava_path,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={'input': input_axes, 'output': output_axes},
    )
def main(cfg):
    """Export the model selected by ``cfg`` to an ONNX file.

    The algorithm named in ``cfg.cfg['Architecture']['algorithm']`` decides
    which model is built and how it is exported:

    * ``'SVTRv2_mobile'``: model from ``OpenRecognizer``, dummy input
      ``[1, 3, 48, 320]``, dynamic axes ``[0, 3]``, file ``rec_model.onnx``.
    * ``'DB_mobile'``: model from ``OpenDetector``, dummy input
      ``[1, 3, 960, 960]``, dynamic axes ``[0, 2, 3]``, file
      ``det_model.onnx``.

    The file is written to ``Global.export_dir`` when that is set and
    non-empty; otherwise to ``<Global.output_dir or 'output'>/export_rec``
    (or ``export_det``). The directory is created if it does not exist.

    Args:
        cfg: Object whose ``cfg`` attribute is a mapping with ``'Global'``
            and ``'Architecture'`` sections.

    Raises:
        ValueError: If the configured algorithm is not one of the two
            supported ones.
    """
    _cfg = cfg.cfg
    global_config = _cfg['Global']
    algorithm = _cfg['Architecture']['algorithm']

    # Reject unsupported algorithms up front. Without this check the code
    # would fall through and fail later with an unrelated error (undefined
    # locals, or os.makedirs('') when no export_dir is configured).
    supported = ('SVTRv2_mobile', 'DB_mobile')
    if algorithm not in supported:
        raise ValueError(
            f'unsupported algorithm for ONNX export: {algorithm!r}; '
            f'expected one of {supported}')

    logger = get_logger()

    # The inference modules are imported lazily so that only the one needed
    # for the selected algorithm is loaded.
    if algorithm == 'SVTRv2_mobile':
        from tools.infer_rec import OpenRecognizer
        model = OpenRecognizer(_cfg).model
        dynamic_axes = [0, 3]
        dummy_input = torch.randn([1, 3, 48, 320], device='cpu')
        default_subdir, file_name = 'export_rec', 'rec_model.onnx'
    else:  # 'DB_mobile'
        from tools.infer_det import OpenDetector
        model = OpenDetector(_cfg).model
        dynamic_axes = [0, 2, 3]
        dummy_input = torch.randn([1, 3, 960, 960], device='cpu')
        default_subdir, file_name = 'export_det', 'det_model.onnx'

    export_dir = global_config.get('export_dir', '')
    if not export_dir:
        export_dir = os.path.join(
            global_config.get('output_dir', 'output'), default_subdir)
    save_path = os.path.join(export_dir, file_name)

    os.makedirs(export_dir, exist_ok=True)
    to_onnx(model, dummy_input, dynamic_axes, save_path)
    logger.info(f'finish export model to {save_path}')
def parse_args():
    """Build an ``ArgsParser`` and return the arguments it parses."""
    return ArgsParser().parse_args()
if __name__ == '__main__':
    # Load the config file named on the command line, then merge the parsed
    # flags into it. The value stored under 'opt' is taken out of the flags
    # and merged separately, after all the other flags.
    cli_args = parse_args()
    cfg = Config(cli_args.config)

    # vars() exposes the namespace's own dict, so the pop below also removes
    # 'opt' from ``cli_args``.
    flag_dict = vars(cli_args)
    opt_values = flag_dict.pop('opt')

    cfg.merge_dict(flag_dict)
    cfg.merge_dict(opt_values)

    main(cfg)
|