rec_metric.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270
  1. import string
  2. import numpy as np
  3. from rapidfuzz.distance import Levenshtein
  4. def match_ss(ss1, ss2):
  5. s1_len = len(ss1)
  6. for c_i in range(s1_len):
  7. if ss1[c_i:] == ss2[:s1_len - c_i]:
  8. return ss2[s1_len - c_i:]
  9. return ss2
  10. def stream_match(text):
  11. bs = len(text)
  12. s_list = []
  13. conf_list = []
  14. for s_conf in text:
  15. s_list.append(s_conf[0])
  16. conf_list.append(s_conf[1])
  17. s_n = bs
  18. s_start = s_list[0][:-1]
  19. s_new = s_start
  20. for s_i in range(1, s_n):
  21. s_start = match_ss(
  22. s_start, s_list[s_i][1:-1] if s_i < s_n - 1 else s_list[s_i][1:])
  23. s_new += s_start
  24. return s_new, sum(conf_list) / bs
  25. class RecMetric(object):
  26. def __init__(self,
  27. main_indicator='acc',
  28. is_filter=False,
  29. is_lower=True,
  30. ignore_space=True,
  31. stream=False,
  32. with_ratio=False,
  33. max_len=25,
  34. max_ratio=4,
  35. **kwargs):
  36. self.main_indicator = main_indicator
  37. self.is_filter = is_filter
  38. self.is_lower = is_lower
  39. self.ignore_space = ignore_space
  40. self.stream = stream
  41. self.eps = 1e-5
  42. self.with_ratio = with_ratio
  43. self.max_len = max_len
  44. self.max_ratio = max_ratio
  45. self.reset()
  46. def _normalize_text(self, text):
  47. text = ''.join(
  48. filter(lambda x: x in (string.digits + string.ascii_letters),
  49. text))
  50. return text
  51. def __call__(self,
  52. pred_label,
  53. batch=None,
  54. training=False,
  55. *args,
  56. **kwargs):
  57. if self.with_ratio and not training:
  58. return self.eval_all_metric(pred_label, batch)
  59. else:
  60. return self.eval_metric(pred_label)
  61. def eval_metric(self, pred_label, *args, **kwargs):
  62. preds, labels = pred_label
  63. correct_num = 0
  64. all_num = 0
  65. norm_edit_dis = 0.0
  66. for (pred, pred_conf), (target, _) in zip(preds, labels):
  67. if self.stream:
  68. assert len(labels) == 1
  69. pred, _ = stream_match(preds)
  70. if self.ignore_space:
  71. pred = pred.replace(' ', '')
  72. target = target.replace(' ', '')
  73. if self.is_filter:
  74. pred = self._normalize_text(pred)
  75. target = self._normalize_text(target)
  76. if self.is_lower:
  77. pred = pred.lower()
  78. target = target.lower()
  79. norm_edit_dis += Levenshtein.normalized_distance(pred, target)
  80. if pred == target:
  81. correct_num += 1
  82. all_num += 1
  83. self.correct_num += correct_num
  84. self.all_num += all_num
  85. self.norm_edit_dis += norm_edit_dis
  86. return {
  87. 'acc': correct_num / (all_num + self.eps),
  88. 'norm_edit_dis': 1 - norm_edit_dis / (all_num + self.eps),
  89. }
  90. def eval_all_metric(self, pred_label, batch=None, *args, **kwargs):
  91. if self.with_ratio:
  92. ratio = batch[-1]
  93. preds, labels = pred_label
  94. correct_num = 0
  95. correct_num_real = 0
  96. correct_num_lower = 0
  97. correct_num_ignore_space = 0
  98. correct_num_ignore_space_lower = 0
  99. correct_num_ignore_space_symbol = 0
  100. all_num = 0
  101. norm_edit_dis = 0.0
  102. each_len_num = [0 for _ in range(self.max_len)]
  103. each_len_correct_num = [0 for _ in range(self.max_len)]
  104. each_len_norm_edit_dis = [0 for _ in range(self.max_len)]
  105. each_ratio_num = [0 for _ in range(self.max_ratio)]
  106. each_ratio_correct_num = [0 for _ in range(self.max_ratio)]
  107. each_ratio_norm_edit_dis = [0 for _ in range(self.max_ratio)]
  108. for (pred, pred_conf), (target, _) in zip(preds, labels):
  109. if self.stream:
  110. assert len(labels) == 1
  111. pred, _ = stream_match(preds)
  112. if pred == target:
  113. correct_num_real += 1
  114. if pred.lower() == target.lower():
  115. correct_num_lower += 1
  116. if self.ignore_space:
  117. pred = pred.replace(' ', '')
  118. target = target.replace(' ', '')
  119. if pred == target:
  120. correct_num_ignore_space += 1
  121. if pred.lower() == target.lower():
  122. correct_num_ignore_space_lower += 1
  123. if self.is_filter:
  124. pred = self._normalize_text(pred)
  125. target = self._normalize_text(target)
  126. if pred == target:
  127. correct_num_ignore_space_symbol += 1
  128. if self.is_lower:
  129. pred = pred.lower()
  130. target = target.lower()
  131. dis = Levenshtein.normalized_distance(pred, target)
  132. norm_edit_dis += dis
  133. ratio_i = ratio[all_num] - 1 if ratio[
  134. all_num] < self.max_ratio else self.max_ratio - 1
  135. len_i = max(0, min(self.max_len, len(target)) - 1)
  136. if pred == target:
  137. correct_num += 1
  138. each_len_correct_num[len_i] += 1
  139. each_ratio_correct_num[ratio_i] += 1
  140. each_len_num[len_i] += 1
  141. each_len_norm_edit_dis[len_i] += dis
  142. each_ratio_num[ratio_i] += 1
  143. each_ratio_norm_edit_dis[ratio_i] += dis
  144. all_num += 1
  145. self.correct_num += correct_num
  146. self.correct_num_real += correct_num_real
  147. self.correct_num_lower += correct_num_lower
  148. self.correct_num_ignore_space += correct_num_ignore_space
  149. self.correct_num_ignore_space_lower += correct_num_ignore_space_lower
  150. self.correct_num_ignore_space_symbol += correct_num_ignore_space_symbol
  151. self.all_num += all_num
  152. self.norm_edit_dis += norm_edit_dis
  153. self.each_len_num = self.each_len_num + np.array(each_len_num)
  154. self.each_len_correct_num = self.each_len_correct_num + np.array(
  155. each_len_correct_num)
  156. self.each_len_norm_edit_dis = self.each_len_norm_edit_dis + np.array(
  157. each_len_norm_edit_dis)
  158. self.each_ratio_num = self.each_ratio_num + np.array(each_ratio_num)
  159. self.each_ratio_correct_num = self.each_ratio_correct_num + np.array(
  160. each_ratio_correct_num)
  161. self.each_ratio_norm_edit_dis = self.each_ratio_norm_edit_dis + np.array(
  162. each_ratio_norm_edit_dis)
  163. return {
  164. 'acc': correct_num / (all_num + self.eps),
  165. 'norm_edit_dis': 1 - norm_edit_dis / (all_num + self.eps),
  166. }
  167. def get_metric(self, training=False):
  168. """
  169. return metrics {
  170. 'acc': 0,
  171. 'norm_edit_dis': 0,
  172. }
  173. """
  174. if self.with_ratio and not training:
  175. return self.get_all_metric()
  176. acc = 1.0 * self.correct_num / (self.all_num + self.eps)
  177. norm_edit_dis = 1 - self.norm_edit_dis / (self.all_num + self.eps)
  178. num_samples = self.all_num
  179. self.reset()
  180. return {
  181. 'acc': acc,
  182. 'norm_edit_dis': norm_edit_dis,
  183. 'num_samples': num_samples
  184. }
  185. def get_all_metric(self):
  186. """
  187. return metrics {
  188. 'acc': 0,
  189. 'norm_edit_dis': 0,
  190. }
  191. """
  192. acc = 1.0 * self.correct_num / (self.all_num + self.eps)
  193. acc_real = 1.0 * self.correct_num_real / (self.all_num + self.eps)
  194. acc_lower = 1.0 * self.correct_num_lower / (self.all_num + self.eps)
  195. acc_ignore_space = 1.0 * self.correct_num_ignore_space / (
  196. self.all_num + self.eps)
  197. acc_ignore_space_lower = 1.0 * self.correct_num_ignore_space_lower / (
  198. self.all_num + self.eps)
  199. acc_ignore_space_symbol = 1.0 * self.correct_num_ignore_space_symbol / (
  200. self.all_num + self.eps)
  201. norm_edit_dis = 1 - self.norm_edit_dis / (self.all_num + self.eps)
  202. num_samples = self.all_num
  203. each_len_acc = (self.each_len_correct_num /
  204. (self.each_len_num + self.eps)).tolist()
  205. each_len_norm_edit_dis = (1 -
  206. ((self.each_len_norm_edit_dis) /
  207. ((self.each_len_num) + self.eps))).tolist()
  208. each_len_num = self.each_len_num.tolist()
  209. each_ratio_acc = (self.each_ratio_correct_num /
  210. (self.each_ratio_num + self.eps)).tolist()
  211. each_ratio_norm_edit_dis = (1 - ((self.each_ratio_norm_edit_dis) / (
  212. (self.each_ratio_num) + self.eps))).tolist()
  213. each_ratio_num = self.each_ratio_num.tolist()
  214. self.reset()
  215. return {
  216. 'acc': acc,
  217. 'acc_real': acc_real,
  218. 'acc_lower': acc_lower,
  219. 'acc_ignore_space': acc_ignore_space,
  220. 'acc_ignore_space_lower': acc_ignore_space_lower,
  221. 'acc_ignore_space_symbol': acc_ignore_space_symbol,
  222. 'acc_ignore_space_lower_symbol': acc,
  223. 'each_len_num': each_len_num,
  224. 'each_len_acc': each_len_acc,
  225. 'each_len_norm_edit_dis': each_len_norm_edit_dis,
  226. 'each_ratio_num': each_ratio_num,
  227. 'each_ratio_acc': each_ratio_acc,
  228. 'each_ratio_norm_edit_dis': each_ratio_norm_edit_dis,
  229. 'norm_edit_dis': norm_edit_dis,
  230. 'num_samples': num_samples
  231. }
  232. def reset(self):
  233. self.correct_num = 0
  234. self.all_num = 0
  235. self.norm_edit_dis = 0
  236. self.correct_num_real = 0
  237. self.correct_num_lower = 0
  238. self.correct_num_ignore_space = 0
  239. self.correct_num_ignore_space_lower = 0
  240. self.correct_num_ignore_space_symbol = 0
  241. self.each_len_num = np.array([0 for _ in range(self.max_len)])
  242. self.each_len_correct_num = np.array([0 for _ in range(self.max_len)])
  243. self.each_len_norm_edit_dis = np.array(
  244. [0. for _ in range(self.max_len)])
  245. self.each_ratio_num = np.array([0 for _ in range(self.max_ratio)])
  246. self.each_ratio_correct_num = np.array(
  247. [0 for _ in range(self.max_ratio)])
  248. self.each_ratio_norm_edit_dis = np.array(
  249. [0. for _ in range(self.max_ratio)])