smtr_postprocess.py

import numpy as np
import torch

from .ctc_postprocess import BaseRecLabelDecode


class SMTRLabelDecode(BaseRecLabelDecode):
    """Convert between text-label and text-index."""

    BOS = '<s>'
    EOS = '</s>'
    IN_F = '<INF>'  # ignore
    IN_B = '<INB>'  # ignore
    PAD = '<pad>'

    def __init__(self,
                 character_dict_path=None,
                 use_space_char=True,
                 next_mode=True,
                 **kwargs):
        super(SMTRLabelDecode, self).__init__(character_dict_path,
                                              use_space_char)
        self.next_mode = next_mode

    def __call__(self, preds, batch=None, *args, **kwargs):
        # Accept either raw logits or a list of outputs (use the last entry).
        if isinstance(preds, list):
            preds = preds[-1]
        if isinstance(preds, torch.Tensor):
            preds = preds.detach().cpu().numpy()
        preds_idx = preds.argmax(axis=2)  # best class index per time step
        preds_prob = preds.max(axis=2)    # its score, used as confidence
        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
        if batch is None:
            return text
        # When ground-truth labels are supplied, decode them as well
        # (dropping the leading BOS token) so the caller can compute metrics.
        label = batch[1]
        label = self.decode(label[:, 1:])
        return text, label

    def add_special_char(self, dict_character):
        # Prepend EOS and append the remaining special tokens to the charset.
        dict_character = [self.EOS] + dict_character + [
            self.BOS, self.IN_F, self.IN_B, self.PAD
        ]
        self.num_character = len(dict_character)
        return dict_character

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """Convert text-index into text-label."""
        result_list = []
        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            char_list = []
            conf_list = []
            for idx in range(len(text_index[batch_idx])):
                try:
                    char_idx = self.character[int(text_index[batch_idx][idx])]
                except IndexError:  # index falls outside the character set
                    continue
                if char_idx == self.EOS:  # end of sequence
                    break
                if char_idx in (self.BOS, self.PAD):
                    continue
                char_list.append(char_idx)
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:
                    conf_list.append(1)
            # In "next" mode characters are predicted left-to-right; otherwise
            # they were produced back-to-front and must be reversed.
            if self.next_mode or text_prob is None:
                text = ''.join(char_list)
            else:
                text = ''.join(char_list[::-1])
            # Guard against an empty prediction so np.mean does not return NaN.
            mean_conf = np.mean(conf_list).tolist() if conf_list else 0.0
            result_list.append((text, mean_conf))
        return result_list
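

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module). It assumes the base class
# falls back to a default alphanumeric character set when character_dict_path
# is None and that it calls add_special_char() during __init__ so
# num_character is populated -- both are assumptions about BaseRecLabelDecode,
# not guarantees of this file. Because of the relative import above, run it
# with the package on the path (e.g. via `python -m`), not as a bare script.
if __name__ == '__main__':
    decoder = SMTRLabelDecode(character_dict_path=None, use_space_char=True)
    # Fake model output with shape (batch, seq_len, num_classes).
    dummy_logits = torch.randn(2, 25, decoder.num_character)
    for text, confidence in decoder(dummy_logits):
        print(repr(text), confidence)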