ratio_sampler.py

import math
import os
import random

import numpy as np
import torch
from torch.utils.data import Sampler


class RatioSampler(Sampler):
    def __init__(self,
                 data_source,
                 scales,
                 first_bs=512,
                 fix_bs=True,
                 divided_factor=[8, 16],
                 is_training=True,
                 max_ratio=10,
                 max_bs=1024,
                 seed=None):
        """Multi-scale sampler that batches samples by aspect ratio.

        Args:
            data_source (dataset): dataset exposing ds_width, seed, wh_ratio
                and wh_ratio_sort attributes
            scales (list): several scales for image resolution, given either
                as a list of ints or as a list of [w, h] pairs
            first_bs (int): batch size for the first scale in scales
            fix_bs (bool): if True, use first_bs for every scale; otherwise
                scale the batch size so the pixel count per batch stays
                roughly constant
            divided_factor (list[w, h]): ImageNet-style models down-sample
                images by a fixed factor; width and height dimensions are
                rounded down to multiples of divided_factor
            is_training (boolean): mode
            max_bs (int): upper bound on the per-ratio batch size
            seed (int): random seed used for shuffling
        """
        # min. and max. spatial dimensions
        self.data_source = data_source
        # self.data_idx_order_list = np.array(data_source.data_idx_order_list)
        self.ds_width = data_source.ds_width
        self.seed = data_source.seed
        if self.ds_width:
            self.wh_ratio = data_source.wh_ratio
            self.wh_ratio_sort = data_source.wh_ratio_sort
        self.n_data_samples = len(self.data_source)
        self.max_ratio = max_ratio
        self.max_bs = max_bs

        if isinstance(scales[0], list):
            width_dims = [i[0] for i in scales]
            height_dims = [i[1] for i in scales]
        elif isinstance(scales[0], int):
            width_dims = scales
            height_dims = scales
        base_im_w = width_dims[0]
        base_im_h = height_dims[0]
        base_batch_size = first_bs
        base_elements = base_im_w * base_im_h * base_batch_size
        self.base_elements = base_elements
        self.base_batch_size = base_batch_size
        self.base_im_h = base_im_h
        self.base_im_w = base_im_w
        # Get the GPU and node related information
        num_replicas = (torch.cuda.device_count()
                        if torch.cuda.is_available() else 1)
        # rank = dist.get_rank()
        rank = (int(os.environ['LOCAL_RANK'])
                if 'LOCAL_RANK' in os.environ else 0)
        # self.rank = rank

        # adjust the total samples to avoid batch dropping
        num_samples_per_replica = int(
            math.ceil(self.n_data_samples * 1.0 / num_replicas))

        img_indices = [idx for idx in range(self.n_data_samples)]
        self.shuffle = False
        if is_training:
            # compute the spatial dimensions and corresponding batch size;
            # ImageNet-style models down-sample images by a fixed factor, so
            # ensure that width and height are multiples of divided_factor
            width_dims = [
                int((w // divided_factor[0]) * divided_factor[0])
                for w in width_dims
            ]
            height_dims = [
                int((h // divided_factor[1]) * divided_factor[1])
                for h in height_dims
            ]

            img_batch_pairs = list()
            for (h, w) in zip(height_dims, width_dims):
                if fix_bs:
                    batch_size = base_batch_size
                else:
                    # keep the pixel count per batch roughly constant
                    batch_size = int(max(1, (base_elements / (h * w))))
                img_batch_pairs.append((w, h, batch_size))
            self.img_batch_pairs = img_batch_pairs
            self.shuffle = True
            np.random.seed(seed)
            random.seed(seed)
        else:
            self.img_batch_pairs = [(base_im_w, base_im_h, base_batch_size)]
        self.img_indices = img_indices
        self.n_samples_per_replica = num_samples_per_replica
        self.epoch = 0
        self.rank = rank
        self.num_replicas = num_replicas
        # self.batch_list = []
        self.current = 0
        self.is_training = is_training

        # each replica only keeps every num_replicas-th index
        if is_training:
            indices_rank_i = self.img_indices[
                self.rank:len(self.img_indices):self.num_replicas]
        else:
            indices_rank_i = self.img_indices
        self.indices_rank_i_ori = np.array(self.wh_ratio_sort[indices_rank_i])
        self.indices_rank_i_ratio = self.wh_ratio[self.indices_rank_i_ori]
        indices_rank_i_ratio_unique = np.unique(self.indices_rank_i_ratio)
        self.indices_rank_i_ratio_unique = indices_rank_i_ratio_unique.tolist()
        self.batch_list = self.create_batch()
        self.length = len(self.batch_list)
        self.batchs_in_one_epoch_id = [i for i in range(self.length)]

    def create_batch(self):
        """Group sample indices by aspect ratio into [w, h, idx, ratio] batches."""
        batch_list = []
        for ratio in self.indices_rank_i_ratio_unique:
            ratio_ids = np.where(self.indices_rank_i_ratio == ratio)[0]
            ratio_ids = self.indices_rank_i_ori[ratio_ids]
            if self.shuffle:
                random.shuffle(ratio_ids)
            num_ratio = ratio_ids.shape[0]
            if ratio < 5:
                batch_size_ratio = self.base_batch_size
            else:
                # wider images get smaller batches so the pixel count per
                # batch stays close to base_elements
                batch_size_ratio = min(
                    self.max_bs,
                    int(
                        max(1, (self.base_elements /
                                (self.base_im_h * ratio * self.base_im_h)))))
            if num_ratio > batch_size_ratio:
                batch_num_ratio = num_ratio // batch_size_ratio
                print(self.rank, num_ratio, ratio * self.base_im_h,
                      batch_num_ratio, batch_size_ratio)
                ratio_ids_full = ratio_ids[:batch_num_ratio *
                                           batch_size_ratio].reshape(
                                               batch_num_ratio,
                                               batch_size_ratio, 1)
                w = np.full_like(ratio_ids_full, ratio * self.base_im_h)
                h = np.full_like(ratio_ids_full, self.base_im_h)
                ra_wh = np.full_like(ratio_ids_full, ratio)
                ratio_ids_full = np.concatenate([w, h, ratio_ids_full, ra_wh],
                                                axis=-1)
                batch_ratio = ratio_ids_full.tolist()
                if batch_num_ratio * batch_size_ratio < num_ratio:
                    # pad the leftover samples into one extra full batch
                    drop = ratio_ids[batch_num_ratio * batch_size_ratio:]
                    if self.is_training:
                        drop_full = ratio_ids[:batch_size_ratio - (
                            num_ratio - batch_num_ratio * batch_size_ratio)]
                        drop = np.append(drop_full, drop)
                    drop = drop.reshape(-1, 1)
                    w = np.full_like(drop, ratio * self.base_im_h)
                    h = np.full_like(drop, self.base_im_h)
                    ra_wh = np.full_like(drop, ratio)
                    drop = np.concatenate([w, h, drop, ra_wh], axis=-1)
                    batch_ratio.append(drop.tolist())
                batch_list += batch_ratio
            else:
                print(self.rank, num_ratio, ratio * self.base_im_h,
                      batch_size_ratio)
                ratio_ids = ratio_ids.reshape(-1, 1)
                w = np.full_like(ratio_ids, ratio * self.base_im_h)
                h = np.full_like(ratio_ids, self.base_im_h)
                ra_wh = np.full_like(ratio_ids, ratio)
                ratio_ids = np.concatenate([w, h, ratio_ids, ra_wh], axis=-1)
                batch_list.append(ratio_ids.tolist())
        return batch_list

    def __iter__(self):
        if self.shuffle or self.is_training:
            random.seed(self.epoch)
            self.epoch += 1
            self.batch_list = self.create_batch()
            random.shuffle(self.batchs_in_one_epoch_id)
        for batch_tuple_id in self.batchs_in_one_epoch_id:
            yield self.batch_list[batch_tuple_id]

    def set_epoch(self, epoch: int):
        self.epoch = epoch

    def __len__(self):
        return self.length
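

if __name__ == '__main__':
    # Minimal usage sketch, for illustration only. RatioSampler expects a
    # dataset that exposes ds_width, seed, wh_ratio (per-sample width/height
    # ratios) and wh_ratio_sort (sample indices sorted by that ratio).
    # _DummyRatioDataset below is an assumption made for this sketch and is
    # not part of the real data pipeline.
    from torch.utils.data import DataLoader, Dataset

    class _DummyRatioDataset(Dataset):
        def __init__(self, n=4096, seed=0):
            rng = np.random.default_rng(seed)
            self.ds_width = True
            self.seed = seed
            self.wh_ratio = rng.integers(1, 8, size=n).astype(np.float64)
            self.wh_ratio_sort = np.argsort(self.wh_ratio)

        def __len__(self):
            return len(self.wh_ratio)

        def __getitem__(self, item):
            # the sampler yields one [w, h, idx, ratio] entry per sample;
            # a real dataset would resize image `idx` to (w, h) here
            w, h, idx, ratio = item
            return torch.zeros(3, int(h), int(w)), int(idx)

    dataset = _DummyRatioDataset()
    sampler = RatioSampler(dataset,
                           scales=[[128, 32]],
                           first_bs=16,
                           fix_bs=False,
                           is_training=True,
                           seed=0)
    # each element produced by the sampler is a whole batch, so it is passed
    # to DataLoader as batch_sampler rather than sampler
    loader = DataLoader(dataset, batch_sampler=sampler, num_workers=0)
    for images, indices in loader:
        print(images.shape, indices[:4])
        break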