predict_rec.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. import os
  2. import sys
  3. __dir__ = os.path.dirname(os.path.abspath(__file__))
  4. sys.path.append(__dir__)
  5. sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))
  6. import math
  7. import time
  8. import cv2
  9. import numpy as np
  10. from openrec.postprocess import build_post_process
  11. from openrec.preprocess import create_operators, transform
  12. from tools.engine.config import Config
  13. from tools.infer.onnx_engine import ONNXEngine
  14. from tools.infer.utility import check_gpu, parse_args
  15. from tools.utils.logging import get_logger
  16. from tools.utils.utility import check_and_read, get_image_file_list
  17. logger = get_logger()
  class TextRecognizer(ONNXEngine):
      """Recognize text in cropped images with an exported ONNX model.

      The model directory must contain ``model.onnx`` and ``config.yaml``;
      the config supplies the preprocessing transforms and the
      post-processing definition.
      """

      def __init__(self, args):
          """Load the ONNX model and build the pre/post-processing steps.

          Args:
              args: Parsed options. Reads ``rec_model_dir``, ``use_gpu``,
                  ``rec_image_shape`` (comma-separated ``"C,H,W"`` string),
                  ``rec_batch_num`` and ``rec_algorithm``.

          Raises:
              Exception: If ``args.rec_model_dir`` is unset or does not exist.
          """
          if args.rec_model_dir is None or not os.path.exists(
                  args.rec_model_dir):
              raise Exception(
                  f'args.rec_model_dir is set to {args.rec_model_dir}, but it is not exists'
              )
          onnx_path = os.path.join(args.rec_model_dir, 'model.onnx')
          config_path = os.path.join(args.rec_model_dir, 'config.yaml')
          super().__init__(onnx_path, args.use_gpu)

          self.rec_image_shape = [
              int(v) for v in args.rec_image_shape.split(',')
          ]
          self.rec_batch_num = args.rec_batch_num
          self.rec_algorithm = args.rec_algorithm

          cfg = Config(config_path).cfg
          # The first configured transform is skipped -- presumably the image
          # decode step, since callers pass decoded arrays (TODO confirm).
          self.ops = create_operators(cfg['Transforms'][1:])
          self.postprocess_op = build_post_process(cfg['PostProcess'])

      def resize_norm_img(self, img, max_wh_ratio):
          """Resize ``img`` to the model height, normalize it and right-pad it.

          Args:
              img: Image array indexed as (height, width, channels).
              max_wh_ratio: Width/height ratio that fixes the padded output
                  width to ``int(imgH * max_wh_ratio)``.

          Returns:
              A float32 array of shape ``(imgC, imgH, int(imgH * max_wh_ratio))``
              holding ``(pixel / 255 - 0.5) / 0.5`` for the resized image;
              columns to the right of the resized image are zero.

          Raises:
              AssertionError: If the image channel count differs from the
                  configured ``imgC``.
          """
          imgC, imgH, _ = self.rec_image_shape
          assert imgC == img.shape[2], (
              f'expected {imgC} channels, got image of shape {img.shape}')
          padded_w = int(imgH * max_wh_ratio)
          h, w = img.shape[:2]
          # Keep the aspect ratio, but never exceed the padded width.
          resized_w = min(padded_w, int(math.ceil(imgH * (w / float(h)))))

          resized_image = cv2.resize(img, (resized_w, imgH)).astype('float32')
          resized_image = resized_image.transpose((2, 0, 1)) / 255
          resized_image -= 0.5
          resized_image /= 0.5

          padding_im = np.zeros((imgC, imgH, padded_w), dtype=np.float32)
          padding_im[:, :, 0:resized_w] = resized_image
          return padding_im

      def __call__(self, img_list):
          """Recognize the text in every image of ``img_list``.

          Images are processed in batches of ``self.rec_batch_num``, visited
          in order of increasing width/height ratio so that each batch is
          padded to a similar width.

          Args:
              img_list: Sequence of image arrays indexed as
                  (height, width, channels).

          Returns:
              A tuple ``(rec_res, elapsed)``. ``rec_res`` has one entry per
              input image, in input order: the post-processor output for that
              image, or the placeholder ``['', 0.0]`` when none was produced.
              ``elapsed`` is the batch-loop duration in seconds.
          """
          img_num = len(img_list)
          # Calculate the aspect ratio of all text bars
          width_list = [img.shape[1] / float(img.shape[0]) for img in img_list]
          # Sorting can speed up the recognition process
          indices = np.argsort(np.array(width_list))
          # One independent placeholder per image; ``[x] * n`` would make every
          # unfilled slot alias the same list object.
          rec_res = [['', 0.0] for _ in range(img_num)]
          batch_num = self.rec_batch_num
          # Loop-invariant: the configured shape gives the minimum batch ratio.
          _, imgH, imgW = self.rec_image_shape[:3]
          base_wh_ratio = imgW / imgH

          # Monotonic clock: unaffected by system clock adjustments.
          st = time.perf_counter()
          for beg_img_no in range(0, img_num, batch_num):
              end_img_no = min(img_num, beg_img_no + batch_num)
              batch_imgs = [
                  img_list[indices[ino]]
                  for ino in range(beg_img_no, end_img_no)
              ]

              # Widest ratio in the batch decides the common padded width.
              max_wh_ratio = base_wh_ratio
              for img in batch_imgs:
                  h, w = img.shape[0:2]
                  max_wh_ratio = max(max_wh_ratio, w * 1.0 / h)

              norm_img_batch = []
              for img in batch_imgs:
                  if self.rec_algorithm == 'nrtr':
                      norm_img = transform({'image': img}, self.ops)[0]
                  else:
                      norm_img = self.resize_norm_img(img, max_wh_ratio)
                  norm_img_batch.append(norm_img[np.newaxis, :])

              # copy() returns a C-ordered array -- presumably the layout the
              # inference session expects; confirm before removing.
              norm_img_batch = np.concatenate(norm_img_batch).copy()

              preds = self.run(norm_img_batch)
              if len(preds) == 1:
                  preds = preds[0]
              rec_result = self.postprocess_op({'res': preds})
              # Scatter results back to the callers' original image order.
              for rno, result in enumerate(rec_result):
                  rec_res[indices[beg_img_no + rno]] = result
          return rec_res, time.perf_counter() - st
  def main(args):
      """Recognize text in every readable image under ``args.image_dir``.

      Optionally warms the recognizer up first, then loads each image,
      runs one recognition call over all loaded images and logs the result
      for each file.

      Args:
          args: Parsed options. Reads ``use_gpu``, ``image_dir``, ``warmup``
              and ``rec_batch_num`` (plus whatever ``TextRecognizer`` reads).
      """
      args.use_gpu = check_gpu(args.use_gpu)
      image_file_list = get_image_file_list(args.image_dir)
      text_recognizer = TextRecognizer(args)

      # warmup 2 times
      if args.warmup:
          warmup_img = np.random.uniform(0, 255, [48, 320, 3]).astype(np.uint8)
          warmup_batch = [warmup_img] * int(args.rec_batch_num)
          for _ in range(2):
              text_recognizer(warmup_batch)

      # Keep each path paired with its decoded image so that files which
      # fail to load drop out of both the inputs and the report.
      loaded = []
      for image_file in image_file_list:
          img, flag, _ = check_and_read(image_file)
          if not flag:
              img = cv2.imread(image_file)
          if img is None:
              logger.info(f'error in loading image:{image_file}')
              continue
          loaded.append((image_file, img))

      rec_res, _ = text_recognizer([picture for _, picture in loaded])
      for ino, (image_file, _) in enumerate(loaded):
          logger.info(f'result of {image_file}:{rec_res[ino]}')
  if __name__ == '__main__':
      # Script entry point: parse the command-line options, then run.
      cli_args = parse_args()
      main(cli_args)