__init__.py 2.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. import copy
  2. from importlib import import_module
  3. __all__ = ['build_post_process']
  4. # 定义类名到模块路径的映射
  5. module_mapping = {
  6. 'CTCLabelDecode': '.ctc_postprocess',
  7. 'CharLabelDecode': '.char_postprocess',
  8. 'CELabelDecode': '.ce_postprocess',
  9. 'CPPDLabelDecode': '.cppd_postprocess',
  10. 'NRTRLabelDecode': '.nrtr_postprocess',
  11. 'ABINetLabelDecode': '.abinet_postprocess',
  12. 'ARLabelDecode': '.ar_postprocess',
  13. 'IGTRLabelDecode': '.igtr_postprocess',
  14. 'VisionLANLabelDecode': '.visionlan_postprocess',
  15. 'SMTRLabelDecode': '.smtr_postprocess',
  16. 'SRNLabelDecode': '.srn_postprocess',
  17. 'LISTERLabelDecode': '.lister_postprocess',
  18. 'MPGLabelDecode': '.mgp_postprocess',
  19. 'GTCLabelDecode': '.' # 当前模块中的类
  20. }
  21. def build_post_process(config, global_config=None):
  22. config = copy.deepcopy(config)
  23. module_name = config.pop('name')
  24. if global_config is not None:
  25. config.update(global_config)
  26. assert module_name in module_mapping, Exception(
  27. 'post process only support {}'.format(list(module_mapping.keys())))
  28. module_path = module_mapping[module_name]
  29. # 处理当前模块中的类
  30. if module_path == '.':
  31. module_class = globals()[module_name]
  32. else:
  33. # 动态导入模块
  34. module = import_module(module_path, package=__package__)
  35. module_class = getattr(module, module_name)
  36. return module_class(**config)
  37. class GTCLabelDecode(object):
  38. """Convert between text-label and text-index."""
  39. def __init__(self,
  40. gtc_label_decode=None,
  41. character_dict_path=None,
  42. use_space_char=True,
  43. only_gtc=False,
  44. with_ratio=False,
  45. **kwargs):
  46. gtc_label_decode['character_dict_path'] = character_dict_path
  47. gtc_label_decode['use_space_char'] = use_space_char
  48. self.gtc_label_decode = build_post_process(gtc_label_decode)
  49. self.ctc_label_decode = build_post_process({
  50. 'name':
  51. 'CTCLabelDecode',
  52. 'character_dict_path':
  53. character_dict_path,
  54. 'use_space_char':
  55. use_space_char
  56. })
  57. self.gtc_character = self.gtc_label_decode.character
  58. self.ctc_character = self.ctc_label_decode.character
  59. self.only_gtc = only_gtc
  60. self.with_ratio = with_ratio
  61. def get_character_num(self):
  62. return [len(self.gtc_character), len(self.ctc_character)]
  63. def __call__(self, preds, batch=None, *args, **kwargs):
  64. if self.with_ratio:
  65. batch = batch[:-1]
  66. gtc = self.gtc_label_decode(preds['gtc_pred'],
  67. batch[:-2] if batch is not None else None)
  68. if self.only_gtc:
  69. return gtc
  70. ctc = self.ctc_label_decode(preds['ctc_pred'], [None] +
  71. batch[-2:] if batch is not None else None)
  72. return [gtc, ctc]