lister_postprocess.py 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. import numpy as np
  2. import torch
  3. from openrec.postprocess.ctc_postprocess import BaseRecLabelDecode
  4. class LISTERLabelDecode(BaseRecLabelDecode):
  5. """Convert between text-label and text-index."""
  6. def __init__(self,
  7. character_dict_path=None,
  8. use_space_char=True,
  9. **kwargs):
  10. super(LISTERLabelDecode, self).__init__(character_dict_path,
  11. use_space_char)
  12. def __call__(self, preds, batch=None, *args, **kwargs):
  13. preds = preds[1]['logits']
  14. if isinstance(preds, torch.Tensor):
  15. preds = preds.detach().cpu().numpy()
  16. preds_idx = preds.argmax(axis=2)
  17. # preds_idx_top5 = preds.argsort(axis=2)[:, :, -5:]
  18. preds_prob = preds.max(axis=2)
  19. text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
  20. if batch is None:
  21. return text
  22. label = batch[1]
  23. label = self.decode(label)
  24. return text, label
  25. def add_special_char(self, dict_character):
  26. dict_character = ['</s>'] + dict_character + ['<pad>']
  27. return dict_character
  28. def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
  29. """convert text-index into text-label."""
  30. result_list = []
  31. batch_size = len(text_index)
  32. for batch_idx in range(batch_size):
  33. char_list = []
  34. conf_list = []
  35. for idx in range(len(text_index[batch_idx])):
  36. try:
  37. char_idx = self.character[int(text_index[batch_idx][idx])]
  38. except:
  39. continue
  40. if char_idx == '</s>': # end
  41. break
  42. if char_idx == '<s>' or char_idx == '<pad>':
  43. continue
  44. char_list.append(char_idx)
  45. if text_prob is not None:
  46. conf_list.append(text_prob[batch_idx][idx])
  47. else:
  48. conf_list.append(1)
  49. text = ''.join(char_list)
  50. result_list.append((text, np.mean(conf_list).tolist()))
  51. return result_list