extract_textpoint_fast.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479
  1. import cv2
  2. import math
  3. import numpy as np
  4. from itertools import groupby
  5. from skimage.morphology._skeletonize import thin
  6. def get_dict(character_dict_path):
  7. character_str = ""
  8. with open(character_dict_path, "rb") as fin:
  9. lines = fin.readlines()
  10. for line in lines:
  11. line = line.decode("utf-8").strip("\n").strip("\r\n")
  12. character_str += line
  13. dict_character = list(character_str)
  14. return dict_character
  15. def softmax(logits):
  16. """
  17. logits: N x d
  18. """
  19. max_value = np.max(logits, axis=1, keepdims=True)
  20. exp = np.exp(logits - max_value)
  21. exp_sum = np.sum(exp, axis=1, keepdims=True)
  22. dist = exp / exp_sum
  23. return dist
  24. def get_keep_pos_idxs(labels, remove_blank=None):
  25. """
  26. Remove duplicate and get pos idxs of keep items.
  27. The value of keep_blank should be [None, 95].
  28. """
  29. duplicate_len_list = []
  30. keep_pos_idx_list = []
  31. keep_char_idx_list = []
  32. for k, v_ in groupby(labels):
  33. current_len = len(list(v_))
  34. if k != remove_blank:
  35. current_idx = int(sum(duplicate_len_list) + current_len // 2)
  36. keep_pos_idx_list.append(current_idx)
  37. keep_char_idx_list.append(k)
  38. duplicate_len_list.append(current_len)
  39. return keep_char_idx_list, keep_pos_idx_list
  40. def remove_blank(labels, blank=0):
  41. new_labels = [x for x in labels if x != blank]
  42. return new_labels
  43. def insert_blank(labels, blank=0):
  44. new_labels = [blank]
  45. for l in labels:
  46. new_labels += [l, blank]
  47. return new_labels
  48. def ctc_greedy_decoder(probs_seq, blank=95, keep_blank_in_idxs=True):
  49. """
  50. CTC greedy (best path) decoder.
  51. """
  52. raw_str = np.argmax(np.array(probs_seq), axis=1)
  53. remove_blank_in_pos = None if keep_blank_in_idxs else blank
  54. dedup_str, keep_idx_list = get_keep_pos_idxs(
  55. raw_str, remove_blank=remove_blank_in_pos)
  56. dst_str = remove_blank(dedup_str, blank=blank)
  57. return dst_str, keep_idx_list
  58. def instance_ctc_greedy_decoder(gather_info,
  59. logits_map,
  60. pts_num=4,
  61. point_gather_mode=None):
  62. _, _, C = logits_map.shape
  63. if point_gather_mode == "align":
  64. insert_num = 0
  65. gather_info = np.array(gather_info)
  66. length = len(gather_info) - 1
  67. for index in range(length):
  68. stride_y = np.abs(gather_info[index + insert_num][0] - gather_info[
  69. index + 1 + insert_num][0])
  70. stride_x = np.abs(gather_info[index + insert_num][1] - gather_info[
  71. index + 1 + insert_num][1])
  72. max_points = int(max(stride_x, stride_y))
  73. stride = (gather_info[index + insert_num] -
  74. gather_info[index + 1 + insert_num]) / (max_points)
  75. insert_num_temp = max_points - 1
  76. for i in range(int(insert_num_temp)):
  77. insert_value = gather_info[index + insert_num] - (i + 1
  78. ) * stride
  79. insert_index = index + i + 1 + insert_num
  80. gather_info = np.insert(
  81. gather_info, insert_index, insert_value, axis=0)
  82. insert_num += insert_num_temp
  83. gather_info = gather_info.tolist()
  84. else:
  85. pass
  86. ys, xs = zip(*gather_info)
  87. logits_seq = logits_map[list(ys), list(xs)]
  88. probs_seq = logits_seq
  89. labels = np.argmax(probs_seq, axis=1)
  90. dst_str = [k for k, v_ in groupby(labels) if k != C - 1]
  91. detal = len(gather_info) // (pts_num - 1)
  92. keep_idx_list = [0] + [detal * (i + 1) for i in range(pts_num - 2)] + [-1]
  93. keep_gather_list = [gather_info[idx] for idx in keep_idx_list]
  94. return dst_str, keep_gather_list
  95. def ctc_decoder_for_image(gather_info_list,
  96. logits_map,
  97. Lexicon_Table,
  98. pts_num=6,
  99. point_gather_mode=None):
  100. """
  101. CTC decoder using multiple processes.
  102. """
  103. decoder_str = []
  104. decoder_xys = []
  105. for gather_info in gather_info_list:
  106. if len(gather_info) < pts_num:
  107. continue
  108. dst_str, xys_list = instance_ctc_greedy_decoder(
  109. gather_info,
  110. logits_map,
  111. pts_num=pts_num,
  112. point_gather_mode=point_gather_mode, )
  113. dst_str_readable = "".join([Lexicon_Table[idx] for idx in dst_str])
  114. if len(dst_str_readable) < 2:
  115. continue
  116. decoder_str.append(dst_str_readable)
  117. decoder_xys.append(xys_list)
  118. return decoder_str, decoder_xys
  119. def sort_with_direction(pos_list, f_direction):
  120. """
  121. f_direction: h x w x 2
  122. pos_list: [[y, x], [y, x], [y, x] ...]
  123. """
  124. def sort_part_with_direction(pos_list, point_direction):
  125. pos_list = np.array(pos_list).reshape(-1, 2)
  126. point_direction = np.array(point_direction).reshape(-1, 2)
  127. average_direction = np.mean(point_direction, axis=0, keepdims=True)
  128. pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
  129. sorted_list = pos_list[np.argsort(pos_proj_leng)].tolist()
  130. sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist()
  131. return sorted_list, sorted_direction
  132. pos_list = np.array(pos_list).reshape(-1, 2)
  133. point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y
  134. point_direction = point_direction[:, ::-1] # x, y -> y, x
  135. sorted_point, sorted_direction = sort_part_with_direction(pos_list,
  136. point_direction)
  137. point_num = len(sorted_point)
  138. if point_num >= 16:
  139. middle_num = point_num // 2
  140. first_part_point = sorted_point[:middle_num]
  141. first_point_direction = sorted_direction[:middle_num]
  142. sorted_fist_part_point, sorted_fist_part_direction = sort_part_with_direction(
  143. first_part_point, first_point_direction)
  144. last_part_point = sorted_point[middle_num:]
  145. last_point_direction = sorted_direction[middle_num:]
  146. sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
  147. last_part_point, last_point_direction)
  148. sorted_point = sorted_fist_part_point + sorted_last_part_point
  149. sorted_direction = sorted_fist_part_direction + sorted_last_part_direction
  150. return sorted_point, np.array(sorted_direction)
  151. def add_id(pos_list, image_id=0):
  152. """
  153. Add id for gather feature, for inference.
  154. """
  155. new_list = []
  156. for item in pos_list:
  157. new_list.append((image_id, item[0], item[1]))
  158. return new_list
  159. def sort_and_expand_with_direction(pos_list, f_direction):
  160. """
  161. f_direction: h x w x 2
  162. pos_list: [[y, x], [y, x], [y, x] ...]
  163. """
  164. h, w, _ = f_direction.shape
  165. sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
  166. point_num = len(sorted_list)
  167. sub_direction_len = max(point_num // 3, 2)
  168. left_direction = point_direction[:sub_direction_len, :]
  169. right_dirction = point_direction[point_num - sub_direction_len:, :]
  170. left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
  171. left_average_len = np.linalg.norm(left_average_direction)
  172. left_start = np.array(sorted_list[0])
  173. left_step = left_average_direction / (left_average_len + 1e-6)
  174. right_average_direction = np.mean(right_dirction, axis=0, keepdims=True)
  175. right_average_len = np.linalg.norm(right_average_direction)
  176. right_step = right_average_direction / (right_average_len + 1e-6)
  177. right_start = np.array(sorted_list[-1])
  178. append_num = max(
  179. int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
  180. left_list = []
  181. right_list = []
  182. for i in range(append_num):
  183. ly, lx = (np.round(left_start + left_step * (i + 1)).flatten()
  184. .astype("int32").tolist())
  185. if ly < h and lx < w and (ly, lx) not in left_list:
  186. left_list.append((ly, lx))
  187. ry, rx = (np.round(right_start + right_step * (i + 1)).flatten()
  188. .astype("int32").tolist())
  189. if ry < h and rx < w and (ry, rx) not in right_list:
  190. right_list.append((ry, rx))
  191. all_list = left_list[::-1] + sorted_list + right_list
  192. return all_list
  193. def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map):
  194. """
  195. f_direction: h x w x 2
  196. pos_list: [[y, x], [y, x], [y, x] ...]
  197. binary_tcl_map: h x w
  198. """
  199. h, w, _ = f_direction.shape
  200. sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
  201. point_num = len(sorted_list)
  202. sub_direction_len = max(point_num // 3, 2)
  203. left_direction = point_direction[:sub_direction_len, :]
  204. right_dirction = point_direction[point_num - sub_direction_len:, :]
  205. left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
  206. left_average_len = np.linalg.norm(left_average_direction)
  207. left_start = np.array(sorted_list[0])
  208. left_step = left_average_direction / (left_average_len + 1e-6)
  209. right_average_direction = np.mean(right_dirction, axis=0, keepdims=True)
  210. right_average_len = np.linalg.norm(right_average_direction)
  211. right_step = right_average_direction / (right_average_len + 1e-6)
  212. right_start = np.array(sorted_list[-1])
  213. append_num = max(
  214. int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
  215. max_append_num = 2 * append_num
  216. left_list = []
  217. right_list = []
  218. for i in range(max_append_num):
  219. ly, lx = (np.round(left_start + left_step * (i + 1)).flatten()
  220. .astype("int32").tolist())
  221. if ly < h and lx < w and (ly, lx) not in left_list:
  222. if binary_tcl_map[ly, lx] > 0.5:
  223. left_list.append((ly, lx))
  224. else:
  225. break
  226. for i in range(max_append_num):
  227. ry, rx = (np.round(right_start + right_step * (i + 1)).flatten()
  228. .astype("int32").tolist())
  229. if ry < h and rx < w and (ry, rx) not in right_list:
  230. if binary_tcl_map[ry, rx] > 0.5:
  231. right_list.append((ry, rx))
  232. else:
  233. break
  234. all_list = left_list[::-1] + sorted_list + right_list
  235. return all_list
  236. def point_pair2poly(point_pair_list):
  237. """
  238. Transfer vertical point_pairs into poly point in clockwise.
  239. """
  240. point_num = len(point_pair_list) * 2
  241. point_list = [0] * point_num
  242. for idx, point_pair in enumerate(point_pair_list):
  243. point_list[idx] = point_pair[0]
  244. point_list[point_num - 1 - idx] = point_pair[1]
  245. return np.array(point_list).reshape(-1, 2)
  246. def shrink_quad_along_width(quad, begin_width_ratio=0.0, end_width_ratio=1.0):
  247. ratio_pair = np.array(
  248. [[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
  249. p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
  250. p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
  251. return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
  252. def expand_poly_along_width(poly, shrink_ratio_of_width=0.3):
  253. """
  254. expand poly along width.
  255. """
  256. point_num = poly.shape[0]
  257. left_quad = np.array(
  258. [poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
  259. left_ratio = (-shrink_ratio_of_width *
  260. np.linalg.norm(left_quad[0] - left_quad[3]) /
  261. (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6))
  262. left_quad_expand = shrink_quad_along_width(left_quad, left_ratio, 1.0)
  263. right_quad = np.array(
  264. [
  265. poly[point_num // 2 - 2],
  266. poly[point_num // 2 - 1],
  267. poly[point_num // 2],
  268. poly[point_num // 2 + 1],
  269. ],
  270. dtype=np.float32, )
  271. right_ratio = 1.0 + shrink_ratio_of_width * np.linalg.norm(right_quad[
  272. 0] - right_quad[3]) / (np.linalg.norm(right_quad[0] - right_quad[1]) +
  273. 1e-6)
  274. right_quad_expand = shrink_quad_along_width(right_quad, 0.0, right_ratio)
  275. poly[0] = left_quad_expand[0]
  276. poly[-1] = left_quad_expand[-1]
  277. poly[point_num // 2 - 1] = right_quad_expand[1]
  278. poly[point_num // 2] = right_quad_expand[2]
  279. return poly
  280. def restore_poly(instance_yxs_list, seq_strs, p_border, ratio_w, ratio_h, src_w,
  281. src_h, valid_set):
  282. poly_list = []
  283. keep_str_list = []
  284. for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs):
  285. if len(keep_str) < 2:
  286. print("--> too short, {}".format(keep_str))
  287. continue
  288. offset_expand = 1.0
  289. if valid_set == "totaltext":
  290. offset_expand = 1.2
  291. point_pair_list = []
  292. for y, x in yx_center_line:
  293. offset = p_border[:, y, x].reshape(2, 2) * offset_expand
  294. ori_yx = np.array([y, x], dtype=np.float32)
  295. point_pair = ((ori_yx + offset)[:, ::-1] * 4.0 /
  296. np.array([ratio_w, ratio_h]).reshape(-1, 2))
  297. point_pair_list.append(point_pair)
  298. detected_poly = point_pair2poly(point_pair_list)
  299. detected_poly = expand_poly_along_width(
  300. detected_poly, shrink_ratio_of_width=0.2)
  301. detected_poly[:, 0] = np.clip(detected_poly[:, 0], a_min=0, a_max=src_w)
  302. detected_poly[:, 1] = np.clip(detected_poly[:, 1], a_min=0, a_max=src_h)
  303. keep_str_list.append(keep_str)
  304. if valid_set == "partvgg":
  305. middle_point = len(detected_poly) // 2
  306. detected_poly = detected_poly[
  307. [0, middle_point - 1, middle_point, -1], :]
  308. poly_list.append(detected_poly)
  309. elif valid_set == "totaltext":
  310. poly_list.append(detected_poly)
  311. else:
  312. print("--> Not supported format.")
  313. exit(-1)
  314. return poly_list, keep_str_list
  315. def generate_pivot_list_fast(
  316. p_score,
  317. p_char_maps,
  318. f_direction,
  319. Lexicon_Table,
  320. score_thresh=0.5,
  321. point_gather_mode=None, ):
  322. """
  323. return center point and end point of TCL instance; filter with the char maps;
  324. """
  325. p_score = p_score[0]
  326. f_direction = f_direction.transpose(1, 2, 0)
  327. p_tcl_map = (p_score > score_thresh) * 1.0
  328. skeleton_map = thin(p_tcl_map.astype(np.uint8))
  329. instance_count, instance_label_map = cv2.connectedComponents(
  330. skeleton_map.astype(np.uint8), connectivity=8)
  331. # get TCL Instance
  332. all_pos_yxs = []
  333. if instance_count > 0:
  334. for instance_id in range(1, instance_count):
  335. pos_list = []
  336. ys, xs = np.where(instance_label_map == instance_id)
  337. pos_list = list(zip(ys, xs))
  338. if len(pos_list) < 3:
  339. continue
  340. pos_list_sorted = sort_and_expand_with_direction_v2(
  341. pos_list, f_direction, p_tcl_map)
  342. all_pos_yxs.append(pos_list_sorted)
  343. p_char_maps = p_char_maps.transpose([1, 2, 0])
  344. decoded_str, keep_yxs_list = ctc_decoder_for_image(
  345. all_pos_yxs,
  346. logits_map=p_char_maps,
  347. Lexicon_Table=Lexicon_Table,
  348. point_gather_mode=point_gather_mode, )
  349. return keep_yxs_list, decoded_str
  350. def extract_main_direction(pos_list, f_direction):
  351. """
  352. f_direction: h x w x 2
  353. pos_list: [[y, x], [y, x], [y, x] ...]
  354. """
  355. pos_list = np.array(pos_list)
  356. point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]]
  357. point_direction = point_direction[:, ::-1] # x, y -> y, x
  358. average_direction = np.mean(point_direction, axis=0, keepdims=True)
  359. average_direction = average_direction / (
  360. np.linalg.norm(average_direction) + 1e-6)
  361. return average_direction
  362. def sort_by_direction_with_image_id_deprecated(pos_list, f_direction):
  363. """
  364. f_direction: h x w x 2
  365. pos_list: [[id, y, x], [id, y, x], [id, y, x] ...]
  366. """
  367. pos_list_full = np.array(pos_list).reshape(-1, 3)
  368. pos_list = pos_list_full[:, 1:]
  369. point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y
  370. point_direction = point_direction[:, ::-1] # x, y -> y, x
  371. average_direction = np.mean(point_direction, axis=0, keepdims=True)
  372. pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
  373. sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
  374. return sorted_list
  375. def sort_by_direction_with_image_id(pos_list, f_direction):
  376. """
  377. f_direction: h x w x 2
  378. pos_list: [[y, x], [y, x], [y, x] ...]
  379. """
  380. def sort_part_with_direction(pos_list_full, point_direction):
  381. pos_list_full = np.array(pos_list_full).reshape(-1, 3)
  382. pos_list = pos_list_full[:, 1:]
  383. point_direction = np.array(point_direction).reshape(-1, 2)
  384. average_direction = np.mean(point_direction, axis=0, keepdims=True)
  385. pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
  386. sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
  387. sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist()
  388. return sorted_list, sorted_direction
  389. pos_list = np.array(pos_list).reshape(-1, 3)
  390. point_direction = f_direction[pos_list[:, 1], pos_list[:, 2]] # x, y
  391. point_direction = point_direction[:, ::-1] # x, y -> y, x
  392. sorted_point, sorted_direction = sort_part_with_direction(pos_list,
  393. point_direction)
  394. point_num = len(sorted_point)
  395. if point_num >= 16:
  396. middle_num = point_num // 2
  397. first_part_point = sorted_point[:middle_num]
  398. first_point_direction = sorted_direction[:middle_num]
  399. sorted_fist_part_point, sorted_fist_part_direction = sort_part_with_direction(
  400. first_part_point, first_point_direction)
  401. last_part_point = sorted_point[middle_num:]
  402. last_point_direction = sorted_direction[middle_num:]
  403. sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
  404. last_part_point, last_point_direction)
  405. sorted_point = sorted_fist_part_point + sorted_last_part_point
  406. sorted_direction = sorted_fist_part_direction + sorted_last_part_direction
  407. return sorted_point