# infer_e2e_parallel.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import os
import queue
import sys
import threading
import time

# Make the repo root importable before pulling in project-local modules.
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))

import cv2
import numpy as np
from PIL import Image

from tools.utils.utility import get_image_file_list, check_and_read
from tools.infer_rec import OpenRecognizer
from tools.infer_det import OpenDetector
from tools.infer_e2e import check_and_download_font, sorted_boxes
from tools.engine.config import Config
from tools.infer.utility import get_rotate_crop_image, get_minarea_rect_crop, draw_ocr_box_txt
  22. class OpenOCRParallel:
  23. def __init__(self, drop_score=0.5, det_box_type='quad', max_rec_threads=1):
  24. cfg_det = Config(
  25. './configs/det/dbnet/repvit_db.yml').cfg # mobile model
  26. # cfg_rec = Config('./configs/rec/svtrv2/svtrv2_ch.yml').cfg # server model
  27. cfg_rec = Config(
  28. './configs/rec/svtrv2/repsvtr_ch.yml').cfg # mobile model
  29. self.text_detector = OpenDetector(cfg_det, numId=0)
  30. self.text_recognizer = OpenRecognizer(cfg_rec, numId=0)
  31. self.det_box_type = det_box_type
  32. self.drop_score = drop_score
  33. self.queue = queue.Queue(
  34. ) # Queue to hold detected boxes for recognition
  35. self.results = {}
  36. self.lock = threading.Lock() # Lock for thread-safe access to results
  37. self.max_rec_threads = max_rec_threads
  38. self.stop_signal = threading.Event() # Signal to stop threads
  39. def start_recognition_threads(self):
  40. """Start recognition threads."""
  41. self.rec_threads = []
  42. for _ in range(self.max_rec_threads):
  43. t = threading.Thread(target=self.recognize_text)
  44. t.start()
  45. self.rec_threads.append(t)
  46. def detect_text(self, image_list):
  47. """Single-threaded text detection for all images."""
  48. for image_id, (img_numpy, ori_img) in enumerate(image_list):
  49. dt_boxes = self.text_detector(img_numpy=img_numpy)[0]['boxes']
  50. if dt_boxes is None:
  51. self.results[image_id] = [] # If no boxes, set empty results
  52. continue
  53. dt_boxes = sorted_boxes(dt_boxes)
  54. img_crop_list = []
  55. for box in dt_boxes:
  56. tmp_box = np.array(box).astype(np.float32)
  57. img_crop = (get_rotate_crop_image(ori_img, tmp_box)
  58. if self.det_box_type == 'quad' else
  59. get_minarea_rect_crop(ori_img, tmp_box))
  60. img_crop_list.append(img_crop)
  61. self.queue.put(
  62. (image_id, dt_boxes, img_crop_list
  63. )) # Put image ID, detected box, and cropped image in queue
  64. # Signal that no more items will be added to the queue
  65. self.stop_signal.set()
  66. def recognize_text(self):
  67. """Recognize text in each cropped image."""
  68. while not self.stop_signal.is_set() or not self.queue.empty():
  69. try:
  70. image_id, boxs, img_crop_list = self.queue.get(timeout=0.5)
  71. rec_results = self.text_recognizer(
  72. img_numpy_list=img_crop_list, batch_num=6)
  73. for rec_result, box in zip(rec_results, boxs):
  74. text, score = rec_result['text'], rec_result['score']
  75. if score >= self.drop_score:
  76. with self.lock:
  77. # Ensure results dictionary has a list for each image ID
  78. if image_id not in self.results:
  79. self.results[image_id] = []
  80. self.results[image_id].append({
  81. 'transcription':
  82. text,
  83. 'points':
  84. box.tolist(),
  85. 'score':
  86. score
  87. })
  88. self.queue.task_done()
  89. except queue.Empty:
  90. continue
  91. def process_images(self, image_list):
  92. """Process a list of images."""
  93. # Initialize results dictionary
  94. self.results = {i: [] for i in range(len(image_list))}
  95. # Start recognition threads
  96. t_start_1 = time.time()
  97. self.start_recognition_threads()
  98. # Start detection in the main thread
  99. t_start = time.time()
  100. self.detect_text(image_list)
  101. print('det time:', time.time() - t_start)
  102. # Wait for recognition threads to finish
  103. for t in self.rec_threads:
  104. t.join()
  105. self.stop_signal.clear()
  106. print('all time:', time.time() - t_start_1)
  107. return self.results
  108. def main(cfg_det, cfg_rec):
  109. img_path = './testA/'
  110. image_file_list = get_image_file_list(img_path)
  111. drop_score = 0.5
  112. text_sys = OpenOCRParallel(
  113. drop_score=drop_score,
  114. det_box_type='quad') # det_box_type: 'quad' or 'poly'
  115. is_visualize = False
  116. if is_visualize:
  117. font_path = './simfang.ttf'
  118. check_and_download_font(font_path)
  119. draw_img_save_dir = img_path + 'e2e_results/' if img_path[
  120. -1] != '/' else img_path[:-1] + 'e2e_results/'
  121. os.makedirs(draw_img_save_dir, exist_ok=True)
  122. save_results = []
  123. # Prepare images
  124. images = []
  125. t_start = time.time()
  126. for image_file in image_file_list:
  127. img, flag_gif, flag_pdf = check_and_read(image_file)
  128. if not flag_gif and not flag_pdf:
  129. img = cv2.imread(image_file)
  130. if img is not None:
  131. images.append((img, img.copy()))
  132. results = text_sys.process_images(images)
  133. print(f'time cost: {time.time() - t_start}')
  134. # Save results and visualize
  135. for image_id, res in results.items():
  136. image_file = image_file_list[image_id]
  137. save_pred = f'{os.path.basename(image_file)}\t{json.dumps(res, ensure_ascii=False)}\n'
  138. # print(save_pred)
  139. save_results.append(save_pred)
  140. if is_visualize:
  141. dt_boxes = [result['points'] for result in res]
  142. rec_res = [result['transcription'] for result in res]
  143. rec_score = [result['score'] for result in res]
  144. image = Image.fromarray(
  145. cv2.cvtColor(images[image_id][0], cv2.COLOR_BGR2RGB))
  146. draw_img = draw_ocr_box_txt(image,
  147. dt_boxes,
  148. rec_res,
  149. rec_score,
  150. drop_score=drop_score,
  151. font_path=font_path)
  152. save_file = os.path.join(draw_img_save_dir,
  153. os.path.basename(image_file))
  154. cv2.imwrite(save_file, draw_img[:, :, ::-1])
  155. with open(os.path.join(draw_img_save_dir, 'system_results.txt'),
  156. 'w',
  157. encoding='utf-8') as f:
  158. f.writelines(save_results)
  159. if __name__ == '__main__':
  160. main()