infer_det.py

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from pathlib import Path
import time
import numpy as np
import os
import sys

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))

os.environ['FLAGS_allocator_strategy'] = 'auto_growth'

import cv2
import json

from tools.engine.config import Config
from tools.utility import ArgsParser
from tools.utils.logging import get_logger
from tools.utils.utility import get_image_file_list

logger = get_logger()

root_dir = Path(__file__).resolve().parent
DEFAULT_CFG_PATH_DET = str(root_dir / '../configs/det/dbnet/repvit_db.yml')

MODEL_NAME_DET = './openocr_det_repvit_ch.pth'  # model file name
DOWNLOAD_URL_DET = 'https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/openocr_det_repvit_ch.pth'  # model file URL
MODEL_NAME_DET_ONNX = './openocr_det_model.onnx'  # model file name
DOWNLOAD_URL_DET_ONNX = 'https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/openocr_det_model.onnx'  # model file URL
def check_and_download_model(model_name: str, url: str):
    """
    Check whether the pretrained model exists; if not, download it from the
    given URL into a fixed cache directory.

    Args:
        model_name (str): name of the model file, e.g. "model.pt"
        url (str): download URL of the model file

    Returns:
        str: full path to the model file
    """
    if os.path.exists(model_name):
        return model_name

    # Fixed cache path: ".cache/openocr" under the user's home directory
    cache_dir = Path.home() / '.cache' / 'openocr'
    model_path = cache_dir / model_name

    # If the model file already exists, return its path directly
    if model_path.exists():
        logger.info(f'Model already exists at: {model_path}')
        return str(model_path)

    # Otherwise, download the model
    logger.info(f'Model not found. Downloading from {url}...')

    # Create the cache directory if it does not exist
    cache_dir.mkdir(parents=True, exist_ok=True)

    try:
        # Download the file
        import urllib.request
        with urllib.request.urlopen(url) as response, open(model_path,
                                                           'wb') as out_file:
            out_file.write(response.read())
        logger.info(f'Model downloaded and saved at: {model_path}')
        return str(model_path)
    except Exception as e:
        logger.error(f'Error downloading the model: {e}')
        # Ask the user to download the model manually
        logger.error(
            f'Unable to download the model automatically. '
            f'Please download the model manually from the following URL:\n{url}\n'
            f'and save it to: {model_name} or {model_path}')
        raise RuntimeError(
            f'Failed to download the model. Please download it manually from {url} '
            f'and save it to {model_path}') from e
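
# Example usage (a minimal sketch): resolves the checkpoint to a local path,
# downloading into ~/.cache/openocr on first use.
#   det_ckpt = check_and_download_model(MODEL_NAME_DET, DOWNLOAD_URL_DET)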
def replace_batchnorm(net):
    """Recursively fuse submodules that expose a `fuse()` method and replace
    remaining BatchNorm2d layers with Identity (inference-time optimization)."""
    import torch
    for child_name, child in net.named_children():
        if hasattr(child, 'fuse'):
            fused = child.fuse()
            setattr(net, child_name, fused)
            replace_batchnorm(fused)
        elif isinstance(child, torch.nn.BatchNorm2d):
            setattr(net, child_name, torch.nn.Identity())
        else:
            replace_batchnorm(child)
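
# Typically applied after the checkpoint has been loaded, e.g.
#   replace_batchnorm(model.backbone)
# as done for the DB_mobile backbone in OpenDetector._init_torch_model below.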
def draw_det_res(dt_boxes, img, img_name, save_path):
    src_im = img
    for box in dt_boxes:
        box = np.array(box).astype(np.int32).reshape((-1, 1, 2))
        cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    save_path = os.path.join(save_path, os.path.basename(img_name))
    cv2.imwrite(save_path, src_im)
def set_device(device, numId=0):
    import torch
    if device == 'gpu' and torch.cuda.is_available():
        device = torch.device(f'cuda:{numId}')
    else:
        logger.info('GPU is not available, using CPU.')
        device = torch.device('cpu')
    return device
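
# Example (sketch): device = set_device('gpu')  # falls back to CPU when CUDA is unavailable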
class OpenDetector(object):

    def __init__(self,
                 config=None,
                 backend='torch',
                 onnx_model_path=None,
                 numId=0):
        """
        Args:
            config (dict, optional): configuration dict. Defaults to None.
            backend (str): 'torch' or 'onnx'
            onnx_model_path (str): path to the ONNX model (required only when backend='onnx')
            numId (int, optional): device index. Defaults to 0.
        """
        if config is None:
            config = Config(DEFAULT_CFG_PATH_DET).cfg

        self._init_common(config)

        backend = backend if config['Global'].get(
            'backend', None) is None else config['Global']['backend']
        self.backend = backend
        if backend == 'torch':
            import torch
            self.torch = torch
            if config['Architecture']['algorithm'] == 'DB_mobile':
                if not os.path.exists(config['Global']['pretrained_model']):
                    config['Global'][
                        'pretrained_model'] = check_and_download_model(
                            MODEL_NAME_DET, DOWNLOAD_URL_DET)
            self._init_torch_model(config, numId)
        elif backend == 'onnx':
            from tools.infer.onnx_engine import ONNXEngine
            onnx_model_path = onnx_model_path if config['Global'].get(
                'onnx_model_path',
                None) is None else config['Global']['onnx_model_path']
            if onnx_model_path is None:
                if config['Architecture']['algorithm'] == 'DB_mobile':
                    onnx_model_path = check_and_download_model(
                        MODEL_NAME_DET_ONNX, DOWNLOAD_URL_DET_ONNX)
                else:
                    raise ValueError(
                        "onnx_model_path must be provided when backend is 'onnx'.")
            self.onnx_det_engine = ONNXEngine(
                onnx_model_path, use_gpu=config['Global']['device'] == 'gpu')
        else:
            raise ValueError("backend must be either 'torch' or 'onnx'.")
    def _init_common(self, config):
        from opendet.postprocess import build_post_process
        from opendet.preprocess import create_operators, transform

        global_config = config['Global']

        # create data ops
        self.transform = transform
        transforms = []
        for op in config['Eval']['dataset']['transforms']:
            op_name = list(op)[0]
            if 'Label' in op_name:
                continue
            elif op_name == 'KeepKeys':
                op[op_name]['keep_keys'] = ['image', 'shape']
            transforms.append(op)
        self.ops = create_operators(transforms, global_config)

        # build post process
        self.post_process_class = build_post_process(config['PostProcess'],
                                                     global_config)
    def _init_torch_model(self, config, numId=0):
        from opendet.modeling import build_model as build_det_model
        from tools.utils.ckpt import load_ckpt

        # build model
        self.model = build_det_model(config['Architecture'])
        self.model.eval()
        load_ckpt(self.model, config)
        if config['Architecture']['algorithm'] == 'DB_mobile':
            replace_batchnorm(self.model.backbone)
        self.device = set_device(config['Global']['device'], numId=numId)
        self.model.to(device=self.device)
    def _inference_onnx(self, images):
        # ONNX inputs must be numpy arrays
        return self.onnx_det_engine.run(images)
    def __call__(self,
                 img_path=None,
                 img_numpy_list=None,
                 img_numpy=None,
                 return_mask=False,
                 **kwargs):
        """
        Run text detection on the input image(s) and return the results.

        Args:
            img_path (str, optional): path to an image file or directory. Defaults to None.
            img_numpy_list (list, optional): list of images, each a numpy array. Defaults to None.
            img_numpy (numpy.ndarray, optional): a single image as a numpy array. Defaults to None.

        Returns:
            list: one dict per image with two keys, 'boxes' and 'elapse'.
            'boxes' holds the detected box points, 'elapse' the inference time.

        Raises:
            Exception: if neither an image path nor a numpy array is provided.
        """
        if img_numpy is not None:
            img_numpy_list = [img_numpy]
            num_img = 1
        elif img_path is not None:
            img_path = get_image_file_list(img_path)
            num_img = len(img_path)
        elif img_numpy_list is not None:
            num_img = len(img_numpy_list)
        else:
            raise Exception('No input image path or numpy array.')
        results = []
        for img_idx in range(num_img):
            if img_numpy_list is not None:
                img = img_numpy_list[img_idx]
                data = {'image': img}
            elif img_path is not None:
                with open(img_path[img_idx], 'rb') as f:
                    img = f.read()
                    data = {'image': img}
                data = self.transform(data, self.ops[:1])
            if kwargs.get('det_input_size', None) is not None:
                data['max_sile_len'] = kwargs['det_input_size']
            batch = self.transform(data, self.ops[1:])

            images = np.expand_dims(batch[0], axis=0)
            shape_list = np.expand_dims(batch[1], axis=0)

            t_start = time.time()
            if self.backend == 'torch':
                images = self.torch.from_numpy(images).to(device=self.device)
                with self.torch.no_grad():
                    preds = self.model(images)
                kwargs['torch_tensor'] = True
            elif self.backend == 'onnx':
                preds_det = self._inference_onnx(images)
                preds = {'maps': preds_det[0]}
                kwargs['torch_tensor'] = False
            t_cost = time.time() - t_start

            post_result = self.post_process_class(preds, [None, shape_list],
                                                  **kwargs)
            info = {'boxes': post_result[0]['points'], 'elapse': t_cost}
            if return_mask:
                if self.backend == 'torch' and isinstance(
                        preds['maps'], self.torch.Tensor):
                    mask = preds['maps'].detach().cpu().numpy()
                else:
                    mask = preds['maps']
                info['mask'] = mask
            results.append(info)
        return results
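
# Example usage (a minimal sketch; 'sample.jpg' is a placeholder path):
#   detector = OpenDetector()                       # default RepViT-DB config, torch backend
#   det_res = detector(img_path='sample.jpg')[0]
#   print(det_res['boxes'], det_res['elapse'])
# or, with an in-memory image:
#   det_res = detector(img_numpy=cv2.imread('sample.jpg'))[0]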
def main(cfg):
    is_visualize = cfg['Global'].get('is_visualize', False)
    model = OpenDetector(cfg)

    save_res_path = './det_results/'
    if not os.path.exists(save_res_path):
        os.makedirs(save_res_path)
    sample_num = 0
    with open(save_res_path + '/det_results.txt', 'wb') as fout:
        for file in get_image_file_list(cfg['Global']['infer_img']):
            preds_result = model(img_path=file)[0]
            logger.info('{} infer_img: {}, time cost: {}'.format(
                sample_num, file, preds_result['elapse']))
            boxes = preds_result['boxes']
            dt_boxes_json = []
            for box in boxes:
                tmp_json = {}
                tmp_json['points'] = np.array(box).tolist()
                dt_boxes_json.append(tmp_json)
            if is_visualize:
                src_img = cv2.imread(file)
                draw_det_res(boxes, src_img, file, save_res_path)
                logger.info('The detected Image saved in {}'.format(
                    os.path.join(save_res_path, os.path.basename(file))))
            otstr = file + '\t' + json.dumps(dt_boxes_json) + '\n'
            logger.info('results: {}'.format(json.dumps(dt_boxes_json)))
            fout.write(otstr.encode())
            sample_num += 1
        logger.info(
            f"Results saved to {os.path.join(save_res_path, 'det_results.txt')}."
        )
    logger.info('success!')
if __name__ == '__main__':
    FLAGS = ArgsParser().parse_args()
    cfg = Config(FLAGS.config)
    FLAGS = vars(FLAGS)
    opt = FLAGS.pop('opt')
    cfg.merge_dict(FLAGS)
    cfg.merge_dict(opt)
    main(cfg.cfg)
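
# Command-line usage (a sketch; assumes ArgsParser exposes the usual -c/--config
# option plus -o key=value overrides, and the infer_img path is a placeholder):
#   python tools/infer_det.py -c configs/det/dbnet/repvit_db.yml \
#       -o Global.infer_img=./imgs Global.is_visualize=True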