abinet_postprocess.py 1.2 KB

12345678910111213141516171819202122232425262728293031323334353637
  1. import torch
  2. from .nrtr_postprocess import NRTRLabelDecode
  class ABINetLabelDecode(NRTRLabelDecode):
      """Convert between text-label and text-index for ABINet outputs.

      The class reuses ``NRTRLabelDecode`` for the actual decoding and only
      customises how raw predictions are selected and which special token
      is prepended to the character dictionary.
      """

      def __init__(self, character_dict_path=None, use_space_char=False,
                   **kwargs):
          """Initialise the parent decoder.

          Args:
              character_dict_path: Forwarded unchanged to
                  ``NRTRLabelDecode.__init__``.
              use_space_char: Forwarded unchanged to
                  ``NRTRLabelDecode.__init__``.
              **kwargs: Accepted so callers may pass extra options;
                  they are not used here.
          """
          super().__init__(character_dict_path, use_space_char)

      def __call__(self, preds, batch=None, *args, **kwargs):
          """Decode model predictions and, optionally, the batch labels.

          Args:
              preds: One of
                  * a dict holding ``'align'`` (the last entry is used when
                    it is non-empty) and ``'vision'`` (used otherwise);
                  * a ``torch.Tensor``;
                  * any other object, used as-is (it must support
                    ``argmax(axis=2)`` and ``max(axis=2)``).
              batch: Optional; when given, ``batch[1]`` is decoded as the
                  label.
              *args: Ignored.
              **kwargs: Ignored.

          Returns:
              The decoded predictions when ``batch`` is ``None``, otherwise
              a ``(decoded_predictions, decoded_labels)`` tuple.
          """
          if isinstance(preds, dict):
              align_outputs = preds['align']
              if len(align_outputs) > 0:
                  # Use the last entry of the 'align' outputs.
                  chosen = align_outputs[-1]
              else:
                  # No 'align' outputs: fall back to the 'vision' output.
                  chosen = preds['vision']
              preds = chosen.detach().cpu().numpy()
          elif isinstance(preds, torch.Tensor):
              preds = preds.detach().cpu().numpy()
          # Any other input type is used unchanged.

          # Reduce over axis 2 -- presumably the class/character axis of a
          # (batch, length, classes) array; TODO confirm against the model.
          best_index = preds.argmax(axis=2)
          best_score = preds.max(axis=2)
          decoded = self.decode(
              best_index, best_score, is_remove_duplicate=False)

          if batch is not None:
              return decoded, self.decode(batch[1])
          return decoded

      def add_special_char(self, dict_character):
          """Return a new list with ``'</s>'`` placed before the characters.

          Args:
              dict_character: List of dictionary characters; it is not
                  modified.

          Returns:
              A new list whose first element is ``'</s>'``.
          """
          return ['</s>'] + dict_character