ar_postprocess.py 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. import numpy as np
  2. import torch
  3. from .ctc_postprocess import BaseRecLabelDecode
  4. class ARLabelDecode(BaseRecLabelDecode):
  5. """Convert between text-label and text-index."""
  6. BOS = '<s>'
  7. EOS = '</s>'
  8. PAD = '<pad>'
  9. def __init__(self,
  10. character_dict_path=None,
  11. use_space_char=True,
  12. **kwargs):
  13. super(ARLabelDecode, self).__init__(character_dict_path,
  14. use_space_char)
  15. def __call__(self, preds, batch=None, *args, **kwargs):
  16. if isinstance(preds, list):
  17. preds = preds[-1]
  18. if isinstance(preds, torch.Tensor):
  19. preds = preds.detach().cpu().numpy()
  20. preds_idx = preds.argmax(axis=2)
  21. preds_prob = preds.max(axis=2)
  22. text = self.decode(preds_idx, preds_prob)
  23. if batch is None:
  24. return text
  25. label = batch[1]
  26. label = self.decode(label[:, 1:])
  27. return text, label
  28. def add_special_char(self, dict_character):
  29. dict_character = [self.EOS] + dict_character + [self.BOS, self.PAD]
  30. return dict_character
  31. def decode(self, text_index, text_prob=None):
  32. """convert text-index into text-label."""
  33. result_list = []
  34. batch_size = len(text_index)
  35. for batch_idx in range(batch_size):
  36. char_list = []
  37. conf_list = []
  38. for idx in range(len(text_index[batch_idx])):
  39. try:
  40. char_idx = self.character[int(text_index[batch_idx][idx])]
  41. except:
  42. continue
  43. if char_idx == self.EOS: # end
  44. break
  45. if char_idx == self.BOS or char_idx == self.PAD:
  46. continue
  47. char_list.append(char_idx)
  48. if text_prob is not None:
  49. conf_list.append(text_prob[batch_idx][idx])
  50. else:
  51. conf_list.append(1)
  52. text = ''.join(char_list)
  53. result_list.append((text, np.mean(conf_list).tolist()))
  54. return result_list