nrtr_postprocess.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. import numpy as np
  2. import torch
  3. from .ctc_postprocess import BaseRecLabelDecode
  4. class NRTRLabelDecode(BaseRecLabelDecode):
  5. """Convert between text-label and text-index."""
  6. def __init__(self,
  7. character_dict_path=None,
  8. use_space_char=True,
  9. **kwargs):
  10. super(NRTRLabelDecode, self).__init__(character_dict_path,
  11. use_space_char)
  12. def __call__(self, preds, batch=None, *args, **kwargs):
  13. preds = preds['res']
  14. if len(preds) == 2:
  15. preds_id = preds[0]
  16. preds_prob = preds[1]
  17. if isinstance(preds_id, torch.Tensor):
  18. preds_id = preds_id.detach().cpu().numpy()
  19. if isinstance(preds_prob, torch.Tensor):
  20. preds_prob = preds_prob.detach().cpu().numpy()
  21. if preds_id[0][0] == 2:
  22. preds_idx = preds_id[:, 1:]
  23. preds_prob = preds_prob[:, 1:]
  24. else:
  25. preds_idx = preds_id
  26. text = self.decode(preds_idx,
  27. preds_prob,
  28. is_remove_duplicate=False)
  29. if batch is None:
  30. return text
  31. label = self.decode(batch[1][:, 1:])
  32. else:
  33. if isinstance(preds, torch.Tensor):
  34. preds = preds.detach().cpu().numpy()
  35. preds_idx = preds.argmax(axis=2)
  36. preds_prob = preds.max(axis=2)
  37. text = self.decode(preds_idx,
  38. preds_prob,
  39. is_remove_duplicate=False)
  40. if batch is None:
  41. return text
  42. label = self.decode(batch[1][:, 1:])
  43. return text, label
  44. def add_special_char(self, dict_character):
  45. dict_character = ['blank', '<unk>', '<s>', '</s>'] + dict_character
  46. return dict_character
  47. def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
  48. """convert text-index into text-label."""
  49. result_list = []
  50. batch_size = len(text_index)
  51. for batch_idx in range(batch_size):
  52. char_list = []
  53. conf_list = []
  54. for idx in range(len(text_index[batch_idx])):
  55. try:
  56. char_idx = self.character[int(text_index[batch_idx][idx])]
  57. except:
  58. continue
  59. if char_idx == '</s>': # end
  60. break
  61. char_list.append(char_idx)
  62. if text_prob is not None:
  63. conf_list.append(text_prob[batch_idx][idx])
  64. else:
  65. conf_list.append(1)
  66. text = ''.join(char_list)
  67. result_list.append((text, np.mean(conf_list).tolist()))
  68. return result_list