__init__.py 1.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041
  1. __all__ = ['build_encoder']
  2. from importlib import import_module
  3. name_to_module = {
  4. 'MobileNetV1Enhance': '.rec_mv1_enhance',
  5. 'ResNet31': '.rec_resnet_31',
  6. 'MobileNetV3': '.rec_mobilenet_v3',
  7. 'PPLCNetV3': '.rec_lcnetv3',
  8. 'PPHGNet_small': '.rec_hgnet',
  9. 'ResNet': '.rec_resnet_vd',
  10. 'MTB': '.rec_nrtr_mtb',
  11. 'SVTRNet': '.svtrnet',
  12. 'ResNet45': '.rec_resnet_45',
  13. 'ViT': '.vit',
  14. 'SVTRNet2DPos': '.svtrnet2dpos',
  15. 'SVTRv2': '.svtrv2',
  16. 'FocalSVTR': '.focalsvtr',
  17. 'ResNet_FPN': '.rec_resnet_fpn',
  18. 'ResNet_ASTER': '.resnet31_rnn',
  19. 'SVTRv2LNConv': '.svtrv2_lnconv',
  20. 'SVTRv2LNConvTwo33': '.svtrv2_lnconv_two33',
  21. 'CAMEncoder': '.cam_encoder',
  22. 'ConvNeXtV2': '.convnextv2',
  23. 'AutoSTREncoder': '.autostr_encoder',
  24. 'NRTREncoder': '.nrtr_encoder',
  25. 'RepSVTREncoder': '.repvit',
  26. }
  27. def build_encoder(config):
  28. module_name = config.pop('name')
  29. assert module_name in name_to_module, Exception(
  30. f'Encoder only supports: {list(name_to_module.keys())}')
  31. module_path = name_to_module[module_name]
  32. mod = import_module(module_path, package=__package__)
  33. module_class = getattr(mod, module_name)(**config)
  34. return module_class