ratio_dataset_tvresize.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. import io
  2. import math
  3. import random
  4. import cv2
  5. import lmdb
  6. import numpy as np
  7. from PIL import Image
  8. from torch.utils.data import Dataset
  9. from torchvision import transforms as T
  10. from torchvision.transforms import functional as F
  11. from openrec.preprocess import create_operators, transform
  12. class RatioDataSetTVResize(Dataset):
  13. def __init__(self, config, mode, logger, seed=None, epoch=1, task='rec'):
  14. super(RatioDataSetTVResize, self).__init__()
  15. self.ds_width = config[mode]['dataset'].get('ds_width', True)
  16. global_config = config['Global']
  17. dataset_config = config[mode]['dataset']
  18. loader_config = config[mode]['loader']
  19. max_ratio = loader_config.get('max_ratio', 10)
  20. min_ratio = loader_config.get('min_ratio', 1)
  21. data_dir_list = dataset_config['data_dir_list']
  22. self.padding = dataset_config.get('padding', True)
  23. self.padding_rand = dataset_config.get('padding_rand', False)
  24. self.padding_doub = dataset_config.get('padding_doub', False)
  25. self.do_shuffle = loader_config['shuffle']
  26. self.seed = epoch
  27. data_source_num = len(data_dir_list)
  28. ratio_list = dataset_config.get('ratio_list', 1.0)
  29. if isinstance(ratio_list, (float, int)):
  30. ratio_list = [float(ratio_list)] * int(data_source_num)
  31. assert (
  32. len(ratio_list) == data_source_num
  33. ), 'The length of ratio_list should be the same as the file_list.'
  34. self.lmdb_sets = self.load_hierarchical_lmdb_dataset(
  35. data_dir_list, ratio_list)
  36. for data_dir in data_dir_list:
  37. logger.info('Initialize indexs of datasets:%s' % data_dir)
  38. self.logger = logger
  39. self.data_idx_order_list = self.dataset_traversal()
  40. wh_ratio = np.around(np.array(self.get_wh_ratio()))
  41. self.wh_ratio = np.clip(wh_ratio, a_min=min_ratio, a_max=max_ratio)
  42. for i in range(max_ratio + 1):
  43. logger.info((1 * (self.wh_ratio == i)).sum())
  44. self.wh_ratio_sort = np.argsort(self.wh_ratio)
  45. self.ops = create_operators(dataset_config['transforms'],
  46. global_config)
  47. self.need_reset = True in [x < 1 for x in ratio_list]
  48. self.error = 0
  49. self.base_shape = dataset_config.get(
  50. 'base_shape', [[64, 64], [96, 48], [112, 40], [128, 32]])
  51. self.base_h = dataset_config.get('base_h', 32)
  52. self.interpolation = T.InterpolationMode.BICUBIC
  53. transforms = []
  54. transforms.extend([
  55. T.ToTensor(),
  56. T.Normalize(0.5, 0.5),
  57. ])
  58. self.transforms = T.Compose(transforms)
  59. def get_wh_ratio(self):
  60. wh_ratio = []
  61. for idx in range(self.data_idx_order_list.shape[0]):
  62. lmdb_idx, file_idx = self.data_idx_order_list[idx]
  63. lmdb_idx = int(lmdb_idx)
  64. file_idx = int(file_idx)
  65. wh_key = 'wh-%09d'.encode() % file_idx
  66. wh = self.lmdb_sets[lmdb_idx]['txn'].get(wh_key)
  67. if wh is None:
  68. img_key = f'image-{file_idx:09d}'.encode()
  69. img = self.lmdb_sets[lmdb_idx]['txn'].get(img_key)
  70. buf = io.BytesIO(img)
  71. w, h = Image.open(buf).size
  72. else:
  73. wh = wh.decode('utf-8')
  74. w, h = wh.split('_')
  75. wh_ratio.append(float(w) / float(h))
  76. return wh_ratio
  77. def load_hierarchical_lmdb_dataset(self, data_dir_list, ratio_list):
  78. lmdb_sets = {}
  79. dataset_idx = 0
  80. for dirpath, ratio in zip(data_dir_list, ratio_list):
  81. env = lmdb.open(dirpath,
  82. max_readers=32,
  83. readonly=True,
  84. lock=False,
  85. readahead=False,
  86. meminit=False)
  87. txn = env.begin(write=False)
  88. num_samples = int(txn.get('num-samples'.encode()))
  89. lmdb_sets[dataset_idx] = {
  90. 'dirpath': dirpath,
  91. 'env': env,
  92. 'txn': txn,
  93. 'num_samples': num_samples,
  94. 'ratio_num_samples': int(ratio * num_samples)
  95. }
  96. dataset_idx += 1
  97. return lmdb_sets
  98. def dataset_traversal(self):
  99. lmdb_num = len(self.lmdb_sets)
  100. total_sample_num = 0
  101. for lno in range(lmdb_num):
  102. total_sample_num += self.lmdb_sets[lno]['ratio_num_samples']
  103. data_idx_order_list = np.zeros((total_sample_num, 2))
  104. beg_idx = 0
  105. for lno in range(lmdb_num):
  106. tmp_sample_num = self.lmdb_sets[lno]['ratio_num_samples']
  107. end_idx = beg_idx + tmp_sample_num
  108. data_idx_order_list[beg_idx:end_idx, 0] = lno
  109. data_idx_order_list[beg_idx:end_idx, 1] = list(
  110. random.sample(range(1, self.lmdb_sets[lno]['num_samples'] + 1),
  111. self.lmdb_sets[lno]['ratio_num_samples']))
  112. beg_idx = beg_idx + tmp_sample_num
  113. return data_idx_order_list
  114. def get_img_data(self, value):
  115. """get_img_data."""
  116. if not value:
  117. return None
  118. imgdata = np.frombuffer(value, dtype='uint8')
  119. if imgdata is None:
  120. return None
  121. imgori = cv2.imdecode(imgdata, 1)
  122. if imgori is None:
  123. return None
  124. return imgori
  125. def resize_norm_img(self, data, gen_ratio, padding=True):
  126. img = data['image']
  127. w, h = img.size
  128. if self.padding_rand and random.random() < 0.5:
  129. padding = not padding
  130. imgW, imgH = self.base_shape[gen_ratio - 1] if gen_ratio <= 4 else [
  131. self.base_h * gen_ratio, self.base_h
  132. ]
  133. use_ratio = imgW // imgH
  134. if use_ratio >= (w // h) + 2:
  135. self.error += 1
  136. return None
  137. if not padding:
  138. resized_w = imgW
  139. else:
  140. ratio = w / float(h)
  141. if math.ceil(imgH * ratio) > imgW:
  142. resized_w = imgW
  143. else:
  144. resized_w = int(
  145. math.ceil(imgH * ratio * (random.random() + 0.5)))
  146. resized_w = min(imgW, resized_w)
  147. resized_image = F.resize(img, (imgH, resized_w),
  148. interpolation=self.interpolation)
  149. img = self.transforms(resized_image)
  150. if resized_w < imgW and padding:
  151. # img = F.pad(img, [0, 0, imgW-resized_w, 0], fill=0.)
  152. if self.padding_doub and random.random() < 0.5:
  153. img = F.pad(img, [0, 0, imgW - resized_w, 0], fill=0.)
  154. else:
  155. img = F.pad(img, [imgW - resized_w, 0, 0, 0], fill=0.)
  156. valid_ratio = min(1.0, float(resized_w / imgW))
  157. data['image'] = img
  158. data['valid_ratio'] = valid_ratio
  159. return data
  160. def get_lmdb_sample_info(self, txn, index):
  161. label_key = 'label-%09d'.encode() % index
  162. label = txn.get(label_key)
  163. if label is None:
  164. return None
  165. label = label.decode('utf-8')
  166. img_key = 'image-%09d'.encode() % index
  167. imgbuf = txn.get(img_key)
  168. return imgbuf, label
  169. def __getitem__(self, properties):
  170. img_width = properties[0]
  171. img_height = properties[1]
  172. idx = properties[2]
  173. ratio = properties[3]
  174. lmdb_idx, file_idx = self.data_idx_order_list[idx]
  175. lmdb_idx = int(lmdb_idx)
  176. file_idx = int(file_idx)
  177. sample_info = self.get_lmdb_sample_info(
  178. self.lmdb_sets[lmdb_idx]['txn'], file_idx)
  179. if sample_info is None:
  180. ratio_ids = np.where(self.wh_ratio == ratio)[0].tolist()
  181. ids = random.sample(ratio_ids, 1)
  182. return self.__getitem__([img_width, img_height, ids[0], ratio])
  183. img, label = sample_info
  184. data = {'image': img, 'label': label}
  185. outs = transform(data, self.ops[:-1])
  186. if outs is not None:
  187. outs = self.resize_norm_img(outs, ratio, padding=self.padding)
  188. if outs is None:
  189. ratio_ids = np.where(self.wh_ratio == ratio)[0].tolist()
  190. ids = random.sample(ratio_ids, 1)
  191. return self.__getitem__([img_width, img_height, ids[0], ratio])
  192. outs = transform(outs, self.ops[-1:])
  193. if outs is None:
  194. ratio_ids = np.where(self.wh_ratio == ratio)[0].tolist()
  195. ids = random.sample(ratio_ids, 1)
  196. return self.__getitem__([img_width, img_height, ids[0], ratio])
  197. return outs
  198. def __len__(self):
  199. return self.data_idx_order_list.shape[0]