create_lmdb_dataset.py

import os
import lmdb
import cv2
from tqdm import tqdm
import numpy as np
import io
from PIL import Image
""" A modified version of the create_dataset tool from the CRNN Torch repository: https://github.com/bgshih/crnn/blob/master/tool/create_dataset.py """


def get_datalist(data_dir, data_path, max_len):
    """
    Collect the data list for training and validation.
    :param data_dir: root directory of the dataset
    :param data_path: annotation file (or list of annotation files) for
        training; each file stores one sample per line in the format
        'path/to/img\tlabel'
    :param max_len: samples whose label is longer than this are skipped
    :return: list of [image_path, label] pairs
    """
    train_data = []
    if isinstance(data_path, list):
        for p in data_path:
            train_data.extend(get_datalist(data_dir, p, max_len))
    else:
        with open(data_path, 'r', encoding='utf-8') as f:
            for line in tqdm(f.readlines(),
                             desc=f'load data from {data_path}'):
                # Some annotation files separate path and label with a space;
                # normalize it to a tab before splitting.
                line = (line.strip('\n').replace('.jpg ', '.jpg\t').replace(
                    '.png ', '.png\t').split('\t'))
                if len(line) > 1:
                    img_path = os.path.join(data_dir, line[0].strip(' '))
                    label = line[1]
                    if len(label) > max_len:
                        continue
                    if os.path.exists(
                            img_path) and os.path.getsize(img_path) > 0:
                        train_data.append([str(img_path), label])
    return train_data
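
# For illustration only: a hypothetical annotation line in the format above,
# i.e. an image path relative to data_dir, a tab, then the text label:
#
#   full_images/word_001.jpg	HELLO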


def checkImageIsValid(imageBin):
    if imageBin is None:
        return False
    imageBuf = np.frombuffer(imageBin, dtype=np.uint8)
    img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # cv2.imdecode returns None when the bytes cannot be decoded.
        return False
    imgH, imgW = img.shape[0], img.shape[1]
    if imgH * imgW == 0:
        return False
    return True


def writeCache(env, cache):
    with env.begin(write=True) as txn:
        for k, v in cache.items():
            txn.put(k, v)


def createDataset(data_list, outputPath, checkValid=True):
    """
    Create LMDB dataset for training and evaluation.
    ARGS:
        data_list  : list of [image_path, label] pairs, e.g. from get_datalist
        outputPath : LMDB output path
        checkValid : if True, check the validity of every image
    """
    os.makedirs(outputPath, exist_ok=True)
    env = lmdb.open(outputPath, map_size=1099511627776)
    cache = {}
    cnt = 1
    for imagePath, label in tqdm(data_list,
                                 desc=f'make dataset, save to {outputPath}'):
        with open(imagePath, 'rb') as f:
            imageBin = f.read()
        buf = io.BytesIO(imageBin)
        w, h = Image.open(buf).size
        if checkValid:
            try:
                if not checkImageIsValid(imageBin):
                    print('%s is not a valid image' % imagePath)
                    continue
            except Exception:
                # Skip samples whose image bytes cannot be decoded at all.
                continue
        imageKey = 'image-%09d'.encode() % cnt
        labelKey = 'label-%09d'.encode() % cnt
        whKey = 'wh-%09d'.encode() % cnt
        cache[imageKey] = imageBin
        cache[labelKey] = label.encode()
        cache[whKey] = (str(w) + '_' + str(h)).encode()
        # Flush to disk every 1000 samples to keep the in-memory cache small.
        if cnt % 1000 == 0:
            writeCache(env, cache)
            cache = {}
        cnt += 1
    nSamples = cnt - 1
    cache['num-samples'.encode()] = str(nSamples).encode()
    writeCache(env, cache)
    print('Created dataset with %d samples' % nSamples)
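

# A minimal sketch (not part of the original tool) of how one sample can be
# read back, assuming only the key schema written above ('num-samples',
# 'image-%09d', 'label-%09d', 'wh-%09d'); the name readSample is hypothetical.
def readSample(lmdbPath, index):
    env = lmdb.open(lmdbPath, readonly=True, lock=False)
    with env.begin(write=False) as txn:
        num_samples = int(txn.get('num-samples'.encode()))
        assert 1 <= index <= num_samples, 'index out of range'
        imageBin = txn.get('image-%09d'.encode() % index)
        label = txn.get('label-%09d'.encode() % index).decode()
        w, h = txn.get('wh-%09d'.encode() % index).decode().split('_')
    img = Image.open(io.BytesIO(imageBin))  # decode the stored image bytes
    return img, label, (int(w), int(h))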


if __name__ == '__main__':
    data_dir = './Union14M-L/'
    # Download the filtered label lists from
    # https://drive.google.com/drive/folders/1x1LC8C_W-Frl3sGV9i9_i_OD-bqNdodJ?usp=drive_link
    label_file_list = [
        './Union14M-L/train_annos/filter_jsonl_mmocr0.x/filter_train_challenging.jsonl.txt',
        './Union14M-L/train_annos/filter_jsonl_mmocr0.x/filter_train_easy.jsonl.txt',
        './Union14M-L/train_annos/filter_jsonl_mmocr0.x/filter_train_hard.jsonl.txt',
        './Union14M-L/train_annos/filter_jsonl_mmocr0.x/filter_train_medium.jsonl.txt',
        './Union14M-L/train_annos/filter_jsonl_mmocr0.x/filter_train_normal.jsonl.txt'
    ]
    save_path_root = './Union14M-L-LMDB-Filtered/'
    for data_list in label_file_list:
        # One LMDB directory per annotation file, named after its base name.
        save_path = save_path_root + data_list.split('/')[-1].split(
            '.')[0] + '/'
        os.makedirs(save_path, exist_ok=True)
        print(save_path)
        train_data_list = get_datalist(data_dir, data_list, 800)
        createDataset(train_data_list, save_path)
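
# After the loop finishes, ./Union14M-L-LMDB-Filtered/ should hold one LMDB
# environment per difficulty split (e.g. filter_train_easy/), each containing
# the usual data.mdb and lock.mdb files created by LMDB's directory mode.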