crop_paste.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. # copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import cv2
  15. import random
  16. import numpy as np
  17. from PIL import Image
  18. from shapely.geometry import Polygon
  19. from .iaa_augment import IaaAugment
  20. from .crop_resize import is_poly_outside_rect
  21. from tools.infer.utility import get_rotate_crop_image
  22. class CopyPaste(object):
  23. def __init__(self, objects_paste_ratio=0.2, limit_paste=True, **kwargs):
  24. self.ext_data_num = 1
  25. self.objects_paste_ratio = objects_paste_ratio
  26. self.limit_paste = limit_paste
  27. augmenter_args = [{'type': 'Resize', 'args': {'size': [0.5, 3]}}]
  28. self.aug = IaaAugment(augmenter_args)
  29. def __call__(self, data):
  30. point_num = data['polys'].shape[1]
  31. src_img = data['image']
  32. src_polys = data['polys'].tolist()
  33. src_texts = data['texts']
  34. src_ignores = data['ignore_tags'].tolist()
  35. ext_data = data['ext_data'][0]
  36. ext_image = ext_data['image']
  37. ext_polys = ext_data['polys']
  38. ext_texts = ext_data['texts']
  39. ext_ignores = ext_data['ignore_tags']
  40. indexes = [i for i in range(len(ext_ignores)) if not ext_ignores[i]]
  41. select_num = max(
  42. 1, min(int(self.objects_paste_ratio * len(ext_polys)), 30))
  43. random.shuffle(indexes)
  44. select_idxs = indexes[:select_num]
  45. select_polys = ext_polys[select_idxs]
  46. select_ignores = ext_ignores[select_idxs]
  47. src_img = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB)
  48. ext_image = cv2.cvtColor(ext_image, cv2.COLOR_BGR2RGB)
  49. src_img = Image.fromarray(src_img).convert('RGBA')
  50. for idx, poly, tag in zip(select_idxs, select_polys, select_ignores):
  51. box_img = get_rotate_crop_image(ext_image, poly)
  52. src_img, box = self.paste_img(src_img, box_img, src_polys)
  53. if box is not None:
  54. box = box.tolist()
  55. for _ in range(len(box), point_num):
  56. box.append(box[-1])
  57. src_polys.append(box)
  58. src_texts.append(ext_texts[idx])
  59. src_ignores.append(tag)
  60. src_img = cv2.cvtColor(np.array(src_img), cv2.COLOR_RGB2BGR)
  61. h, w = src_img.shape[:2]
  62. src_polys = np.array(src_polys)
  63. src_polys[:, :, 0] = np.clip(src_polys[:, :, 0], 0, w)
  64. src_polys[:, :, 1] = np.clip(src_polys[:, :, 1], 0, h)
  65. data['image'] = src_img
  66. data['polys'] = src_polys
  67. data['texts'] = src_texts
  68. data['ignore_tags'] = np.array(src_ignores)
  69. return data
  70. def paste_img(self, src_img, box_img, src_polys):
  71. box_img_pil = Image.fromarray(box_img).convert('RGBA')
  72. src_w, src_h = src_img.size
  73. box_w, box_h = box_img_pil.size
  74. angle = np.random.randint(0, 360)
  75. box = np.array([[[0, 0], [box_w, 0], [box_w, box_h], [0, box_h]]])
  76. box = rotate_bbox(box_img, box, angle)[0]
  77. box_img_pil = box_img_pil.rotate(angle, expand=1)
  78. box_w, box_h = box_img_pil.width, box_img_pil.height
  79. if src_w - box_w < 0 or src_h - box_h < 0:
  80. return src_img, None
  81. paste_x, paste_y = self.select_coord(src_polys, box, src_w - box_w,
  82. src_h - box_h)
  83. if paste_x is None:
  84. return src_img, None
  85. box[:, 0] += paste_x
  86. box[:, 1] += paste_y
  87. r, g, b, A = box_img_pil.split()
  88. src_img.paste(box_img_pil, (paste_x, paste_y), mask=A)
  89. return src_img, box
  90. def select_coord(self, src_polys, box, endx, endy):
  91. if self.limit_paste:
  92. xmin, ymin, xmax, ymax = (
  93. box[:, 0].min(),
  94. box[:, 1].min(),
  95. box[:, 0].max(),
  96. box[:, 1].max(),
  97. )
  98. for _ in range(50):
  99. paste_x = random.randint(0, endx)
  100. paste_y = random.randint(0, endy)
  101. xmin1 = xmin + paste_x
  102. xmax1 = xmax + paste_x
  103. ymin1 = ymin + paste_y
  104. ymax1 = ymax + paste_y
  105. num_poly_in_rect = 0
  106. for poly in src_polys:
  107. if not is_poly_outside_rect(poly, xmin1, ymin1,
  108. xmax1 - xmin1, ymax1 - ymin1):
  109. num_poly_in_rect += 1
  110. break
  111. if num_poly_in_rect == 0:
  112. return paste_x, paste_y
  113. return None, None
  114. else:
  115. paste_x = random.randint(0, endx)
  116. paste_y = random.randint(0, endy)
  117. return paste_x, paste_y
  118. def get_union(pD, pG):
  119. return Polygon(pD).union(Polygon(pG)).area
  120. def get_intersection_over_union(pD, pG):
  121. return get_intersection(pD, pG) / get_union(pD, pG)
  122. def get_intersection(pD, pG):
  123. return Polygon(pD).intersection(Polygon(pG)).area
  124. def rotate_bbox(img, text_polys, angle, scale=1):
  125. """
  126. from https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/augment.py
  127. Args:
  128. img: np.ndarray
  129. text_polys: np.ndarray N*4*2
  130. angle: int
  131. scale: int
  132. Returns:
  133. """
  134. w = img.shape[1]
  135. h = img.shape[0]
  136. rangle = np.deg2rad(angle)
  137. nw = abs(np.sin(rangle) * h) + abs(np.cos(rangle) * w)
  138. nh = abs(np.cos(rangle) * h) + abs(np.sin(rangle) * w)
  139. rot_mat = cv2.getRotationMatrix2D((nw * 0.5, nh * 0.5), angle, scale)
  140. rot_move = np.dot(rot_mat, np.array([(nw - w) * 0.5, (nh - h) * 0.5, 0]))
  141. rot_mat[0, 2] += rot_move[0]
  142. rot_mat[1, 2] += rot_move[1]
  143. # ---------------------- rotate box ----------------------
  144. rot_text_polys = list()
  145. for bbox in text_polys:
  146. point1 = np.dot(rot_mat, np.array([bbox[0, 0], bbox[0, 1], 1]))
  147. point2 = np.dot(rot_mat, np.array([bbox[1, 0], bbox[1, 1], 1]))
  148. point3 = np.dot(rot_mat, np.array([bbox[2, 0], bbox[2, 1], 1]))
  149. point4 = np.dot(rot_mat, np.array([bbox[3, 0], bbox[3, 1], 1]))
  150. rot_text_polys.append([point1, point2, point3, point4])
  151. return np.array(rot_text_polys, dtype=np.float32)