text_lmdb_dataset.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. import os
  2. import lmdb
  3. import numpy as np
  4. from torch.utils.data import Dataset
  5. from openrec.preprocess import create_operators, transform
  6. class TextLMDBDataSet(Dataset):
  7. def __init__(self, config, mode, logger, seed=None, epoch=1, task='rec'):
  8. super(TextLMDBDataSet, self).__init__()
  9. global_config = config['Global']
  10. dataset_config = config[mode]['dataset']
  11. loader_config = config[mode]['loader']
  12. loader_config['batch_size_per_card']
  13. data_dir = dataset_config['data_dir']
  14. self.do_shuffle = loader_config['shuffle']
  15. self.lmdb_sets = self.load_hierarchical_lmdb_dataset(data_dir)
  16. logger.info(f'Initialize indexs of datasets: {data_dir}')
  17. self.data_idx_order_list = self.dataset_traversal()
  18. if self.do_shuffle:
  19. np.random.shuffle(self.data_idx_order_list)
  20. self.ops = create_operators(dataset_config['transforms'],
  21. global_config)
  22. self.ext_op_transform_idx = dataset_config.get('ext_op_transform_idx',
  23. 1)
  24. ratio_list = dataset_config.get('ratio_list', [1.0])
  25. self.need_reset = True in [x < 1 for x in ratio_list]
  26. def load_hierarchical_lmdb_dataset(self, data_dir):
  27. lmdb_sets = {}
  28. dataset_idx = 0
  29. for dirpath, dirnames, filenames in os.walk(data_dir + '/'):
  30. if not dirnames:
  31. env = lmdb.open(
  32. dirpath,
  33. max_readers=32,
  34. readonly=True,
  35. lock=False,
  36. readahead=False,
  37. meminit=False,
  38. )
  39. txn = env.begin(write=False)
  40. num_samples = int(txn.get('num-samples'.encode()))
  41. lmdb_sets[dataset_idx] = {
  42. 'dirpath': dirpath,
  43. 'env': env,
  44. 'txn': txn,
  45. 'num_samples': num_samples,
  46. }
  47. dataset_idx += 1
  48. return lmdb_sets
  49. def dataset_traversal(self):
  50. lmdb_num = len(self.lmdb_sets)
  51. total_sample_num = 0
  52. for lno in range(lmdb_num):
  53. total_sample_num += self.lmdb_sets[lno]['num_samples']
  54. data_idx_order_list = np.zeros((total_sample_num, 2))
  55. beg_idx = 0
  56. for lno in range(lmdb_num):
  57. tmp_sample_num = self.lmdb_sets[lno]['num_samples']
  58. end_idx = beg_idx + tmp_sample_num
  59. data_idx_order_list[beg_idx:end_idx, 0] = lno
  60. data_idx_order_list[beg_idx:end_idx,
  61. 1] = list(range(tmp_sample_num))
  62. data_idx_order_list[beg_idx:end_idx, 1] += 1
  63. beg_idx = beg_idx + tmp_sample_num
  64. return data_idx_order_list
  65. def get_ext_data(self):
  66. ext_data_num = 0
  67. for op in self.ops:
  68. if hasattr(op, 'ext_data_num'):
  69. ext_data_num = getattr(op, 'ext_data_num')
  70. break
  71. load_data_ops = self.ops[:self.ext_op_transform_idx]
  72. ext_data = []
  73. while len(ext_data) < ext_data_num:
  74. lmdb_idx, file_idx = self.data_idx_order_list[np.random.randint(
  75. len(self))]
  76. lmdb_idx = int(lmdb_idx)
  77. file_idx = int(file_idx)
  78. sample_info = self.get_lmdb_sample_info(
  79. self.lmdb_sets[lmdb_idx]['txn'], file_idx)
  80. if sample_info is None:
  81. continue
  82. label = sample_info
  83. data = {'label': label}
  84. data = transform(data, load_data_ops)
  85. if data is None:
  86. continue
  87. ext_data.append(data)
  88. return ext_data
  89. def get_lmdb_sample_info(self,
  90. txn,
  91. index,
  92. normalize_unicode=True,
  93. remove_whitespace=True,
  94. max_length=True):
  95. label_key = 'label-%09d'.encode() % index
  96. label = txn.get(label_key)
  97. if label is None:
  98. return None
  99. label = label.decode('utf-8')
  100. return label
  101. def __getitem__(self, idx):
  102. lmdb_idx, file_idx = self.data_idx_order_list[idx]
  103. lmdb_idx = int(lmdb_idx)
  104. file_idx = int(file_idx)
  105. sample_info = self.get_lmdb_sample_info(
  106. self.lmdb_sets[lmdb_idx]['txn'], file_idx)
  107. if sample_info is None:
  108. return self.__getitem__(np.random.randint(self.__len__()))
  109. label = sample_info
  110. data = {'label': label}
  111. outs = transform(data, self.ops)
  112. if outs is None:
  113. return self.__getitem__(np.random.randint(self.__len__()))
  114. return outs
  115. def __len__(self):
  116. return self.data_idx_order_list.shape[0]