db_postprocess.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277
  1. import numpy as np
  2. import cv2
  3. from shapely.geometry import Polygon
  4. import pyclipper
  5. """
  6. This code is adapted from:
  7. https://github.com/WenmuZhou/DBNet.pytorch/blob/master/post_processing/seg_detector_representer.py
  8. """
  9. class DBPostProcess(object):
  10. """
  11. The post process for Differentiable Binarization (DB).
  12. """
  13. def __init__(
  14. self,
  15. thresh=0.3,
  16. box_thresh=0.7,
  17. max_candidates=1000,
  18. unclip_ratio=2.0,
  19. use_dilation=False,
  20. score_mode='fast',
  21. box_type='quad',
  22. **kwargs,
  23. ):
  24. self.thresh = thresh
  25. self.box_thresh = box_thresh
  26. self.max_candidates = max_candidates
  27. self.unclip_ratio = unclip_ratio
  28. self.min_size = 3
  29. self.score_mode = score_mode
  30. self.box_type = box_type
  31. assert score_mode in [
  32. 'slow',
  33. 'fast',
  34. ], 'Score mode must be in [slow, fast] but got: {}'.format(score_mode)
  35. self.dilation_kernel = None if not use_dilation else np.array([[1, 1],
  36. [1, 1]])
  37. def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
  38. """
  39. _bitmap: single map with shape (1, H, W),
  40. whose values are binarized as {0, 1}
  41. """
  42. bitmap = _bitmap
  43. height, width = bitmap.shape
  44. boxes = []
  45. scores = []
  46. contours, _ = cv2.findContours((bitmap * 255).astype(np.uint8),
  47. cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
  48. for contour in contours[:self.max_candidates]:
  49. epsilon = 0.002 * cv2.arcLength(contour, True)
  50. approx = cv2.approxPolyDP(contour, epsilon, True)
  51. points = approx.reshape((-1, 2))
  52. if points.shape[0] < 4:
  53. continue
  54. score = self.box_score_fast(pred, points.reshape(-1, 2))
  55. if self.box_thresh > score:
  56. continue
  57. if points.shape[0] > 2:
  58. box = self.unclip(points, self.unclip_ratio)
  59. if len(box) > 1:
  60. continue
  61. else:
  62. continue
  63. box = np.array(box).reshape(-1, 2)
  64. if len(box) == 0:
  65. continue
  66. _, sside = self.get_mini_boxes(box.reshape((-1, 1, 2)))
  67. if sside < self.min_size + 2:
  68. continue
  69. box = np.array(box)
  70. box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0,
  71. dest_width)
  72. box[:, 1] = np.clip(np.round(box[:, 1] / height * dest_height), 0,
  73. dest_height)
  74. boxes.append(box.tolist())
  75. scores.append(score)
  76. return boxes, scores
  77. def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
  78. """
  79. _bitmap: single map with shape (1, H, W),
  80. whose values are binarized as {0, 1}
  81. """
  82. bitmap = _bitmap
  83. height, width = bitmap.shape
  84. outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST,
  85. cv2.CHAIN_APPROX_SIMPLE)
  86. if len(outs) == 3:
  87. img, contours, _ = outs[0], outs[1], outs[2]
  88. elif len(outs) == 2:
  89. contours, _ = outs[0], outs[1]
  90. num_contours = min(len(contours), self.max_candidates)
  91. boxes = []
  92. scores = []
  93. for index in range(num_contours):
  94. contour = contours[index]
  95. points, sside = self.get_mini_boxes(contour)
  96. if sside < self.min_size:
  97. continue
  98. points = np.array(points)
  99. if self.score_mode == 'fast':
  100. score = self.box_score_fast(pred, points.reshape(-1, 2))
  101. else:
  102. score = self.box_score_slow(pred, contour)
  103. if self.box_thresh > score:
  104. continue
  105. box = self.unclip(points, self.unclip_ratio)
  106. if len(box) > 1:
  107. continue
  108. box = np.array(box).reshape(-1, 1, 2)
  109. box, sside = self.get_mini_boxes(box)
  110. if sside < self.min_size + 2:
  111. continue
  112. box = np.array(box)
  113. box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0,
  114. dest_width)
  115. box[:, 1] = np.clip(np.round(box[:, 1] / height * dest_height), 0,
  116. dest_height)
  117. boxes.append(box.astype('int32'))
  118. scores.append(score)
  119. return np.array(boxes, dtype='int32'), scores
  120. def unclip(self, box, unclip_ratio):
  121. poly = Polygon(box)
  122. distance = poly.area * unclip_ratio / poly.length
  123. offset = pyclipper.PyclipperOffset()
  124. offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
  125. expanded = offset.Execute(distance)
  126. return expanded
  127. def get_mini_boxes(self, contour):
  128. bounding_box = cv2.minAreaRect(contour)
  129. points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
  130. index_1, index_2, index_3, index_4 = 0, 1, 2, 3
  131. if points[1][1] > points[0][1]:
  132. index_1 = 0
  133. index_4 = 1
  134. else:
  135. index_1 = 1
  136. index_4 = 0
  137. if points[3][1] > points[2][1]:
  138. index_2 = 2
  139. index_3 = 3
  140. else:
  141. index_2 = 3
  142. index_3 = 2
  143. box = [
  144. points[index_1], points[index_2], points[index_3], points[index_4]
  145. ]
  146. return box, min(bounding_box[1])
  147. def box_score_fast(self, bitmap, _box):
  148. """
  149. box_score_fast: use bbox mean score as the mean score
  150. """
  151. h, w = bitmap.shape[:2]
  152. box = _box.copy()
  153. xmin = np.clip(np.floor(box[:, 0].min()).astype('int32'), 0, w - 1)
  154. xmax = np.clip(np.ceil(box[:, 0].max()).astype('int32'), 0, w - 1)
  155. ymin = np.clip(np.floor(box[:, 1].min()).astype('int32'), 0, h - 1)
  156. ymax = np.clip(np.ceil(box[:, 1].max()).astype('int32'), 0, h - 1)
  157. mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
  158. box[:, 0] = box[:, 0] - xmin
  159. box[:, 1] = box[:, 1] - ymin
  160. cv2.fillPoly(mask, box.reshape(1, -1, 2).astype('int32'), 1)
  161. return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
  162. def box_score_slow(self, bitmap, contour):
  163. """
  164. box_score_slow: use polyon mean score as the mean score
  165. """
  166. h, w = bitmap.shape[:2]
  167. contour = contour.copy()
  168. contour = np.reshape(contour, (-1, 2))
  169. xmin = np.clip(np.min(contour[:, 0]), 0, w - 1)
  170. xmax = np.clip(np.max(contour[:, 0]), 0, w - 1)
  171. ymin = np.clip(np.min(contour[:, 1]), 0, h - 1)
  172. ymax = np.clip(np.max(contour[:, 1]), 0, h - 1)
  173. mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
  174. contour[:, 0] = contour[:, 0] - xmin
  175. contour[:, 1] = contour[:, 1] - ymin
  176. cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype('int32'), 1)
  177. return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
  178. def __call__(self, outs_dict, batch, **kwargs):
  179. self.thresh = kwargs.get('mask_thresh', self.thresh)
  180. self.box_thresh = kwargs.get('box_thresh', self.box_thresh)
  181. self.unclip_ratio = kwargs.get('unclip_ratio', self.unclip_ratio)
  182. self.box_type = kwargs.get('box_type', self.box_type)
  183. self.score_mode = kwargs.get('score_mode', self.score_mode)
  184. pred = outs_dict['maps']
  185. if kwargs.get('torch_tensor', True):
  186. pred = pred.detach().cpu().numpy()
  187. pred = pred[:, 0, :, :]
  188. segmentation = pred > self.thresh
  189. boxes_batch = []
  190. for batch_index in range(pred.shape[0]):
  191. src_h, src_w, ratio_h, ratio_w = batch[1][batch_index]
  192. if self.dilation_kernel is not None:
  193. mask = cv2.dilate(
  194. np.array(segmentation[batch_index]).astype(np.uint8),
  195. self.dilation_kernel,
  196. )
  197. else:
  198. mask = segmentation[batch_index]
  199. if self.box_type == 'poly':
  200. boxes, scores = self.polygons_from_bitmap(
  201. pred[batch_index], mask, src_w, src_h)
  202. elif self.box_type == 'quad':
  203. boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask,
  204. src_w, src_h)
  205. else:
  206. raise ValueError(
  207. "box_type can only be one of ['quad', 'poly']")
  208. boxes_batch.append({'points': boxes})
  209. return boxes_batch
  210. class DistillationDBPostProcess(object):
  211. def __init__(
  212. self,
  213. model_name=['student'],
  214. key=None,
  215. thresh=0.3,
  216. box_thresh=0.6,
  217. max_candidates=1000,
  218. unclip_ratio=1.5,
  219. use_dilation=False,
  220. score_mode='fast',
  221. box_type='quad',
  222. **kwargs,
  223. ):
  224. self.model_name = model_name
  225. self.key = key
  226. self.post_process = DBPostProcess(
  227. thresh=thresh,
  228. box_thresh=box_thresh,
  229. max_candidates=max_candidates,
  230. unclip_ratio=unclip_ratio,
  231. use_dilation=use_dilation,
  232. score_mode=score_mode,
  233. box_type=box_type,
  234. )
  235. def __call__(self, predicts, shape_list):
  236. results = {}
  237. for k in self.model_name:
  238. results[k] = self.post_process(predicts[k], shape_list=shape_list)
  239. return results