simple_dataset.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266
  1. import json
  2. import math
  3. import os
  4. import random
  5. import traceback
  6. import cv2
  7. import numpy as np
  8. from torch.utils.data import Dataset
  9. from openrec.preprocess import transform
  class SimpleDataSet(Dataset):
      """Dataset of images listed in one or more label files.

      Each label-file line is ``<file name><delimiter><label>``; the file name
      may instead be a JSON list of names, one of which is picked at random on
      every read. Images are loaded as raw bytes and passed through the
      operator pipeline built from ``config[mode]['dataset']['transforms']``.
      """

      # Upper bound on consecutive failed loads in ``__getitem__``. A recursive
      # retry could not go deeper than Python's default recursion limit (1000)
      # either, so this bound is at least as tolerant while avoiding stack growth.
      _MAX_LOAD_RETRIES = 1000

      def __init__(self, config, mode, logger, seed=None, epoch=0, task='rec'):
          """Read the label files and build the transform pipeline.

          Args:
              config: Full config dict. Reads ``config['Global']``,
                  ``config[mode]['dataset']`` and ``config[mode]['loader']``.
              mode: Config section to use (e.g. ``'Train'``); its lower-cased
                  form is stored in ``self.mode``.
              logger: Logger for progress and per-sample error messages.
              seed: Value passed to ``random.seed`` before sampling/shuffling
                  lines, and written as ``'epoch'`` by ``set_epoch_as_seed``.
              epoch: Accepted for interface compatibility; unused here.
              task: ``'rec'`` or ``'det'``; selects the package whose
                  ``create_operators`` builds ``self.ops``.

          Raises:
              ValueError: If ``task`` is neither ``'rec'`` nor ``'det'``.
              AssertionError: If ``ratio_list`` and ``label_file_list`` differ
                  in length.
          """
          super(SimpleDataSet, self).__init__()
          if task not in ('rec', 'det'):
              # Without this check neither import below runs and the failure
              # surfaces later as an opaque UnboundLocalError.
              raise ValueError(
                  "Unsupported task '{}': expected 'rec' or 'det'.".format(task))
          self.logger = logger
          self.mode = mode.lower()

          global_config = config['Global']
          dataset_config = config[mode]['dataset']
          loader_config = config[mode]['loader']

          self.delimiter = dataset_config.get('delimiter', '\t')
          # Read instead of pop: the caller's config stays intact, so the same
          # config can be used to build the dataset again.
          label_file_list = dataset_config['label_file_list']
          if isinstance(label_file_list, str):
              # A single path is one source, not len(path) sources.
              label_file_list = [label_file_list]
          data_source_num = len(label_file_list)
          ratio_list = dataset_config.get('ratio_list', 1.0)
          if isinstance(ratio_list, (float, int)):
              ratio_list = [float(ratio_list)] * data_source_num
          assert len(ratio_list) == data_source_num, (
              'The length of ratio_list should be the same as the file_list.')

          self.data_dir = dataset_config['data_dir']
          self.do_shuffle = loader_config['shuffle']
          self.seed = seed
          logger.info(f'Initialize indexs of datasets: {label_file_list}')
          self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
          self.data_idx_order_list = list(range(len(self.data_lines)))
          if self.mode == 'train' and self.do_shuffle:
              self.shuffle_data_random()

          # Runs before create_operators so the 'epoch' it writes into the
          # transform configs is visible when the operators are created.
          self.set_epoch_as_seed(self.seed, dataset_config)
          if task == 'rec':
              from openrec.preprocess import create_operators
          else:
              from opendet.preprocess import create_operators
          self.ops = create_operators(dataset_config['transforms'],
                                      global_config)
          self.ext_op_transform_idx = dataset_config.get('ext_op_transform_idx',
                                                         2)
          # True when any source is sub-sampled (ratio < 1); presumably read by
          # the trainer to rebuild the dataset with a new seed -- TODO confirm.
          self.need_reset = any(ratio < 1 for ratio in ratio_list)

      def set_epoch_as_seed(self, seed, dataset_config):
          """In train mode, write ``seed`` (or 0 if None) as ``'epoch'`` into the
          ``MakeBorderMap`` and ``MakeShrinkMap`` transform configs.

          Best effort: nothing is written outside train mode or when either
          transform is absent from ``dataset_config['transforms']``.
          """
          if self.mode != 'train':
              return
          try:
              transforms = dataset_config['transforms']
              border_map_id = [
                  i for i, op_cfg in enumerate(transforms)
                  if 'MakeBorderMap' in op_cfg
              ][0]
              shrink_map_id = [
                  i for i, op_cfg in enumerate(transforms)
                  if 'MakeShrinkMap' in op_cfg
              ][0]
              epoch = seed if seed is not None else 0
              transforms[border_map_id]['MakeBorderMap']['epoch'] = epoch
              transforms[shrink_map_id]['MakeShrinkMap']['epoch'] = epoch
          except (KeyError, IndexError, TypeError):
              # Missing transforms (IndexError/KeyError) or non-mapping op
              # params (TypeError) are expected, e.g. for recognition configs.
              return

      def get_image_info_list(self, file_list, ratio_list):
          """Read the raw lines (bytes) of every label file.

          In train mode, or whenever a file's ratio is below 1, that file's
          lines are replaced by ``random.sample`` of ``round(len * ratio)`` of
          them, after re-seeding the global ``random`` module with
          ``self.seed``.
          """
          if isinstance(file_list, str):
              file_list = [file_list]
          data_lines = []
          for idx, file in enumerate(file_list):
              with open(file, 'rb') as f:
                  lines = f.readlines()
              if self.mode == 'train' or ratio_list[idx] < 1.0:
                  random.seed(self.seed)
                  lines = random.sample(lines,
                                        round(len(lines) * ratio_list[idx]))
              data_lines.extend(lines)
          return data_lines

      def shuffle_data_random(self):
          """Shuffle ``self.data_lines`` in place after seeding ``random``."""
          random.seed(self.seed)
          random.shuffle(self.data_lines)
          return

      def _try_parse_filename_list(self, file_name):
          """Resolve a file-name field that may hold a JSON list of names.

          Several images can share one label; such fields hold a JSON list and
          one entry is chosen with ``random.choice``. Fields that are not a
          parsable, non-empty JSON list are returned unchanged.
          """
          if file_name.startswith('['):
              try:
                  file_name = random.choice(json.loads(file_name))
              except (ValueError, IndexError):
                  # ValueError: not valid JSON; IndexError: empty list.
                  pass
          return file_name

      def _read_sample(self, data_line):
          """Parse one decoded label line and read its image bytes.

          Args:
              data_line: ``<file name><delimiter><label>`` text line.

          Returns:
              dict with ``'img_path'``, ``'label'`` and raw ``'image'`` bytes.

          Raises:
              IndexError: If the line has no label field.
              FileNotFoundError: If the resolved image path does not exist.
          """
          fields = data_line.strip('\n').split(self.delimiter)
          file_name = self._try_parse_filename_list(fields[0])
          label = fields[1]
          img_path = os.path.join(self.data_dir, file_name)
          data = {'img_path': img_path, 'label': label}
          if not os.path.exists(img_path):
              raise FileNotFoundError('{} does not exist!'.format(img_path))
          with open(img_path, 'rb') as f:
              data['image'] = f.read()
          return data

      def get_ext_data(self):
          """Load extra random samples for an op that declares ``ext_data_num``.

          The count comes from the first op in ``self.ops`` with an
          ``ext_data_num`` attribute (0 if none). Each extra sample goes through
          ``self.ops[:self.ext_op_transform_idx]``; samples whose image is
          missing, whose transform returns None, or whose ``polys`` second
          dimension is not 4 are skipped and redrawn.
          """
          ext_data_num = next(
              (op.ext_data_num for op in self.ops if hasattr(op, 'ext_data_num')),
              0)
          load_data_ops = self.ops[:self.ext_op_transform_idx]
          ext_data = []
          while len(ext_data) < ext_data_num:
              file_idx = self.data_idx_order_list[np.random.randint(len(self))]
              data_line = self.data_lines[file_idx].decode('utf-8')
              try:
                  data = self._read_sample(data_line)
              except FileNotFoundError:
                  continue
              data = transform(data, load_data_ops)
              if data is None:
                  continue
              if 'polys' in data and data['polys'].shape[1] != 4:
                  continue
              ext_data.append(data)
          return ext_data

      def __getitem__(self, idx):
          """Return the fully transformed sample at ``idx``, retrying on failure.

          A sample that raises or is filtered out (transform returns None) is
          replaced by a random index in train mode. In other modes it is
          replaced by the next index (wrapping), so repeated evaluations see
          the same samples.

          Raises:
              IndexError: If ``idx`` is out of range.
              RuntimeError: If ``_MAX_LOAD_RETRIES`` consecutive loads fail.
          """
          for _ in range(self._MAX_LOAD_RETRIES):
              file_idx = self.data_idx_order_list[idx]
              data_line = self.data_lines[file_idx]
              try:
                  data_line = data_line.decode('utf-8')
                  data = self._read_sample(data_line)
                  data['ext_data'] = self.get_ext_data()
                  outs = transform(data, self.ops)
              except Exception:
                  # Exception, not a bare except: KeyboardInterrupt/SystemExit
                  # must still stop the process.
                  self.logger.error(
                      'When parsing line {}, error happened with msg: {}'.format(
                          data_line, traceback.format_exc()))
                  outs = None
              if outs is not None:
                  return outs
              if self.mode == 'train':
                  idx = np.random.randint(len(self))
              else:
                  idx = (idx + 1) % len(self)
          raise RuntimeError(
              'Failed to load a valid sample after {} attempts.'.format(
                  self._MAX_LOAD_RETRIES))

      def __len__(self):
          """Number of samples (length of the index order list)."""
          return len(self.data_idx_order_list)
  class MultiScaleDataSet(SimpleDataSet):
      """``SimpleDataSet`` variant that resizes each sample to a requested size.

      ``__getitem__`` takes a sequence ``(width, height, index, wh_ratio)``
      rather than a plain index (presumably produced by a multi-scale batch
      sampler -- confirm against the sampler). With ``ds_width`` enabled in the
      dataset config, ``index`` addresses samples sorted by their width/height
      ratio.
      """

      # Upper bound on consecutive failed loads in ``__getitem__``. A recursive
      # retry could not go deeper than Python's default recursion limit (1000)
      # either, so this bound is at least as tolerant while avoiding stack growth.
      _MAX_LOAD_RETRIES = 1000

      def __init__(self, config, mode, logger, seed=None, epoch=0, task='rec'):
          """Build the base dataset, then precompute aspect ratios if ``ds_width``.

          ``epoch`` and ``task`` are optional and forwarded to
          ``SimpleDataSet.__init__``; their defaults match that signature.
          """
          super(MultiScaleDataSet, self).__init__(
              config, mode, logger, seed=seed, epoch=epoch, task=task)
          self.ds_width = config[mode]['dataset'].get('ds_width', False)
          if self.ds_width:
              self.wh_aware()

      def wh_aware(self):
          """Precompute each line's width/height ratio for ratio-aware indexing.

          Every label line must have exactly four delimiter-separated fields:
          name, label, width, height. Sets ``wh_ratio`` (numpy array),
          ``wh_ratio_sort`` (its argsort) and resets ``data_idx_order_list``
          to the identity order.
          """
          wh_ratio = []
          for raw_line in self.data_lines:
              _, _, w, h = raw_line.decode('utf-8').strip('\n').split(
                  self.delimiter)
              wh_ratio.append(float(w) / float(h))
          self.wh_ratio = np.array(wh_ratio)
          self.wh_ratio_sort = np.argsort(self.wh_ratio)
          self.data_idx_order_list = list(range(len(self.data_lines)))

      def resize_norm_img(self, data, imgW, imgH, padding=True):
          """Resize ``data['image']`` to height ``imgH``, normalize, pad to ``imgW``.

          Expects an H x W x 3 array (it is transposed to 3 x H x W). With
          ``padding`` the aspect ratio is kept and the width is capped at
          ``imgW``; without it the image is stretched to ``(imgW, imgH)``.
          Pixels are divided by 255 and mapped through ``(x - 0.5) / 0.5``
          (i.e. to [-1, 1] for 0-255 input), then written into a zero-filled
          ``(3, imgH, imgW)`` float32 array.

          Returns:
              ``data`` with ``'image'`` replaced by the padded array and
              ``'valid_ratio'`` set to ``min(1.0, resized_width / imgW)``.
          """
          img = data['image']
          h, w = img.shape[0], img.shape[1]
          if padding:
              resized_w = min(imgW, int(math.ceil(imgH * (w / float(h)))))
              resized_image = cv2.resize(img, (resized_w, imgH))
          else:
              resized_w = imgW
              resized_image = cv2.resize(img, (imgW, imgH),
                                         interpolation=cv2.INTER_LINEAR)
          resized_image = resized_image.astype('float32')
          resized_image = resized_image.transpose((2, 0, 1)) / 255
          resized_image -= 0.5
          resized_image /= 0.5
          padding_im = np.zeros((3, imgH, imgW), dtype=np.float32)
          padding_im[:, :, :resized_w] = resized_image
          data['image'] = padding_im
          data['valid_ratio'] = min(1.0, float(resized_w / imgW))
          return data

      def __getitem__(self, properties):
          """Return the sample selected by ``properties``, resized and transformed.

          Args:
              properties: ``(width, height, index, wh_ratio)``. When
                  ``self.ds_width`` is set and ``wh_ratio`` is not None,
                  ``index`` addresses ``self.wh_ratio_sort`` and the target
                  width is ``height * round(wh_ratio)`` (factor 1 when that
                  rounds to 0); otherwise ``index`` addresses
                  ``self.data_idx_order_list`` and ``width`` is used as given.
                  ``wh_ratio`` is only read when ``self.ds_width`` is set.

          A sample that raises or is filtered out is replaced by a random index
          in train mode, or by the next index (wrapping) otherwise, keeping the
          same target size.

          Raises:
              IndexError: If ``index`` is out of range.
              RuntimeError: If ``_MAX_LOAD_RETRIES`` consecutive loads fail.
          """
          img_height = properties[1]
          idx = properties[2]
          if self.ds_width and properties[3] is not None:
              wh_ratio = properties[3]
              img_width = img_height * (int(round(wh_ratio)) or 1)
              order = self.wh_ratio_sort
          else:
              img_width = properties[0]
              order = self.data_idx_order_list

          for _ in range(self._MAX_LOAD_RETRIES):
              data_line = self.data_lines[order[idx]]
              try:
                  data_line = data_line.decode('utf-8')
                  substr = data_line.strip('\n').split(self.delimiter)
                  file_name = self._try_parse_filename_list(substr[0])
                  label = substr[1]
                  img_path = os.path.join(self.data_dir, file_name)
                  data = {'img_path': img_path, 'label': label}
                  if not os.path.exists(img_path):
                      raise FileNotFoundError(
                          '{} does not exist!'.format(img_path))
                  with open(img_path, 'rb') as f:
                      data['image'] = f.read()
                  data['ext_data'] = self.get_ext_data()
                  outs = transform(data, self.ops[:-1])
                  # Only finish the pipeline for samples that survived the
                  # first stage; a None here is a filtered sample, not an error.
                  if outs is not None:
                      outs = self.resize_norm_img(outs, img_width, img_height)
                      outs = transform(outs, self.ops[-1:])
              except Exception:
                  # Exception, not a bare except: KeyboardInterrupt/SystemExit
                  # must still stop the process.
                  self.logger.error(
                      'When parsing line {}, error happened with msg: {}'.format(
                          data_line, traceback.format_exc()))
                  outs = None
              if outs is not None:
                  return outs
              # Outside training, step deterministically so repeated
              # evaluations load the same samples.
              if self.mode == 'train':
                  idx = np.random.randint(len(self))
              else:
                  idx = (idx + 1) % len(self)
          raise RuntimeError(
              'Failed to load a valid sample after {} attempts.'.format(
                  self._MAX_LOAD_RETRIES))