rec_aug.py

import random

import cv2
import numpy as np
from PIL import Image

from .parseq_aug import rand_augment_transform


class PARSeqAugPIL(object):
    def __init__(self, **kwargs):
        self.transforms = rand_augment_transform()

    def __call__(self, data):
        img = data['image']
        img_aug = self.transforms(img)
        data['image'] = img_aug
        return data


class PARSeqAug(object):
    def __init__(self, **kwargs):
        self.transforms = rand_augment_transform()

    def __call__(self, data):
        img = data['image']
        img = np.array(self.transforms(Image.fromarray(img)))
        data['image'] = img
        return data
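

# Usage sketch (illustrative, not part of the original pipeline): both classes
# read and write the 'image' key of a dict. PARSeqAugPIL passes the value
# straight to the rand-augment transform (so it is expected to be a PIL.Image),
# while PARSeqAug converts a numpy HWC array to PIL and back.
#
#   aug = PARSeqAug()
#   data = aug({'image': np_img})        # np_img: HWC uint8 numpy array
#
#   pil_aug = PARSeqAugPIL()
#   data = pil_aug({'image': pil_img})   # pil_img: PIL.Image instance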


class ABINetAug(object):
    def __init__(self,
                 geometry_p=0.5,
                 deterioration_p=0.25,
                 colorjitter_p=0.25,
                 **kwargs):
        from torchvision.transforms import Compose
        from .abinet_aug import CVColorJitter, CVDeterioration, CVGeometry
        self.transforms = Compose([
            CVGeometry(
                degrees=45,
                translate=(0.0, 0.0),
                scale=(0.5, 2.0),
                shear=(45, 15),
                distortion=0.5,
                p=geometry_p,
            ),
            CVDeterioration(var=20, degrees=6, factor=4, p=deterioration_p),
            CVColorJitter(
                brightness=0.5,
                contrast=0.5,
                saturation=0.5,
                hue=0.1,
                p=colorjitter_p,
            ),
        ])

    def __call__(self, data):
        img = data['image']
        img = self.transforms(img)
        data['image'] = img
        return data


class SVTRAug(object):
    def __init__(self,
                 aug_type=0,
                 geometry_p=0.5,
                 deterioration_p=0.25,
                 colorjitter_p=0.25,
                 **kwargs):
        from torchvision.transforms import Compose
        from .abinet_aug import CVColorJitter, SVTRDeterioration, SVTRGeometry
        self.transforms = Compose([
            SVTRGeometry(
                aug_type=aug_type,
                degrees=45,
                translate=(0.0, 0.0),
                scale=(0.5, 2.0),
                shear=(45, 15),
                distortion=0.5,
                p=geometry_p,
            ),
            SVTRDeterioration(var=20, degrees=6, factor=4, p=deterioration_p),
            CVColorJitter(
                brightness=0.5,
                contrast=0.5,
                saturation=0.5,
                hue=0.1,
                p=colorjitter_p,
            ),
        ])

    def __call__(self, data):
        img = data['image']
        img = self.transforms(img)
        data['image'] = img
        return data
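

# Usage sketch (illustrative; the exact input layout expected by the
# CV*/SVTR* transforms is defined in .abinet_aug, not here): ABINetAug and
# SVTRAug wrap those transforms in a torchvision Compose and apply them to
# data['image'] as a whole. SVTRAug additionally forwards aug_type to
# SVTRGeometry to select the geometry-augmentation variant.
#
#   aug = SVTRAug(aug_type=0, geometry_p=0.5, deterioration_p=0.25)
#   data = aug({'image': img})   # img: image in the format .abinet_aug expects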


class BaseDataAugmentation(object):
    def __init__(self,
                 crop_prob=0.4,
                 reverse_prob=0.4,
                 noise_prob=0.4,
                 jitter_prob=0.4,
                 blur_prob=0.4,
                 hsv_aug_prob=0.4,
                 **kwargs):
        self.crop_prob = crop_prob
        self.reverse_prob = reverse_prob
        self.noise_prob = noise_prob
        self.jitter_prob = jitter_prob
        self.blur_prob = blur_prob
        self.hsv_aug_prob = hsv_aug_prob
        # for GaussianBlur
        self.fil = cv2.getGaussianKernel(ksize=5, sigma=1, ktype=cv2.CV_32F)

    def __call__(self, data):
        img = data['image']
        h, w, _ = img.shape
        if random.random() <= self.crop_prob and h >= 20 and w >= 20:
            img = get_crop(img)
        if random.random() <= self.blur_prob:
            # GaussianBlur
            img = cv2.sepFilter2D(img, -1, self.fil, self.fil)
        if random.random() <= self.hsv_aug_prob:
            img = hsv_aug(img)
        if random.random() <= self.jitter_prob:
            img = jitter(img)
        if random.random() <= self.noise_prob:
            img = add_gasuss_noise(img)
        if random.random() <= self.reverse_prob:
            img = 255 - img
        data['image'] = img
        return data
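

# Usage sketch (illustrative; 'word.png' is a made-up file name):
# BaseDataAugmentation works on an HWC uint8 BGR array (cv2 convention) and
# applies each perturbation independently with its own probability.
#
#   aug = BaseDataAugmentation(crop_prob=0.4, blur_prob=0.4)
#   out = aug({'image': cv2.imread('word.png')})['image']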


def hsv_aug(img):
    """Randomly scale the V (brightness) channel in HSV space."""
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    delta = 0.001 * random.random() * flag()
    hsv[:, :, 2] = hsv[:, :, 2] * (1 + delta)
    new_img = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
    return new_img


def blur(img):
    """blur."""
    h, w, _ = img.shape
    if h > 10 and w > 10:
        return cv2.GaussianBlur(img, (5, 5), 1)
    else:
        return img


def jitter(img):
    """Apply a small random diagonal (down-right) jitter to the image."""
    h, w, _ = img.shape
    if h > 10 and w > 10:
        thres = min(w, h)
        s = int(random.random() * thres * 0.01)
        src_img = img.copy()
        for i in range(s):
            img[i:, i:, :] = src_img[:h - i, :w - i, :]
        return img
    else:
        return img


def add_gasuss_noise(image, mean=0, var=0.1):
    """Add Gaussian noise to the image."""
    noise = np.random.normal(mean, var**0.5, image.shape)
    out = image + 0.5 * noise
    out = np.clip(out, 0, 255)
    out = np.uint8(out)
    return out


def get_crop(image):
    """Randomly crop 1-8 rows from either the top or the bottom of the image."""
    h, w, _ = image.shape
    top_min = 1
    top_max = 8
    top_crop = int(random.randint(top_min, top_max))
    top_crop = min(top_crop, h - 1)
    crop_img = image.copy()
    ratio = random.randint(0, 1)
    if ratio:
        crop_img = crop_img[top_crop:h, :, :]
    else:
        crop_img = crop_img[0:h - top_crop, :, :]
    return crop_img


def flag():
    """Randomly return +1 or -1."""
    return 1 if random.random() > 0.5000001 else -1
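

if __name__ == '__main__':
    # Minimal smoke test (illustrative sketch, not part of the original file).
    # Because of the relative imports above, run it as a module, e.g.
    # `python -m <package>.rec_aug`, not as a standalone script.
    dummy = np.random.randint(0, 256, size=(32, 100, 3), dtype=np.uint8)

    base_aug = BaseDataAugmentation()
    print(base_aug({'image': dummy.copy()})['image'].shape)

    parseq_aug = PARSeqAug()
    print(parseq_aug({'image': dummy.copy()})['image'].shape)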