import random

import numpy as np
import torch.distributed as dist
from torch.utils.data import Sampler


class MultiScaleSampler(Sampler):
    def __init__(
        self,
        data_source,
        scales,
        first_bs=128,
        fix_bs=True,
        divided_factor=[8, 16],
        is_training=True,
        ratio_wh=0.8,
        max_w=480.0,
        seed=None,
    ):
        """Multi-scale sampler.

        Args:
            data_source (Dataset): dataset to sample from.
            scales (list): image resolutions to cycle through; a list of ints
                (square sizes) or a list of [w, h] pairs.
            first_bs (int): batch size for the first scale in scales.
            fix_bs (bool): if True, every scale uses first_bs; otherwise the
                batch size is scaled so the pixel count per batch stays
                roughly constant.
            divided_factor (list[w, h]): models down-sample images by a fixed
                factor; width and height are rounded down to multiples of
                divided_factor.
            is_training (bool): training mode enables multi-scale batches and
                shuffling.
            ratio_wh (float): width/height ratio threshold (stored on the
                sampler).
            max_w (float): maximum image width; the per-batch aspect ratio is
                clamped so that ratio * height does not exceed max_w.
            seed (int): unused here; the sampler takes its shuffling seed
                from data_source.seed.
        """
        self.data_source = data_source
        self.data_idx_order_list = np.array(data_source.data_idx_order_list)
        self.ds_width = data_source.ds_width
        # The shuffling seed comes from the dataset, not the constructor.
        self.seed = data_source.seed
        if self.ds_width:
            self.wh_ratio = data_source.wh_ratio
            self.wh_ratio_sort = data_source.wh_ratio_sort
        self.n_data_samples = len(self.data_source)
        self.ratio_wh = ratio_wh
        self.max_w = max_w

        # Resolve the width/height lists from the scales argument.
        if isinstance(scales[0], list):
            width_dims = [i[0] for i in scales]
            height_dims = [i[1] for i in scales]
        elif isinstance(scales[0], int):
            width_dims = scales
            height_dims = scales
        else:
            raise ValueError(
                "scales must be a list of ints or a list of [w, h] pairs")
        base_im_w = width_dims[0]
        base_im_h = height_dims[0]
        base_batch_size = first_bs

        # Get the GPU and node related information.
        if dist.is_initialized():
            num_replicas = dist.get_world_size()
            rank = dist.get_rank()
        else:
            num_replicas = 1
            rank = 0

        # Truncate the sample count so every replica sees the same number of
        # samples and no replica drops a batch.
        num_samples_per_replica = int(self.n_data_samples * 1.0 / num_replicas)
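        # For example (illustrative numbers, not defaults): with 1,001
        # samples across 4 replicas, each rank draws int(1001 / 4) = 250
        # samples, so one trailing sample is left out of the epoch.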
        img_indices = list(range(self.n_data_samples))

        self.shuffle = False
        if is_training:
            # Compute the spatial dimensions and corresponding batch sizes.
            # The model down-samples images by a fixed factor, so round width
            # and height down to multiples of divided_factor.
            width_dims = [
                int((w // divided_factor[0]) * divided_factor[0])
                for w in width_dims
            ]
            height_dims = [
                int((h // divided_factor[1]) * divided_factor[1])
                for h in height_dims
            ]
            img_batch_pairs = list()
            base_elements = base_im_w * base_im_h * base_batch_size
            for h, w in zip(height_dims, width_dims):
                if fix_bs:
                    batch_size = base_batch_size
                else:
                    # Larger images get proportionally smaller batches so the
                    # pixel count per batch stays close to base_elements.
                    batch_size = int(max(1, (base_elements / (h * w))))
                img_batch_pairs.append((w, h, batch_size))
            self.img_batch_pairs = img_batch_pairs
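            # Worked example (illustrative numbers, not defaults): with a
            # base scale of 320x32 and first_bs=128, base_elements equals
            # 320 * 32 * 128 = 1,310,720; a larger 640x48 scale then gets
            # batch_size = int(1,310,720 / (640 * 48)) = 42.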
            self.shuffle = True
        else:
            self.img_batch_pairs = [(base_im_w, base_im_h, base_batch_size)]

        self.img_indices = img_indices
        self.n_samples_per_replica = num_samples_per_replica
        self.epoch = 0
        self.rank = rank
        self.num_replicas = num_replicas

        # Build the (width, height, batch_size) plan for one epoch by striding
        # over this replica's indices and cycling through the scales.
        self.batch_list = []
        self.current = 0
        last_index = num_samples_per_replica * num_replicas
        indices_rank_i = self.img_indices[self.rank:last_index:
                                          self.num_replicas]
        while self.current < self.n_samples_per_replica:
            for curr_w, curr_h, curr_bsz in self.img_batch_pairs:
                end_index = min(self.current + curr_bsz,
                                self.n_samples_per_replica)
                batch_ids = indices_rank_i[self.current:end_index]
                n_batch_samples = len(batch_ids)
                if n_batch_samples != curr_bsz:
                    # Pad a short final batch from the front of the list so
                    # it still holds curr_bsz samples.
                    batch_ids += indices_rank_i[:(curr_bsz - n_batch_samples)]
                self.current += curr_bsz
                if len(batch_ids) > 0:
                    batch = [curr_w, curr_h, len(batch_ids)]
                    self.batch_list.append(batch)
        random.shuffle(self.batch_list)
        self.length = len(self.batch_list)
        self.batchs_in_one_epoch = self.iter()
        self.batchs_in_one_epoch_id = list(
            range(len(self.batchs_in_one_epoch)))
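
    # Note on shuffling: iter() materializes one epoch of batches up front;
    # __iter__ only permutes the order of those pre-built batches, seeded by
    # the epoch index (or by the fixed dataset seed, if one is set).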

    def __iter__(self):
        if self.seed is None:
            random.seed(self.epoch)
            self.epoch += 1
        else:
            random.seed(self.seed)
        random.shuffle(self.batchs_in_one_epoch_id)
        for batch_tuple_id in self.batchs_in_one_epoch_id:
            yield self.batchs_in_one_epoch[batch_tuple_id]

    def iter(self):
        # Materialize one epoch of batches. Each batch element is a
        # (width, height, sample_id, ratio) tuple that the dataset uses to
        # resize the sample.
        if self.shuffle:
            if self.seed is not None:
                random.seed(self.seed)
            else:
                random.seed(self.epoch)
            # Ratio-sorted indices (ds_width) must keep their order, so only
            # shuffle when the dataset is not width-sorted.
            if not self.ds_width:
                random.shuffle(self.img_indices)
                random.shuffle(self.img_batch_pairs)
        indices_rank_i = self.img_indices[self.rank:len(self.img_indices):
                                          self.num_replicas]

        start_index = 0
        batchs_in_one_epoch = []
        for batch_tuple in self.batch_list:
            curr_w, curr_h, curr_bsz = batch_tuple
            end_index = min(start_index + curr_bsz,
                            self.n_samples_per_replica)
            batch_ids = indices_rank_i[start_index:end_index]
            n_batch_samples = len(batch_ids)
            if n_batch_samples != curr_bsz:
                batch_ids += indices_rank_i[:(curr_bsz - n_batch_samples)]
            start_index += curr_bsz

            if len(batch_ids) > 0:
                if self.ds_width:
                    # Use the batch's mean aspect ratio, clamped so that
                    # ratio * height never exceeds max_w.
                    wh_ratio_current = self.wh_ratio[
                        self.wh_ratio_sort[batch_ids]]
                    ratio_current = wh_ratio_current.mean()
                    if ratio_current * curr_h >= self.max_w:
                        ratio_current = self.max_w / curr_h
                else:
                    ratio_current = None
                batch = [(curr_w, curr_h, b_id, ratio_current)
                         for b_id in batch_ids]
                batchs_in_one_epoch.append(batch)
        return batchs_in_one_epoch

    def set_epoch(self, epoch: int):
        self.epoch = epoch

    def __len__(self):
        return self.length
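

# A minimal usage sketch, assuming a dataset that exposes the attributes the
# sampler reads (data_idx_order_list, ds_width, seed, and, when ds_width is
# set, wh_ratio and wh_ratio_sort). MyDataset and my_collate_fn below are
# hypothetical names, not part of this module. Each yielded batch is a list
# of (width, height, sample_id, ratio) tuples, so the dataset's __getitem__
# must accept such tuples and resize its samples accordingly.
#
#     from torch.utils.data import DataLoader
#
#     dataset = MyDataset(...)
#     sampler = MultiScaleSampler(
#         dataset, scales=[[320, 32], [640, 48]], first_bs=128, fix_bs=False)
#     loader = DataLoader(
#         dataset, batch_sampler=sampler, collate_fn=my_collate_fn)
#     for epoch in range(10):
#         sampler.set_epoch(epoch)
#         for batch in loader:
#             ...  # training step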