ce_postprocess.py 1.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243
  1. import torch
  2. from .ctc_postprocess import BaseRecLabelDecode
  3. class CELabelDecode(BaseRecLabelDecode):
  4. """Convert between text-label and text-index."""
  5. def __init__(self,
  6. character_dict_path=None,
  7. use_space_char=False,
  8. **kwargs):
  9. super(CELabelDecode, self).__init__(character_dict_path,
  10. use_space_char)
  11. def __call__(self, preds, label=None, *args, **kwargs):
  12. if isinstance(preds, tuple) or isinstance(preds, list):
  13. preds = preds[-1]
  14. if isinstance(preds, torch.Tensor):
  15. preds = preds.numpy()
  16. preds_idx = preds.argmax(axis=-1)
  17. preds_prob = preds.max(axis=-1)
  18. text = self.decode(preds_idx, preds_prob)
  19. if label is None:
  20. return text
  21. label = self.decode(label.flatten())
  22. return text, label
  23. def decode(self, text_index, text_prob=None):
  24. """convert text-index into text-label."""
  25. result_list = []
  26. batch_size = len(text_index)
  27. for batch_idx in range(batch_size):
  28. text = self.character[text_index[batch_idx]]
  29. if text_prob is not None:
  30. conf_list = text_prob[batch_idx]
  31. else:
  32. conf_list = 1.0
  33. result_list.append((text, conf_list))
  34. return result_list
  35. def add_special_char(self, dict_character):
  36. return dict_character