rec_metric_gtc.py 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. from .rec_metric import RecMetric
  2. class RecGTCMetric(object):
  3. def __init__(self,
  4. main_indicator='acc',
  5. is_filter=False,
  6. ignore_space=True,
  7. stream=False,
  8. with_ratio=False,
  9. max_len=25,
  10. max_ratio=4,
  11. **kwargs):
  12. self.main_indicator = main_indicator
  13. self.is_filter = is_filter
  14. self.ignore_space = ignore_space
  15. self.eps = 1e-5
  16. self.gtc_metric = RecMetric(main_indicator=main_indicator,
  17. is_filter=is_filter,
  18. ignore_space=ignore_space,
  19. stream=stream,
  20. with_ratio=with_ratio,
  21. max_len=max_len,
  22. max_ratio=max_ratio)
  23. self.ctc_metric = RecMetric(main_indicator=main_indicator,
  24. is_filter=is_filter,
  25. ignore_space=ignore_space,
  26. stream=stream,
  27. with_ratio=with_ratio,
  28. max_len=max_len,
  29. max_ratio=max_ratio)
  30. def __call__(self,
  31. pred_label,
  32. batch=None,
  33. training=False,
  34. *args,
  35. **kwargs):
  36. ctc_metric = self.ctc_metric(pred_label[1], batch, training=training)
  37. gtc_metric = self.gtc_metric(pred_label[0], batch, training=training)
  38. ctc_metric['gtc_acc'] = gtc_metric['acc']
  39. ctc_metric['gtc_norm_edit_dis'] = gtc_metric['norm_edit_dis']
  40. return ctc_metric
  41. def get_metric(self):
  42. """
  43. return metrics {
  44. 'acc': 0,
  45. 'norm_edit_dis': 0,
  46. }
  47. """
  48. ctc_metric = self.ctc_metric.get_metric()
  49. gtc_metric = self.gtc_metric.get_metric()
  50. ctc_metric['gtc_acc'] = gtc_metric['acc']
  51. ctc_metric['gtc_norm_edit_dis'] = gtc_metric['norm_edit_dis']
  52. return ctc_metric