rec_metric_long.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. import string
  2. import numpy as np
  3. from rapidfuzz.distance import Levenshtein
  4. from .rec_metric import stream_match
  5. # f_pred = open('pred_focal_subs_rand1_h2_bi_first.txt', 'w')
  class RecMetricLong(object):
      """Accumulate text-recognition metrics, including per-target-length stats.

      Call the instance once per batch with ``(preds, labels)``; it returns the
      batch's accuracy / normalized-edit-distance score and adds the batch to
      running totals.  :meth:`get_metric` reports the aggregated values and
      then resets the totals.

      Args:
          main_indicator: stored as ``self.main_indicator``; not used inside
              this class (presumably read by the caller to pick the key that
              ranks checkpoints -- confirm).
          is_filter: if True, keep only ASCII digits/letters and lowercase both
              prediction and target before comparing.
          ignore_space: if True, remove every ``' '`` from prediction and
              target before comparing.
          stream: if True, each call must carry exactly one label and the
              prediction text is taken from ``stream_match(preds)``.
          max_len: number of per-length buckets.  Targets of length
              ``max_len - 1`` or longer all share the last bucket.
      """

      def __init__(self,
                   main_indicator='acc',
                   is_filter=False,
                   ignore_space=True,
                   stream=False,
                   max_len=201,
                   **kwargs):
          if max_len < 1:
              raise ValueError('max_len must be >= 1, got %r' % (max_len, ))
          self.main_indicator = main_indicator
          self.is_filter = is_filter
          self.ignore_space = ignore_space
          self.stream = stream
          # Added to every denominator so empty totals yield 0 instead of
          # raising ZeroDivisionError.
          self.eps = 1e-5
          self.max_len = max_len
          self.reset()

      def _normalize_text(self, text):
          """Keep only ASCII digits and letters of ``text``, lowercased."""
          text = ''.join(
              filter(lambda x: x in (string.digits + string.ascii_letters),
                     text))
          return text.lower()

      def _len_bucket(self, length):
          """Return the per-length table index for a target of ``length``."""
          # The tables have exactly max_len slots; indexing them with a longer
          # length used to raise IndexError, so overflow goes to the last slot.
          return min(length, self.max_len - 1)

      def __call__(self, pred_label, *args, **kwargs):
          """Score one batch and add it to the running totals.

          Args:
              pred_label: ``(preds, labels)``.  Each item of ``preds`` is
                  unpacked as ``(text, confidence)`` and each item of
                  ``labels`` as ``(text, _)``; the two are paired with ``zip``.

          Returns:
              dict with this batch's ``'acc'`` (exact-match rate) and
              ``'norm_edit_dis'`` (1 - mean normalized Levenshtein distance).
          """
          preds, labels = pred_label
          correct_num = 0
          correct_num_slice = 0
          f_l_acc = 0
          all_num = 0
          norm_edit_dis = 0.0
          len_acc = 0
          each_len_num = np.zeros(self.max_len, dtype=int)
          each_len_correct_num = np.zeros(self.max_len, dtype=int)
          each_len_norm_edit_dis = np.zeros(self.max_len)
          for (pred, _pred_conf), (target, _) in zip(preds, labels):
              if self.stream:
                  assert len(labels) == 1
                  pred, _ = stream_match(preds)
              if self.ignore_space:
                  pred = pred.replace(' ', '')
                  target = target.replace(' ', '')
              if self.is_filter:
                  pred = self._normalize_text(pred)
                  target = self._normalize_text(target)
              dis = Levenshtein.normalized_distance(pred, target)
              norm_edit_dis += dis
              bucket = self._len_bucket(len(target))
              each_len_num[bucket] += 1
              each_len_norm_edit_dis[bucket] += dis
              if pred == target:
                  correct_num += 1
                  each_len_correct_num[bucket] += 1
              # First/last character agreement needs both strings non-empty.
              if pred and target:
                  if pred[0] == target[0] and pred[-1] == target[-1]:
                      f_l_acc += 1
              if len(pred) == len(target):
                  len_acc += 1
              # "Slice" accuracy: prediction is a prefix of the target.
              # NOTE(review): an empty prediction counts as a prefix match.
              if target.startswith(pred):
                  correct_num_slice += 1
              all_num += 1
          self.correct_num += correct_num
          self.correct_num_slice += correct_num_slice
          self.f_l_acc += f_l_acc
          self.all_num += all_num
          self.len_acc += len_acc
          self.each_len_num = self.each_len_num + each_len_num
          self.each_len_correct_num = (self.each_len_correct_num +
                                       each_len_correct_num)
          self.each_len_norm_edit_dis = (self.each_len_norm_edit_dis +
                                         each_len_norm_edit_dis)
          self.norm_edit_dis += norm_edit_dis
          return {
              'acc': correct_num / (all_num + self.eps),
              'norm_edit_dis': 1 - norm_edit_dis / (all_num + self.eps),
          }

      def get_metric(self):
          """Return the aggregated metrics and reset the running totals.

          Returns:
              dict with keys:
                  'acc': exact-match rate.
                  'norm_edit_dis': 1 - mean normalized edit distance.
                  'acc_slice': rate at which the prediction is a prefix of
                      the target.
                  'f_l_acc': rate at which first and last characters match
                      (only non-empty pairs can count).
                  'len_acc': rate at which prediction and target lengths
                      are equal.
                  'each_len_num': list of sample counts per target length.
                  'each_len_acc': list of exact-match rates per length.
                  'each_len_norm_edit_dis': list of 1 - mean normalized edit
                      distance per length.
                  'all_num': total number of samples.
              Every rate divides by (count + eps), so empty buckets give
              0.0 accuracy and 1.0 for the edit-distance scores.
          """
          denom = self.all_num + self.eps
          acc = self.correct_num / denom
          acc_slice = self.correct_num_slice / denom
          f_l_acc = self.f_l_acc / denom
          len_acc = self.len_acc / denom
          norm_edit_dis = 1 - self.norm_edit_dis / denom
          bucket_denom = self.each_len_num + self.eps
          each_len_acc = (self.each_len_correct_num / bucket_denom).tolist()
          each_len_norm_edit_dis = (
              1 - self.each_len_norm_edit_dis / bucket_denom).tolist()
          each_len_num = self.each_len_num.tolist()
          all_num = self.all_num
          self.reset()
          return {
              'acc': acc,
              'norm_edit_dis': norm_edit_dis,
              'acc_slice': acc_slice,
              'f_l_acc': f_l_acc,
              'len_acc': len_acc,
              'each_len_num': each_len_num,
              'each_len_acc': each_len_acc,
              'each_len_norm_edit_dis': each_len_norm_edit_dis,
              'all_num': all_num
          }

      def reset(self):
          """Zero all running totals and per-length tables."""
          self.correct_num = 0
          self.all_num = 0
          self.norm_edit_dis = 0
          self.correct_num_slice = 0
          self.each_len_num = np.zeros(self.max_len, dtype=int)
          self.each_len_correct_num = np.zeros(self.max_len, dtype=int)
          self.each_len_norm_edit_dis = np.zeros(self.max_len)
          self.f_l_acc = 0
          self.len_acc = 0