from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from pathlib import Path
import time
import numpy as np
import os
import sys

# Make this script's directory and its parent importable so the local
# `tools` / `opendet` packages resolve when the script is run directly.
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))

# Allocator flag: grow device memory on demand instead of pre-reserving it.
# NOTE(review): this is a Paddle-style FLAGS_ variable — confirm it has any
# effect with the torch/onnx backends used below.
os.environ['FLAGS_allocator_strategy'] = 'auto_growth'

import cv2
import json

from tools.engine.config import Config
from tools.utility import ArgsParser
from tools.utils.logging import get_logger
from tools.utils.utility import get_image_file_list

logger = get_logger()

root_dir = Path(__file__).resolve().parent
DEFAULT_CFG_PATH_DET = str(root_dir / '../configs/det/dbnet/repvit_db.yml')

MODEL_NAME_DET = './openocr_det_repvit_ch.pth'  # model file name
DOWNLOAD_URL_DET = 'https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/openocr_det_repvit_ch.pth'  # model download URL

MODEL_NAME_DET_ONNX = './openocr_det_model.onnx'  # model file name
DOWNLOAD_URL_DET_ONNX = 'https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/openocr_det_model.onnx'  # model download URL
def check_and_download_model(model_name: str, url: str) -> str:
    """Return a local path to the pretrained model, downloading it if needed.

    If ``model_name`` already exists on disk it is returned as-is. Otherwise
    the file is looked up in — and, when missing, downloaded to — the fixed
    cache directory ``~/.cache/openocr``.

    Args:
        model_name (str): Model file name, e.g. ``"model.pt"``.
        url (str): Download URL for the model file.

    Returns:
        str: Full path to the model file.

    Raises:
        RuntimeError: If the model cannot be downloaded automatically.
    """
    if os.path.exists(model_name):
        return model_name

    # Fixed cache location under the user's home directory.
    cache_dir = Path.home() / '.cache' / 'openocr'
    model_path = cache_dir / model_name

    # Reuse an already-cached model.
    if model_path.exists():
        logger.info(f'Model already exists at: {model_path}')
        return str(model_path)

    logger.info(f'Model not found. Downloading from {url}...')
    cache_dir.mkdir(parents=True, exist_ok=True)

    # Download to a temporary file and rename on success so an interrupted
    # download never leaves a partial file that a later call would treat as
    # a valid cached model.
    tmp_path = model_path.with_name(model_path.name + '.part')
    try:
        import shutil
        import urllib.request
        with urllib.request.urlopen(url) as response, open(tmp_path,
                                                           'wb') as out_file:
            # Stream in chunks instead of reading the whole file into memory.
            shutil.copyfileobj(response, out_file)
        tmp_path.replace(model_path)
        logger.info(f'Model downloaded and saved at: {model_path}')
        return str(model_path)
    except Exception as e:
        if tmp_path.exists():
            tmp_path.unlink()
        logger.error(f'Error downloading the model: {e}')
        # Tell the user how to install the model manually.
        logger.error(
            f'Unable to download the model automatically. '
            f'Please download the model manually from the following URL:\n{url}\n'
            f'and save it to: {model_name} or {model_path}')
        raise RuntimeError(
            f'Failed to download the model. Please download it manually from {url} '
            f'and save it to {model_path}') from e
def replace_batchnorm(net):
    """Recursively fuse or strip BatchNorm layers for inference.

    A child exposing a ``fuse()`` method is replaced by its fused form (and
    then descended into); a bare ``BatchNorm2d`` is swapped for ``Identity``;
    any other child is recursed into unchanged.
    """
    import torch
    for name, module in net.named_children():
        if hasattr(module, 'fuse'):
            fused_module = module.fuse()
            setattr(net, name, fused_module)
            replace_batchnorm(fused_module)
        elif isinstance(module, torch.nn.BatchNorm2d):
            setattr(net, name, torch.nn.Identity())
        else:
            replace_batchnorm(module)
def draw_det_res(dt_boxes, img, img_name, save_path):
    """Draw detection polygons on *img* and save it into *save_path*.

    Args:
        dt_boxes: Iterable of polygon point sets (anything np.array accepts).
        img: Image array (drawn on in place).
        img_name: Source image path; only its basename is reused for output.
        save_path: Output directory, created if missing.
    """
    for box in dt_boxes:
        pts = np.array(box).astype(np.int32).reshape((-1, 1, 2))
        cv2.polylines(img, [pts], True, color=(255, 255, 0), thickness=2)
    # exist_ok avoids the race between an existence check and creation.
    os.makedirs(save_path, exist_ok=True)
    out_file = os.path.join(save_path, os.path.basename(img_name))
    cv2.imwrite(out_file, img)
