# srn_postprocess.py
import numpy as np
import torch

from .ctc_postprocess import BaseRecLabelDecode


  4. class SRNLabelDecode(BaseRecLabelDecode):
  5. """Convert between text-label and text-index."""
  6. def __init__(self,
  7. character_dict_path=None,
  8. use_space_char=False,
  9. **kwargs):
  10. super(SRNLabelDecode, self).__init__(character_dict_path,
  11. use_space_char)
  12. self.max_len = 25
  13. def add_special_char(self, dict_character):
  14. dict_character = dict_character + ['<BOS>', '<EOS>']
  15. self.start_idx = len(dict_character) - 2
  16. self.end_idx = len(dict_character) - 1
  17. return dict_character
  18. def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
  19. """convert text-index into text-label."""
  20. result_list = []
  21. ignored_tokens = self.get_ignored_tokens()
  22. # [B,25]
  23. batch_size = len(text_index)
  24. for batch_idx in range(batch_size):
  25. char_list = []
  26. conf_list = []
  27. for idx in range(len(text_index[batch_idx])):
  28. # print(f"text_index[{batch_idx}][{idx}]:{text_index[batch_idx][idx]}")
  29. if text_index[batch_idx][idx] in ignored_tokens:
  30. continue
  31. if int(text_index[batch_idx][idx]) == int(self.end_idx):
  32. if text_prob is None and idx == 0:
  33. continue
  34. else:
  35. break
  36. if is_remove_duplicate:
  37. # only for predict
  38. if idx > 0 and text_index[batch_idx][
  39. idx - 1] == text_index[batch_idx][idx]:
  40. continue
  41. char_list.append(self.character[int(
  42. text_index[batch_idx][idx])])
  43. if text_prob is not None:
  44. conf_list.append(text_prob[batch_idx][idx])
  45. else:
  46. conf_list.append(1)
  47. text = ''.join(char_list)
  48. result_list.append((text, np.mean(conf_list).tolist()))
  49. return result_list
  50. def __call__(self, preds, batch=None, *args, **kwargs):
  51. if isinstance(preds, torch.Tensor):
  52. preds = preds.reshape([-1, self.max_len, preds.shape[-1]])
  53. preds = preds.detach().cpu().numpy()
  54. else:
  55. preds = preds[-1]
  56. preds = preds.reshape([-1, self.max_len,
  57. preds.shape[-1]]).detach().cpu().numpy()
  58. preds_idx = preds.argmax(axis=2)
  59. preds_prob = preds.max(axis=2)
  60. text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
  61. if batch is None:
  62. return text
  63. label = batch[1]
  64. # print(f"label.shape:{label.shape}")
  65. label = self.decode(label, is_remove_duplicate=False)
  66. return text, label
  67. def get_ignored_tokens(self):
  68. return [self.start_idx, self.end_idx]