extract_textpoint_slow.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582
  1. import cv2
  2. import math
  3. import numpy as np
  4. from itertools import groupby
  5. from skimage.morphology._skeletonize import thin
  6. def get_dict(character_dict_path):
  7. character_str = ""
  8. with open(character_dict_path, "rb") as fin:
  9. lines = fin.readlines()
  10. for line in lines:
  11. line = line.decode("utf-8").strip("\n").strip("\r\n")
  12. character_str += line
  13. dict_character = list(character_str)
  14. return dict_character
  15. def point_pair2poly(point_pair_list):
  16. """
  17. Transfer vertical point_pairs into poly point in clockwise.
  18. """
  19. pair_length_list = []
  20. for point_pair in point_pair_list:
  21. pair_length = np.linalg.norm(point_pair[0] - point_pair[1])
  22. pair_length_list.append(pair_length)
  23. pair_length_list = np.array(pair_length_list)
  24. pair_info = (
  25. pair_length_list.max(),
  26. pair_length_list.min(),
  27. pair_length_list.mean(), )
  28. point_num = len(point_pair_list) * 2
  29. point_list = [0] * point_num
  30. for idx, point_pair in enumerate(point_pair_list):
  31. point_list[idx] = point_pair[0]
  32. point_list[point_num - 1 - idx] = point_pair[1]
  33. return np.array(point_list).reshape(-1, 2), pair_info
  34. def shrink_quad_along_width(quad, begin_width_ratio=0.0, end_width_ratio=1.0):
  35. """
  36. Generate shrink_quad_along_width.
  37. """
  38. ratio_pair = np.array(
  39. [[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
  40. p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
  41. p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
  42. return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
  43. def expand_poly_along_width(poly, shrink_ratio_of_width=0.3):
  44. """
  45. expand poly along width.
  46. """
  47. point_num = poly.shape[0]
  48. left_quad = np.array(
  49. [poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
  50. left_ratio = (-shrink_ratio_of_width *
  51. np.linalg.norm(left_quad[0] - left_quad[3]) /
  52. (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6))
  53. left_quad_expand = shrink_quad_along_width(left_quad, left_ratio, 1.0)
  54. right_quad = np.array(
  55. [
  56. poly[point_num // 2 - 2],
  57. poly[point_num // 2 - 1],
  58. poly[point_num // 2],
  59. poly[point_num // 2 + 1],
  60. ],
  61. dtype=np.float32, )
  62. right_ratio = 1.0 + shrink_ratio_of_width * np.linalg.norm(right_quad[
  63. 0] - right_quad[3]) / (np.linalg.norm(right_quad[0] - right_quad[1]) +
  64. 1e-6)
  65. right_quad_expand = shrink_quad_along_width(right_quad, 0.0, right_ratio)
  66. poly[0] = left_quad_expand[0]
  67. poly[-1] = left_quad_expand[-1]
  68. poly[point_num // 2 - 1] = right_quad_expand[1]
  69. poly[point_num // 2] = right_quad_expand[2]
  70. return poly
  71. def softmax(logits):
  72. """
  73. logits: N x d
  74. """
  75. max_value = np.max(logits, axis=1, keepdims=True)
  76. exp = np.exp(logits - max_value)
  77. exp_sum = np.sum(exp, axis=1, keepdims=True)
  78. dist = exp / exp_sum
  79. return dist
  80. def get_keep_pos_idxs(labels, remove_blank=None):
  81. """
  82. Remove duplicate and get pos idxs of keep items.
  83. The value of keep_blank should be [None, 95].
  84. """
  85. duplicate_len_list = []
  86. keep_pos_idx_list = []
  87. keep_char_idx_list = []
  88. for k, v_ in groupby(labels):
  89. current_len = len(list(v_))
  90. if k != remove_blank:
  91. current_idx = int(sum(duplicate_len_list) + current_len // 2)
  92. keep_pos_idx_list.append(current_idx)
  93. keep_char_idx_list.append(k)
  94. duplicate_len_list.append(current_len)
  95. return keep_char_idx_list, keep_pos_idx_list
  96. def remove_blank(labels, blank=0):
  97. new_labels = [x for x in labels if x != blank]
  98. return new_labels
  99. def insert_blank(labels, blank=0):
  100. new_labels = [blank]
  101. for l in labels:
  102. new_labels += [l, blank]
  103. return new_labels
  104. def ctc_greedy_decoder(probs_seq, blank=95, keep_blank_in_idxs=True):
  105. """
  106. CTC greedy (best path) decoder.
  107. """
  108. raw_str = np.argmax(np.array(probs_seq), axis=1)
  109. remove_blank_in_pos = None if keep_blank_in_idxs else blank
  110. dedup_str, keep_idx_list = get_keep_pos_idxs(
  111. raw_str, remove_blank=remove_blank_in_pos)
  112. dst_str = remove_blank(dedup_str, blank=blank)
  113. return dst_str, keep_idx_list
  114. def instance_ctc_greedy_decoder(gather_info,
  115. logits_map,
  116. keep_blank_in_idxs=True):
  117. """
  118. gather_info: [[x, y], [x, y] ...]
  119. logits_map: H x W X (n_chars + 1)
  120. """
  121. _, _, C = logits_map.shape
  122. ys, xs = zip(*gather_info)
  123. logits_seq = logits_map[list(ys), list(xs)] # n x 96
  124. probs_seq = softmax(logits_seq)
  125. dst_str, keep_idx_list = ctc_greedy_decoder(
  126. probs_seq, blank=C - 1, keep_blank_in_idxs=keep_blank_in_idxs)
  127. keep_gather_list = [gather_info[idx] for idx in keep_idx_list]
  128. return dst_str, keep_gather_list
  129. def ctc_decoder_for_image(gather_info_list, logits_map,
  130. keep_blank_in_idxs=True):
  131. """
  132. CTC decoder using multiple processes.
  133. """
  134. decoder_results = []
  135. for gather_info in gather_info_list:
  136. res = instance_ctc_greedy_decoder(
  137. gather_info, logits_map, keep_blank_in_idxs=keep_blank_in_idxs)
  138. decoder_results.append(res)
  139. return decoder_results
  140. def sort_with_direction(pos_list, f_direction):
  141. """
  142. f_direction: h x w x 2
  143. pos_list: [[y, x], [y, x], [y, x] ...]
  144. """
  145. def sort_part_with_direction(pos_list, point_direction):
  146. pos_list = np.array(pos_list).reshape(-1, 2)
  147. point_direction = np.array(point_direction).reshape(-1, 2)
  148. average_direction = np.mean(point_direction, axis=0, keepdims=True)
  149. pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
  150. sorted_list = pos_list[np.argsort(pos_proj_leng)].tolist()
  151. sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist()
  152. return sorted_list, sorted_direction
  153. pos_list = np.array(pos_list).reshape(-1, 2)
  154. point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y
  155. point_direction = point_direction[:, ::-1] # x, y -> y, x
  156. sorted_point, sorted_direction = sort_part_with_direction(pos_list,
  157. point_direction)
  158. point_num = len(sorted_point)
  159. if point_num >= 16:
  160. middle_num = point_num // 2
  161. first_part_point = sorted_point[:middle_num]
  162. first_point_direction = sorted_direction[:middle_num]
  163. sorted_fist_part_point, sorted_fist_part_direction = sort_part_with_direction(
  164. first_part_point, first_point_direction)
  165. last_part_point = sorted_point[middle_num:]
  166. last_point_direction = sorted_direction[middle_num:]
  167. sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
  168. last_part_point, last_point_direction)
  169. sorted_point = sorted_fist_part_point + sorted_last_part_point
  170. sorted_direction = sorted_fist_part_direction + sorted_last_part_direction
  171. return sorted_point, np.array(sorted_direction)
  172. def add_id(pos_list, image_id=0):
  173. """
  174. Add id for gather feature, for inference.
  175. """
  176. new_list = []
  177. for item in pos_list:
  178. new_list.append((image_id, item[0], item[1]))
  179. return new_list
  180. def sort_and_expand_with_direction(pos_list, f_direction):
  181. """
  182. f_direction: h x w x 2
  183. pos_list: [[y, x], [y, x], [y, x] ...]
  184. """
  185. h, w, _ = f_direction.shape
  186. sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
  187. # expand along
  188. point_num = len(sorted_list)
  189. sub_direction_len = max(point_num // 3, 2)
  190. left_direction = point_direction[:sub_direction_len, :]
  191. right_dirction = point_direction[point_num - sub_direction_len:, :]
  192. left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
  193. left_average_len = np.linalg.norm(left_average_direction)
  194. left_start = np.array(sorted_list[0])
  195. left_step = left_average_direction / (left_average_len + 1e-6)
  196. right_average_direction = np.mean(right_dirction, axis=0, keepdims=True)
  197. right_average_len = np.linalg.norm(right_average_direction)
  198. right_step = right_average_direction / (right_average_len + 1e-6)
  199. right_start = np.array(sorted_list[-1])
  200. append_num = max(
  201. int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
  202. left_list = []
  203. right_list = []
  204. for i in range(append_num):
  205. ly, lx = (np.round(left_start + left_step * (i + 1)).flatten()
  206. .astype("int32").tolist())
  207. if ly < h and lx < w and (ly, lx) not in left_list:
  208. left_list.append((ly, lx))
  209. ry, rx = (np.round(right_start + right_step * (i + 1)).flatten()
  210. .astype("int32").tolist())
  211. if ry < h and rx < w and (ry, rx) not in right_list:
  212. right_list.append((ry, rx))
  213. all_list = left_list[::-1] + sorted_list + right_list
  214. return all_list
  215. def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map):
  216. """
  217. f_direction: h x w x 2
  218. pos_list: [[y, x], [y, x], [y, x] ...]
  219. binary_tcl_map: h x w
  220. """
  221. h, w, _ = f_direction.shape
  222. sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
  223. # expand along
  224. point_num = len(sorted_list)
  225. sub_direction_len = max(point_num // 3, 2)
  226. left_direction = point_direction[:sub_direction_len, :]
  227. right_dirction = point_direction[point_num - sub_direction_len:, :]
  228. left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
  229. left_average_len = np.linalg.norm(left_average_direction)
  230. left_start = np.array(sorted_list[0])
  231. left_step = left_average_direction / (left_average_len + 1e-6)
  232. right_average_direction = np.mean(right_dirction, axis=0, keepdims=True)
  233. right_average_len = np.linalg.norm(right_average_direction)
  234. right_step = right_average_direction / (right_average_len + 1e-6)
  235. right_start = np.array(sorted_list[-1])
  236. append_num = max(
  237. int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
  238. max_append_num = 2 * append_num
  239. left_list = []
  240. right_list = []
  241. for i in range(max_append_num):
  242. ly, lx = (np.round(left_start + left_step * (i + 1)).flatten()
  243. .astype("int32").tolist())
  244. if ly < h and lx < w and (ly, lx) not in left_list:
  245. if binary_tcl_map[ly, lx] > 0.5:
  246. left_list.append((ly, lx))
  247. else:
  248. break
  249. for i in range(max_append_num):
  250. ry, rx = (np.round(right_start + right_step * (i + 1)).flatten()
  251. .astype("int32").tolist())
  252. if ry < h and rx < w and (ry, rx) not in right_list:
  253. if binary_tcl_map[ry, rx] > 0.5:
  254. right_list.append((ry, rx))
  255. else:
  256. break
  257. all_list = left_list[::-1] + sorted_list + right_list
  258. return all_list
  259. def generate_pivot_list_curved(
  260. p_score,
  261. p_char_maps,
  262. f_direction,
  263. score_thresh=0.5,
  264. is_expand=True,
  265. is_backbone=False,
  266. image_id=0, ):
  267. """
  268. return center point and end point of TCL instance; filter with the char maps;
  269. """
  270. p_score = p_score[0]
  271. f_direction = f_direction.transpose(1, 2, 0)
  272. p_tcl_map = (p_score > score_thresh) * 1.0
  273. skeleton_map = thin(p_tcl_map)
  274. instance_count, instance_label_map = cv2.connectedComponents(
  275. skeleton_map.astype(np.uint8), connectivity=8)
  276. # get TCL Instance
  277. all_pos_yxs = []
  278. center_pos_yxs = []
  279. end_points_yxs = []
  280. instance_center_pos_yxs = []
  281. pred_strs = []
  282. if instance_count > 0:
  283. for instance_id in range(1, instance_count):
  284. pos_list = []
  285. ys, xs = np.where(instance_label_map == instance_id)
  286. pos_list = list(zip(ys, xs))
  287. ### FIX-ME, eliminate outlier
  288. if len(pos_list) < 3:
  289. continue
  290. if is_expand:
  291. pos_list_sorted = sort_and_expand_with_direction_v2(
  292. pos_list, f_direction, p_tcl_map)
  293. else:
  294. pos_list_sorted, _ = sort_with_direction(pos_list, f_direction)
  295. all_pos_yxs.append(pos_list_sorted)
  296. # use decoder to filter backgroud points.
  297. p_char_maps = p_char_maps.transpose([1, 2, 0])
  298. decode_res = ctc_decoder_for_image(
  299. all_pos_yxs, logits_map=p_char_maps, keep_blank_in_idxs=True)
  300. for decoded_str, keep_yxs_list in decode_res:
  301. if is_backbone:
  302. keep_yxs_list_with_id = add_id(keep_yxs_list, image_id=image_id)
  303. instance_center_pos_yxs.append(keep_yxs_list_with_id)
  304. pred_strs.append(decoded_str)
  305. else:
  306. end_points_yxs.extend((keep_yxs_list[0], keep_yxs_list[-1]))
  307. center_pos_yxs.extend(keep_yxs_list)
  308. if is_backbone:
  309. return pred_strs, instance_center_pos_yxs
  310. else:
  311. return center_pos_yxs, end_points_yxs
  312. def generate_pivot_list_horizontal(p_score,
  313. p_char_maps,
  314. f_direction,
  315. score_thresh=0.5,
  316. is_backbone=False,
  317. image_id=0):
  318. """
  319. return center point and end point of TCL instance; filter with the char maps;
  320. """
  321. p_score = p_score[0]
  322. f_direction = f_direction.transpose(1, 2, 0)
  323. p_tcl_map_bi = (p_score > score_thresh) * 1.0
  324. instance_count, instance_label_map = cv2.connectedComponents(
  325. p_tcl_map_bi.astype(np.uint8), connectivity=8)
  326. # get TCL Instance
  327. all_pos_yxs = []
  328. center_pos_yxs = []
  329. end_points_yxs = []
  330. instance_center_pos_yxs = []
  331. if instance_count > 0:
  332. for instance_id in range(1, instance_count):
  333. pos_list = []
  334. ys, xs = np.where(instance_label_map == instance_id)
  335. pos_list = list(zip(ys, xs))
  336. ### FIX-ME, eliminate outlier
  337. if len(pos_list) < 5:
  338. continue
  339. # add rule here
  340. main_direction = extract_main_direction(pos_list,
  341. f_direction) # y x
  342. reference_directin = np.array([0, 1]).reshape([-1, 2]) # y x
  343. is_h_angle = abs(np.sum(
  344. main_direction * reference_directin)) < math.cos(math.pi / 180 *
  345. 70)
  346. point_yxs = np.array(pos_list)
  347. max_y, max_x = np.max(point_yxs, axis=0)
  348. min_y, min_x = np.min(point_yxs, axis=0)
  349. is_h_len = (max_y - min_y) < 1.5 * (max_x - min_x)
  350. pos_list_final = []
  351. if is_h_len:
  352. xs = np.unique(xs)
  353. for x in xs:
  354. ys = instance_label_map[:, x].copy().reshape((-1, ))
  355. y = int(np.where(ys == instance_id)[0].mean())
  356. pos_list_final.append((y, x))
  357. else:
  358. ys = np.unique(ys)
  359. for y in ys:
  360. xs = instance_label_map[y, :].copy().reshape((-1, ))
  361. x = int(np.where(xs == instance_id)[0].mean())
  362. pos_list_final.append((y, x))
  363. pos_list_sorted, _ = sort_with_direction(pos_list_final,
  364. f_direction)
  365. all_pos_yxs.append(pos_list_sorted)
  366. # use decoder to filter backgroud points.
  367. p_char_maps = p_char_maps.transpose([1, 2, 0])
  368. decode_res = ctc_decoder_for_image(
  369. all_pos_yxs, logits_map=p_char_maps, keep_blank_in_idxs=True)
  370. for decoded_str, keep_yxs_list in decode_res:
  371. if is_backbone:
  372. keep_yxs_list_with_id = add_id(keep_yxs_list, image_id=image_id)
  373. instance_center_pos_yxs.append(keep_yxs_list_with_id)
  374. else:
  375. end_points_yxs.extend((keep_yxs_list[0], keep_yxs_list[-1]))
  376. center_pos_yxs.extend(keep_yxs_list)
  377. if is_backbone:
  378. return instance_center_pos_yxs
  379. else:
  380. return center_pos_yxs, end_points_yxs
  381. def generate_pivot_list_slow(
  382. p_score,
  383. p_char_maps,
  384. f_direction,
  385. score_thresh=0.5,
  386. is_backbone=False,
  387. is_curved=True,
  388. image_id=0, ):
  389. """
  390. Warp all the function together.
  391. """
  392. if is_curved:
  393. return generate_pivot_list_curved(
  394. p_score,
  395. p_char_maps,
  396. f_direction,
  397. score_thresh=score_thresh,
  398. is_expand=True,
  399. is_backbone=is_backbone,
  400. image_id=image_id, )
  401. else:
  402. return generate_pivot_list_horizontal(
  403. p_score,
  404. p_char_maps,
  405. f_direction,
  406. score_thresh=score_thresh,
  407. is_backbone=is_backbone,
  408. image_id=image_id, )
  409. # for refine module
  410. def extract_main_direction(pos_list, f_direction):
  411. """
  412. f_direction: h x w x 2
  413. pos_list: [[y, x], [y, x], [y, x] ...]
  414. """
  415. pos_list = np.array(pos_list)
  416. point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]]
  417. point_direction = point_direction[:, ::-1] # x, y -> y, x
  418. average_direction = np.mean(point_direction, axis=0, keepdims=True)
  419. average_direction = average_direction / (
  420. np.linalg.norm(average_direction) + 1e-6)
  421. return average_direction
  422. def sort_by_direction_with_image_id_deprecated(pos_list, f_direction):
  423. """
  424. f_direction: h x w x 2
  425. pos_list: [[id, y, x], [id, y, x], [id, y, x] ...]
  426. """
  427. pos_list_full = np.array(pos_list).reshape(-1, 3)
  428. pos_list = pos_list_full[:, 1:]
  429. point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y
  430. point_direction = point_direction[:, ::-1] # x, y -> y, x
  431. average_direction = np.mean(point_direction, axis=0, keepdims=True)
  432. pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
  433. sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
  434. return sorted_list
  435. def sort_by_direction_with_image_id(pos_list, f_direction):
  436. """
  437. f_direction: h x w x 2
  438. pos_list: [[y, x], [y, x], [y, x] ...]
  439. """
  440. def sort_part_with_direction(pos_list_full, point_direction):
  441. pos_list_full = np.array(pos_list_full).reshape(-1, 3)
  442. pos_list = pos_list_full[:, 1:]
  443. point_direction = np.array(point_direction).reshape(-1, 2)
  444. average_direction = np.mean(point_direction, axis=0, keepdims=True)
  445. pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
  446. sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
  447. sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist()
  448. return sorted_list, sorted_direction
  449. pos_list = np.array(pos_list).reshape(-1, 3)
  450. point_direction = f_direction[pos_list[:, 1], pos_list[:, 2]] # x, y
  451. point_direction = point_direction[:, ::-1] # x, y -> y, x
  452. sorted_point, sorted_direction = sort_part_with_direction(pos_list,
  453. point_direction)
  454. point_num = len(sorted_point)
  455. if point_num >= 16:
  456. middle_num = point_num // 2
  457. first_part_point = sorted_point[:middle_num]
  458. first_point_direction = sorted_direction[:middle_num]
  459. sorted_fist_part_point, sorted_fist_part_direction = sort_part_with_direction(
  460. first_part_point, first_point_direction)
  461. last_part_point = sorted_point[middle_num:]
  462. last_point_direction = sorted_direction[middle_num:]
  463. sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
  464. last_part_point, last_point_direction)
  465. sorted_point = sorted_fist_part_point + sorted_last_part_point
  466. sorted_direction = sorted_fist_part_direction + sorted_last_part_direction
  467. return sorted_point
  468. def generate_pivot_list_tt_inference(
  469. p_score,
  470. p_char_maps,
  471. f_direction,
  472. score_thresh=0.5,
  473. is_backbone=False,
  474. is_curved=True,
  475. image_id=0, ):
  476. """
  477. return center point and end point of TCL instance; filter with the char maps;
  478. """
  479. p_score = p_score[0]
  480. f_direction = f_direction.transpose(1, 2, 0)
  481. p_tcl_map = (p_score > score_thresh) * 1.0
  482. skeleton_map = thin(p_tcl_map)
  483. instance_count, instance_label_map = cv2.connectedComponents(
  484. skeleton_map.astype(np.uint8), connectivity=8)
  485. # get TCL Instance
  486. all_pos_yxs = []
  487. if instance_count > 0:
  488. for instance_id in range(1, instance_count):
  489. pos_list = []
  490. ys, xs = np.where(instance_label_map == instance_id)
  491. pos_list = list(zip(ys, xs))
  492. ### FIX-ME, eliminate outlier
  493. if len(pos_list) < 3:
  494. continue
  495. pos_list_sorted = sort_and_expand_with_direction_v2(
  496. pos_list, f_direction, p_tcl_map)
  497. pos_list_sorted_with_id = add_id(pos_list_sorted, image_id=image_id)
  498. all_pos_yxs.append(pos_list_sorted_with_id)
  499. return all_pos_yxs