infer_rec.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415
  1. import os
  2. from pathlib import Path
  3. import sys
  4. import time
  5. __dir__ = os.path.dirname(os.path.abspath(__file__))
  6. sys.path.append(__dir__)
  7. sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
  8. import numpy as np
  9. from tools.engine.config import Config
  10. from tools.utility import ArgsParser
  11. from tools.utils.logging import get_logger
  12. from tools.utils.utility import get_image_file_list
  13. logger = get_logger()
  14. root_dir = Path(__file__).resolve().parent
  15. DEFAULT_CFG_PATH_REC_SERVER = str(root_dir /
  16. '../configs/rec/svtrv2/svtrv2_ch.yml')
  17. DEFAULT_CFG_PATH_REC = str(root_dir / '../configs/rec/svtrv2/repsvtr_ch.yml')
  18. DEFAULT_DICT_PATH_REC = str(root_dir / './utils/ppocr_keys_v1.txt')
  19. MODEL_NAME_REC = './openocr_repsvtr_ch.pth' # 模型文件名称
  20. DOWNLOAD_URL_REC = 'https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/openocr_repsvtr_ch.pth' # 模型文件 URL
  21. MODEL_NAME_REC_SERVER = './openocr_svtrv2_ch.pth' # 模型文件名称
  22. DOWNLOAD_URL_REC_SERVER = 'https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/openocr_svtrv2_ch.pth' # 模型文件 URL
  23. MODEL_NAME_REC_ONNX = './openocr_rec_model.onnx' # 模型文件名称
  24. DOWNLOAD_URL_REC_ONNX = 'https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/openocr_rec_model.onnx' # 模型文件 URL
  25. def check_and_download_model(model_name: str, url: str):
  26. """
  27. 检查预训练模型是否存在,若不存在则从指定 URL 下载到固定缓存目录。
  28. Args:
  29. model_name (str): 模型文件的名称,例如 "model.pt"
  30. url (str): 模型文件的下载地址
  31. Returns:
  32. str: 模型文件的完整路径
  33. """
  34. if os.path.exists(model_name):
  35. return model_name
  36. # 固定缓存路径为用户主目录下的 ".cache/openocr"
  37. cache_dir = Path.home() / '.cache' / 'openocr'
  38. model_path = cache_dir / model_name
  39. # 如果模型文件已存在,直接返回路径
  40. if model_path.exists():
  41. logger.info(f'Model already exists at: {model_path}')
  42. return str(model_path)
  43. # 如果文件不存在,下载模型
  44. logger.info(f'Model not found. Downloading from {url}...')
  45. # 创建缓存目录(如果不存在)
  46. cache_dir.mkdir(parents=True, exist_ok=True)
  47. try:
  48. # 下载文件
  49. import urllib.request
  50. with urllib.request.urlopen(url) as response, open(model_path,
  51. 'wb') as out_file:
  52. out_file.write(response.read())
  53. logger.info(f'Model downloaded and saved at: {model_path}')
  54. return str(model_path)
  55. except Exception as e:
  56. logger.error(f'Error downloading the model: {e}')
  57. # 提示用户手动下载
  58. logger.error(
  59. f'Unable to download the model automatically. '
  60. f'Please download the model manually from the following URL:\n{url}\n'
  61. f'and save it to: {model_name} or {model_path}')
  62. raise RuntimeError(
  63. f'Failed to download the model. Please download it manually from {url} '
  64. f'and save it to {model_path}') from e
  65. class RatioRecTVReisze(object):
  66. def __init__(self, cfg):
  67. self.max_ratio = cfg['Eval']['loader'].get('max_ratio', 12)
  68. self.base_shape = cfg['Eval']['dataset'].get(
  69. 'base_shape', [[64, 64], [96, 48], [112, 40], [128, 32]])
  70. self.base_h = cfg['Eval']['dataset'].get('base_h', 32)
  71. from torchvision import transforms as T
  72. from torchvision.transforms import functional as F
  73. self.F = F
  74. self.interpolation = T.InterpolationMode.BICUBIC
  75. transforms = []
  76. transforms.extend([
  77. T.ToTensor(),
  78. T.Normalize(0.5, 0.5),
  79. ])
  80. self.transforms = T.Compose(transforms)
  81. self.ceil = cfg['Eval']['dataset'].get('ceil', False),
  82. def __call__(self, data):
  83. img = data['image']
  84. imgH = self.base_h
  85. w, h = img.size
  86. if self.ceil:
  87. gen_ratio = int(float(w) / float(h)) + 1
  88. else:
  89. gen_ratio = max(1, round(float(w) / float(h)))
  90. ratio_resize = min(gen_ratio, self.max_ratio)
  91. imgW, imgH = self.base_shape[ratio_resize -
  92. 1] if ratio_resize <= 4 else [
  93. self.base_h *
  94. ratio_resize, self.base_h
  95. ]
  96. resized_w = imgW
  97. resized_image = self.F.resize(img, (imgH, resized_w),
  98. interpolation=self.interpolation)
  99. img = self.transforms(resized_image)
  100. data['image'] = img
  101. return data
  102. def build_rec_process(cfg):
  103. transforms = []
  104. ratio_resize_flag = True
  105. for op in cfg['Eval']['dataset']['transforms']:
  106. op_name = list(op)[0]
  107. if 'Resize' in op_name:
  108. ratio_resize_flag = False
  109. if 'Label' in op_name:
  110. continue
  111. elif op_name in ['RecResizeImg']:
  112. op[op_name]['infer_mode'] = True
  113. elif op_name == 'KeepKeys':
  114. if cfg['Architecture']['algorithm'] in ['SAR', 'RobustScanner']:
  115. if 'valid_ratio' in op[op_name]['keep_keys']:
  116. op[op_name]['keep_keys'] = ['image', 'valid_ratio']
  117. else:
  118. op[op_name]['keep_keys'] = ['image']
  119. else:
  120. op[op_name]['keep_keys'] = ['image']
  121. transforms.append(op)
  122. return transforms, ratio_resize_flag
  123. def set_device(device, numId=0):
  124. import torch
  125. if device == 'gpu' and torch.cuda.is_available():
  126. device = torch.device(f'cuda:{numId}')
  127. else:
  128. logger.info('GPU is not available, using CPU.')
  129. device = torch.device('cpu')
  130. return device
  131. class OpenRecognizer:
  132. def __init__(self,
  133. config=None,
  134. mode='mobile',
  135. backend='torch',
  136. onnx_model_path=None,
  137. numId=0):
  138. """
  139. Args:
  140. config (dict, optional): 配置信息。默认为None。
  141. mode (str, optional): 模式,'server' 或 'mobile'。默认为'mobile'。
  142. backend (str): 'torch' 或 'onnx'
  143. onnx_model_path (str): ONNX模型路径(仅当backend='onnx'时需要)
  144. numId (int, optional): 设备编号。默认为0。
  145. """
  146. if config is None:
  147. config_file = DEFAULT_CFG_PATH_REC_SERVER if mode == 'server' else DEFAULT_CFG_PATH_REC
  148. config = Config(config_file).cfg
  149. self.cfg = config
  150. # 公共初始化
  151. self._init_common()
  152. backend = backend if config['Global'].get(
  153. 'backend', None) is None else config['Global']['backend']
  154. self.backend = backend
  155. if backend == 'torch':
  156. import torch
  157. self.torch = torch
  158. self._init_torch_model(numId)
  159. elif backend == 'onnx':
  160. from tools.infer.onnx_engine import ONNXEngine
  161. onnx_model_path = onnx_model_path if config['Global'].get(
  162. 'onnx_model_path',
  163. None) is None else config['Global']['onnx_model_path']
  164. if not onnx_model_path:
  165. if self.cfg['Architecture']['algorithm'] == 'SVTRv2_mobile':
  166. onnx_model_path = check_and_download_model(
  167. MODEL_NAME_REC_ONNX, DOWNLOAD_URL_REC_ONNX)
  168. else:
  169. raise ValueError('ONNX模式需要指定onnx_model_path参数')
  170. self.onnx_rec_engine = ONNXEngine(
  171. onnx_model_path, use_gpu=config['Global']['device'] == 'gpu')
  172. else:
  173. raise ValueError("backend参数必须是'torch'或'onnx'")
  174. def _init_common(self):
  175. # 初始化公共组件
  176. from openrec.postprocess import build_post_process
  177. from openrec.preprocess import create_operators, transform
  178. self.transform = transform
  179. # 构建预处理流程
  180. algorithm_name = self.cfg['Architecture']['algorithm']
  181. if algorithm_name in ['SVTRv2_mobile', 'SVTRv2_server']:
  182. self.cfg['Global']['character_dict_path'] = DEFAULT_DICT_PATH_REC
  183. self.post_process_class = build_post_process(self.cfg['PostProcess'],
  184. self.cfg['Global'])
  185. char_num = self.post_process_class.get_character_num()
  186. self.cfg['Architecture']['Decoder']['out_channels'] = char_num
  187. transforms, ratio_resize_flag = build_rec_process(self.cfg)
  188. self.ops = create_operators(transforms, self.cfg['Global'])
  189. if ratio_resize_flag:
  190. ratio_resize = RatioRecTVReisze(cfg=self.cfg)
  191. self.ops.insert(-1, ratio_resize)
  192. def _init_torch_model(self, numId):
  193. from tools.utils.ckpt import load_ckpt
  194. from tools.infer_det import replace_batchnorm
  195. # PyTorch专用初始化
  196. algorithm_name = self.cfg['Architecture']['algorithm']
  197. if algorithm_name in ['SVTRv2_mobile', 'SVTRv2_server']:
  198. if not os.path.exists(self.cfg['Global']['pretrained_model']):
  199. pretrained_model = check_and_download_model(
  200. MODEL_NAME_REC, DOWNLOAD_URL_REC
  201. ) if algorithm_name == 'SVTRv2_mobile' else check_and_download_model(
  202. MODEL_NAME_REC_SERVER, DOWNLOAD_URL_REC_SERVER)
  203. self.cfg['Global']['pretrained_model'] = pretrained_model
  204. from openrec.modeling import build_model as build_rec_model
  205. self.model = build_rec_model(self.cfg['Architecture'])
  206. load_ckpt(self.model, self.cfg)
  207. self.device = set_device(self.cfg['Global']['device'], numId)
  208. self.model.to(self.device)
  209. self.model.eval()
  210. if algorithm_name == 'SVTRv2_mobile':
  211. replace_batchnorm(self.model.encoder)
  212. def _inference_onnx(self, images):
  213. # ONNX输入需要为numpy数组
  214. return self.onnx_rec_engine.run(images)
  215. def __call__(self,
  216. img_path=None,
  217. img_numpy_list=None,
  218. img_numpy=None,
  219. batch_num=1):
  220. """
  221. 调用函数,处理输入图像,并返回识别结果。
  222. Args:
  223. img_path (str, optional): 图像文件的路径。默认为 None。
  224. img_numpy_list (list, optional): 包含多个图像 numpy 数组的列表。默认为 None。
  225. img_numpy (numpy.ndarray, optional): 单个图像的 numpy 数组。默认为 None。
  226. batch_num (int, optional): 每次处理的图像数量。默认为 1。
  227. Returns:
  228. list: 包含识别结果的列表,每个元素为一个字典,包含文件路径(如果有的话)、文本、分数和延迟时间。
  229. Raises:
  230. Exception: 如果没有提供图像路径或 numpy 数组,则引发异常。
  231. """
  232. if img_numpy is not None:
  233. img_numpy_list = [img_numpy]
  234. num_img = 1
  235. elif img_path is not None:
  236. img_path = get_image_file_list(img_path)
  237. num_img = len(img_path)
  238. elif img_numpy_list is not None:
  239. num_img = len(img_numpy_list)
  240. else:
  241. raise Exception('No input image path or numpy array.')
  242. results = []
  243. for start_idx in range(0, num_img, batch_num):
  244. batch_data = []
  245. batch_others = []
  246. batch_file_names = []
  247. max_width, max_height = 0, 0
  248. # Prepare batch data
  249. for img_idx in range(start_idx, min(start_idx + batch_num,
  250. num_img)):
  251. if img_numpy_list is not None:
  252. img = img_numpy_list[img_idx]
  253. data = {'image': img}
  254. elif img_path is not None:
  255. file_name = img_path[img_idx]
  256. with open(file_name, 'rb') as f:
  257. img = f.read()
  258. data = {'image': img}
  259. data = self.transform(data, self.ops[:1])
  260. batch_file_names.append(file_name)
  261. batch = self.transform(data, self.ops[1:])
  262. others = None
  263. if self.cfg['Architecture']['algorithm'] in [
  264. 'SAR', 'RobustScanner'
  265. ]:
  266. valid_ratio = np.expand_dims(batch[-1], axis=0)
  267. batch_others.append(valid_ratio)
  268. resized_image = batch[0] if isinstance(
  269. batch[0], np.ndarray) else batch[0].numpy()
  270. h, w = resized_image.shape[-2:]
  271. max_width = max(max_width, w)
  272. max_height = max(max_height, h)
  273. batch_data.append(batch[0])
  274. padded_batch = np.zeros(
  275. (len(batch_data), 3, max_height, max_width), dtype=np.float32)
  276. for i, img in enumerate(batch_data):
  277. h, w = img.shape[-2:]
  278. padded_batch[i, :, :h, :w] = img
  279. if batch_others:
  280. others = np.concatenate(batch_others, axis=0)
  281. else:
  282. others = None
  283. t_start = time.time()
  284. if self.backend == 'torch':
  285. images = self.torch.from_numpy(padded_batch).to(
  286. device=self.device)
  287. with self.torch.no_grad():
  288. preds = self.model(images, others) # bs, len, num_classes
  289. torch_tensor = True
  290. elif self.backend == 'onnx':
  291. # ONNX推理
  292. preds = self._inference_onnx(padded_batch)
  293. preds = preds[0] # bs, len, num_classes
  294. torch_tensor = False
  295. t_cost = time.time() - t_start
  296. post_results = self.post_process_class(preds,
  297. torch_tensor=torch_tensor)
  298. for i, post_result in enumerate(post_results):
  299. if img_path is not None:
  300. info = {
  301. 'file': batch_file_names[i],
  302. 'text': post_result[0],
  303. 'score': post_result[1],
  304. 'elapse': t_cost
  305. }
  306. else:
  307. info = {
  308. 'text': post_result[0],
  309. 'score': post_result[1],
  310. 'elapse': t_cost
  311. }
  312. results.append(info)
  313. return results
  314. def main(cfg):
  315. model = OpenRecognizer(cfg)
  316. save_res_path = './rec_results/'
  317. if not os.path.exists(save_res_path):
  318. os.makedirs(save_res_path)
  319. t_sum = 0
  320. sample_num = 0
  321. max_len = cfg['Global']['max_text_length']
  322. text_len_time = [0 for _ in range(max_len)]
  323. text_len_num = [0 for _ in range(max_len)]
  324. sample_num = 0
  325. with open(save_res_path + '/rec_results.txt', 'wb') as fout:
  326. for file in get_image_file_list(cfg['Global']['infer_img']):
  327. preds_result = model(img_path=file, batch_num=1)[0]
  328. rec_text = preds_result['text']
  329. score = preds_result['score']
  330. t_cost = preds_result['elapse']
  331. info = rec_text + '\t' + str(score)
  332. text_len_num[min(max_len - 1, len(rec_text))] += 1
  333. text_len_time[min(max_len - 1, len(rec_text))] += t_cost
  334. logger.info(
  335. f'{sample_num} {file}\t result: {info}, time cost: {t_cost}')
  336. otstr = file + '\t' + info + '\n'
  337. t_sum += t_cost
  338. fout.write(otstr.encode())
  339. sample_num += 1
  340. logger.info(
  341. f"Results saved to {os.path.join(save_res_path, 'rec_results.txt')}.)"
  342. )
  343. print(text_len_num)
  344. w_avg_t_cost = []
  345. for l_t_cost, l_num in zip(text_len_time, text_len_num):
  346. if l_num != 0:
  347. w_avg_t_cost.append(l_t_cost / l_num)
  348. print(w_avg_t_cost)
  349. w_avg_t_cost = sum(w_avg_t_cost) / len(w_avg_t_cost)
  350. logger.info(
  351. f'Sample num: {sample_num}, Weighted Avg time cost: {t_sum/sample_num}, Avg time cost: {w_avg_t_cost}'
  352. )
  353. logger.info('success!')
  354. if __name__ == '__main__':
  355. FLAGS = ArgsParser().parse_args()
  356. cfg = Config(FLAGS.config)
  357. FLAGS = vars(FLAGS)
  358. opt = FLAGS.pop('opt')
  359. cfg.merge_dict(FLAGS)
  360. cfg.merge_dict(opt)
  361. main(cfg.cfg)