__init__.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. import os
  2. import sys
  3. import copy
  4. import importlib
  5. __dir__ = os.path.dirname(os.path.abspath(__file__))
  6. sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
  7. from torch.utils.data import DataLoader, DistributedSampler
  8. # 定义支持的 Dataset 类及其对应的模块路径
  9. DATASET_MODULES = {
  10. 'SimpleDataSet': 'tools.data.simple_dataset',
  11. 'LMDBDataSet': 'tools.data.lmdb_dataset',
  12. 'TextLMDBDataSet': 'tools.data.text_lmdb_dataset',
  13. 'MultiScaleDataSet': 'tools.data.simple_dataset',
  14. 'STRLMDBDataSet': 'tools.data.strlmdb_dataset',
  15. 'LMDBDataSetTest': 'tools.data.lmdb_dataset_test',
  16. 'RatioDataSet': 'tools.data.ratio_dataset',
  17. 'RatioDataSetTest': 'tools.data.ratio_dataset_test',
  18. 'RatioDataSetTVResize': 'tools.data.ratio_dataset_tvresize',
  19. 'RatioDataSetTVResizeTest': 'tools.data.ratio_dataset_tvresize_test'
  20. }
  21. # 定义支持的 Sampler 类及其对应的模块路径
  22. SAMPLER_MODULES = {
  23. 'MultiScaleSampler': 'tools.data.multi_scale_sampler',
  24. 'RatioSampler': 'tools.data.ratio_sampler'
  25. }
  26. __all__ = [
  27. 'build_dataloader',
  28. ]
  29. def build_dataloader(config, mode, logger, seed=None, epoch=3, task='rec'):
  30. config = copy.deepcopy(config)
  31. mode = mode.capitalize() # 确保 mode 是首字母大写形式(Train/Eval/Test)
  32. # 获取 dataset 配置
  33. dataset_config = config[mode]['dataset']
  34. module_name = dataset_config['name']
  35. # 动态导入 dataset 类
  36. if module_name not in DATASET_MODULES:
  37. raise ValueError(
  38. f'Unsupported dataset: {module_name}. Supported datasets: {list(DATASET_MODULES.keys())}'
  39. )
  40. dataset_module = importlib.import_module(DATASET_MODULES[module_name])
  41. dataset_class = getattr(dataset_module, module_name)
  42. dataset = dataset_class(config, mode, logger, seed, epoch=epoch, task=task)
  43. # DataLoader 配置
  44. loader_config = config[mode]['loader']
  45. batch_size = loader_config['batch_size_per_card']
  46. drop_last = loader_config['drop_last']
  47. shuffle = loader_config['shuffle']
  48. num_workers = loader_config['num_workers']
  49. pin_memory = loader_config.get('pin_memory', False)
  50. sampler = None
  51. batch_sampler = None
  52. if 'sampler' in config[mode]:
  53. sampler_config = config[mode]['sampler']
  54. sampler_name = sampler_config.pop('name')
  55. if sampler_name not in SAMPLER_MODULES:
  56. raise ValueError(
  57. f'Unsupported sampler: {sampler_name}. Supported samplers: {list(SAMPLER_MODULES.keys())}'
  58. )
  59. sampler_module = importlib.import_module(SAMPLER_MODULES[sampler_name])
  60. sampler_class = getattr(sampler_module, sampler_name)
  61. batch_sampler = sampler_class(dataset, **sampler_config)
  62. elif config['Global']['distributed'] and mode == 'Train':
  63. sampler = DistributedSampler(dataset=dataset, shuffle=shuffle)
  64. if 'collate_fn' in loader_config:
  65. from . import collate_fn
  66. collate_fn = getattr(collate_fn, loader_config['collate_fn'])()
  67. else:
  68. collate_fn = None
  69. if batch_sampler is None:
  70. data_loader = DataLoader(
  71. dataset=dataset,
  72. sampler=sampler,
  73. num_workers=num_workers,
  74. pin_memory=pin_memory,
  75. collate_fn=collate_fn,
  76. batch_size=batch_size,
  77. drop_last=drop_last,
  78. )
  79. else:
  80. data_loader = DataLoader(
  81. dataset=dataset,
  82. batch_sampler=batch_sampler,
  83. num_workers=num_workers,
  84. pin_memory=pin_memory,
  85. collate_fn=collate_fn,
  86. )
  87. # 检查数据加载器是否为空
  88. if len(data_loader) == 0:
  89. logger.error(
  90. f'No Images in {mode.lower()} dataloader. Please check:\n'
  91. '\t1. The images num in the train label_file_list should be >= batch size.\n'
  92. '\t2. The annotation file and path in the configuration are correct.\n'
  93. '\t3. The BatchSize is not larger than the number of images.')
  94. sys.exit()
  95. return data_loader