# lmdb_dataset.py

import os

import cv2
import lmdb
import numpy as np
from torch.utils.data import Dataset

from openrec.preprocess import create_operators, transform
  7. class LMDBDataSet(Dataset):
  8. def __init__(self, config, mode, logger, seed=None, epoch=1, task='rec'):
  9. super(LMDBDataSet, self).__init__()
  10. global_config = config['Global']
  11. dataset_config = config[mode]['dataset']
  12. loader_config = config[mode]['loader']
  13. loader_config['batch_size_per_card']
  14. data_dir = dataset_config['data_dir']
  15. self.do_shuffle = loader_config['shuffle']
  16. self.lmdb_sets = self.load_hierarchical_lmdb_dataset(data_dir)
  17. logger.info(f'Initialize indexs of datasets: {data_dir}')
  18. self.data_idx_order_list = self.dataset_traversal()
  19. if self.do_shuffle:
  20. np.random.shuffle(self.data_idx_order_list)
  21. self.ops = create_operators(dataset_config['transforms'],
  22. global_config)
  23. self.ext_op_transform_idx = dataset_config.get('ext_op_transform_idx',
  24. 1)
  25. ratio_list = dataset_config.get('ratio_list', [1.0])
  26. self.need_reset = True in [x < 1 for x in ratio_list]
  27. def load_hierarchical_lmdb_dataset(self, data_dir):
  28. lmdb_sets = {}
  29. dataset_idx = 0
  30. for dirpath, dirnames, filenames in os.walk(data_dir + '/'):
  31. if not dirnames:
  32. env = lmdb.open(
  33. dirpath,
  34. max_readers=32,
  35. readonly=True,
  36. lock=False,
  37. readahead=False,
  38. meminit=False,
  39. )
  40. txn = env.begin(write=False)
  41. num_samples = int(txn.get('num-samples'.encode()))
  42. lmdb_sets[dataset_idx] = {
  43. 'dirpath': dirpath,
  44. 'env': env,
  45. 'txn': txn,
  46. 'num_samples': num_samples,
  47. }
  48. dataset_idx += 1
  49. return lmdb_sets
  50. def dataset_traversal(self):
  51. lmdb_num = len(self.lmdb_sets)
  52. total_sample_num = 0
  53. for lno in range(lmdb_num):
  54. total_sample_num += self.lmdb_sets[lno]['num_samples']
  55. data_idx_order_list = np.zeros((total_sample_num, 2))
  56. beg_idx = 0
  57. for lno in range(lmdb_num):
  58. tmp_sample_num = self.lmdb_sets[lno]['num_samples']
  59. end_idx = beg_idx + tmp_sample_num
  60. data_idx_order_list[beg_idx:end_idx, 0] = lno
  61. data_idx_order_list[beg_idx:end_idx,
  62. 1] = list(range(tmp_sample_num))
  63. data_idx_order_list[beg_idx:end_idx, 1] += 1
  64. beg_idx = beg_idx + tmp_sample_num
  65. return data_idx_order_list
  66. def get_img_data(self, value):
  67. """get_img_data."""
  68. if not value:
  69. return None
  70. imgdata = np.frombuffer(value, dtype='uint8')
  71. if imgdata is None:
  72. return None
  73. imgori = cv2.imdecode(imgdata, 1)
  74. if imgori is None:
  75. return None
  76. return imgori
  77. def get_ext_data(self):
  78. ext_data_num = 0
  79. for op in self.ops:
  80. if hasattr(op, 'ext_data_num'):
  81. ext_data_num = getattr(op, 'ext_data_num')
  82. break
  83. load_data_ops = self.ops[:self.ext_op_transform_idx]
  84. ext_data = []
  85. while len(ext_data) < ext_data_num:
  86. lmdb_idx, file_idx = self.data_idx_order_list[np.random.randint(
  87. len(self))]
  88. lmdb_idx = int(lmdb_idx)
  89. file_idx = int(file_idx)
  90. sample_info = self.get_lmdb_sample_info(
  91. self.lmdb_sets[lmdb_idx]['txn'], file_idx)
  92. if sample_info is None:
  93. continue
  94. img, label = sample_info
  95. data = {'image': img, 'label': label}
  96. data = transform(data, load_data_ops)
  97. if data is None:
  98. continue
  99. ext_data.append(data)
  100. return ext_data
  101. def get_lmdb_sample_info(self, txn, index):
  102. label_key = 'label-%09d'.encode() % index
  103. label = txn.get(label_key)
  104. if label is None:
  105. return None
  106. label = label.decode('utf-8')
  107. img_key = 'image-%09d'.encode() % index
  108. imgbuf = txn.get(img_key)
  109. return imgbuf, label
  110. def __getitem__(self, idx):
  111. lmdb_idx, file_idx = self.data_idx_order_list[idx]
  112. lmdb_idx = int(lmdb_idx)
  113. file_idx = int(file_idx)
  114. sample_info = self.get_lmdb_sample_info(
  115. self.lmdb_sets[lmdb_idx]['txn'], file_idx)
  116. if sample_info is None:
  117. return self.__getitem__(np.random.randint(self.__len__()))
  118. img, label = sample_info
  119. data = {'image': img, 'label': label}
  120. data['ext_data'] = self.get_ext_data()
  121. outs = transform(data, self.ops)
  122. if outs is None:
  123. return self.__getitem__(np.random.randint(self.__len__()))
  124. return outs
  125. def __len__(self):
  126. return self.data_idx_order_list.shape[0]