# mgp_postprocess.py

from .ctc_postprocess import BaseRecLabelDecode


class MPGLabelDecode(BaseRecLabelDecode):
    """Convert between text-label and text-index."""
    SPACE = '[s]'
    GO = '[GO]'
    list_token = [GO, SPACE]
    def __init__(self,
                 character_dict_path=None,
                 use_space_char=False,
                 only_char=False,
                 **kwargs):
        super(MPGLabelDecode, self).__init__(character_dict_path,
                                             use_space_char)
        self.only_char = only_char
        self.EOS = '[s]'
        self.PAD = '[GO]'
        if not only_char:
            # transformers==4.2.1
            from transformers import BertTokenizer, GPT2Tokenizer
            self.bpe_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
            self.wp_tokenizer = BertTokenizer.from_pretrained(
                'bert-base-uncased')
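
    # The three decode paths below mirror the model's multi-granularity
    # heads: character-level, BPE-level (GPT-2 vocab) and WordPiece-level
    # (BERT vocab). final_decode() keeps, per sample, whichever prediction
    # was scored with the highest confidence by its head.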
    def __call__(self, preds, batch=None, *args, **kwargs):
        if isinstance(preds, list):
            char_preds = preds[0].detach().cpu().numpy()
        else:
            char_preds = preds.detach().cpu().numpy()
        preds_idx = char_preds.argmax(axis=2)
        preds_prob = char_preds.max(axis=2)
        # Drop the first ([GO]) position before decoding.
        char_text = self.char_decode(preds_idx[:, 1:], preds_prob[:, 1:])
        if batch is None:
            return char_text
        label = batch[1]
        label = self.char_decode(label[:, 1:])
        if self.only_char:
            return char_text, label
        bpe_preds = preds[1].detach().cpu().numpy()
        bpe_preds_idx = bpe_preds.argmax(axis=2)
        bpe_preds_prob = bpe_preds.max(axis=2)
        bpe_text = self.bpe_decode(bpe_preds_idx[:, 1:],
                                   bpe_preds_prob[:, 1:])
        # Keep the WordPiece scores as tensors; wp_decode() moves them to
        # the CPU itself.
        wp_preds = preds[2].detach()
        wp_preds_prob, wp_preds_idx = wp_preds.max(-1)
        wp_text = self.wp_decode(wp_preds_idx[:, 1:], wp_preds_prob[:, 1:])
        final_text = self.final_decode(char_text, bpe_text, wp_text)
        return char_text, bpe_text, wp_text, final_text, label
    def add_special_char(self, dict_character):
        dict_character = self.list_token + dict_character
        return dict_character
    def final_decode(self, char_text, bpe_text, wp_text):
        """Pick, for each sample, the prediction with the highest confidence."""
        result_list = []
        for (char_pred, char_conf), (bpe_pred, bpe_conf), \
                (wp_pred, wp_conf) in zip(char_text, bpe_text, wp_text):
            final_text, final_prob = char_pred, char_conf
            if bpe_conf > final_prob:
                final_text, final_prob = bpe_pred, bpe_conf
            if wp_conf > final_prob:
                final_text, final_prob = wp_pred, wp_conf
            result_list.append((final_text, final_prob))
        return result_list
    def char_decode(self, text_index, text_prob=None):
        """Convert text-index into text-label."""
        result_list = []
        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            char_list = []
            confidence = 1.0
            for idx in range(len(text_index[batch_idx])):
                try:
                    char_idx = self.character[int(text_index[batch_idx][idx])]
                except IndexError:
                    continue
                if text_prob is not None:
                    confidence *= text_prob[batch_idx][idx]
                if char_idx == self.EOS:  # end of sequence
                    break
                if char_idx == self.PAD:
                    continue
                char_list.append(char_idx)
            text = ''.join(char_list)
            result_list.append((text, confidence))
        return result_list
    def bpe_decode(self, text_index, text_prob):
        """Convert text-index into text-label."""
        result_list = []
        for text, probs in zip(text_index, text_prob):
            text_decoded = []
            confidence = 1.0
            for bpe_idx, prob in zip(text, probs):
                token_str = self.bpe_tokenizer.decode([bpe_idx])
                if token_str == '#':  # treated as the end-of-sequence marker
                    break
                text_decoded.append(token_str)
                confidence *= prob
            text = ''.join(text_decoded)
            result_list.append((text, confidence))
        return result_list
    def wp_decode(self, text_index, text_prob=None):
        """Convert text-index into text-label."""
        result_list = []
        for batch_idx, text in enumerate(text_index):
            wp_pred = self.wp_tokenizer.decode(text)
            # Truncate at the first [SEP]; keep the full string if absent.
            wp_pred_EOS = wp_pred.find('[SEP]')
            if wp_pred_EOS != -1:
                wp_pred = wp_pred[:wp_pred_EOS]
            if text_prob is not None:
                try:
                    # 102 is the [SEP] token id in the BERT vocabulary.
                    wp_pred_EOS_index = text.cpu().tolist().index(102) + 1
                except ValueError:
                    wp_pred_EOS_index = -1
                wp_pred_max_prob = text_prob[batch_idx][:wp_pred_EOS_index]
                try:
                    # Confidence is the product of per-step probabilities.
                    wp_confidence_score = wp_pred_max_prob.cumprod(
                        dim=0)[-1].item()
                except IndexError:
                    wp_confidence_score = 0.0
            else:
                wp_confidence_score = 1.0
            result_list.append((wp_pred, wp_confidence_score))
        return result_list
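
# Minimal usage sketch (illustrative, not part of the original module): the
# tensor shapes below are assumptions for a recognizer whose heads emit
# per-step probabilities of shape (batch, seq_len, vocab_size).
#
#     import torch
#
#     decoder = MPGLabelDecode(only_char=True)  # avoids HF tokenizer downloads
#     # vocab = base charset plus the [GO] and [s] special tokens
#     char_probs = torch.rand(1, 27, 38).softmax(-1)
#     print(decoder(char_probs))  # -> [(decoded_text, confidence)]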