rec_metric_mgp.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. from .rec_metric import RecMetric
  2. class RecMPGMetric(object):
  3. def __init__(self,
  4. main_indicator='acc',
  5. is_filter=False,
  6. ignore_space=True,
  7. stream=False,
  8. with_ratio=False,
  9. max_len=25,
  10. max_ratio=4,
  11. **kwargs):
  12. self.main_indicator = main_indicator
  13. self.is_filter = is_filter
  14. self.ignore_space = ignore_space
  15. self.eps = 1e-5
  16. self.char_metric = RecMetric(main_indicator=main_indicator,
  17. is_filter=is_filter,
  18. ignore_space=ignore_space,
  19. stream=stream,
  20. with_ratio=with_ratio,
  21. max_len=max_len,
  22. max_ratio=max_ratio)
  23. self.bpe_metric = RecMetric(main_indicator=main_indicator,
  24. is_filter=is_filter,
  25. ignore_space=ignore_space,
  26. stream=stream,
  27. with_ratio=with_ratio,
  28. max_len=max_len,
  29. max_ratio=max_ratio)
  30. self.wp_metric = RecMetric(main_indicator=main_indicator,
  31. is_filter=is_filter,
  32. ignore_space=ignore_space,
  33. stream=stream,
  34. with_ratio=with_ratio,
  35. max_len=max_len,
  36. max_ratio=max_ratio)
  37. self.final_metric = RecMetric(main_indicator=main_indicator,
  38. is_filter=is_filter,
  39. ignore_space=ignore_space,
  40. stream=stream,
  41. with_ratio=with_ratio,
  42. max_len=max_len,
  43. max_ratio=max_ratio)
  44. def __call__(self,
  45. pred_label,
  46. batch=None,
  47. training=False,
  48. *args,
  49. **kwargs):
  50. char_metric = self.char_metric((pred_label[0], pred_label[-1]),
  51. batch,
  52. training=training)
  53. bpe_metric = self.bpe_metric((pred_label[1], pred_label[-1]),
  54. batch,
  55. training=training)
  56. wp_metric = self.wp_metric((pred_label[2], pred_label[-1]),
  57. batch,
  58. training=training)
  59. final_metric = self.final_metric((pred_label[3], pred_label[-1]),
  60. batch,
  61. training=training)
  62. final_metric['char_acc'] = char_metric['acc']
  63. final_metric['char_norm_edit_dis'] = char_metric['norm_edit_dis']
  64. final_metric['bpe_acc'] = bpe_metric['acc']
  65. final_metric['bpe_norm_edit_dis'] = bpe_metric['norm_edit_dis']
  66. final_metric['wp_acc'] = wp_metric['acc']
  67. final_metric['wp_norm_edit_dis'] = wp_metric['norm_edit_dis']
  68. return final_metric
  69. def get_metric(self):
  70. """
  71. return metrics {
  72. 'acc': 0,
  73. 'norm_edit_dis': 0,
  74. }
  75. """
  76. char_metric = self.char_metric.get_metric()
  77. bpe_metric = self.bpe_metric.get_metric()
  78. wp_metric = self.wp_metric.get_metric()
  79. final_metric = self.final_metric.get_metric()
  80. final_metric['char_acc'] = char_metric['acc']
  81. final_metric['char_norm_edit_dis'] = char_metric['norm_edit_dis']
  82. final_metric['bpe_acc'] = bpe_metric['acc']
  83. final_metric['bpe_norm_edit_dis'] = bpe_metric['norm_edit_dis']
  84. final_metric['wp_acc'] = wp_metric['acc']
  85. final_metric['wp_norm_edit_dis'] = wp_metric['norm_edit_dis']
  86. return final_metric