onnx_engine.py 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. import os
  2. import onnxruntime
  3. class ONNXEngine:
  4. def __init__(self, onnx_path, use_gpu):
  5. """
  6. :param onnx_path:
  7. """
  8. if not os.path.exists(onnx_path):
  9. raise Exception(f'{onnx_path} is not exists')
  10. providers = ['CPUExecutionProvider']
  11. if use_gpu:
  12. providers = ([
  13. 'TensorrtExecutionProvider',
  14. 'CUDAExecutionProvider',
  15. 'CPUExecutionProvider',
  16. ], )
  17. self.onnx_session = onnxruntime.InferenceSession(onnx_path,
  18. providers=providers)
  19. self.input_name = self.get_input_name(self.onnx_session)
  20. self.output_name = self.get_output_name(self.onnx_session)
  21. def get_output_name(self, onnx_session):
  22. """
  23. output_name = onnx_session.get_outputs()[0].name
  24. :param onnx_session:
  25. :return:
  26. """
  27. output_name = []
  28. for node in onnx_session.get_outputs():
  29. output_name.append(node.name)
  30. return output_name
  31. def get_input_name(self, onnx_session):
  32. """
  33. input_name = onnx_session.get_inputs()[0].name
  34. :param onnx_session:
  35. :return:
  36. """
  37. input_name = []
  38. for node in onnx_session.get_inputs():
  39. input_name.append(node.name)
  40. return input_name
  41. def get_input_feed(self, input_name, image_numpy):
  42. """
  43. input_feed={self.input_name: image_numpy}
  44. :param input_name:
  45. :param image_numpy:
  46. :return:
  47. """
  48. input_feed = {}
  49. for name in input_name:
  50. input_feed[name] = image_numpy
  51. return input_feed
  52. def run(self, image_numpy):
  53. # 输入数据的类型必须与模型一致,以下三种写法都是可以的
  54. input_feed = self.get_input_feed(self.input_name, image_numpy)
  55. result = self.onnx_session.run(self.output_name, input_feed=input_feed)
  56. return result