def set_device(device, numId=0):
    """Resolve a device string to a ``torch.device``.

    Args:
        device: ``'gpu'`` to request CUDA; anything else selects the CPU.
        numId: CUDA device index used when a GPU is selected. Defaults to 0.

    Returns:
        torch.device: ``cuda:<numId>`` when requested and available, else CPU.
    """
    import torch
    if device == 'gpu':
        if torch.cuda.is_available():
            return torch.device(f'cuda:{numId}')
        # Only warn when the caller actually asked for a GPU; the original
        # also logged this when CPU was explicitly requested.
        logger.info('GPU is not available, using CPU.')
    return torch.device('cpu')
class OpenDetector(object):
    """Scene-text detector with interchangeable torch / ONNX backends."""

    def __init__(self,
                 config=None,
                 backend='torch',
                 onnx_model_path=None,
                 numId=0):
        """
        Args:
            config (dict, optional): Parsed config; defaults to the bundled
                RepViT-DB detection config when None.
            backend (str): 'torch' or 'onnx'. Overridden by
                config['Global']['backend'] when that key is set.
            onnx_model_path (str): ONNX model path (only used when
                backend='onnx'). Overridden by
                config['Global']['onnx_model_path'] when set.
            numId (int, optional): Device index. Defaults to 0.
        """
        if config is None:
            config = Config(DEFAULT_CFG_PATH_DET).cfg

        self._init_common(config)

        # A backend specified in the config wins over the constructor arg.
        cfg_backend = config['Global'].get('backend', None)
        if cfg_backend is not None:
            backend = cfg_backend
        self.backend = backend

        if backend == 'torch':
            import torch
            self.torch = torch
            if config['Architecture']['algorithm'] == 'DB_mobile':
                # Auto-download the default weights if they are missing.
                if not os.path.exists(config['Global']['pretrained_model']):
                    config['Global'][
                        'pretrained_model'] = check_and_download_model(
                            MODEL_NAME_DET, DOWNLOAD_URL_DET)
            self._init_torch_model(config, numId)
        elif backend == 'onnx':
            from tools.infer.onnx_engine import ONNXEngine
            cfg_onnx_path = config['Global'].get('onnx_model_path', None)
            if cfg_onnx_path is not None:
                onnx_model_path = cfg_onnx_path
            if onnx_model_path is None:
                if config['Architecture']['algorithm'] == 'DB_mobile':
                    onnx_model_path = check_and_download_model(
                        MODEL_NAME_DET_ONNX, DOWNLOAD_URL_DET_ONNX)
                else:
                    raise ValueError('ONNX模式需要指定onnx_model_path参数')
            self.onnx_det_engine = ONNXEngine(
                onnx_model_path, use_gpu=config['Global']['device'] == 'gpu')
        else:
            raise ValueError("backend参数必须是'torch'或'onnx'")

    def _init_common(self, config):
        """Build the preprocessing operators and the post-process class."""
        from opendet.postprocess import build_post_process
        from opendet.preprocess import create_operators, transform

        global_config = config['Global']
        self.transform = transform

        # Keep only inference-relevant transforms: drop label ops and restrict
        # KeepKeys to the tensors inference needs.
        transforms = []
        for op in config['Eval']['dataset']['transforms']:
            op_name = list(op)[0]
            if 'Label' in op_name:
                continue
            elif op_name == 'KeepKeys':
                op[op_name]['keep_keys'] = ['image', 'shape']
            transforms.append(op)
        self.ops = create_operators(transforms, global_config)

        self.post_process_class = build_post_process(config['PostProcess'],
                                                     global_config)

    def _init_torch_model(self, config, numId=0):
        """Build the torch model, load its checkpoint and move it to device."""
        from opendet.modeling import build_model as build_det_model
        from tools.utils.ckpt import load_ckpt

        self.model = build_det_model(config['Architecture'])
        self.model.eval()
        load_ckpt(self.model, config)
        if config['Architecture']['algorithm'] == 'DB_mobile':
            # Fuse BatchNorm layers for faster mobile-backbone inference.
            replace_batchnorm(self.model.backbone)
        self.device = set_device(config['Global']['device'], numId=numId)
        self.model.to(device=self.device)

    def _inference_onnx(self, images):
        # The ONNX engine expects a numpy array input.
        return self.onnx_det_engine.run(images)

    def __call__(self,
                 img_path=None,
                 img_numpy_list=None,
                 img_numpy=None,
                 return_mask=False,
                 **kwargs):
        """Run text detection on the given image(s).

        Args:
            img_path (str, optional): Image file or directory path.
            img_numpy_list (list, optional): Images as numpy arrays.
            img_numpy (numpy.ndarray, optional): A single image array.
            return_mask (bool): Also return the raw prediction map.

        Returns:
            list: One dict per image with keys 'boxes' (detected box point
                sets) and 'elapse' (inference time in seconds); plus 'mask'
                when ``return_mask`` is True.

        Raises:
            Exception: If neither a path nor a numpy array is provided.
        """
        if img_numpy is not None:
            img_numpy_list = [img_numpy]
            num_img = 1
        elif img_path is not None:
            img_path = get_image_file_list(img_path)
            num_img = len(img_path)
        elif img_numpy_list is not None:
            num_img = len(img_numpy_list)
        else:
            raise Exception('No input image path or numpy array.')

        results = []
        for img_idx in range(num_img):
            if img_numpy_list is not None:
                img = img_numpy_list[img_idx]
                data = {'image': img}
            elif img_path is not None:
                with open(img_path[img_idx], 'rb') as f:
                    img = f.read()
                    data = {'image': img}
                # First op decodes the raw bytes read from disk.
                data = self.transform(data, self.ops[:1])
            if kwargs.get('det_input_size', None) is not None:
                # NOTE(review): 'max_sile_len' looks like a typo for
                # 'max_side_len' — confirm against the resize operator
                # before changing the key.
                data['max_sile_len'] = kwargs['det_input_size']
            batch = self.transform(data, self.ops[1:])

            images = np.expand_dims(batch[0], axis=0)
            shape_list = np.expand_dims(batch[1], axis=0)

            t_start = time.time()
            if self.backend == 'torch':
                images = self.torch.from_numpy(images).to(device=self.device)
                with self.torch.no_grad():
                    preds = self.model(images)
                kwargs['torch_tensor'] = True
            elif self.backend == 'onnx':
                preds_det = self._inference_onnx(images)
                preds = {'maps': preds_det[0]}
                kwargs['torch_tensor'] = False
            t_cost = time.time() - t_start

            post_result = self.post_process_class(preds, [None, shape_list],
                                                  **kwargs)
            info = {'boxes': post_result[0]['points'], 'elapse': t_cost}
            if return_mask:
                maps = preds['maps']
                # Fix: the onnx backend never sets ``self.torch``; the
                # original unconditionally touched it here and crashed.
                if self.backend == 'torch' and isinstance(
                        maps, self.torch.Tensor):
                    info['mask'] = maps.detach().cpu().numpy()
                else:
                    info['mask'] = maps
            results.append(info)
        return results
