1234567891011121314151617181920212223242526272829303132333435363738394041 |
- __all__ = ['build_encoder']
- from importlib import import_module
- name_to_module = {
- 'MobileNetV1Enhance': '.rec_mv1_enhance',
- 'ResNet31': '.rec_resnet_31',
- 'MobileNetV3': '.rec_mobilenet_v3',
- 'PPLCNetV3': '.rec_lcnetv3',
- 'PPHGNet_small': '.rec_hgnet',
- 'ResNet': '.rec_resnet_vd',
- 'MTB': '.rec_nrtr_mtb',
- 'SVTRNet': '.svtrnet',
- 'ResNet45': '.rec_resnet_45',
- 'ViT': '.vit',
- 'SVTRNet2DPos': '.svtrnet2dpos',
- 'SVTRv2': '.svtrv2',
- 'FocalSVTR': '.focalsvtr',
- 'ResNet_FPN': '.rec_resnet_fpn',
- 'ResNet_ASTER': '.resnet31_rnn',
- 'SVTRv2LNConv': '.svtrv2_lnconv',
- 'SVTRv2LNConvTwo33': '.svtrv2_lnconv_two33',
- 'CAMEncoder': '.cam_encoder',
- 'ConvNeXtV2': '.convnextv2',
- 'AutoSTREncoder': '.autostr_encoder',
- 'NRTREncoder': '.nrtr_encoder',
- 'RepSVTREncoder': '.repvit',
- }
- def build_encoder(config):
- module_name = config.pop('name')
- assert module_name in name_to_module, Exception(
- f'Encoder only supports: {list(name_to_module.keys())}')
- module_path = name_to_module[module_name]
- mod = import_module(module_path, package=__package__)
- module_class = getattr(mod, module_name)(**config)
- return module_class
|