cppd_postprocess.py 1.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142
  1. import torch
  2. from .nrtr_postprocess import NRTRLabelDecode
  3. class CPPDLabelDecode(NRTRLabelDecode):
  4. """Convert between text-label and text-index."""
  5. def __init__(self,
  6. character_dict_path=None,
  7. use_space_char=False,
  8. **kwargs):
  9. super(CPPDLabelDecode, self).__init__(character_dict_path,
  10. use_space_char)
  11. def __call__(self, preds, batch=None, *args, **kwargs):
  12. if isinstance(preds, tuple):
  13. if isinstance(preds[-1], dict):
  14. preds = preds[-1]['align'][-1].detach().cpu().numpy()
  15. else:
  16. preds = preds[-1].detach().cpu().numpy()
  17. if isinstance(preds, list):
  18. preds = preds[-1].detach().cpu().numpy()
  19. if isinstance(preds, torch.Tensor):
  20. preds = preds.detach().cpu().numpy()
  21. elif isinstance(preds, dict):
  22. preds = preds['align'][-1].detach().cpu().numpy()
  23. else:
  24. preds = preds
  25. preds_idx = preds.argmax(axis=2)
  26. preds_prob = preds.max(axis=2)
  27. text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
  28. if batch is None:
  29. return text
  30. label = batch[1]
  31. label = self.decode(label)
  32. return text, label
  33. def add_special_char(self, dict_character):
  34. dict_character = ['</s>'] + dict_character
  35. return dict_character