multi_scale_sampler.py

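"""Multi-scale batch sampler: every yielded batch shares one target (width,
height); batch sizes can optionally be rescaled so each batch covers a
similar number of pixels."""
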
import random

import numpy as np
import torch.distributed as dist
from torch.utils.data import Sampler


class MultiScaleSampler(Sampler):
    def __init__(
        self,
        data_source,
        scales,
        first_bs=128,
        fix_bs=True,
        divided_factor=[8, 16],
        is_training=True,
        ratio_wh=0.8,
        max_w=480.0,
        seed=None,
    ):
  18. """
  19. multi scale samper
  20. Args:
  21. data_source(dataset)
  22. scales(list): several scales for image resolution
  23. first_bs(int): batch size for the first scale in scales
  24. divided_factor(list[w, h]): ImageNet models down-sample images by a factor, ensure that width and height dimensions are multiples are multiple of devided_factor.
  25. is_training(boolean): mode
  26. """
        self.data_source = data_source
        self.data_idx_order_list = np.array(data_source.data_idx_order_list)
        self.ds_width = data_source.ds_width
        self.seed = data_source.seed
        if self.ds_width:
            self.wh_ratio = data_source.wh_ratio
            self.wh_ratio_sort = data_source.wh_ratio_sort
        self.n_data_samples = len(self.data_source)
        self.ratio_wh = ratio_wh
        self.max_w = max_w

        # spatial dimensions for each scale, and the base (first) scale
        if isinstance(scales[0], list):
            width_dims = [i[0] for i in scales]
            height_dims = [i[1] for i in scales]
        elif isinstance(scales[0], int):
            width_dims = scales
            height_dims = scales
        base_im_w = width_dims[0]
        base_im_h = height_dims[0]
        base_batch_size = first_bs

        # Get the GPU and node related information
        if dist.is_initialized():
            num_replicas = dist.get_world_size()
            rank = dist.get_rank()
        else:
            num_replicas = 1
            rank = 0

        # adjust the total samples to avoid batch dropping
        num_samples_per_replica = int(self.n_data_samples * 1.0 / num_replicas)

        img_indices = [idx for idx in range(self.n_data_samples)]

        self.shuffle = False
        if is_training:
            # compute the spatial dimensions and corresponding batch size.
            # ImageNet models down-sample images by a fixed factor, so ensure
            # that width and height are multiples of divided_factor.
            width_dims = [
                int((w // divided_factor[0]) * divided_factor[0])
                for w in width_dims
            ]
            height_dims = [
                int((h // divided_factor[1]) * divided_factor[1])
                for h in height_dims
            ]
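            # Pair each scale with a batch size: either the fixed base size, or
            # one scaled so batch_size * h * w stays near the pixel budget
            # base_im_w * base_im_h * first_bs.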
            img_batch_pairs = list()
            base_elements = base_im_w * base_im_h * base_batch_size
            for h, w in zip(height_dims, width_dims):
                if fix_bs:
                    batch_size = base_batch_size
                else:
                    batch_size = int(max(1, (base_elements / (h * w))))
                img_batch_pairs.append((w, h, batch_size))
            self.img_batch_pairs = img_batch_pairs
            self.shuffle = True
        else:
            self.img_batch_pairs = [(base_im_w, base_im_h, base_batch_size)]

        self.img_indices = img_indices
        self.n_samples_per_replica = num_samples_per_replica
        self.epoch = 0
        self.rank = rank
        self.num_replicas = num_replicas

        self.batch_list = []
        self.current = 0
        last_index = num_samples_per_replica * num_replicas
        indices_rank_i = self.img_indices[self.rank:last_index:self.num_replicas]
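        # Pre-plan the epoch: walk this replica's indices, cycling through the
        # (w, h, batch_size) pairs; only shapes and batch sizes are recorded
        # here, concrete sample ids are resolved later in iter().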
        while self.current < self.n_samples_per_replica:
            for curr_w, curr_h, curr_bsz in self.img_batch_pairs:
                end_index = min(self.current + curr_bsz,
                                self.n_samples_per_replica)
                batch_ids = indices_rank_i[self.current:end_index]
                n_batch_samples = len(batch_ids)
                if n_batch_samples != curr_bsz:
                    batch_ids += indices_rank_i[:(curr_bsz - n_batch_samples)]
                self.current += curr_bsz
                if len(batch_ids) > 0:
                    batch = [curr_w, curr_h, len(batch_ids)]
                    self.batch_list.append(batch)
        random.shuffle(self.batch_list)
        self.length = len(self.batch_list)
        self.batchs_in_one_epoch = self.iter()
        self.batchs_in_one_epoch_id = [
            i for i in range(len(self.batchs_in_one_epoch))
        ]

    def __iter__(self):
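        # Reshuffle only the serving order of the precomputed batches each
        # epoch; batch composition itself is fixed by iter().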
        if self.seed is None:
            random.seed(self.epoch)
            self.epoch += 1
        else:
            random.seed(self.seed)
        random.shuffle(self.batchs_in_one_epoch_id)
        for batch_tuple_id in self.batchs_in_one_epoch_id:
            yield self.batchs_in_one_epoch[batch_tuple_id]

    def iter(self):
        if self.shuffle:
            if self.seed is not None:
                random.seed(self.seed)
            else:
                random.seed(self.epoch)
            if not self.ds_width:
                random.shuffle(self.img_indices)
            random.shuffle(self.img_batch_pairs)
        indices_rank_i = self.img_indices[
            self.rank:len(self.img_indices):self.num_replicas]

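        # Materialize the epoch plan: turn each (w, h, bsz) entry into concrete
        # sample ids, topping up short batches from the front of the list.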
        start_index = 0
        batchs_in_one_epoch = []
        for batch_tuple in self.batch_list:
            curr_w, curr_h, curr_bsz = batch_tuple
            end_index = min(start_index + curr_bsz, self.n_samples_per_replica)
            batch_ids = indices_rank_i[start_index:end_index]
            n_batch_samples = len(batch_ids)
            if n_batch_samples != curr_bsz:
                batch_ids += indices_rank_i[:(curr_bsz - n_batch_samples)]
            start_index += curr_bsz

            if len(batch_ids) > 0:
                if self.ds_width:
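                    # share one aspect ratio per batch: the mean w/h ratio,
                    # capped so that ratio * curr_h never exceeds max_w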
                    wh_ratio_current = self.wh_ratio[
                        self.wh_ratio_sort[batch_ids]]
                    ratio_current = wh_ratio_current.mean()
                    ratio_current = (ratio_current
                                     if ratio_current * curr_h < self.max_w
                                     else self.max_w / curr_h)
                else:
                    ratio_current = None
                batch = [(curr_w, curr_h, b_id, ratio_current)
                         for b_id in batch_ids]
                batchs_in_one_epoch.append(batch)
        return batchs_in_one_epoch

    def set_epoch(self, epoch: int):
        self.epoch = epoch

    def __len__(self):
        return self.length
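

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original file), assuming a
    # hypothetical stub dataset that exposes the attributes the sampler reads
    # (data_idx_order_list, ds_width, seed, wh_ratio, wh_ratio_sort) and whose
    # __getitem__ accepts the (w, h, idx, ratio) tuples yielded per batch.
    from torch.utils.data import DataLoader, Dataset

    class _StubDataset(Dataset):
        def __init__(self, n=512):
            self.ds_width = True
            self.seed = None
            self.wh_ratio = np.random.uniform(0.5, 4.0, size=n)
            self.wh_ratio_sort = np.argsort(self.wh_ratio)
            # samples are served in ratio-sorted order, as the sampler expects
            self.data_idx_order_list = self.wh_ratio_sort.tolist()

        def __len__(self):
            return len(self.data_idx_order_list)

        def __getitem__(self, properties):
            w, h, idx, ratio = properties
            # a real dataset would load image `idx` and resize it to the
            # batch's shared (w, h) / aspect ratio; here we return zeros
            return np.zeros((3, int(h), int(w)), dtype=np.float32)

    dataset = _StubDataset()
    sampler = MultiScaleSampler(
        dataset,
        scales=[[320, 32], [320, 48], [320, 64]],
        first_bs=16,
        fix_bs=False,
    )
    loader = DataLoader(dataset, batch_sampler=sampler)
    for epoch in range(2):
        sampler.set_epoch(epoch)
        for batch in loader:
            pass  # every batch shares one (w, h) and one target ratio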