ratio_dataset_tvresize_test.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276
  1. import io
  2. import math
  3. import random
  4. import re
  5. import unicodedata
  6. import cv2
  7. import lmdb
  8. import numpy as np
  9. from PIL import Image
  10. from torch.utils.data import Dataset
  11. from torchvision import transforms as T
  12. from torchvision.transforms import functional as F
  13. from openrec.preprocess import create_operators, transform
  14. class CharsetAdapter:
  15. """Transforms labels according to the target charset."""
  16. def __init__(self, target_charset) -> None:
  17. super().__init__()
  18. self.lowercase_only = target_charset == target_charset.lower()
  19. self.uppercase_only = target_charset == target_charset.upper()
  20. self.unsupported = re.compile(f'[^{re.escape(target_charset)}]')
  21. def __call__(self, label):
  22. if self.lowercase_only:
  23. label = label.lower()
  24. elif self.uppercase_only:
  25. label = label.upper()
  26. # Remove unsupported characters
  27. label = self.unsupported.sub('', label)
  28. return label
  29. class RatioDataSetTVResizeTest(Dataset):
  30. def __init__(self, config, mode, logger, seed=None, epoch=1, task='rec'):
  31. super(RatioDataSetTVResizeTest, self).__init__()
  32. self.ds_width = config[mode]['dataset'].get('ds_width', True)
  33. global_config = config['Global']
  34. dataset_config = config[mode]['dataset']
  35. loader_config = config[mode]['loader']
  36. max_ratio = loader_config.get('max_ratio', 10)
  37. min_ratio = loader_config.get('min_ratio', 1)
  38. data_dir_list = dataset_config['data_dir_list']
  39. self.do_shuffle = loader_config['shuffle']
  40. self.seed = epoch
  41. self.max_text_length = global_config['max_text_length']
  42. data_source_num = len(data_dir_list)
  43. ratio_list = dataset_config.get('ratio_list', 1.0)
  44. if isinstance(ratio_list, (float, int)):
  45. ratio_list = [float(ratio_list)] * int(data_source_num)
  46. assert len(
  47. ratio_list
  48. ) == data_source_num, 'The length of ratio_list should be the same as the file_list.'
  49. self.lmdb_sets = self.load_hierarchical_lmdb_dataset(
  50. data_dir_list, ratio_list)
  51. for data_dir in data_dir_list:
  52. logger.info('Initialize indexs of datasets:%s' % data_dir)
  53. self.logger = logger
  54. data_idx_order_list = self.dataset_traversal()
  55. character_dict_path = global_config.get('character_dict_path', None)
  56. use_space_char = global_config.get('use_space_char', False)
  57. if character_dict_path is None:
  58. char_test = '0123456789abcdefghijklmnopqrstuvwxyz'
  59. else:
  60. char_test = ''
  61. with open(character_dict_path, 'rb') as fin:
  62. lines = fin.readlines()
  63. for line in lines:
  64. line = line.decode('utf-8').strip('\n').strip('\r\n')
  65. char_test += line
  66. if use_space_char:
  67. char_test += ' '
  68. wh_ratio, data_idx_order_list = self.get_wh_ratio(
  69. data_idx_order_list, char_test)
  70. self.data_idx_order_list = np.array(data_idx_order_list)
  71. wh_ratio = np.around(np.array(wh_ratio))
  72. self.wh_ratio = np.clip(wh_ratio, a_min=min_ratio, a_max=max_ratio)
  73. for i in range(max_ratio + 1):
  74. logger.info((1 * (self.wh_ratio == i)).sum())
  75. self.wh_ratio_sort = np.argsort(self.wh_ratio)
  76. self.ops = create_operators(dataset_config['transforms'],
  77. global_config)
  78. self.need_reset = True in [x < 1 for x in ratio_list]
  79. self.error = 0
  80. self.base_shape = dataset_config.get(
  81. 'base_shape', [[64, 64], [96, 48], [112, 40], [128, 32]])
  82. self.base_h = dataset_config.get('base_h', 32)
  83. self.interpolation = T.InterpolationMode.BICUBIC
  84. transforms = []
  85. transforms.extend([
  86. T.ToTensor(),
  87. T.Normalize(0.5, 0.5),
  88. ])
  89. self.transforms = T.Compose(transforms)
  90. def get_wh_ratio(self, data_idx_order_list, char_test):
  91. wh_ratio = []
  92. wh_ratio_len = [[0 for _ in range(26)] for _ in range(11)]
  93. data_idx_order_list_filter = []
  94. charset_adapter = CharsetAdapter(char_test)
  95. for idx in range(data_idx_order_list.shape[0]):
  96. lmdb_idx, file_idx = data_idx_order_list[idx]
  97. lmdb_idx = int(lmdb_idx)
  98. file_idx = int(file_idx)
  99. wh_key = 'wh-%09d'.encode() % file_idx
  100. wh = self.lmdb_sets[lmdb_idx]['txn'].get(wh_key)
  101. if wh is None:
  102. img_key = f'image-{file_idx:09d}'.encode()
  103. img = self.lmdb_sets[lmdb_idx]['txn'].get(img_key)
  104. buf = io.BytesIO(img)
  105. w, h = Image.open(buf).size
  106. else:
  107. wh = wh.decode('utf-8')
  108. w, h = wh.split('_')
  109. label_key = 'label-%09d'.encode() % file_idx
  110. label = self.lmdb_sets[lmdb_idx]['txn'].get(label_key)
  111. if label is not None:
  112. # return None
  113. label = label.decode('utf-8')
  114. # if remove_whitespace:
  115. label = ''.join(label.split())
  116. # Normalize unicode composites (if any) and convert to compatible ASCII characters
  117. # if normalize_unicode:
  118. label = unicodedata.normalize('NFKD',
  119. label).encode('ascii',
  120. 'ignore').decode()
  121. # Filter by length before removing unsupported characters. The original label might be too long.
  122. if len(label) > self.max_text_length:
  123. continue
  124. label = charset_adapter(label)
  125. if not label:
  126. continue
  127. wh_ratio.append(float(w) / float(h))
  128. wh_ratio_len[int(float(w) /
  129. float(h)) if int(float(w) /
  130. float(h)) <= 10 else
  131. 10][len(label) if len(label) <= 25 else 25] += 1
  132. data_idx_order_list_filter.append([lmdb_idx, file_idx])
  133. self.logger.info(wh_ratio_len)
  134. return wh_ratio, data_idx_order_list_filter
  135. def load_hierarchical_lmdb_dataset(self, data_dir_list, ratio_list):
  136. lmdb_sets = {}
  137. dataset_idx = 0
  138. for dirpath, ratio in zip(data_dir_list, ratio_list):
  139. env = lmdb.open(dirpath,
  140. max_readers=32,
  141. readonly=True,
  142. lock=False,
  143. readahead=False,
  144. meminit=False)
  145. txn = env.begin(write=False)
  146. num_samples = int(txn.get('num-samples'.encode()))
  147. lmdb_sets[dataset_idx] = {
  148. 'dirpath': dirpath,
  149. 'env': env,
  150. 'txn': txn,
  151. 'num_samples': num_samples,
  152. 'ratio_num_samples': int(ratio * num_samples),
  153. }
  154. dataset_idx += 1
  155. return lmdb_sets
  156. def dataset_traversal(self):
  157. lmdb_num = len(self.lmdb_sets)
  158. total_sample_num = 0
  159. for lno in range(lmdb_num):
  160. total_sample_num += self.lmdb_sets[lno]['ratio_num_samples']
  161. data_idx_order_list = np.zeros((total_sample_num, 2))
  162. beg_idx = 0
  163. for lno in range(lmdb_num):
  164. tmp_sample_num = self.lmdb_sets[lno]['ratio_num_samples']
  165. end_idx = beg_idx + tmp_sample_num
  166. data_idx_order_list[beg_idx:end_idx, 0] = lno
  167. data_idx_order_list[beg_idx:end_idx, 1] = list(
  168. random.sample(range(1, self.lmdb_sets[lno]['num_samples'] + 1),
  169. self.lmdb_sets[lno]['ratio_num_samples']))
  170. beg_idx = beg_idx + tmp_sample_num
  171. return data_idx_order_list
  172. def get_img_data(self, value):
  173. """get_img_data."""
  174. if not value:
  175. return None
  176. imgdata = np.frombuffer(value, dtype='uint8')
  177. if imgdata is None:
  178. return None
  179. imgori = cv2.imdecode(imgdata, 1)
  180. if imgori is None:
  181. return None
  182. return imgori
  183. def resize_norm_img(self, data, gen_ratio, padding=True):
  184. img = data['image']
  185. w, h = img.size
  186. imgW, imgH = self.base_shape[gen_ratio - 1] if gen_ratio <= 4 else [
  187. self.base_h * gen_ratio, self.base_h
  188. ]
  189. use_ratio = imgW // imgH
  190. if use_ratio >= (w // h) + 2:
  191. self.error += 1
  192. return None
  193. if not padding:
  194. resized_w = imgW
  195. else:
  196. ratio = w / float(h)
  197. if math.ceil(imgH * ratio) > imgW:
  198. resized_w = imgW
  199. else:
  200. resized_w = int(
  201. math.ceil(imgH * ratio * (random.random() + 0.5)))
  202. resized_w = min(imgW, resized_w)
  203. resized_image = F.resize(img, (imgH, resized_w),
  204. interpolation=self.interpolation)
  205. img = self.transforms(resized_image)
  206. if resized_w < imgW and padding:
  207. img = F.pad(img, [0, 0, imgW - resized_w, 0], fill=0.)
  208. valid_ratio = min(1.0, float(resized_w / imgW))
  209. data['image'] = img
  210. data['valid_ratio'] = valid_ratio
  211. data['gen_ratio'] = imgW // imgH
  212. r = float(w) / float(h)
  213. data['real_ratio'] = max(1, round(r))
  214. return data
  215. def get_lmdb_sample_info(self, txn, index):
  216. label_key = 'label-%09d'.encode() % index
  217. label = txn.get(label_key)
  218. if label is None:
  219. return None
  220. label = label.decode('utf-8')
  221. img_key = 'image-%09d'.encode() % index
  222. imgbuf = txn.get(img_key)
  223. return imgbuf, label
  224. def __getitem__(self, properties):
  225. img_width = properties[0]
  226. img_height = properties[1]
  227. idx = properties[2]
  228. ratio = properties[3]
  229. lmdb_idx, file_idx = self.data_idx_order_list[idx]
  230. lmdb_idx = int(lmdb_idx)
  231. file_idx = int(file_idx)
  232. sample_info = self.get_lmdb_sample_info(
  233. self.lmdb_sets[lmdb_idx]['txn'], file_idx)
  234. if sample_info is None:
  235. ratio_ids = np.where(self.wh_ratio == ratio)[0].tolist()
  236. ids = random.sample(ratio_ids, 1)
  237. return self.__getitem__([img_width, img_height, ids[0], ratio])
  238. img, label = sample_info
  239. data = {'image': img, 'label': label}
  240. outs = transform(data, self.ops[:-1])
  241. if outs is not None:
  242. outs = self.resize_norm_img(outs, ratio, padding=False)
  243. if outs is None:
  244. ratio_ids = np.where(self.wh_ratio == ratio)[0].tolist()
  245. ids = random.sample(ratio_ids, 1)
  246. return self.__getitem__([img_width, img_height, ids[0], ratio])
  247. outs = transform(outs, self.ops[-1:])
  248. if outs is None:
  249. ratio_ids = np.where(self.wh_ratio == ratio)[0].tolist()
  250. ids = random.sample(ratio_ids, 1)
  251. return self.__getitem__([img_width, img_height, ids[0], ratio])
  252. return outs
  253. def __len__(self):
  254. return self.data_idx_order_list.shape[0]