ratio_dataset_test.py

import io
import math
import random
import re
import unicodedata

import cv2
import lmdb
import numpy as np
from PIL import Image
from torch.utils.data import Dataset

from openrec.preprocess import create_operators, transform


class CharsetAdapter:
    """Transforms labels according to the target charset."""

    def __init__(self, target_charset) -> None:
        super().__init__()
        self.lowercase_only = target_charset == target_charset.lower()
        self.uppercase_only = target_charset == target_charset.upper()
        self.unsupported = re.compile(f'[^{re.escape(target_charset)}]')

    def __call__(self, label):
        if self.lowercase_only:
            label = label.lower()
        elif self.uppercase_only:
            label = label.upper()
        # Remove unsupported characters
        label = self.unsupported.sub('', label)
        return label
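
# Example sketch: with the lowercase digits+letters charset that RatioDataSetTest
# falls back to below, CharsetAdapter lowercases a label and strips every character
# outside the charset.
#
#     adapter = CharsetAdapter('0123456789abcdefghijklmnopqrstuvwxyz')
#     adapter('Hello, World!')  # -> 'helloworld'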


class RatioDataSetTest(Dataset):

    def __init__(self, config, mode, logger, seed=None, epoch=1, task='rec'):
        super(RatioDataSetTest, self).__init__()
        self.ds_width = config[mode]['dataset'].get('ds_width', True)
        global_config = config['Global']
        dataset_config = config[mode]['dataset']
        loader_config = config[mode]['loader']
        max_ratio = loader_config.get('max_ratio', 10)
        min_ratio = loader_config.get('min_ratio', 1)
        data_dir_list = dataset_config['data_dir_list']
        self.do_shuffle = loader_config['shuffle']
        self.seed = epoch
        self.max_text_length = global_config['max_text_length']
        data_source_num = len(data_dir_list)
        ratio_list = dataset_config.get('ratio_list', 1.0)
        if isinstance(ratio_list, (float, int)):
            ratio_list = [float(ratio_list)] * int(data_source_num)
        assert len(ratio_list) == data_source_num, \
            'The length of ratio_list should be the same as the file_list.'
        self.lmdb_sets = self.load_hierarchical_lmdb_dataset(
            data_dir_list, ratio_list)
        for data_dir in data_dir_list:
            logger.info('Initialize indexes of dataset: %s' % data_dir)
        self.logger = logger
        data_idx_order_list = self.dataset_traversal()
        character_dict_path = global_config.get('character_dict_path', None)
        use_space_char = global_config.get('use_space_char', False)
        if character_dict_path is None:
            char_test = '0123456789abcdefghijklmnopqrstuvwxyz'
        else:
            char_test = ''
            with open(character_dict_path, 'rb') as fin:
                lines = fin.readlines()
                for line in lines:
                    line = line.decode('utf-8').strip('\n').strip('\r\n')
                    char_test += line
            if use_space_char:
                char_test += ' '
        wh_ratio, data_idx_order_list = self.get_wh_ratio(
            data_idx_order_list, char_test)
        self.data_idx_order_list = np.array(data_idx_order_list)
        # Round each W/H ratio to the nearest integer and clip it into
        # [min_ratio, max_ratio] so samples can be grouped into ratio buckets.
        wh_ratio = np.around(np.array(wh_ratio))
        self.wh_ratio = np.clip(wh_ratio, a_min=min_ratio, a_max=max_ratio)
        # Log how many samples fall into each ratio bucket.
        for i in range(max_ratio + 1):
            logger.info((1 * (self.wh_ratio == i)).sum())
        self.wh_ratio_sort = np.argsort(self.wh_ratio)
        self.ops = create_operators(dataset_config['transforms'],
                                    global_config)
        self.need_reset = any(x < 1 for x in ratio_list)
        self.error = 0
        self.base_shape = dataset_config.get(
            'base_shape', [[64, 64], [96, 48], [112, 40], [128, 32]])
        self.base_h = 32

    def get_wh_ratio(self, data_idx_order_list, char_test):
        wh_ratio = []
        # Histogram of samples bucketed by W/H ratio (0..10) and label length (0..25).
        wh_ratio_len = [[0 for _ in range(26)] for _ in range(11)]
        data_idx_order_list_filter = []
        charset_adapter = CharsetAdapter(char_test)
        for idx in range(data_idx_order_list.shape[0]):
            lmdb_idx, file_idx = data_idx_order_list[idx]
            lmdb_idx = int(lmdb_idx)
            file_idx = int(file_idx)
            wh_key = 'wh-%09d'.encode() % file_idx
            wh = self.lmdb_sets[lmdb_idx]['txn'].get(wh_key)
            if wh is None:
                # No cached size entry: read the image to get its size.
                img_key = f'image-{file_idx:09d}'.encode()
                img = self.lmdb_sets[lmdb_idx]['txn'].get(img_key)
                buf = io.BytesIO(img)
                w, h = Image.open(buf).size
            else:
                wh = wh.decode('utf-8')
                w, h = wh.split('_')
            label_key = 'label-%09d'.encode() % file_idx
            label = self.lmdb_sets[lmdb_idx]['txn'].get(label_key)
            if label is not None:
                label = label.decode('utf-8')
                # Remove whitespace.
                label = ''.join(label.split())
                # Normalize unicode composites (if any) and convert to
                # compatible ASCII characters.
                label = unicodedata.normalize('NFKD', label).encode(
                    'ascii', 'ignore').decode()
                # Filter by length before removing unsupported characters.
                # The original label might be too long.
                if len(label) > self.max_text_length:
                    continue
                label = charset_adapter(label)
                if not label:
                    continue
                wh_ratio.append(float(w) / float(h))
                ratio_bucket = min(int(float(w) / float(h)), 10)
                len_bucket = min(len(label), 25)
                wh_ratio_len[ratio_bucket][len_bucket] += 1
                data_idx_order_list_filter.append([lmdb_idx, file_idx])
        self.logger.info(wh_ratio_len)
        return wh_ratio, data_idx_order_list_filter

    def load_hierarchical_lmdb_dataset(self, data_dir_list, ratio_list):
        lmdb_sets = {}
        dataset_idx = 0
        for dirpath, ratio in zip(data_dir_list, ratio_list):
            env = lmdb.open(dirpath,
                            max_readers=32,
                            readonly=True,
                            lock=False,
                            readahead=False,
                            meminit=False)
            txn = env.begin(write=False)
            num_samples = int(txn.get('num-samples'.encode()))
            lmdb_sets[dataset_idx] = {
                'dirpath': dirpath,
                'env': env,
                'txn': txn,
                'num_samples': num_samples,
                'ratio_num_samples': int(ratio * num_samples),
            }
            dataset_idx += 1
        return lmdb_sets
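
    # Record layout read from each LMDB environment (record indices are 1-based):
    #   'num-samples'  -> total record count
    #   'image-%09d'   -> encoded image bytes
    #   'label-%09d'   -> UTF-8 transcription
    #   'wh-%09d'      -> optional cached '<width>_<height>' string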

    def dataset_traversal(self):
        lmdb_num = len(self.lmdb_sets)
        total_sample_num = 0
        for lno in range(lmdb_num):
            total_sample_num += self.lmdb_sets[lno]['ratio_num_samples']
        data_idx_order_list = np.zeros((total_sample_num, 2))
        beg_idx = 0
        for lno in range(lmdb_num):
            tmp_sample_num = self.lmdb_sets[lno]['ratio_num_samples']
            end_idx = beg_idx + tmp_sample_num
            data_idx_order_list[beg_idx:end_idx, 0] = lno
            data_idx_order_list[beg_idx:end_idx, 1] = list(
                random.sample(range(1, self.lmdb_sets[lno]['num_samples'] + 1),
                              self.lmdb_sets[lno]['ratio_num_samples']))
            beg_idx = beg_idx + tmp_sample_num
        return data_idx_order_list

    def get_img_data(self, value):
        """Decode raw image bytes into a BGR ndarray; return None on failure."""
        if not value:
            return None
        imgdata = np.frombuffer(value, dtype='uint8')
        if imgdata is None:
            return None
        imgori = cv2.imdecode(imgdata, 1)
        if imgori is None:
            return None
        return imgori

    def resize_norm_img(self, data, gen_ratio, padding=True):
        img = data['image']
        h = img.shape[0]
        w = img.shape[1]
        # Pick the target shape for this ratio bucket: a preset base shape for
        # small ratios, otherwise a strip of height base_h and width base_h * ratio.
        imgW, imgH = self.base_shape[gen_ratio - 1] if gen_ratio <= 4 else [
            self.base_h * gen_ratio, self.base_h
        ]
        use_ratio = imgW // imgH
        if use_ratio >= (w // h) + 2:
            # The requested bucket is far wider than the actual image; skip it.
            self.error += 1
            return None
        if not padding:
            resized_image = cv2.resize(img, (imgW, imgH),
                                       interpolation=cv2.INTER_LINEAR)
            resized_w = imgW
        else:
            ratio = w / float(h)
            if math.ceil(imgH * ratio) > imgW:
                resized_w = imgW
            else:
                resized_w = int(
                    math.ceil(imgH * ratio * (random.random() + 0.5)))
                resized_w = min(imgW, resized_w)
            resized_image = cv2.resize(img, (resized_w, imgH))
        resized_image = resized_image.astype('float32')
        # Scale to [0, 1], then normalize to [-1, 1] in CHW layout.
        resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        padding_im = np.zeros((3, imgH, imgW), dtype=np.float32)
        padding_im[:, :, :resized_w] = resized_image
        valid_ratio = min(1.0, float(resized_w / imgW))
        data['image'] = padding_im
        data['valid_ratio'] = valid_ratio
        data['gen_ratio'] = imgW // imgH
        data['real_ratio'] = max(1, round(w / h))
        return data

    def get_lmdb_sample_info(self, txn, index):
        label_key = 'label-%09d'.encode() % index
        label = txn.get(label_key)
        if label is None:
            return None
        label = label.decode('utf-8')
        img_key = 'image-%09d'.encode() % index
        imgbuf = txn.get(img_key)
        return imgbuf, label

    def __getitem__(self, properties):
        # `properties` is supplied by a ratio-aware sampler as
        # [img_width, img_height, idx, ratio].
        img_width = properties[0]
        img_height = properties[1]
        idx = properties[2]
        ratio = properties[3]
        lmdb_idx, file_idx = self.data_idx_order_list[idx]
        lmdb_idx = int(lmdb_idx)
        file_idx = int(file_idx)
        sample_info = self.get_lmdb_sample_info(
            self.lmdb_sets[lmdb_idx]['txn'], file_idx)
        if sample_info is None:
            # Missing record: re-draw another sample from the same ratio bucket.
            ratio_ids = np.where(self.wh_ratio == ratio)[0].tolist()
            ids = random.sample(ratio_ids, 1)
            return self.__getitem__([img_width, img_height, ids[0], ratio])
        img, label = sample_info
        data = {'image': img, 'label': label}
        outs = transform(data, self.ops[:-1])
        if outs is not None:
            outs = self.resize_norm_img(outs, ratio, padding=False)
            if outs is None:
                ratio_ids = np.where(self.wh_ratio == ratio)[0].tolist()
                ids = random.sample(ratio_ids, 1)
                return self.__getitem__([img_width, img_height, ids[0], ratio])
            outs = transform(outs, self.ops[-1:])
        if outs is None:
            ratio_ids = np.where(self.wh_ratio == ratio)[0].tolist()
            ids = random.sample(ratio_ids, 1)
            return self.__getitem__([img_width, img_height, ids[0], ratio])
        return outs

    def __len__(self):
        return self.data_idx_order_list.shape[0]
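

# Usage sketch (hedged): RatioDataSetTest is indexed with a 4-element list rather
# than a plain int. A ratio-aware batch sampler is expected to group indices by
# self.wh_ratio (see self.wh_ratio_sort) and pass [img_width, img_height, idx, ratio]
# for each sample. `config` and `logger` below stand for the caller's YAML config
# and a logging.Logger, and the 'Eval' mode key is likewise an assumption.
#
#     dataset = RatioDataSetTest(config, 'Eval', logger)
#     idx = int(dataset.wh_ratio_sort[0])      # a sample from the narrowest bucket
#     ratio = int(dataset.wh_ratio[idx])
#     w, h = (dataset.base_shape[ratio - 1] if ratio <= 4
#             else [dataset.base_h * ratio, dataset.base_h])
#     sample = dataset[[w, h, idx, ratio]]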