123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184 |
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
- import threading
- import queue
- import os
- import sys
- import time
- __dir__ = os.path.dirname(os.path.abspath(__file__))
- sys.path.append(__dir__)
- sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
- import numpy as np
- import cv2
- import json
- from PIL import Image
- from tools.utils.utility import get_image_file_list, check_and_read
- from tools.infer_rec import OpenRecognizer
- from tools.infer_det import OpenDetector
- from tools.infer_e2e import check_and_download_font, sorted_boxes
- from tools.engine.config import Config
- from tools.infer.utility import get_rotate_crop_image, get_minarea_rect_crop, draw_ocr_box_txt
- class OpenOCRParallel:
- def __init__(self, drop_score=0.5, det_box_type='quad', max_rec_threads=1):
- cfg_det = Config(
- './configs/det/dbnet/repvit_db.yml').cfg # mobile model
- # cfg_rec = Config('./configs/rec/svtrv2/svtrv2_ch.yml').cfg # server model
- cfg_rec = Config(
- './configs/rec/svtrv2/repsvtr_ch.yml').cfg # mobile model
- self.text_detector = OpenDetector(cfg_det, numId=0)
- self.text_recognizer = OpenRecognizer(cfg_rec, numId=0)
- self.det_box_type = det_box_type
- self.drop_score = drop_score
- self.queue = queue.Queue(
- ) # Queue to hold detected boxes for recognition
- self.results = {}
- self.lock = threading.Lock() # Lock for thread-safe access to results
- self.max_rec_threads = max_rec_threads
- self.stop_signal = threading.Event() # Signal to stop threads
- def start_recognition_threads(self):
- """Start recognition threads."""
- self.rec_threads = []
- for _ in range(self.max_rec_threads):
- t = threading.Thread(target=self.recognize_text)
- t.start()
- self.rec_threads.append(t)
- def detect_text(self, image_list):
- """Single-threaded text detection for all images."""
- for image_id, (img_numpy, ori_img) in enumerate(image_list):
- dt_boxes = self.text_detector(img_numpy=img_numpy)[0]['boxes']
- if dt_boxes is None:
- self.results[image_id] = [] # If no boxes, set empty results
- continue
- dt_boxes = sorted_boxes(dt_boxes)
- img_crop_list = []
- for box in dt_boxes:
- tmp_box = np.array(box).astype(np.float32)
- img_crop = (get_rotate_crop_image(ori_img, tmp_box)
- if self.det_box_type == 'quad' else
- get_minarea_rect_crop(ori_img, tmp_box))
- img_crop_list.append(img_crop)
- self.queue.put(
- (image_id, dt_boxes, img_crop_list
- )) # Put image ID, detected box, and cropped image in queue
- # Signal that no more items will be added to the queue
- self.stop_signal.set()
- def recognize_text(self):
- """Recognize text in each cropped image."""
- while not self.stop_signal.is_set() or not self.queue.empty():
- try:
- image_id, boxs, img_crop_list = self.queue.get(timeout=0.5)
- rec_results = self.text_recognizer(
- img_numpy_list=img_crop_list, batch_num=6)
- for rec_result, box in zip(rec_results, boxs):
- text, score = rec_result['text'], rec_result['score']
- if score >= self.drop_score:
- with self.lock:
- # Ensure results dictionary has a list for each image ID
- if image_id not in self.results:
- self.results[image_id] = []
- self.results[image_id].append({
- 'transcription':
- text,
- 'points':
- box.tolist(),
- 'score':
- score
- })
- self.queue.task_done()
- except queue.Empty:
- continue
- def process_images(self, image_list):
- """Process a list of images."""
- # Initialize results dictionary
- self.results = {i: [] for i in range(len(image_list))}
- # Start recognition threads
- t_start_1 = time.time()
- self.start_recognition_threads()
- # Start detection in the main thread
- t_start = time.time()
- self.detect_text(image_list)
- print('det time:', time.time() - t_start)
- # Wait for recognition threads to finish
- for t in self.rec_threads:
- t.join()
- self.stop_signal.clear()
- print('all time:', time.time() - t_start_1)
- return self.results
def main(cfg_det=None,
         cfg_rec=None,
         img_path='./testA/',
         drop_score=0.5,
         is_visualize=False):
    """Run the parallel OCR pipeline over ``img_path`` and save predictions.

    Writes one ``<file name>\\t<json results>`` line per loaded image to
    ``system_results.txt`` in the directory named
    ``<img_path without trailing '/'>e2e_results/``. When ``is_visualize`` is
    true, it also saves an annotated copy of each image there.

    Args:
        cfg_det: unused; kept so existing ``main(cfg_det, cfg_rec)`` calls
            keep working (``OpenOCRParallel`` loads its own configs).
        cfg_rec: unused; see ``cfg_det``.
        img_path: path handed to ``get_image_file_list``.
        drop_score: recognitions below this score are dropped.
        is_visualize: also draw boxes/text and save the rendered images.
    """
    image_file_list = get_image_file_list(img_path)
    text_sys = OpenOCRParallel(
        drop_score=drop_score,
        det_box_type='quad')  # det_box_type: 'quad' or 'poly'
    font_path = './simfang.ttf'
    if is_visualize:
        check_and_download_font(font_path)
    # Always defined: system_results.txt is written here even when
    # visualization is off.
    draw_img_save_dir = img_path.rstrip('/') + 'e2e_results/'
    os.makedirs(draw_img_save_dir, exist_ok=True)

    # Load images, remembering the source file of each loaded image so that
    # result indices stay aligned with file names when some files fail to load.
    images = []
    loaded_files = []
    t_start = time.time()
    for image_file in image_file_list:
        img, flag_gif, flag_pdf = check_and_read(image_file)
        if not flag_gif and not flag_pdf:
            img = cv2.imread(image_file)
        if img is None:
            print(f'skip unreadable image: {image_file}')
            continue
        images.append((img, img.copy()))
        loaded_files.append(image_file)
    results = text_sys.process_images(images)
    print(f'time cost: {time.time() - t_start}')

    save_results = []
    for image_id, res in results.items():
        image_file = loaded_files[image_id]
        save_results.append(f'{os.path.basename(image_file)}\t'
                            f'{json.dumps(res, ensure_ascii=False)}\n')
        if not is_visualize:
            continue
        dt_boxes = [result['points'] for result in res]
        rec_res = [result['transcription'] for result in res]
        rec_score = [result['score'] for result in res]
        image = Image.fromarray(
            cv2.cvtColor(images[image_id][0], cv2.COLOR_BGR2RGB))
        draw_img = draw_ocr_box_txt(image,
                                    dt_boxes,
                                    rec_res,
                                    rec_score,
                                    drop_score=drop_score,
                                    font_path=font_path)
        save_file = os.path.join(draw_img_save_dir,
                                 os.path.basename(image_file))
        # Reverse the channel axis before writing with OpenCV.
        cv2.imwrite(save_file, draw_img[:, :, ::-1])
    with open(os.path.join(draw_img_save_dir, 'system_results.txt'),
              'w',
              encoding='utf-8') as f:
        f.writelines(save_results)
if __name__ == '__main__':
    # main() never reads cfg_det/cfg_rec. Pass them explicitly (as None) so
    # the call works whether or not main gives them default values.
    main(cfg_det=None, cfg_rec=None)
|