infer_e2e.py

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
from pathlib import Path
import sys

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))

os.environ['FLAGS_allocator_strategy'] = 'auto_growth'

import argparse
import numpy as np
import copy
import time
import cv2
import json
from PIL import Image

from tools.utils.utility import get_image_file_list, check_and_read
from tools.infer_rec import OpenRecognizer
from tools.infer_det import OpenDetector
from tools.engine.config import Config
from tools.infer.utility import get_rotate_crop_image, get_minarea_rect_crop, draw_ocr_box_txt
from tools.utils.logging import get_logger

root_dir = Path(__file__).resolve().parent
DEFAULT_CFG_PATH_DET = str(root_dir / '../configs/det/dbnet/repvit_db.yml')
DEFAULT_CFG_PATH_REC_SERVER = str(root_dir /
                                  '../configs/rec/svtrv2/svtrv2_ch.yml')
DEFAULT_CFG_PATH_REC = str(root_dir / '../configs/rec/svtrv2/repsvtr_ch.yml')

logger = get_logger()


def check_and_download_font(font_path):
    """Return a usable path to the font, downloading it to the cache if needed."""
    if not os.path.exists(font_path):
        cache_dir = Path.home() / '.cache' / 'openocr'
        font_path = str(cache_dir / font_path)
        if os.path.exists(font_path):
            return font_path
        # Ensure the cache directory exists; urlretrieve fails otherwise.
        cache_dir.mkdir(parents=True, exist_ok=True)
        logger.info(f"Downloading '{font_path}' ...")
        try:
            import urllib.request
            font_url = 'https://shuiche-shop.oss-cn-chengdu.aliyuncs.com/fonts/simfang.ttf'
            urllib.request.urlretrieve(font_url, font_path)
            logger.info(f'Downloading font success: {font_path}')
        except Exception as e:
            logger.info(f'Downloading font error: {e}')
    return font_path


def sorted_boxes(dt_boxes):
    """
    Sort text boxes in order from top to bottom, left to right.
    args:
        dt_boxes(array): detected text boxes with shape [N, 4, 2]
    return:
        sorted boxes(list) of N boxes, each with shape [4, 2]
    """
    num_boxes = dt_boxes.shape[0]
    # Pre-sort by the top-left corner: top to bottom, then left to right.
    _boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))

    # Bubble pass: boxes whose tops lie within 10 px of each other are treated
    # as one text line and re-ordered strictly left to right.
    for i in range(num_boxes - 1):
        for j in range(i, -1, -1):
            if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 and (
                    _boxes[j + 1][0][0] < _boxes[j][0][0]):
                _boxes[j], _boxes[j + 1] = _boxes[j + 1], _boxes[j]
            else:
                break
    return _boxes
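

# Illustrative example (not part of the original module; coordinates are made
# up): two boxes on the same visual line get re-ordered left to right.
#   import numpy as np
#   boxes = np.array([
#       [[100, 12], [200, 12], [200, 40], [100, 40]],  # right box, top at y=12
#       [[10, 15], [90, 15], [90, 42], [10, 42]],      # left box, top at y=15
#   ])
#   # The (y, x) pre-sort puts the y=12 box first; the bubble pass then swaps
#   # them because |15 - 12| < 10 px, so sorted_boxes(boxes) returns the
#   # x=10 box before the x=100 box.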


class OpenOCR(object):

    def __init__(self,
                 mode='mobile',
                 backend='torch',
                 onnx_det_model_path=None,
                 onnx_rec_model_path=None,
                 drop_score=0.5,
                 det_box_type='quad',
                 device='gpu'):
        """
        Initialize the OCR engine's configuration and components.

        Args:
            mode (str, optional): Run mode, 'mobile' or 'server'. Defaults to 'mobile'.
            backend (str, optional): Inference backend, e.g. 'torch' or 'onnx'. Defaults to 'torch'.
            onnx_det_model_path (str, optional): Path to the ONNX text detection model. Defaults to None.
            onnx_rec_model_path (str, optional): Path to the ONNX text recognition model. Defaults to None.
            drop_score (float, optional): Confidence threshold; detected boxes whose recognition
                score falls below it are discarded. Defaults to 0.5.
            det_box_type (str, optional): Detection box type, 'quad' or 'poly'. Defaults to 'quad'.
            device (str, optional): Device to use for inference, e.g. 'gpu' or 'cpu'. Defaults to 'gpu'.
        """
        cfg_det = Config(DEFAULT_CFG_PATH_DET).cfg  # mobile detection model
        cfg_det['Global']['device'] = device
        if mode == 'server':
            cfg_rec = Config(DEFAULT_CFG_PATH_REC_SERVER).cfg  # server model
        else:
            cfg_rec = Config(DEFAULT_CFG_PATH_REC).cfg  # mobile model
        cfg_rec['Global']['device'] = device
        self.text_detector = OpenDetector(cfg_det,
                                          backend=backend,
                                          onnx_model_path=onnx_det_model_path)
        self.text_recognizer = OpenRecognizer(
            cfg_rec, backend=backend, onnx_model_path=onnx_rec_model_path)
        self.det_box_type = det_box_type
        self.drop_score = drop_score
        self.crop_image_res_index = 0
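
    # Construction sketch (hedged; assumes the default YAML configs bundled
    # with this repo are resolvable, and the ONNX paths below are hypothetical):
    #   ocr = OpenOCR(mode='mobile', device='cpu')  # torch backend on CPU
    #   ocr = OpenOCR(backend='onnx',
    #                 onnx_det_model_path='det.onnx',
    #                 onnx_rec_model_path='rec.onnx')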

    def draw_crop_rec_res(self, output_dir, img_crop_list, rec_res):
        os.makedirs(output_dir, exist_ok=True)
        bbox_num = len(img_crop_list)
        for bno in range(bbox_num):
            cv2.imwrite(
                os.path.join(
                    output_dir,
                    f'img_crop_{bno + self.crop_image_res_index}.jpg'),
                img_crop_list[bno],
            )
        self.crop_image_res_index += bbox_num

    def infer_single_image(self,
                           img_numpy,
                           ori_img,
                           crop_infer=False,
                           rec_batch_num=6,
                           return_mask=False,
                           **kwargs):
        start = time.time()
        if crop_infer:
            dt_boxes = self.text_detector.crop_infer(
                img_numpy=img_numpy)[0]['boxes']
        else:
            det_res = self.text_detector(img_numpy=img_numpy,
                                         return_mask=return_mask,
                                         **kwargs)[0]
            dt_boxes = det_res['boxes']
            # logger.info(dt_boxes)
        det_time_cost = time.time() - start

        if dt_boxes is None:
            # Keep the tuple arity consistent with the return_mask contract,
            # so callers can always unpack the same number of values.
            if return_mask:
                return None, None, None, None
            return None, None, None

        # Crop each detected region for recognition.
        img_crop_list = []
        dt_boxes = sorted_boxes(dt_boxes)
        for bno in range(len(dt_boxes)):
            tmp_box = np.array(copy.deepcopy(dt_boxes[bno])).astype(np.float32)
            if self.det_box_type == 'quad':
                img_crop = get_rotate_crop_image(ori_img, tmp_box)
            else:
                img_crop = get_minarea_rect_crop(ori_img, tmp_box)
            img_crop_list.append(img_crop)

        start = time.time()
        rec_res = self.text_recognizer(img_numpy_list=img_crop_list,
                                       batch_num=rec_batch_num)
        rec_time_cost = time.time() - start

        # Drop low-confidence results and accumulate per-crop recognition time.
        filter_boxes, filter_rec_res = [], []
        rec_time_cost_sig = 0.0
        for box, rec_result in zip(dt_boxes, rec_res):
            text, score = rec_result['text'], rec_result['score']
            rec_time_cost_sig += rec_result['elapse']
            if score >= self.drop_score:
                filter_boxes.append(box)
                filter_rec_res.append([text, score])
        avg_rec_time_cost = rec_time_cost_sig / len(dt_boxes) if len(
            dt_boxes) > 0 else 0.0

        time_dict = {
            'time_cost': det_time_cost + rec_time_cost,
            'detection_time': det_time_cost,
            'recognition_time': rec_time_cost,
            'avg_rec_time_cost': avg_rec_time_cost
        }
        if return_mask:
            return filter_boxes, filter_rec_res, time_dict, det_res['mask']
        return filter_boxes, filter_rec_res, time_dict

    def __call__(self,
                 img_path=None,
                 save_dir='e2e_results/',
                 is_visualize=False,
                 img_numpy=None,
                 rec_batch_num=6,
                 crop_infer=False,
                 return_mask=False,
                 **kwargs):
        """
        img_path: str, optional, default=None
            Path to the directory containing images or the image filename.
        save_dir: str, optional, default='e2e_results/'
            Directory to save prediction and visualization results.
        is_visualize: bool, optional, default=False
            Visualize the results.
        img_numpy: numpy or list[numpy], optional, default=None
            Numpy array of an image or a list of numpy arrays representing images.
        rec_batch_num: int, optional, default=6
            Batch size for text recognition.
        crop_infer: bool, optional, default=False
            Whether to use crop inference.
        return_mask: bool, optional, default=False
            Whether to also return the detection mask (img_numpy input only).
        """
        if img_numpy is None and img_path is None:
            raise ValueError('img_path and img_numpy cannot both be None.')

        # In-memory input: run inference and return structured results directly.
        if img_numpy is not None:
            if not isinstance(img_numpy, list):
                img_numpy = [img_numpy]
            results = []
            time_dicts = []
            for index, img in enumerate(img_numpy):
                ori_img = img.copy()
                if return_mask:
                    dt_boxes, rec_res, time_dict, mask = self.infer_single_image(
                        img_numpy=img,
                        ori_img=ori_img,
                        crop_infer=crop_infer,
                        rec_batch_num=rec_batch_num,
                        return_mask=return_mask,
                        **kwargs)
                else:
                    dt_boxes, rec_res, time_dict = self.infer_single_image(
                        img_numpy=img,
                        ori_img=ori_img,
                        crop_infer=crop_infer,
                        rec_batch_num=rec_batch_num,
                        **kwargs)
                if dt_boxes is None:
                    results.append([])
                    time_dicts.append({})
                    continue
                res = [{
                    'transcription': rec_res[i][0],
                    'points': np.array(dt_boxes[i]).tolist(),
                    'score': rec_res[i][1],
                } for i in range(len(dt_boxes))]
                results.append(res)
                time_dicts.append(time_dict)
            if return_mask:
                return results, time_dicts, mask
            return results, time_dicts

        # File input: iterate over images (GIF and PDF pages are expanded).
        image_file_list = get_image_file_list(img_path)
        save_results = []
        time_dicts_return = []
        for idx, image_file in enumerate(image_file_list):
            img, flag_gif, flag_pdf = check_and_read(image_file)
            if not flag_gif and not flag_pdf:
                img = cv2.imread(image_file)
            if not flag_pdf:
                if img is None:
                    return None
                imgs = [img]
            else:
                imgs = img
            logger.info(
                f'Processing {idx + 1}/{len(image_file_list)}: {image_file}')

            res_list = []
            time_dicts = []
            for index, img_numpy in enumerate(imgs):
                ori_img = img_numpy.copy()
                dt_boxes, rec_res, time_dict = self.infer_single_image(
                    img_numpy=img_numpy,
                    ori_img=ori_img,
                    crop_infer=crop_infer,
                    rec_batch_num=rec_batch_num,
                    **kwargs)
                if dt_boxes is None:
                    res_list.append([])
                    time_dicts.append({})
                    continue
                res = [{
                    'transcription': rec_res[i][0],
                    'points': np.array(dt_boxes[i]).tolist(),
                    'score': rec_res[i][1],
                } for i in range(len(dt_boxes))]
                res_list.append(res)
                time_dicts.append(time_dict)

            for index, (res, time_dict) in enumerate(zip(res_list,
                                                         time_dicts)):
                if len(res) > 0:
                    logger.info(f'Results: {res}.')
                    logger.info(f'Time cost: {time_dict}.')
                else:
                    logger.info('No text detected.')
                if len(res_list) > 1:
                    # Multi-page input: suffix the page index to the file name.
                    save_pred = (os.path.basename(image_file) + '_' +
                                 str(index) + '\t' +
                                 json.dumps(res, ensure_ascii=False) + '\n')
                else:
                    if len(res) > 0:
                        save_pred = (os.path.basename(image_file) + '\t' +
                                     json.dumps(res, ensure_ascii=False) +
                                     '\n')
                    else:
                        continue
                save_results.append(save_pred)
                time_dicts_return.append(time_dict)

                if is_visualize and len(res) > 0:
                    # Font and directory setup is idempotent, so it runs on
                    # every visualized page rather than only when idx == 0;
                    # this avoids an undefined font_path when the first image
                    # contains no text.
                    font_path = check_and_download_font('./simfang.ttf')
                    draw_img_save_dir = os.path.join(save_dir, 'vis_results/')
                    os.makedirs(draw_img_save_dir, exist_ok=True)
                    dt_boxes = [res[i]['points'] for i in range(len(res))]
                    rec_res = [
                        res[i]['transcription'] for i in range(len(res))
                    ]
                    rec_score = [res[i]['score'] for i in range(len(res))]
                    # Draw on the page that produced this result, not on the
                    # loop's last image (matters for multi-page PDFs).
                    image = Image.fromarray(
                        cv2.cvtColor(imgs[index], cv2.COLOR_BGR2RGB))
                    draw_img = draw_ocr_box_txt(
                        image,
                        dt_boxes,
                        rec_res,
                        rec_score,
                        drop_score=self.drop_score,
                        font_path=font_path,
                    )
                    if flag_gif:
                        save_file = image_file[:-3] + 'png'
                    elif flag_pdf:
                        save_file = image_file.replace(
                            '.pdf', '_' + str(index) + '.png')
                    else:
                        save_file = image_file
                    cv2.imwrite(
                        os.path.join(draw_img_save_dir,
                                     os.path.basename(save_file)),
                        draw_img[:, :, ::-1],
                    )

        if save_results:
            os.makedirs(save_dir, exist_ok=True)
            save_path = os.path.join(save_dir, 'system_results.txt')
            with open(save_path, 'w', encoding='utf-8') as f:
                f.writelines(save_results)
            logger.info(f'Results saved to {save_path}.')
            if is_visualize:
                logger.info(f'Visualized results saved to '
                            f"{os.path.join(save_dir, 'vis_results/')}.")
            return save_results, time_dicts_return
        else:
            logger.info('No text detected.')
            return None, None
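

# Result-consumption sketch (hedged; 'test_img.jpg' is a hypothetical file).
# With img_numpy input, __call__ returns one list of
# {'transcription', 'points', 'score'} dicts per image:
#   import cv2
#   ocr = OpenOCR(mode='mobile', device='cpu')
#   results, time_dicts = ocr(img_numpy=cv2.imread('test_img.jpg'))
#   for line in results[0]:
#       print(line['score'], line['transcription'])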


def main():
    parser = argparse.ArgumentParser(description='OpenOCR system')
    parser.add_argument(
        '--img_path',
        type=str,
        help='Path to the directory containing images or the image filename.')
    parser.add_argument(
        '--mode',
        type=str,
        default='mobile',
        help="Mode of the OCR system, e.g., 'mobile' or 'server'.")
    parser.add_argument(
        '--backend',
        type=str,
        default='torch',
        help="Backend of the OCR system, e.g., 'torch' or 'onnx'.")
    parser.add_argument('--onnx_det_model_path',
                        type=str,
                        default=None,
                        help='Path to the ONNX model for text detection.')
    parser.add_argument('--onnx_rec_model_path',
                        type=str,
                        default=None,
                        help='Path to the ONNX model for text recognition.')
    parser.add_argument(
        '--save_dir',
        type=str,
        default='e2e_results/',
        help='Directory to save prediction and visualization results. '
        'Defaults to ./e2e_results/.')
    parser.add_argument('--is_vis',
                        action='store_true',
                        default=False,
                        help='Visualize the results.')
    parser.add_argument('--drop_score',
                        type=float,
                        default=0.5,
                        help='Score threshold for text recognition.')
    parser.add_argument('--device',
                        type=str,
                        default='gpu',
                        help='Device to use for inference.')
    args = parser.parse_args()

    text_sys = OpenOCR(mode=args.mode,
                       backend=args.backend,
                       onnx_det_model_path=args.onnx_det_model_path,
                       onnx_rec_model_path=args.onnx_rec_model_path,
                       drop_score=args.drop_score,
                       det_box_type='quad',  # 'quad' or 'poly'
                       device=args.device)
    text_sys(img_path=args.img_path,
             save_dir=args.save_dir,
             is_visualize=args.is_vis)


if __name__ == '__main__':
    main()
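
# CLI usage sketch (paths are hypothetical):
#   python infer_e2e.py --img_path ./test_imgs/ --mode mobile --is_vis
#   python infer_e2e.py --img_path img.jpg --backend onnx \
#       --onnx_det_model_path det.onnx --onnx_rec_model_path rec.onnx
# Predictions are written to <save_dir>/system_results.txt; with --is_vis,
# annotated images are saved under <save_dir>/vis_results/.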