# pgnet_pp_utils.py
  1. import torch
  2. import os
  3. import sys
  4. __dir__ = os.path.dirname(__file__)
  5. sys.path.append(__dir__)
  6. sys.path.append(os.path.join(__dir__, ".."))
  7. from extract_textpoint_slow import *
  8. from extract_textpoint_fast import generate_pivot_list_fast, restore_poly
  9. class PGNet_PostProcess(object):
  10. # two different post-process
  11. def __init__(
  12. self,
  13. character_dict_path,
  14. valid_set,
  15. score_thresh,
  16. outs_dict,
  17. shape_list,
  18. point_gather_mode=None, ):
  19. self.Lexicon_Table = get_dict(character_dict_path)
  20. self.valid_set = valid_set
  21. self.score_thresh = score_thresh
  22. self.outs_dict = outs_dict
  23. self.shape_list = shape_list
  24. self.point_gather_mode = point_gather_mode
  25. def pg_postprocess_fast(self):
  26. p_score = self.outs_dict["f_score"]
  27. p_border = self.outs_dict["f_border"]
  28. p_char = self.outs_dict["f_char"]
  29. p_direction = self.outs_dict["f_direction"]
  30. if isinstance(p_score, torch.Tensor):
  31. p_score = p_score[0].numpy()
  32. p_border = p_border[0].numpy()
  33. p_direction = p_direction[0].numpy()
  34. p_char = p_char[0].numpy()
  35. else:
  36. p_score = p_score[0]
  37. p_border = p_border[0]
  38. p_direction = p_direction[0]
  39. p_char = p_char[0]
  40. src_h, src_w, ratio_h, ratio_w = self.shape_list[0]
  41. instance_yxs_list, seq_strs = generate_pivot_list_fast(
  42. p_score,
  43. p_char,
  44. p_direction,
  45. self.Lexicon_Table,
  46. score_thresh=self.score_thresh,
  47. point_gather_mode=self.point_gather_mode, )
  48. poly_list, keep_str_list = restore_poly(
  49. instance_yxs_list,
  50. seq_strs,
  51. p_border,
  52. ratio_w,
  53. ratio_h,
  54. src_w,
  55. src_h,
  56. self.valid_set, )
  57. data = {
  58. "points": poly_list,
  59. "texts": keep_str_list,
  60. }
  61. return data
  62. def pg_postprocess_slow(self):
  63. p_score = self.outs_dict["f_score"]
  64. p_border = self.outs_dict["f_border"]
  65. p_char = self.outs_dict["f_char"]
  66. p_direction = self.outs_dict["f_direction"]
  67. if isinstance(p_score, torch.Tensor):
  68. p_score = p_score[0].numpy()
  69. p_border = p_border[0].numpy()
  70. p_direction = p_direction[0].numpy()
  71. p_char = p_char[0].numpy()
  72. else:
  73. p_score = p_score[0]
  74. p_border = p_border[0]
  75. p_direction = p_direction[0]
  76. p_char = p_char[0]
  77. src_h, src_w, ratio_h, ratio_w = self.shape_list[0]
  78. is_curved = self.valid_set == "totaltext"
  79. char_seq_idx_set, instance_yxs_list = generate_pivot_list_slow(
  80. p_score,
  81. p_char,
  82. p_direction,
  83. score_thresh=self.score_thresh,
  84. is_backbone=True,
  85. is_curved=is_curved, )
  86. seq_strs = []
  87. for char_idx_set in char_seq_idx_set:
  88. pr_str = "".join([self.Lexicon_Table[pos] for pos in char_idx_set])
  89. seq_strs.append(pr_str)
  90. poly_list = []
  91. keep_str_list = []
  92. all_point_list = []
  93. all_point_pair_list = []
  94. for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs):
  95. if len(yx_center_line) == 1:
  96. yx_center_line.append(yx_center_line[-1])
  97. offset_expand = 1.0
  98. if self.valid_set == "totaltext":
  99. offset_expand = 1.2
  100. point_pair_list = []
  101. for batch_id, y, x in yx_center_line:
  102. offset = p_border[:, y, x].reshape(2, 2)
  103. if offset_expand != 1.0:
  104. offset_length = np.linalg.norm(
  105. offset, axis=1, keepdims=True)
  106. expand_length = np.clip(
  107. offset_length * (offset_expand - 1),
  108. a_min=0.5,
  109. a_max=3.0)
  110. offset_detal = offset / offset_length * expand_length
  111. offset = offset + offset_detal
  112. ori_yx = np.array([y, x], dtype=np.float32)
  113. point_pair = ((ori_yx + offset)[:, ::-1] * 4.0 /
  114. np.array([ratio_w, ratio_h]).reshape(-1, 2))
  115. point_pair_list.append(point_pair)
  116. all_point_list.append([
  117. int(round(x * 4.0 / ratio_w)),
  118. int(round(y * 4.0 / ratio_h))
  119. ])
  120. all_point_pair_list.append(point_pair.round().astype(np.int32)
  121. .tolist())
  122. detected_poly, pair_length_info = point_pair2poly(point_pair_list)
  123. detected_poly = expand_poly_along_width(
  124. detected_poly, shrink_ratio_of_width=0.2)
  125. detected_poly[:, 0] = np.clip(
  126. detected_poly[:, 0], a_min=0, a_max=src_w)
  127. detected_poly[:, 1] = np.clip(
  128. detected_poly[:, 1], a_min=0, a_max=src_h)
  129. if len(keep_str) < 2:
  130. continue
  131. keep_str_list.append(keep_str)
  132. detected_poly = np.round(detected_poly).astype("int32")
  133. if self.valid_set == "partvgg":
  134. middle_point = len(detected_poly) // 2
  135. detected_poly = detected_poly[
  136. [0, middle_point - 1, middle_point, -1], :]
  137. poly_list.append(detected_poly)
  138. elif self.valid_set == "totaltext":
  139. poly_list.append(detected_poly)
  140. else:
  141. print("--> Not supported format.")
  142. exit(-1)
  143. data = {
  144. "points": poly_list,
  145. "texts": keep_str_list,
  146. }
  147. return data