# igtr_postprocess.py
import numpy as np
import torch

from .nrtr_postprocess import NRTRLabelDecode
  4. class IGTRLabelDecode(NRTRLabelDecode):
  5. """Convert between text-label and text-index."""
  6. def __init__(self,
  7. character_dict_path=None,
  8. use_space_char=False,
  9. **kwargs):
  10. super(IGTRLabelDecode, self).__init__(character_dict_path,
  11. use_space_char)
  12. def __call__(self, preds, batch=None, *args, **kwargs):
  13. if isinstance(preds, list):
  14. if isinstance(preds[0], dict):
  15. preds = preds[-1].detach().cpu().numpy()
  16. if isinstance(preds, torch.Tensor):
  17. preds = preds.detach().cpu().numpy()
  18. elif isinstance(preds, dict):
  19. preds = preds['align'][-1].detach().cpu().numpy()
  20. else:
  21. preds = preds
  22. preds_idx = preds.argmax(axis=2)
  23. preds_prob = preds.max(axis=2)
  24. text = self.decode(preds_idx,
  25. preds_prob,
  26. is_remove_duplicate=False)
  27. else:
  28. preds_idx = preds[0].detach().cpu().numpy()
  29. preds_prob = preds[1].detach().cpu().numpy()
  30. text = self.decode(preds_idx,
  31. preds_prob,
  32. is_remove_duplicate=False)
  33. else:
  34. if isinstance(preds, torch.Tensor):
  35. preds = preds.detach().cpu().numpy()
  36. elif isinstance(preds, dict):
  37. preds = preds['align'][-1].detach().cpu().numpy()
  38. else:
  39. preds = preds
  40. preds_idx = preds.argmax(axis=2)
  41. preds_idx_top5 = preds.argsort(axis=2)[:, :, -5:]
  42. preds_prob = preds.max(axis=2)
  43. text = self.decode(preds_idx,
  44. preds_prob,
  45. is_remove_duplicate=False,
  46. idx_top5=preds_idx_top5)
  47. if batch is None:
  48. return text
  49. label = batch[1]
  50. label = self.decode(label)
  51. return text, label
  52. def add_special_char(self, dict_character):
  53. dict_character = ['</s>'] + dict_character + ['<s>', '<pad>']
  54. return dict_character
  55. def decode(self,
  56. text_index,
  57. text_prob=None,
  58. is_remove_duplicate=False,
  59. idx_top5=None):
  60. """convert text-index into text-label."""
  61. result_list = []
  62. batch_size = len(text_index)
  63. for batch_idx in range(batch_size):
  64. char_list = []
  65. char_list_top5 = []
  66. conf_list = []
  67. for idx in range(len(text_index[batch_idx])):
  68. char_idx_top5 = []
  69. try:
  70. char_idx = self.character[int(text_index[batch_idx][idx])]
  71. if idx_top5 is not None:
  72. for top5_i in idx_top5[batch_idx][idx]:
  73. char_idx_top5.append(self.character[top5_i])
  74. except:
  75. continue
  76. if char_idx == '</s>': # end
  77. break
  78. if char_idx == '<s>' or char_idx == '<pad>':
  79. continue
  80. char_list.append(char_idx)
  81. char_list_top5.append(char_idx_top5)
  82. if text_prob is not None:
  83. conf_list.append(text_prob[batch_idx][idx])
  84. else:
  85. conf_list.append(1)
  86. text = ''.join(char_list)
  87. if idx_top5 is not None:
  88. result_list.append(
  89. (text, [np.mean(conf_list).tolist(), char_list_top5]))
  90. else:
  91. result_list.append((text, np.mean(conf_list).tolist()))
  92. return result_list