__init__.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. import io
  2. import copy
  3. import importlib
  4. import cv2
  5. import numpy as np
  6. from PIL import Image
  7. class KeepKeys:
  8. def __init__(self, keep_keys, **kwargs):
  9. self.keep_keys = keep_keys
  10. def __call__(self, data):
  11. return [data[key] for key in self.keep_keys]
  12. class Fasttext:
  13. def __init__(self, path='None', **kwargs):
  14. import fasttext
  15. self.fast_model = fasttext.load_model(path)
  16. def __call__(self, data):
  17. data['fast_label'] = self.fast_model[data['label']]
  18. return data
  19. class DecodeImage:
  20. def __init__(self,
  21. img_mode='RGB',
  22. channel_first=False,
  23. ignore_orientation=False,
  24. **kwargs):
  25. self.img_mode = img_mode
  26. self.channel_first = channel_first
  27. self.ignore_orientation = ignore_orientation
  28. def __call__(self, data):
  29. assert isinstance(data['image'], bytes) and len(data['image']) > 0
  30. img = np.frombuffer(data['image'], dtype='uint8')
  31. flags = cv2.IMREAD_IGNORE_ORIENTATION | cv2.IMREAD_COLOR if self.ignore_orientation else 1
  32. img = cv2.imdecode(img, flags)
  33. if self.img_mode == 'GRAY':
  34. img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
  35. elif self.img_mode == 'RGB':
  36. img = img[:, :, ::-1]
  37. if self.channel_first:
  38. img = img.transpose((2, 0, 1))
  39. data['image'] = img
  40. return data
  41. class DecodeImagePIL:
  42. def __init__(self, img_mode='RGB', **kwargs):
  43. self.img_mode = img_mode
  44. def __call__(self, data):
  45. assert isinstance(data['image'], bytes) and len(data['image']) > 0
  46. img = Image.open(io.BytesIO(data['image'])).convert('RGB')
  47. if self.img_mode == 'Gray':
  48. img = img.convert('L')
  49. elif self.img_mode == 'BGR':
  50. img = Image.fromarray(np.array(img)[:, :, ::-1])
  51. data['image'] = img
  52. return data
  53. def transform(data, ops=None):
  54. """transform."""
  55. if ops is None:
  56. ops = []
  57. for op in ops:
  58. data = op(data)
  59. if data is None:
  60. return None
  61. return data
  62. # 类名到模块的映射
  63. MODULE_MAPPING = {
  64. 'ABINetLabelEncode': '.abinet_label_encode',
  65. 'ARLabelEncode': '.ar_label_encode',
  66. 'CELabelEncode': '.ce_label_encode',
  67. 'CharLabelEncode': '.char_label_encode',
  68. 'CPPDLabelEncode': '.cppd_label_encode',
  69. 'CTCLabelEncode': '.ctc_label_encode',
  70. 'EPLabelEncode': '.ep_label_encode',
  71. 'IGTRLabelEncode': '.igtr_label_encode',
  72. 'MGPLabelEncode': '.mgp_label_encode',
  73. 'SMTRLabelEncode': '.smtr_label_encode',
  74. 'SRNLabelEncode': '.srn_label_encode',
  75. 'VisionLANLabelEncode': '.visionlan_label_encode',
  76. 'CAMLabelEncode': '.cam_label_encode',
  77. 'ABINetAug': '.rec_aug',
  78. 'BDA': '.rec_aug',
  79. 'PARSeqAug': '.rec_aug',
  80. 'PARSeqAugPIL': '.rec_aug',
  81. 'SVTRAug': '.rec_aug',
  82. 'ABINetResize': '.resize',
  83. 'CDistNetResize': '.resize',
  84. 'LongResize': '.resize',
  85. 'RecTVResize': '.resize',
  86. 'RobustScannerRecResizeImg': '.resize',
  87. 'SliceResize': '.resize',
  88. 'SliceTVResize': '.resize',
  89. 'SRNRecResizeImg': '.resize',
  90. 'SVTRResize': '.resize',
  91. 'VisionLANResize': '.resize',
  92. 'RecDynamicResize': '.resize',
  93. }
  94. def dynamic_import(class_name):
  95. module_path = MODULE_MAPPING.get(class_name)
  96. if not module_path:
  97. raise ValueError(f'Unsupported class: {class_name}')
  98. module = importlib.import_module(module_path, package=__package__)
  99. return getattr(module, class_name)
  100. def create_operators(op_param_list, global_config=None):
  101. ops = []
  102. for op_info in op_param_list:
  103. op_name = list(op_info.keys())[0]
  104. param = copy.deepcopy(op_info[op_name]) or {}
  105. if global_config:
  106. param.update(global_config)
  107. if op_name in globals():
  108. op_class = globals()[op_name]
  109. else:
  110. op_class = dynamic_import(op_name)
  111. ops.append(op_class(**param))
  112. return ops
  113. class GTCLabelEncode():
  114. """Convert between text-label and text-index."""
  115. def __init__(self,
  116. gtc_label_encode,
  117. max_text_length,
  118. character_dict_path=None,
  119. use_space_char=False,
  120. **kwargs):
  121. self.gtc_label_encode = dynamic_import(gtc_label_encode['name'])(
  122. max_text_length=max_text_length,
  123. character_dict_path=character_dict_path,
  124. use_space_char=use_space_char,
  125. **gtc_label_encode)
  126. self.ctc_label_encode = dynamic_import('CTCLabelEncode')(
  127. max_text_length, character_dict_path, use_space_char)
  128. def __call__(self, data):
  129. data_ctc = self.ctc_label_encode({'label': data['label']})
  130. data = self.gtc_label_encode(data)
  131. if data_ctc is None or data is None:
  132. return None
  133. data['ctc_label'] = data_ctc['label']
  134. data['ctc_length'] = data_ctc['length']
  135. return data