# strlmdb_dataset.py
import os

import cv2
import lmdb
import numpy as np
from torch.utils.data import Dataset

from openrec.preprocess import create_operators, transform
  7. class STRLMDBDataSet(Dataset):
  8. def __init__(self, config, mode, logger, seed=None, epoch=1, gpu_i=0):
  9. super(STRLMDBDataSet, self).__init__()
  10. global_config = config['Global']
  11. dataset_config = config[mode]['dataset']
  12. loader_config = config[mode]['loader']
  13. loader_config['batch_size_per_card']
  14. # data_dir = dataset_config['data_dir']
  15. data_dir = '../training_aug_lmdb_noerror/ep' + str(
  16. epoch % 20 if epoch % 20 != 0 else 20)
  17. self.do_shuffle = loader_config['shuffle']
  18. self.lmdb_sets = self.load_hierarchical_lmdb_dataset(data_dir)
  19. logger.info('Initialize indexs of datasets:%s' % data_dir)
  20. self.data_idx_order_list = self.dataset_traversal()
  21. if self.do_shuffle:
  22. np.random.shuffle(self.data_idx_order_list)
  23. self.ops = create_operators(dataset_config['transforms'],
  24. global_config)
  25. self.ext_op_transform_idx = dataset_config.get('ext_op_transform_idx',
  26. 1)
  27. dataset_config.get('ratio_list', [1.0])
  28. self.need_reset = True # in [x < 1 for x in ratio_list]
  29. def load_hierarchical_lmdb_dataset(self, data_dir):
  30. lmdb_sets = {}
  31. dataset_idx = 0
  32. for dirpath, dirnames, filenames in os.walk(data_dir + '/'):
  33. if not dirnames:
  34. env = lmdb.open(
  35. dirpath,
  36. max_readers=32,
  37. readonly=True,
  38. lock=False,
  39. readahead=False,
  40. meminit=False,
  41. )
  42. txn = env.begin(write=False)
  43. num_samples = int(txn.get('num-samples'.encode()))
  44. lmdb_sets[dataset_idx] = {
  45. 'dirpath': dirpath,
  46. 'env': env,
  47. 'txn': txn,
  48. 'num_samples': num_samples,
  49. }
  50. dataset_idx += 1
  51. return lmdb_sets
  52. def dataset_traversal(self):
  53. lmdb_num = len(self.lmdb_sets)
  54. total_sample_num = 0
  55. for lno in range(lmdb_num):
  56. total_sample_num += self.lmdb_sets[lno]['num_samples']
  57. data_idx_order_list = np.zeros((total_sample_num, 2))
  58. beg_idx = 0
  59. for lno in range(lmdb_num):
  60. tmp_sample_num = self.lmdb_sets[lno]['num_samples']
  61. end_idx = beg_idx + tmp_sample_num
  62. data_idx_order_list[beg_idx:end_idx, 0] = lno
  63. data_idx_order_list[beg_idx:end_idx,
  64. 1] = list(range(tmp_sample_num))
  65. data_idx_order_list[beg_idx:end_idx, 1] += 1
  66. beg_idx = beg_idx + tmp_sample_num
  67. return data_idx_order_list
  68. def get_img_data(self, value):
  69. """get_img_data."""
  70. if not value:
  71. return None
  72. imgdata = np.frombuffer(value, dtype='uint8')
  73. if imgdata is None:
  74. return None
  75. imgori = cv2.imdecode(imgdata, 1)
  76. if imgori is None:
  77. return None
  78. return imgori
  79. def get_ext_data(self):
  80. ext_data_num = 0
  81. for op in self.ops:
  82. if hasattr(op, 'ext_data_num'):
  83. ext_data_num = getattr(op, 'ext_data_num')
  84. break
  85. load_data_ops = self.ops[:self.ext_op_transform_idx]
  86. ext_data = []
  87. while len(ext_data) < ext_data_num:
  88. lmdb_idx, file_idx = self.data_idx_order_list[np.random.randint(
  89. len(self))]
  90. lmdb_idx = int(lmdb_idx)
  91. file_idx = int(file_idx)
  92. sample_info = self.get_lmdb_sample_info(
  93. self.lmdb_sets[lmdb_idx]['txn'], file_idx)
  94. if sample_info is None:
  95. continue
  96. img, label = sample_info
  97. data = {'image': img, 'label': label}
  98. data = transform(data, load_data_ops)
  99. if data is None:
  100. continue
  101. ext_data.append(data)
  102. return ext_data
  103. def get_lmdb_sample_info(self, txn, index):
  104. label_key = 'label-%09d'.encode() % index
  105. label = txn.get(label_key)
  106. if label is None:
  107. return None
  108. label = label.decode('utf-8')
  109. img_key = 'image-%09d'.encode() % index
  110. imgbuf = txn.get(img_key)
  111. return imgbuf, label
  112. def __getitem__(self, idx):
  113. lmdb_idx, file_idx = self.data_idx_order_list[idx]
  114. lmdb_idx = int(lmdb_idx)
  115. file_idx = int(file_idx)
  116. sample_info = self.get_lmdb_sample_info(
  117. self.lmdb_sets[lmdb_idx]['txn'], file_idx)
  118. if sample_info is None:
  119. return self.__getitem__(np.random.randint(self.__len__()))
  120. img, label = sample_info
  121. data = {'image': img, 'label': label}
  122. data['ext_data'] = self.get_ext_data()
  123. outs = transform(data, self.ops)
  124. if outs is None:
  125. return self.__getitem__(np.random.randint(self.__len__()))
  126. return outs
  127. def __len__(self):
  128. return self.data_idx_order_list.shape[0]