1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586 |
- import copy
- from importlib import import_module
- __all__ = ['build_post_process']
# Registry of supported post-process classes, keyed by class name. Each value
# is the module (relative to this package) that defines the class; it is only
# imported when build_post_process() is asked for that class. The special
# value '.' means the class is defined in this module itself.
module_mapping = dict(
    CTCLabelDecode='.ctc_postprocess',
    CharLabelDecode='.char_postprocess',
    CELabelDecode='.ce_postprocess',
    CPPDLabelDecode='.cppd_postprocess',
    NRTRLabelDecode='.nrtr_postprocess',
    ABINetLabelDecode='.abinet_postprocess',
    ARLabelDecode='.ar_postprocess',
    IGTRLabelDecode='.igtr_postprocess',
    VisionLANLabelDecode='.visionlan_postprocess',
    SMTRLabelDecode='.smtr_postprocess',
    SRNLabelDecode='.srn_postprocess',
    LISTERLabelDecode='.lister_postprocess',
    MPGLabelDecode='.mgp_postprocess',
    GTCLabelDecode='.',  # defined in the current module
)
def build_post_process(config, global_config=None):
    """Instantiate the post-process (label decoder) described by ``config``.

    Args:
        config (dict): Post-process configuration. Must contain a ``'name'``
            key naming one of the classes registered in ``module_mapping``;
            every other key is forwarded to that class's constructor as a
            keyword argument. The dict is deep-copied, so the caller's object
            is never modified.
        global_config (dict, optional): Extra keyword arguments merged into
            the copied config before instantiation. On key collisions the
            value from ``global_config`` wins.

    Returns:
        An instance of the class named by ``config['name']``.

    Raises:
        KeyError: if ``config`` has no ``'name'`` key.
        ValueError: if the name is not a supported post-process.
    """
    config = copy.deepcopy(config)
    module_name = config.pop('name')
    if global_config is not None:
        config.update(global_config)
    # Explicit check instead of ``assert``: asserts are stripped under
    # ``python -O``, which would turn this into an opaque KeyError below.
    if module_name not in module_mapping:
        raise ValueError('post process only support {}'.format(
            list(module_mapping.keys())))
    module_path = module_mapping[module_name]
    if module_path == '.':
        # The class is defined in this module, so no import is needed.
        module_class = globals()[module_name]
    else:
        # Import lazily so only the requested decoder's module is loaded.
        module = import_module(module_path, package=__package__)
        module_class = getattr(module, module_name)
    return module_class(**config)
class GTCLabelDecode(object):
    """Decode predictions from a model with both a GTC and a CTC output.

    ``preds['gtc_pred']`` is decoded by the post-process configured through
    ``gtc_label_decode``; ``preds['ctc_pred']`` is decoded by a
    ``CTCLabelDecode`` built with the same character-dictionary settings.
    """

    def __init__(self,
                 gtc_label_decode=None,
                 character_dict_path=None,
                 use_space_char=True,
                 only_gtc=False,
                 with_ratio=False,
                 **kwargs):
        """Build the GTC and CTC sub-decoders.

        Args:
            gtc_label_decode (dict): Post-process config (including its
                ``'name'``) for the GTC branch. Required despite the ``None``
                default. It is copied, never modified in place.
            character_dict_path: Forwarded to both sub-decoders.
            use_space_char (bool): Forwarded to both sub-decoders.
            only_gtc (bool): If True, ``__call__`` returns only the GTC result.
            with_ratio (bool): If True, ``__call__`` drops the last element of
                a non-None ``batch`` before slicing it for the sub-decoders.
            **kwargs: Accepted and ignored (e.g. extra global-config entries).

        Raises:
            TypeError: If ``gtc_label_decode`` is not provided.
        """
        if gtc_label_decode is None:
            raise TypeError(
                'GTCLabelDecode requires a gtc_label_decode config dict')
        # Copy so injecting the shared settings below does not mutate the
        # caller's config dict.
        gtc_label_decode = dict(gtc_label_decode)
        gtc_label_decode['character_dict_path'] = character_dict_path
        gtc_label_decode['use_space_char'] = use_space_char
        self.gtc_label_decode = build_post_process(gtc_label_decode)
        self.ctc_label_decode = build_post_process({
            'name': 'CTCLabelDecode',
            'character_dict_path': character_dict_path,
            'use_space_char': use_space_char,
        })
        self.gtc_character = self.gtc_label_decode.character
        self.ctc_character = self.ctc_label_decode.character
        self.only_gtc = only_gtc
        self.with_ratio = with_ratio

    def get_character_num(self):
        """Return ``[len(gtc_character), len(ctc_character)]``."""
        return [len(self.gtc_character), len(self.ctc_character)]

    def __call__(self, preds, batch=None, *args, **kwargs):
        """Decode the GTC and (unless ``only_gtc``) the CTC predictions.

        Args:
            preds (dict): Must contain ``'gtc_pred'``, and also ``'ctc_pred'``
                unless ``only_gtc`` is set.
            batch (list or tuple, optional): Ground-truth batch, or ``None``.
                The GTC decoder receives ``batch[:-2]`` and the CTC decoder
                receives ``[None] + list(batch[-2:])``; with ``with_ratio``
                the trailing element is dropped first. NOTE(review): the last
                two entries are presumably the CTC label/length - confirm
                against the data loader.

        Returns:
            The GTC decoder's result if ``only_gtc``, else ``[gtc, ctc]``.
        """
        # Guard on ``batch``: it may be None (as handled below), and None
        # cannot be sliced.
        if self.with_ratio and batch is not None:
            batch = batch[:-1]
        gtc = self.gtc_label_decode(
            preds['gtc_pred'], batch[:-2] if batch is not None else None)
        if self.only_gtc:
            return gtc
        # list() lets tuple batches work too. The leading None is presumably a
        # placeholder so the CTC decoder finds labels at the index it expects
        # - confirm against CTCLabelDecode.
        if batch is not None:
            ctc_batch = [None] + list(batch[-2:])
        else:
            ctc_batch = None
        ctc = self.ctc_label_decode(preds['ctc_pred'], ctc_batch)
        return [gtc, ctc]
|