ratio_dataset.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217
  1. import io
  2. import math
  3. import random
  4. import os
  5. import cv2
  6. import lmdb
  7. import numpy as np
  8. from PIL import Image
  9. from torch.utils.data import Dataset
  10. from openrec.preprocess import create_operators, transform
  11. class RatioDataSet(Dataset):
  12. def __init__(self, config, mode, logger, seed=None, epoch=1, task='rec'):
  13. super(RatioDataSet, self).__init__()
  14. self.ds_width = config[mode]['dataset'].get('ds_width', True)
  15. global_config = config['Global']
  16. dataset_config = config[mode]['dataset']
  17. loader_config = config[mode]['loader']
  18. max_ratio = loader_config.get('max_ratio', 10)
  19. min_ratio = loader_config.get('min_ratio', 1)
  20. syn = dataset_config.get('syn', False)
  21. if syn:
  22. data_dir_list = []
  23. data_dir = '../training_aug_lmdb_noerror/ep' + str(epoch)
  24. for dir_syn in os.listdir(data_dir):
  25. data_dir_list.append(data_dir + '/' + dir_syn)
  26. else:
  27. data_dir_list = dataset_config['data_dir_list']
  28. self.padding = dataset_config.get('padding', True)
  29. self.padding_rand = dataset_config.get('padding_rand', False)
  30. self.padding_doub = dataset_config.get('padding_doub', False)
  31. self.do_shuffle = loader_config['shuffle']
  32. self.seed = epoch
  33. data_source_num = len(data_dir_list)
  34. ratio_list = dataset_config.get('ratio_list', 1.0)
  35. if isinstance(ratio_list, (float, int)):
  36. ratio_list = [float(ratio_list)] * int(data_source_num)
  37. assert (
  38. len(ratio_list) == data_source_num
  39. ), 'The length of ratio_list should be the same as the file_list.'
  40. self.lmdb_sets = self.load_hierarchical_lmdb_dataset(
  41. data_dir_list, ratio_list)
  42. for data_dir in data_dir_list:
  43. logger.info('Initialize indexs of datasets:%s' % data_dir)
  44. self.logger = logger
  45. self.data_idx_order_list = self.dataset_traversal()
  46. wh_ratio = np.around(np.array(self.get_wh_ratio()))
  47. self.wh_ratio = np.clip(wh_ratio, a_min=min_ratio, a_max=max_ratio)
  48. for i in range(max_ratio + 1):
  49. logger.info((1 * (self.wh_ratio == i)).sum())
  50. self.wh_ratio_sort = np.argsort(self.wh_ratio)
  51. self.ops = create_operators(dataset_config['transforms'],
  52. global_config)
  53. self.need_reset = True in [x < 1 for x in ratio_list]
  54. self.error = 0
  55. self.base_shape = dataset_config.get(
  56. 'base_shape', [[64, 64], [96, 48], [112, 40], [128, 32]])
  57. self.base_h = 32
  58. def get_wh_ratio(self):
  59. wh_ratio = []
  60. for idx in range(self.data_idx_order_list.shape[0]):
  61. lmdb_idx, file_idx = self.data_idx_order_list[idx]
  62. lmdb_idx = int(lmdb_idx)
  63. file_idx = int(file_idx)
  64. wh_key = 'wh-%09d'.encode() % file_idx
  65. wh = self.lmdb_sets[lmdb_idx]['txn'].get(wh_key)
  66. if wh is None:
  67. img_key = f'image-{file_idx:09d}'.encode()
  68. img = self.lmdb_sets[lmdb_idx]['txn'].get(img_key)
  69. buf = io.BytesIO(img)
  70. w, h = Image.open(buf).size
  71. else:
  72. wh = wh.decode('utf-8')
  73. w, h = wh.split('_')
  74. wh_ratio.append(float(w) / float(h))
  75. return wh_ratio
  76. def load_hierarchical_lmdb_dataset(self, data_dir_list, ratio_list):
  77. lmdb_sets = {}
  78. dataset_idx = 0
  79. for dirpath, ratio in zip(data_dir_list, ratio_list):
  80. env = lmdb.open(dirpath,
  81. max_readers=32,
  82. readonly=True,
  83. lock=False,
  84. readahead=False,
  85. meminit=False)
  86. txn = env.begin(write=False)
  87. num_samples = int(txn.get('num-samples'.encode()))
  88. lmdb_sets[dataset_idx] = {
  89. 'dirpath': dirpath,
  90. 'env': env,
  91. 'txn': txn,
  92. 'num_samples': num_samples,
  93. 'ratio_num_samples': int(ratio * num_samples)
  94. }
  95. dataset_idx += 1
  96. return lmdb_sets
  97. def dataset_traversal(self):
  98. lmdb_num = len(self.lmdb_sets)
  99. total_sample_num = 0
  100. for lno in range(lmdb_num):
  101. total_sample_num += self.lmdb_sets[lno]['ratio_num_samples']
  102. data_idx_order_list = np.zeros((total_sample_num, 2))
  103. beg_idx = 0
  104. for lno in range(lmdb_num):
  105. tmp_sample_num = self.lmdb_sets[lno]['ratio_num_samples']
  106. end_idx = beg_idx + tmp_sample_num
  107. data_idx_order_list[beg_idx:end_idx, 0] = lno
  108. data_idx_order_list[beg_idx:end_idx, 1] = list(
  109. random.sample(range(1, self.lmdb_sets[lno]['num_samples'] + 1),
  110. self.lmdb_sets[lno]['ratio_num_samples']))
  111. beg_idx = beg_idx + tmp_sample_num
  112. return data_idx_order_list
  113. def get_img_data(self, value):
  114. """get_img_data."""
  115. if not value:
  116. return None
  117. imgdata = np.frombuffer(value, dtype='uint8')
  118. if imgdata is None:
  119. return None
  120. imgori = cv2.imdecode(imgdata, 1)
  121. if imgori is None:
  122. return None
  123. return imgori
  124. def resize_norm_img(self, data, gen_ratio, padding=True):
  125. img = data['image']
  126. h = img.shape[0]
  127. w = img.shape[1]
  128. if self.padding_rand and random.random() < 0.5:
  129. padding = not padding
  130. imgW, imgH = self.base_shape[gen_ratio - 1] if gen_ratio <= 4 else [
  131. self.base_h * gen_ratio, self.base_h
  132. ]
  133. use_ratio = imgW // imgH
  134. if use_ratio >= (w // h) + 2:
  135. self.error += 1
  136. return None
  137. if not padding:
  138. resized_image = cv2.resize(img, (imgW, imgH),
  139. interpolation=cv2.INTER_LINEAR)
  140. resized_w = imgW
  141. else:
  142. ratio = w / float(h)
  143. if math.ceil(imgH * ratio) > imgW:
  144. resized_w = imgW
  145. else:
  146. resized_w = int(
  147. math.ceil(imgH * ratio * (random.random() + 0.5)))
  148. resized_w = min(imgW, resized_w)
  149. resized_image = cv2.resize(img, (resized_w, imgH))
  150. resized_image = resized_image.astype('float32')
  151. resized_image = resized_image.transpose((2, 0, 1)) / 255
  152. resized_image -= 0.5
  153. resized_image /= 0.5
  154. padding_im = np.zeros((3, imgH, imgW), dtype=np.float32)
  155. if self.padding_doub and random.random() < 0.5:
  156. padding_im[:, :, -resized_w:] = resized_image
  157. else:
  158. padding_im[:, :, :resized_w] = resized_image
  159. valid_ratio = min(1.0, float(resized_w / imgW))
  160. data['image'] = padding_im
  161. data['valid_ratio'] = valid_ratio
  162. data['real_ratio'] = round(w / h)
  163. return data
  164. def get_lmdb_sample_info(self, txn, index):
  165. label_key = 'label-%09d'.encode() % index
  166. label = txn.get(label_key)
  167. if label is None:
  168. return None
  169. label = label.decode('utf-8')
  170. img_key = 'image-%09d'.encode() % index
  171. imgbuf = txn.get(img_key)
  172. return imgbuf, label
  173. def __getitem__(self, properties):
  174. img_width = properties[0]
  175. img_height = properties[1]
  176. idx = properties[2]
  177. ratio = properties[3]
  178. lmdb_idx, file_idx = self.data_idx_order_list[idx]
  179. lmdb_idx = int(lmdb_idx)
  180. file_idx = int(file_idx)
  181. sample_info = self.get_lmdb_sample_info(
  182. self.lmdb_sets[lmdb_idx]['txn'], file_idx)
  183. if sample_info is None:
  184. ratio_ids = np.where(self.wh_ratio == ratio)[0].tolist()
  185. ids = random.sample(ratio_ids, 1)
  186. return self.__getitem__([img_width, img_height, ids[0], ratio])
  187. img, label = sample_info
  188. data = {'image': img, 'label': label}
  189. outs = transform(data, self.ops[:-1])
  190. if outs is not None:
  191. outs = self.resize_norm_img(outs, ratio, padding=self.padding)
  192. if outs is None:
  193. ratio_ids = np.where(self.wh_ratio == ratio)[0].tolist()
  194. ids = random.sample(ratio_ids, 1)
  195. return self.__getitem__([img_width, img_height, ids[0], ratio])
  196. outs = transform(outs, self.ops[-1:])
  197. if outs is None:
  198. ratio_ids = np.where(self.wh_ratio == ratio)[0].tolist()
  199. ids = random.sample(ratio_ids, 1)
  200. return self.__getitem__([img_width, img_height, ids[0], ratio])
  201. return outs
  202. def __len__(self):
  203. return self.data_idx_order_list.shape[0]