123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293 |
- from .rec_metric import RecMetric
class RecMPGMetric(object):
    """Run four ``RecMetric`` evaluators side by side and report them together.

    ``pred_label`` is indexed positionally: entries 0, 1, 2 and 3 are each
    paired with the last entry ``pred_label[-1]`` and handed to the ``char``,
    ``bpe``, ``wp`` and ``final`` evaluator respectively.
    NOTE(review): presumably entries 0-3 are per-branch predictions
    (character / BPE / word-piece / fused) and the last entry is the shared
    ground truth -- confirm against the post-processor that builds
    ``pred_label``.

    The result dict is whatever the ``final`` evaluator returns, extended
    with ``<branch>_acc`` and ``<branch>_norm_edit_dis`` for the ``char``,
    ``bpe`` and ``wp`` branches.
    """

    # Branches whose scores are copied into the final result, in output order.
    _AUX_BRANCHES = ('char', 'bpe', 'wp')
    # Keys read from every auxiliary result and re-published with a prefix.
    _COPIED_KEYS = ('acc', 'norm_edit_dis')

    def __init__(self,
                 main_indicator='acc',
                 is_filter=False,
                 ignore_space=True,
                 stream=False,
                 with_ratio=False,
                 max_len=25,
                 max_ratio=4,
                 **kwargs):
        """Create the four evaluators with identical settings.

        Every named argument is forwarded unchanged to each ``RecMetric``;
        see ``RecMetric`` for what the options mean. Extra ``**kwargs`` are
        accepted and ignored.
        """
        self.main_indicator = main_indicator
        self.is_filter = is_filter
        self.ignore_space = ignore_space
        # Not read anywhere in this class; kept so existing attribute access
        # on instances keeps working.
        self.eps = 1e-5

        # One settings dict instead of four copy-pasted argument lists, so the
        # evaluators cannot drift apart when an option is added or changed.
        metric_options = dict(
            main_indicator=main_indicator,
            is_filter=is_filter,
            ignore_space=ignore_space,
            stream=stream,
            with_ratio=with_ratio,
            max_len=max_len,
            max_ratio=max_ratio,
        )
        self.char_metric = RecMetric(**metric_options)
        self.bpe_metric = RecMetric(**metric_options)
        self.wp_metric = RecMetric(**metric_options)
        self.final_metric = RecMetric(**metric_options)

    @classmethod
    def _merge_scores(cls, final_result, aux_results):
        """Copy each auxiliary branch's scores into ``final_result`` in place.

        ``aux_results`` must be ordered like ``_AUX_BRANCHES``. Returns the
        same ``final_result`` object that was passed in.
        """
        for branch, result in zip(cls._AUX_BRANCHES, aux_results):
            for key in cls._COPIED_KEYS:
                final_result[branch + '_' + key] = result[key]
        return final_result

    def __call__(self,
                 pred_label,
                 batch=None,
                 training=False,
                 *args,
                 **kwargs):
        """Evaluate one batch with all four evaluators and merge the scores.

        Args:
            pred_label: indexable with at least entries 0-3 plus a last
                entry; entry ``i`` is paired with ``pred_label[-1]`` for
                evaluator ``i`` (char, bpe, wp, final).
            batch: passed through to every evaluator.
            training: passed through to every evaluator as a keyword.
            *args, **kwargs: accepted and ignored.

        Returns:
            The dict returned by the ``final`` evaluator, with ``char_*``,
            ``bpe_*`` and ``wp_*`` copies of ``acc`` and ``norm_edit_dis``
            added to it.
        """
        evaluators = (self.char_metric, self.bpe_metric, self.wp_metric,
                      self.final_metric)
        # All evaluators run (in char, bpe, wp, final order) before any
        # merging happens.
        results = [
            evaluator((pred_label[index], pred_label[-1]),
                      batch,
                      training=training)
            for index, evaluator in enumerate(evaluators)
        ]
        return self._merge_scores(results[-1], results[:-1])

    def get_metric(self):
        """Collect the metrics of all four evaluators into one dict.

        Calls ``get_metric()`` on the char, bpe, wp and final evaluators (in
        that order) and returns the final evaluator's dict extended with the
        auxiliary scores, i.e. it contains at least::

            {
                'char_acc': ..., 'char_norm_edit_dis': ...,
                'bpe_acc': ...,  'bpe_norm_edit_dis': ...,
                'wp_acc': ...,   'wp_norm_edit_dis': ...,
            }

        plus whatever keys the final evaluator itself reports.
        """
        aux_results = [
            self.char_metric.get_metric(),
            self.bpe_metric.get_metric(),
            self.wp_metric.get_metric(),
        ]
        final_result = self.final_metric.get_metric()
        return self._merge_scores(final_result, aux_results)
|