# utility.py — argument parsing and OCR visualization/cropping helpers.
  1. import argparse
  2. import math
  3. import cv2
  4. import numpy as np
  5. import PIL
  6. from PIL import Image, ImageDraw, ImageFont
  7. import random
  8. def str2bool(v):
  9. return v.lower() in ('true', 'yes', 't', 'y', '1')
  10. def str2int_tuple(v):
  11. return tuple([int(i.strip()) for i in v.split(',')])
  12. def init_args():
  13. parser = argparse.ArgumentParser()
  14. # params for prediction engine
  15. parser.add_argument('--use_gpu', type=str2bool, default=False)
  16. # params for text detector
  17. parser.add_argument('--image_dir', type=str)
  18. parser.add_argument('--det_algorithm', type=str, default='DB')
  19. parser.add_argument('--det_model_dir', type=str)
  20. parser.add_argument('--det_limit_side_len', type=float, default=960)
  21. parser.add_argument('--det_limit_type', type=str, default='max')
  22. parser.add_argument('--det_box_type', type=str, default='quad')
  23. # DB parmas
  24. parser.add_argument('--det_db_thresh', type=float, default=0.3)
  25. parser.add_argument('--det_db_box_thresh', type=float, default=0.6)
  26. parser.add_argument('--det_db_unclip_ratio', type=float, default=1.5)
  27. parser.add_argument('--max_batch_size', type=int, default=10)
  28. parser.add_argument('--use_dilation', type=str2bool, default=False)
  29. parser.add_argument('--det_db_score_mode', type=str, default='fast')
  30. # params for text recognizer
  31. parser.add_argument('--rec_algorithm', type=str, default='SVTR_LCNet')
  32. parser.add_argument('--rec_model_dir', type=str)
  33. parser.add_argument('--rec_image_inverse', type=str2bool, default=True)
  34. parser.add_argument('--rec_image_shape', type=str, default='3, 48, 320')
  35. parser.add_argument('--rec_batch_num', type=int, default=6)
  36. parser.add_argument('--max_text_length', type=int, default=25)
  37. parser.add_argument('--vis_font_path',
  38. type=str,
  39. default='./doc/fonts/simfang.ttf')
  40. parser.add_argument('--drop_score', type=float, default=0.5)
  41. # params for text classifier
  42. parser.add_argument('--use_angle_cls', type=str2bool, default=False)
  43. parser.add_argument('--cls_model_dir', type=str)
  44. parser.add_argument('--cls_image_shape', type=str, default='3, 48, 192')
  45. parser.add_argument('--label_list', type=list, default=['0', '180'])
  46. parser.add_argument('--cls_batch_num', type=int, default=6)
  47. parser.add_argument('--cls_thresh', type=float, default=0.9)
  48. parser.add_argument('--warmup', type=str2bool, default=False)
  49. #
  50. parser.add_argument('--output', type=str, default='./inference_results')
  51. parser.add_argument('--save_crop_res', type=str2bool, default=False)
  52. parser.add_argument('--crop_res_save_dir', type=str, default='./output')
  53. # multi-process
  54. parser.add_argument('--use_mp', type=str2bool, default=False)
  55. parser.add_argument('--total_process_num', type=int, default=1)
  56. parser.add_argument('--process_id', type=int, default=0)
  57. parser.add_argument('--show_log', type=str2bool, default=True)
  58. return parser
  59. def parse_args():
  60. parser = init_args()
  61. return parser.parse_args()
  62. def create_font(txt, sz, font_path='./doc/fonts/simfang.ttf'):
  63. font_size = int(sz[1] * 0.99)
  64. font = ImageFont.truetype(font_path, font_size, encoding='utf-8')
  65. if int(PIL.__version__.split('.')[0]) < 10:
  66. length = font.getsize(txt)[0]
  67. else:
  68. length = font.getlength(txt)
  69. if length > sz[0]:
  70. font_size = int(font_size * sz[0] / length)
  71. font = ImageFont.truetype(font_path, font_size, encoding='utf-8')
  72. return font
  73. def draw_box_txt_fine(img_size, box, txt, font_path='./doc/fonts/simfang.ttf'):
  74. box_height = int(
  75. math.sqrt((box[0][0] - box[3][0])**2 + (box[0][1] - box[3][1])**2))
  76. box_width = int(
  77. math.sqrt((box[0][0] - box[1][0])**2 + (box[0][1] - box[1][1])**2))
  78. if box_height > 2 * box_width and box_height > 30:
  79. img_text = Image.new('RGB', (box_height, box_width), (255, 255, 255))
  80. draw_text = ImageDraw.Draw(img_text)
  81. if txt:
  82. font = create_font(txt, (box_height, box_width), font_path)
  83. draw_text.text([0, 0], txt, fill=(0, 0, 0), font=font)
  84. img_text = img_text.transpose(Image.ROTATE_270)
  85. else:
  86. img_text = Image.new('RGB', (box_width, box_height), (255, 255, 255))
  87. draw_text = ImageDraw.Draw(img_text)
  88. if txt:
  89. font = create_font(txt, (box_width, box_height), font_path)
  90. draw_text.text([0, 0], txt, fill=(0, 0, 0), font=font)
  91. pts1 = np.float32([[0, 0], [box_width, 0], [box_width, box_height],
  92. [0, box_height]])
  93. pts2 = np.array(box, dtype=np.float32)
  94. M = cv2.getPerspectiveTransform(pts1, pts2)
  95. img_text = np.array(img_text, dtype=np.uint8)
  96. img_right_text = cv2.warpPerspective(
  97. img_text,
  98. M,
  99. img_size,
  100. flags=cv2.INTER_NEAREST,
  101. borderMode=cv2.BORDER_CONSTANT,
  102. borderValue=(255, 255, 255),
  103. )
  104. return img_right_text
  105. def draw_ocr_box_txt(
  106. image,
  107. boxes,
  108. txts=None,
  109. scores=None,
  110. drop_score=0.5,
  111. font_path='./doc/fonts/simfang.ttf',
  112. ):
  113. h, w = image.height, image.width
  114. img_left = image.copy()
  115. img_right = np.ones((h, w, 3), dtype=np.uint8) * 255
  116. random.seed(0)
  117. draw_left = ImageDraw.Draw(img_left)
  118. if txts is None or len(txts) != len(boxes):
  119. txts = [None] * len(boxes)
  120. for idx, (box, txt) in enumerate(zip(boxes, txts)):
  121. if scores is not None and scores[idx] < drop_score:
  122. continue
  123. color = (random.randint(0, 255), random.randint(0, 255),
  124. random.randint(0, 255))
  125. if isinstance(box[0], list):
  126. box = list(map(tuple, box))
  127. draw_left.polygon(box, fill=color)
  128. img_right_text = draw_box_txt_fine((w, h), box, txt, font_path)
  129. pts = np.array(box, np.int32).reshape((-1, 1, 2))
  130. cv2.polylines(img_right_text, [pts], True, color, 1)
  131. img_right = cv2.bitwise_and(img_right, img_right_text)
  132. img_left = Image.blend(image, img_left, 0.5)
  133. img_show = Image.new('RGB', (w * 2, h), (255, 255, 255))
  134. img_show.paste(img_left, (0, 0, w, h))
  135. img_show.paste(Image.fromarray(img_right), (w, 0, w * 2, h))
  136. return np.array(img_show)
  137. def get_rotate_crop_image(img, points):
  138. """
  139. img_height, img_width = img.shape[0:2]
  140. left = int(np.min(points[:, 0]))
  141. right = int(np.max(points[:, 0]))
  142. top = int(np.min(points[:, 1]))
  143. bottom = int(np.max(points[:, 1]))
  144. img_crop = img[top:bottom, left:right, :].copy()
  145. points[:, 0] = points[:, 0] - left
  146. points[:, 1] = points[:, 1] - top
  147. """
  148. assert len(points) == 4, 'shape of points must be 4*2'
  149. img_crop_width = int(
  150. max(np.linalg.norm(points[0] - points[1]),
  151. np.linalg.norm(points[2] - points[3])))
  152. img_crop_height = int(
  153. max(np.linalg.norm(points[0] - points[3]),
  154. np.linalg.norm(points[1] - points[2])))
  155. pts_std = np.float32([
  156. [0, 0],
  157. [img_crop_width, 0],
  158. [img_crop_width, img_crop_height],
  159. [0, img_crop_height],
  160. ])
  161. M = cv2.getPerspectiveTransform(points, pts_std)
  162. dst_img = cv2.warpPerspective(
  163. img,
  164. M,
  165. (img_crop_width, img_crop_height),
  166. borderMode=cv2.BORDER_REPLICATE,
  167. flags=cv2.INTER_CUBIC,
  168. )
  169. dst_img_height, dst_img_width = dst_img.shape[0:2]
  170. if dst_img_height * 1.0 / dst_img_width >= 1.5:
  171. dst_img = np.rot90(dst_img)
  172. return dst_img
  173. def get_minarea_rect_crop(img, points):
  174. bounding_box = cv2.minAreaRect(np.array(points).astype(np.int32))
  175. points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
  176. index_a, index_b, index_c, index_d = 0, 1, 2, 3
  177. if points[1][1] > points[0][1]:
  178. index_a = 0
  179. index_d = 1
  180. else:
  181. index_a = 1
  182. index_d = 0
  183. if points[3][1] > points[2][1]:
  184. index_b = 2
  185. index_c = 3
  186. else:
  187. index_b = 3
  188. index_c = 2
  189. box = [points[index_a], points[index_b], points[index_c], points[index_d]]
  190. crop_img = get_rotate_crop_image(img, np.array(box))
  191. return crop_img
  192. def check_gpu(use_gpu):
  193. import torch
  194. if use_gpu and not torch.cuda.is_available():
  195. use_gpu = False
  196. return use_gpu
if __name__ == '__main__':
    # No standalone CLI behavior; this module is meant to be imported.
    pass