def main(cfg):
    """Run detection over ``cfg['Global']['infer_img']`` and dump the results.

    Writes one line per image (``<path>\\t<json boxes>``) to
    ``./det_results/det_results.txt`` and, when
    ``cfg['Global']['is_visualize']`` is truthy, saves visualized images into
    the same directory.
    """
    is_visualize = cfg['Global'].get('is_visualize', False)
    model = OpenDetector(cfg)

    save_res_path = './det_results/'
    # exist_ok avoids the race between an existence check and creation.
    os.makedirs(save_res_path, exist_ok=True)

    res_file = os.path.join(save_res_path, 'det_results.txt')
    sample_num = 0
    with open(res_file, 'wb') as fout:
        for file in get_image_file_list(cfg['Global']['infer_img']):
            preds_result = model(img_path=file)[0]
            logger.info('{} infer_img: {}, time cost: {}'.format(
                sample_num, file, preds_result['elapse']))
            boxes = preds_result['boxes']
            dt_boxes_json = [{
                'points': np.array(box).tolist()
            } for box in boxes]
            if is_visualize:
                src_img = cv2.imread(file)
                draw_det_res(boxes, src_img, file, save_res_path)
                logger.info('The detected Image saved in {}'.format(
                    os.path.join(save_res_path, os.path.basename(file))))
            otstr = file + '\t' + json.dumps(dt_boxes_json) + '\n'
            logger.info('results: {}'.format(json.dumps(dt_boxes_json)))
            fout.write(otstr.encode())
            sample_num += 1
    # Fix: the original message ended with a stray ')'.
    logger.info(f'Results saved to {res_file}.')
    logger.info('success!')
if __name__ == '__main__':
    # Parse CLI arguments, load the YAML config, then merge the remaining
    # CLI flags and the free-form '-o key=value' overrides ('opt') into the
    # config before running inference.
    FLAGS = ArgsParser().parse_args()
    cfg = Config(FLAGS.config)
    FLAGS = vars(FLAGS)
    opt = FLAGS.pop('opt')
    cfg.merge_dict(FLAGS)
    cfg.merge_dict(opt)
    main(cfg.cfg)