# lmdb_dataset_test.py
import io
import re
import unicodedata

import lmdb
from PIL import Image
from torch.utils.data import Dataset

from openrec.preprocess import create_operators, transform
  8. class CharsetAdapter:
  9. """Transforms labels according to the target charset."""
  10. def __init__(self, target_charset) -> None:
  11. super().__init__()
  12. self.lowercase_only = target_charset == target_charset.lower()
  13. self.uppercase_only = target_charset == target_charset.upper()
  14. self.unsupported = re.compile(f'[^{re.escape(target_charset)}]')
  15. def __call__(self, label):
  16. if self.lowercase_only:
  17. label = label.lower()
  18. elif self.uppercase_only:
  19. label = label.upper()
  20. # Remove unsupported characters
  21. label = self.unsupported.sub('', label)
  22. return label
  23. class LMDBDataSetTest(Dataset):
  24. """Dataset interface to an LMDB database.
  25. It supports both labelled and unlabelled datasets. For unlabelled datasets,
  26. the image index itself is returned as the label. Unicode characters are
  27. normalized by default. Case-sensitivity is inferred from the charset.
  28. Labels are transformed according to the charset.
  29. """
  30. def __init__(self,
  31. config,
  32. mode,
  33. logger,
  34. seed=None,
  35. epoch=1,
  36. gpu_i=0,
  37. max_label_len: int = 25,
  38. min_image_dim: int = 0,
  39. remove_whitespace: bool = True,
  40. normalize_unicode: bool = True,
  41. unlabelled: bool = False,
  42. transform=None,
  43. task='rec'):
  44. dataset_config = config[mode]['dataset']
  45. global_config = config['Global']
  46. max_label_len = global_config['max_text_length']
  47. self.root = dataset_config['data_dir']
  48. self._env = None
  49. self.unlabelled = unlabelled
  50. self.transform = transform
  51. self.labels = []
  52. self.filtered_index_list = []
  53. self.min_image_dim = min_image_dim
  54. self.filter_label = dataset_config.get('filter_label',
  55. True) #'data_dir']filter_label
  56. character_dict_path = global_config.get('character_dict_path', None)
  57. use_space_char = global_config.get('use_space_char', False)
  58. if character_dict_path is None:
  59. char_test = '0123456789abcdefghijklmnopqrstuvwxyz'
  60. else:
  61. char_test = ''
  62. with open(character_dict_path, 'rb') as fin:
  63. lines = fin.readlines()
  64. for line in lines:
  65. line = line.decode('utf-8').strip('\n').strip('\r\n')
  66. char_test += line
  67. if use_space_char:
  68. char_test += ' '
  69. self.ops = create_operators(dataset_config['transforms'],
  70. global_config)
  71. self.num_samples = self._preprocess_labels(char_test,
  72. remove_whitespace,
  73. normalize_unicode,
  74. max_label_len,
  75. min_image_dim)
  76. def __del__(self):
  77. if self._env is not None:
  78. self._env.close()
  79. self._env = None
  80. def _create_env(self):
  81. return lmdb.open(self.root,
  82. max_readers=1,
  83. readonly=True,
  84. create=False,
  85. readahead=False,
  86. meminit=False,
  87. lock=False)
  88. @property
  89. def env(self):
  90. if self._env is None:
  91. self._env = self._create_env()
  92. return self._env
  93. def _preprocess_labels(self, charset, remove_whitespace, normalize_unicode,
  94. max_label_len, min_image_dim):
  95. charset_adapter = CharsetAdapter(charset)
  96. with self._create_env() as env, env.begin() as txn:
  97. num_samples = int(txn.get('num-samples'.encode()))
  98. if self.unlabelled:
  99. return num_samples
  100. for index in range(num_samples):
  101. index += 1 # lmdb starts with 1
  102. label_key = f'label-{index:09d}'.encode()
  103. label = txn.get(label_key).decode()
  104. # Normally, whitespace is removed from the labels.
  105. if remove_whitespace:
  106. label = ''.join(label.split())
  107. # Normalize unicode composites (if any) and convert to compatible ASCII characters
  108. if self.filter_label:
  109. # if normalize_unicode:
  110. label = unicodedata.normalize('NFKD', label).encode(
  111. 'ascii', 'ignore').decode()
  112. # Filter by length before removing unsupported characters. The original label might be too long.
  113. if len(label) > max_label_len:
  114. continue
  115. if self.filter_label:
  116. label = charset_adapter(label)
  117. # We filter out samples which don't contain any supported characters
  118. if not label:
  119. continue
  120. # Filter images that are too small.
  121. if min_image_dim > 0:
  122. img_key = f'image-{index:09d}'.encode()
  123. img = txn.get(img_key)
  124. data = {'image': img, 'label': label}
  125. outs = transform(data, self.ops)
  126. if outs is None:
  127. continue
  128. buf = io.BytesIO(img)
  129. w, h = Image.open(buf).size
  130. if w < self.min_image_dim or h < self.min_image_dim:
  131. continue
  132. self.labels.append(label)
  133. self.filtered_index_list.append(index)
  134. return len(self.labels)
  135. def __len__(self):
  136. return self.num_samples
  137. def __getitem__(self, index):
  138. if self.unlabelled:
  139. label = index
  140. else:
  141. label = self.labels[index]
  142. index = self.filtered_index_list[index]
  143. img_key = f'image-{index:09d}'.encode()
  144. with self.env.begin() as txn:
  145. img = txn.get(img_key)
  146. data = {'image': img, 'label': label}
  147. outs = transform(data, self.ops)
  148. return outs