char_postprocess.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. import numpy as np
  2. import torch
  3. from .ctc_postprocess import BaseRecLabelDecode
  4. class CharLabelDecode(BaseRecLabelDecode):
  5. """Convert between text-label and text-index."""
  6. def __init__(self,
  7. character_dict_path=None,
  8. use_space_char=True,
  9. **kwargs):
  10. super(CharLabelDecode, self).__init__(character_dict_path,
  11. use_space_char)
  12. def __call__(self, preds, label=None, *args, **kwargs):
  13. if len(preds) >= 4:
  14. preds_id = preds[0]
  15. preds_prob = preds[1]
  16. char_preds = preds[2]
  17. if isinstance(preds_id, torch.Tensor):
  18. preds_id = preds_id.numpy()
  19. if isinstance(preds_prob, torch.Tensor):
  20. preds_prob = preds_prob.numpy()
  21. if preds_id[0][0] == 2:
  22. preds_idx = preds_id[:, 1:]
  23. preds_prob = preds_prob[:, 1:]
  24. # char_preds = char_preds[:, 1:]
  25. else:
  26. preds_idx = preds_id
  27. char_preds = char_preds.numpy()
  28. char_preds_idx = char_preds.argmax(-1) + 4
  29. char_preds_prob = char_preds.max(-1)
  30. text, text_box = self.decode(preds_idx, preds_prob, char_preds_idx,
  31. char_preds_prob)
  32. else:
  33. preds_logit = preds[0].numpy()
  34. char_preds = preds[1].numpy()
  35. # if isinstance(preds, torch.Tensor):
  36. # preds = preds.numpy()
  37. preds_idx = preds_logit.argmax(axis=2)
  38. preds_prob = preds_logit.max(axis=2)
  39. char_preds_idx = char_preds.argmax(-1) + 4
  40. char_preds_prob = char_preds.max(-1)
  41. text, text_box = self.decode(preds_idx, preds_prob, char_preds_idx,
  42. char_preds_prob)
  43. if label is None:
  44. return text, text_box
  45. label = self.decode(label[:, 1:])
  46. return text, text_box, label
  47. def add_special_char(self, dict_character):
  48. dict_character = ['blank', '<unk>', '<s>', '</s>'] + dict_character
  49. return dict_character
  50. def decode(
  51. self,
  52. text_index,
  53. text_prob=None,
  54. char_text_index=None,
  55. char_text_prob=None,
  56. is_remove_duplicate=False,
  57. ):
  58. """convert text-index into text-label."""
  59. result_list = []
  60. box_result_list = []
  61. batch_size = len(text_index)
  62. for batch_idx in range(batch_size):
  63. char_list = []
  64. conf_list = []
  65. char_box_list = []
  66. conf_box_list = []
  67. for idx in range(len(text_index[batch_idx])):
  68. try:
  69. char_idx = self.character[int(text_index[batch_idx][idx])]
  70. if char_text_index is not None:
  71. char_box_idx = self.character[int(
  72. char_text_index[batch_idx][idx])]
  73. except:
  74. continue
  75. if char_idx == '</s>': # end
  76. break
  77. char_list.append(char_idx)
  78. if char_text_index is not None:
  79. char_box_list.append(char_box_idx)
  80. if text_prob is not None:
  81. conf_list.append(text_prob[batch_idx][idx])
  82. else:
  83. conf_list.append(1)
  84. if char_text_prob is not None:
  85. conf_box_list.append(char_text_prob[batch_idx][idx])
  86. else:
  87. conf_box_list.append(1)
  88. text = ''.join(char_list)
  89. result_list.append((text, np.mean(conf_list).tolist()))
  90. if char_text_index is not None:
  91. text_box = ''.join(char_box_list)
  92. box_result_list.append(
  93. (text_box, np.mean(conf_box_list).tolist()))
  94. if char_text_index is not None:
  95. return result_list, box_result_list
  96. return result_list