# Deteval.py — DetEval end-to-end text detection/recognition metric.
import json
import numpy as np
import scipy.io as io
from tools.utils.utility import check_install
from tools.utils.e2e_metric.polygon_fast import iod, area_of_intersection, area
def get_socre_A(gt_dir, pred_dict):
    """Build per-image sigma/tau tables (DetEval protocol, variant A).

    Ground truths arrive pre-parsed in ``gt_dir`` (a sequence of dicts with
    "points"/"text") and predictions in ``pred_dict`` (a sequence of dicts
    with "points"/"texts").  Returns a dict holding the sigma table, tau
    table and per-index transcription strings, to be aggregated later by
    ``combine_results``.
    """
    allInputs = 1

    def input_reading_mod(pred_dict):
        """This helper reads input from txt files"""
        det = []
        n = len(pred_dict)
        for i in range(n):
            points = pred_dict[i]["points"]
            text = pred_dict[i]["texts"]
            # Flatten the polygon into "x1,y1,x2,y2,..." so it can be
            # re-split into coordinates further down.
            point = ",".join(map(
                str,
                points.reshape(-1, ), ))
            det.append([point, text])
        return det

    def gt_reading_mod(gt_dict):
        """This helper reads groundtruths from mat files"""
        gt = []
        n = len(gt_dict)
        for i in range(n):
            points = gt_dict[i]["points"].tolist()
            h = len(points)
            text = gt_dict[i]["text"]
            # Mimic the legacy .mat "polygt" row layout:
            # [tag "x:", x-coords, tag "y:", y-coords, transcription, care-flag]
            xx = [
                np.array(
                    ["x:"], dtype="<U2"),
                0,
                np.array(
                    ["y:"], dtype="<U2"),
                0,
                np.array(
                    ["#"], dtype="<U1"),
                np.array(
                    ["#"], dtype="<U1"),
            ]
            t_x, t_y = [], []
            for j in range(h):
                t_x.append(points[j][0])
                t_y.append(points[j][1])
            xx[1] = np.array([t_x], dtype="int16")
            xx[3] = np.array([t_y], dtype="int16")
            if text != "":
                # Non-empty transcription => "care" box (flag "c");
                # empty text keeps the "#" (don't-care) flag.
                xx[4] = np.array([text], dtype="U{}".format(len(text)))
                xx[5] = np.array(["c"], dtype="<U1")
            gt.append(xx)
        return gt

    def detection_filtering(detections, groundtruths, threshold=0.5):
        # Drop every detection whose intersection-over-detection (iod) with
        # any don't-care ("#") ground truth exceeds `threshold`.
        for gt_id, gt in enumerate(groundtruths):
            if (gt[5] == "#") and (gt[1].shape[1] > 1):
                gt_x = list(map(int, np.squeeze(gt[1])))
                gt_y = list(map(int, np.squeeze(gt[3])))
                for det_id, detection in enumerate(detections):
                    detection_orig = detection
                    detection = [float(x) for x in detection[0].split(",")]
                    detection = list(map(int, detection))
                    det_x = detection[0::2]
                    det_y = detection[1::2]
                    det_gt_iou = iod(det_x, det_y, gt_x, gt_y)
                    if det_gt_iou > threshold:
                        detections[det_id] = []  # mark for removal
                # compact in place so outer enumerate sees the filtered list
                detections[:] = [item for item in detections if item != []]
        return detections

    def sigma_calculation(det_x, det_y, gt_x, gt_y):
        """
        sigma = inter_area / gt_area
        """
        return np.round(
            (area_of_intersection(det_x, det_y, gt_x, gt_y) / area(gt_x, gt_y)),
            2)

    def tau_calculation(det_x, det_y, gt_x, gt_y):
        # tau = inter_area / det_area; guard against degenerate detections
        if area(det_x, det_y) == 0.0:
            return 0
        return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) /
                         area(det_x, det_y)), 2)

    ##############################Initialization###################################
    # global_sigma = []
    # global_tau = []
    # global_pred_str = []
    # global_gt_str = []
    ###############################################################################
    for input_id in range(allInputs):
        # NOTE(review): input_id is an int, so none of these string
        # comparisons can ever be equal — the guard is vestigial (left over
        # from a directory-walking script) and is always True.
        if ((input_id != ".DS_Store") and (input_id != "Pascal_result.txt") and
                (input_id != "Pascal_result_curved.txt") and
                (input_id != "Pascal_result_non_curved.txt") and
                (input_id != "Deteval_result.txt") and
                (input_id != "Deteval_result_curved.txt") and
                (input_id != "Deteval_result_non_curved.txt")):
            detections = input_reading_mod(pred_dict)
            groundtruths = gt_reading_mod(gt_dir)
            detections = detection_filtering(
                detections,
                groundtruths)  # filters detections overlapping with DC area
            # Remove the don't-care ground truths themselves before scoring.
            dc_id = []
            for i in range(len(groundtruths)):
                if groundtruths[i][5] == "#":
                    dc_id.append(i)
            cnt = 0
            for a in dc_id:
                num = a - cnt  # indices shift left after each deletion
                del groundtruths[num]
                cnt += 1

            local_sigma_table = np.zeros((len(groundtruths), len(detections)))
            local_tau_table = np.zeros((len(groundtruths), len(detections)))
            local_pred_str = {}
            local_gt_str = {}

            for gt_id, gt in enumerate(groundtruths):
                if len(detections) > 0:
                    for det_id, detection in enumerate(detections):
                        detection_orig = detection
                        detection = [float(x) for x in detection[0].split(",")]
                        detection = list(map(int, detection))
                        pred_seq_str = detection_orig[1].strip()
                        det_x = detection[0::2]
                        det_y = detection[1::2]
                        gt_x = list(map(int, np.squeeze(gt[1])))
                        gt_y = list(map(int, np.squeeze(gt[3])))
                        gt_seq_str = str(gt[4].tolist()[0])

                        local_sigma_table[gt_id, det_id] = sigma_calculation(
                            det_x, det_y, gt_x, gt_y)
                        local_tau_table[gt_id, det_id] = tau_calculation(
                            det_x, det_y, gt_x, gt_y)
                        local_pred_str[det_id] = pred_seq_str
                        local_gt_str[gt_id] = gt_seq_str

            global_sigma = local_sigma_table
            global_tau = local_tau_table
            global_pred_str = local_pred_str
            global_gt_str = local_gt_str

    single_data = {}
    # NOTE: the sigma table is stored under "sigma" (not "global_sigma");
    # combine_results reads exactly these key names — keep them in sync.
    single_data["sigma"] = global_sigma
    single_data["global_tau"] = global_tau
    single_data["global_pred_str"] = global_pred_str
    single_data["global_gt_str"] = global_gt_str
    return single_data
def get_socre_B(gt_dir, img_id, pred_dict):
    """Build per-image sigma/tau tables (DetEval protocol, variant B).

    Unlike variant A, the ground truth is loaded from a MATLAB file
    ``<gt_dir>/poly_gt_img<img_id>.mat`` (key "polygt").  Predictions come
    in ``pred_dict`` (a sequence of dicts with "points"/"texts").  Returns
    a dict consumed by ``combine_results``.
    """
    allInputs = 1

    def input_reading_mod(pred_dict):
        """This helper reads input from txt files"""
        det = []
        n = len(pred_dict)
        for i in range(n):
            points = pred_dict[i]["points"]
            text = pred_dict[i]["texts"]
            # Flatten the polygon into "x1,y1,x2,y2,..." for later re-split.
            point = ",".join(map(
                str,
                points.reshape(-1, ), ))
            det.append([point, text])
        return det

    def gt_reading_mod(gt_dir, gt_id):
        # Load the legacy Total-Text style annotation; each "polygt" row is
        # [tag "x:", x-coords, tag "y:", y-coords, transcription, care-flag].
        gt = io.loadmat("%s/poly_gt_img%s.mat" % (gt_dir, gt_id))
        gt = gt["polygt"]
        return gt

    def detection_filtering(detections, groundtruths, threshold=0.5):
        # Drop every detection whose intersection-over-detection (iod) with
        # any don't-care ("#") ground truth exceeds `threshold`.
        for gt_id, gt in enumerate(groundtruths):
            if (gt[5] == "#") and (gt[1].shape[1] > 1):
                gt_x = list(map(int, np.squeeze(gt[1])))
                gt_y = list(map(int, np.squeeze(gt[3])))
                for det_id, detection in enumerate(detections):
                    detection_orig = detection
                    detection = [float(x) for x in detection[0].split(",")]
                    detection = list(map(int, detection))
                    det_x = detection[0::2]
                    det_y = detection[1::2]
                    det_gt_iou = iod(det_x, det_y, gt_x, gt_y)
                    if det_gt_iou > threshold:
                        detections[det_id] = []  # mark for removal
                detections[:] = [item for item in detections if item != []]
        return detections

    def sigma_calculation(det_x, det_y, gt_x, gt_y):
        """
        sigma = inter_area / gt_area
        """
        return np.round(
            (area_of_intersection(det_x, det_y, gt_x, gt_y) / area(gt_x, gt_y)),
            2)

    def tau_calculation(det_x, det_y, gt_x, gt_y):
        # tau = inter_area / det_area; guard against degenerate detections
        if area(det_x, det_y) == 0.0:
            return 0
        return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) /
                         area(det_x, det_y)), 2)

    ##############################Initialization###################################
    # global_sigma = []
    # global_tau = []
    # global_pred_str = []
    # global_gt_str = []
    ###############################################################################
    for input_id in range(allInputs):
        # NOTE(review): input_id is an int, so these string comparisons are
        # always True — vestigial guard kept from the original script.
        if ((input_id != ".DS_Store") and (input_id != "Pascal_result.txt") and
                (input_id != "Pascal_result_curved.txt") and
                (input_id != "Pascal_result_non_curved.txt") and
                (input_id != "Deteval_result.txt") and
                (input_id != "Deteval_result_curved.txt") and
                (input_id != "Deteval_result_non_curved.txt")):
            detections = input_reading_mod(pred_dict)
            groundtruths = gt_reading_mod(gt_dir, img_id).tolist()
            detections = detection_filtering(
                detections,
                groundtruths)  # filters detections overlapping with DC area
            # Remove the don't-care ground truths themselves before scoring.
            dc_id = []
            for i in range(len(groundtruths)):
                if groundtruths[i][5] == "#":
                    dc_id.append(i)
            cnt = 0
            for a in dc_id:
                num = a - cnt  # indices shift left after each deletion
                del groundtruths[num]
                cnt += 1

            local_sigma_table = np.zeros((len(groundtruths), len(detections)))
            local_tau_table = np.zeros((len(groundtruths), len(detections)))
            local_pred_str = {}
            local_gt_str = {}

            for gt_id, gt in enumerate(groundtruths):
                if len(detections) > 0:
                    for det_id, detection in enumerate(detections):
                        detection_orig = detection
                        detection = [float(x) for x in detection[0].split(",")]
                        detection = list(map(int, detection))
                        pred_seq_str = detection_orig[1].strip()
                        det_x = detection[0::2]
                        det_y = detection[1::2]
                        gt_x = list(map(int, np.squeeze(gt[1])))
                        gt_y = list(map(int, np.squeeze(gt[3])))
                        gt_seq_str = str(gt[4].tolist()[0])

                        local_sigma_table[gt_id, det_id] = sigma_calculation(
                            det_x, det_y, gt_x, gt_y)
                        local_tau_table[gt_id, det_id] = tau_calculation(
                            det_x, det_y, gt_x, gt_y)
                        local_pred_str[det_id] = pred_seq_str
                        local_gt_str[gt_id] = gt_seq_str

            global_sigma = local_sigma_table
            global_tau = local_tau_table
            global_pred_str = local_pred_str
            global_gt_str = local_gt_str

    single_data = {}
    # NOTE: the sigma table is stored under "sigma" (not "global_sigma");
    # combine_results reads exactly these key names — keep them in sync.
    single_data["sigma"] = global_sigma
    single_data["global_tau"] = global_tau
    single_data["global_pred_str"] = global_pred_str
    single_data["global_gt_str"] = global_gt_str
    return single_data
  243. def get_score_C(gt_label, text, pred_bboxes):
  244. """
  245. get score for CentripetalText (CT) prediction.
  246. """
  247. check_install("Polygon", "Polygon3")
  248. import Polygon as plg
  249. def gt_reading_mod(gt_label, text):
  250. """This helper reads groundtruths from mat files"""
  251. groundtruths = []
  252. nbox = len(gt_label)
  253. for i in range(nbox):
  254. label = {"transcription": text[i][0], "points": gt_label[i].numpy()}
  255. groundtruths.append(label)
  256. return groundtruths
  257. def get_union(pD, pG):
  258. areaA = pD.area()
  259. areaB = pG.area()
  260. return areaA + areaB - get_intersection(pD, pG)
  261. def get_intersection(pD, pG):
  262. pInt = pD & pG
  263. if len(pInt) == 0:
  264. return 0
  265. return pInt.area()
  266. def detection_filtering(detections, groundtruths, threshold=0.5):
  267. for gt in groundtruths:
  268. point_num = gt["points"].shape[1] // 2
  269. if gt["transcription"] == "###" and (point_num > 1):
  270. gt_p = np.array(gt["points"]).reshape(point_num,
  271. 2).astype("int32")
  272. gt_p = plg.Polygon(gt_p)
  273. for det_id, detection in enumerate(detections):
  274. det_y = detection[0::2]
  275. det_x = detection[1::2]
  276. det_p = np.concatenate((np.array(det_x), np.array(det_y)))
  277. det_p = det_p.reshape(2, -1).transpose()
  278. det_p = plg.Polygon(det_p)
  279. try:
  280. det_gt_iou = get_intersection(det_p,
  281. gt_p) / det_p.area()
  282. except:
  283. print(det_x, det_y, gt_p)
  284. if det_gt_iou > threshold:
  285. detections[det_id] = []
  286. detections[:] = [item for item in detections if item != []]
  287. return detections
  288. def sigma_calculation(det_p, gt_p):
  289. """
  290. sigma = inter_area / gt_area
  291. """
  292. if gt_p.area() == 0.0:
  293. return 0
  294. return get_intersection(det_p, gt_p) / gt_p.area()
  295. def tau_calculation(det_p, gt_p):
  296. """
  297. tau = inter_area / det_area
  298. """
  299. if det_p.area() == 0.0:
  300. return 0
  301. return get_intersection(det_p, gt_p) / det_p.area()
  302. detections = []
  303. for item in pred_bboxes:
  304. detections.append(item[:, ::-1].reshape(-1))
  305. groundtruths = gt_reading_mod(gt_label, text)
  306. detections = detection_filtering(
  307. detections, groundtruths) # filters detections overlapping with DC area
  308. for idx in range(len(groundtruths) - 1, -1, -1):
  309. # NOTE: source code use 'orin' to indicate '#', here we use 'anno',
  310. # which may cause slight drop in fscore, about 0.12
  311. if groundtruths[idx]["transcription"] == "###":
  312. groundtruths.pop(idx)
  313. local_sigma_table = np.zeros((len(groundtruths), len(detections)))
  314. local_tau_table = np.zeros((len(groundtruths), len(detections)))
  315. for gt_id, gt in enumerate(groundtruths):
  316. if len(detections) > 0:
  317. for det_id, detection in enumerate(detections):
  318. point_num = gt["points"].shape[1] // 2
  319. gt_p = np.array(gt["points"]).reshape(point_num,
  320. 2).astype("int32")
  321. gt_p = plg.Polygon(gt_p)
  322. det_y = detection[0::2]
  323. det_x = detection[1::2]
  324. det_p = np.concatenate((np.array(det_x), np.array(det_y)))
  325. det_p = det_p.reshape(2, -1).transpose()
  326. det_p = plg.Polygon(det_p)
  327. local_sigma_table[gt_id, det_id] = sigma_calculation(det_p,
  328. gt_p)
  329. local_tau_table[gt_id, det_id] = tau_calculation(det_p, gt_p)
  330. data = {}
  331. data["sigma"] = local_sigma_table
  332. data["global_tau"] = local_tau_table
  333. data["global_pred_str"] = ""
  334. data["global_gt_str"] = ""
  335. return data
def combine_results(all_data, rec_flag=True):
    """Aggregate per-image sigma/tau tables into final DetEval metrics.

    Implements the DetEval matching scheme: one-to-one, one-to-many and
    many-to-one matches are credited in that order, with fractional score
    ``fsc_k`` for split/merge matches.  When ``rec_flag`` is True, matched
    pairs are also compared by transcription (case-insensitive fallback)
    to produce end-to-end recall/precision/F-score.

    Args:
        all_data: iterable of dicts as produced by get_socre_A/get_socre_B/
            get_score_C (keys "sigma", "global_tau", "global_pred_str",
            "global_gt_str").
        rec_flag: whether to score recognition strings as well.

    Returns:
        dict of accumulated counts plus detection and end-to-end metrics.
    """
    # DetEval constants: tr/tp are the sigma/tau acceptance thresholds,
    # fsc_k the fractional credit for one-to-many / many-to-one matches,
    # k the minimum number of overlapping counterparts to consider a split.
    tr = 0.7
    tp = 0.6
    fsc_k = 0.8
    k = 2
    global_sigma = []
    global_tau = []
    global_pred_str = []
    global_gt_str = []
    for data in all_data:
        global_sigma.append(data["sigma"])
        global_tau.append(data["global_tau"])
        global_pred_str.append(data["global_pred_str"])
        global_gt_str.append(data["global_gt_str"])

    global_accumulative_recall = 0
    global_accumulative_precision = 0
    total_num_gt = 0
    total_num_det = 0
    hit_str_count = 0
    hit_count = 0  # NOTE(review): never updated or read below

    # The three matchers below are closures: they read num_gt/num_det (set
    # in the per-image loop further down) plus tr/tp/fsc_k/k and the
    # global_*_str lists from this enclosing scope.
    def one_to_one(
            local_sigma_table,
            local_tau_table,
            local_accumulative_recall,
            local_accumulative_precision,
            global_accumulative_recall,
            global_accumulative_precision,
            gt_flag,
            det_flag,
            idy,
            rec_flag, ):
        """Credit gt/det pairs that uniquely satisfy both thresholds."""
        hit_str_num = 0
        for gt_id in range(num_gt):
            gt_matching_qualified_sigma_candidates = np.where(
                local_sigma_table[gt_id, :] > tr)
            gt_matching_num_qualified_sigma_candidates = (
                gt_matching_qualified_sigma_candidates[0].shape[0])
            gt_matching_qualified_tau_candidates = np.where(
                local_tau_table[gt_id, :] > tp)
            gt_matching_num_qualified_tau_candidates = (
                gt_matching_qualified_tau_candidates[0].shape[0])

            det_matching_qualified_sigma_candidates = np.where(
                local_sigma_table[:, gt_matching_qualified_sigma_candidates[0]]
                > tr)
            det_matching_num_qualified_sigma_candidates = (
                det_matching_qualified_sigma_candidates[0].shape[0])
            det_matching_qualified_tau_candidates = np.where(
                local_tau_table[:, gt_matching_qualified_tau_candidates[0]] >
                tp)
            det_matching_num_qualified_tau_candidates = (
                det_matching_qualified_tau_candidates[0].shape[0])

            # A one-to-one match requires exactly one candidate in each of
            # the four directions (gt->det and det->gt, for sigma and tau).
            if ((gt_matching_num_qualified_sigma_candidates == 1) and
                    (gt_matching_num_qualified_tau_candidates == 1) and
                    (det_matching_num_qualified_sigma_candidates == 1) and
                    (det_matching_num_qualified_tau_candidates == 1)):
                global_accumulative_recall = global_accumulative_recall + 1.0
                global_accumulative_precision = global_accumulative_precision + 1.0
                local_accumulative_recall = local_accumulative_recall + 1.0
                local_accumulative_precision = local_accumulative_precision + 1.0

                gt_flag[0, gt_id] = 1
                matched_det_id = np.where(local_sigma_table[gt_id, :] > tr)
                # recg start
                if rec_flag:
                    gt_str_cur = global_gt_str[idy][gt_id]
                    pred_str_cur = global_pred_str[idy][matched_det_id[0]
                                                        .tolist()[0]]
                    if pred_str_cur == gt_str_cur:
                        hit_str_num += 1
                    else:
                        # case-insensitive fallback
                        if pred_str_cur.lower() == gt_str_cur.lower():
                            hit_str_num += 1
                # recg end
                det_flag[0, matched_det_id] = 1
        return (
            local_accumulative_recall,
            local_accumulative_precision,
            global_accumulative_recall,
            global_accumulative_precision,
            gt_flag,
            det_flag,
            hit_str_num, )

    def one_to_many(
            local_sigma_table,
            local_tau_table,
            local_accumulative_recall,
            local_accumulative_precision,
            global_accumulative_recall,
            global_accumulative_precision,
            gt_flag,
            det_flag,
            idy,
            rec_flag, ):
        """Credit one gt covered by several detections (split case)."""
        hit_str_num = 0
        for gt_id in range(num_gt):
            # skip the following if the groundtruth was matched
            if gt_flag[0, gt_id] > 0:
                continue

            non_zero_in_sigma = np.where(local_sigma_table[gt_id, :] > 0)
            num_non_zero_in_sigma = non_zero_in_sigma[0].shape[0]

            if num_non_zero_in_sigma >= k:
                ####search for all detections that overlaps with this groundtruth
                qualified_tau_candidates = np.where((local_tau_table[
                    gt_id, :] >= tp) & (det_flag[0, :] == 0))
                num_qualified_tau_candidates = qualified_tau_candidates[
                    0].shape[0]

                if num_qualified_tau_candidates == 1:
                    if (local_tau_table[gt_id, qualified_tau_candidates] >= tp
                        ) and (
                            local_sigma_table[gt_id, qualified_tau_candidates]
                            >= tr):
                        # became an one-to-one case
                        global_accumulative_recall = global_accumulative_recall + 1.0
                        global_accumulative_precision = (
                            global_accumulative_precision + 1.0)
                        local_accumulative_recall = local_accumulative_recall + 1.0
                        local_accumulative_precision = (
                            local_accumulative_precision + 1.0)

                        gt_flag[0, gt_id] = 1
                        det_flag[0, qualified_tau_candidates] = 1
                        # recg start
                        if rec_flag:
                            gt_str_cur = global_gt_str[idy][gt_id]
                            pred_str_cur = global_pred_str[idy][
                                qualified_tau_candidates[0].tolist()[0]]
                            if pred_str_cur == gt_str_cur:
                                hit_str_num += 1
                            else:
                                if pred_str_cur.lower() == gt_str_cur.lower():
                                    hit_str_num += 1
                        # recg end
                elif np.sum(local_sigma_table[gt_id,
                                              qualified_tau_candidates]) >= tr:
                    # genuine split: several detections together cover the gt
                    gt_flag[0, gt_id] = 1
                    det_flag[0, qualified_tau_candidates] = 1
                    # recg start
                    if rec_flag:
                        gt_str_cur = global_gt_str[idy][gt_id]
                        # only the first matched detection's text is compared
                        pred_str_cur = global_pred_str[idy][
                            qualified_tau_candidates[0].tolist()[0]]
                        if pred_str_cur == gt_str_cur:
                            hit_str_num += 1
                        else:
                            if pred_str_cur.lower() == gt_str_cur.lower():
                                hit_str_num += 1
                    # recg end

                    # fractional credit: fsc_k per gt, fsc_k per detection
                    global_accumulative_recall = global_accumulative_recall + fsc_k
                    global_accumulative_precision = (
                        global_accumulative_precision +
                        num_qualified_tau_candidates * fsc_k)

                    local_accumulative_recall = local_accumulative_recall + fsc_k
                    local_accumulative_precision = (
                        local_accumulative_precision +
                        num_qualified_tau_candidates * fsc_k)

        return (
            local_accumulative_recall,
            local_accumulative_precision,
            global_accumulative_recall,
            global_accumulative_precision,
            gt_flag,
            det_flag,
            hit_str_num, )

    def many_to_one(
            local_sigma_table,
            local_tau_table,
            local_accumulative_recall,
            local_accumulative_precision,
            global_accumulative_recall,
            global_accumulative_precision,
            gt_flag,
            det_flag,
            idy,
            rec_flag, ):
        """Credit one detection covering several gts (merge case)."""
        hit_str_num = 0
        for det_id in range(num_det):
            # skip the following if the detection was matched
            if det_flag[0, det_id] > 0:
                continue

            non_zero_in_tau = np.where(local_tau_table[:, det_id] > 0)
            num_non_zero_in_tau = non_zero_in_tau[0].shape[0]

            if num_non_zero_in_tau >= k:
                ####search for all detections that overlaps with this groundtruth
                qualified_sigma_candidates = np.where((
                    local_sigma_table[:, det_id] >= tp) & (gt_flag[0, :] == 0))
                num_qualified_sigma_candidates = qualified_sigma_candidates[
                    0].shape[0]

                if num_qualified_sigma_candidates == 1:
                    if (
                            local_tau_table[qualified_sigma_candidates, det_id]
                            >= tp
                    ) and (local_sigma_table[qualified_sigma_candidates, det_id]
                           >= tr):
                        # became an one-to-one case
                        global_accumulative_recall = global_accumulative_recall + 1.0
                        global_accumulative_precision = (
                            global_accumulative_precision + 1.0)
                        local_accumulative_recall = local_accumulative_recall + 1.0
                        local_accumulative_precision = (
                            local_accumulative_precision + 1.0)

                        gt_flag[0, qualified_sigma_candidates] = 1
                        det_flag[0, det_id] = 1
                        # recg start
                        if rec_flag:
                            pred_str_cur = global_pred_str[idy][det_id]
                            gt_len = len(qualified_sigma_candidates[0])
                            for idx in range(gt_len):
                                ele_gt_id = qualified_sigma_candidates[
                                    0].tolist()[idx]
                                if ele_gt_id not in global_gt_str[idy]:
                                    continue
                                gt_str_cur = global_gt_str[idy][ele_gt_id]
                                if pred_str_cur == gt_str_cur:
                                    hit_str_num += 1
                                    break
                                else:
                                    if pred_str_cur.lower() == gt_str_cur.lower(
                                    ):
                                        hit_str_num += 1
                                    break
                        # recg end
                elif np.sum(local_tau_table[qualified_sigma_candidates,
                                            det_id]) >= tp:
                    # genuine merge: one detection spans several gts
                    det_flag[0, det_id] = 1
                    gt_flag[0, qualified_sigma_candidates] = 1
                    # recg start
                    if rec_flag:
                        pred_str_cur = global_pred_str[idy][det_id]
                        gt_len = len(qualified_sigma_candidates[0])
                        for idx in range(gt_len):
                            ele_gt_id = qualified_sigma_candidates[0].tolist()[
                                idx]
                            if ele_gt_id not in global_gt_str[idy]:
                                continue
                            gt_str_cur = global_gt_str[idy][ele_gt_id]
                            if pred_str_cur == gt_str_cur:
                                hit_str_num += 1
                                break
                            else:
                                if pred_str_cur.lower() == gt_str_cur.lower():
                                    hit_str_num += 1
                                break
                    # recg end

                    # fractional credit: fsc_k per gt, fsc_k per detection
                    global_accumulative_recall = (
                        global_accumulative_recall +
                        num_qualified_sigma_candidates * fsc_k)
                    global_accumulative_precision = (
                        global_accumulative_precision + fsc_k)

                    local_accumulative_recall = (
                        local_accumulative_recall +
                        num_qualified_sigma_candidates * fsc_k)
                    local_accumulative_precision = local_accumulative_precision + fsc_k
        return (
            local_accumulative_recall,
            local_accumulative_precision,
            global_accumulative_recall,
            global_accumulative_precision,
            gt_flag,
            det_flag,
            hit_str_num, )

    for idx in range(len(global_sigma)):
        local_sigma_table = np.array(global_sigma[idx])
        local_tau_table = global_tau[idx]

        # num_gt/num_det are read by the closures above for this image
        num_gt = local_sigma_table.shape[0]
        num_det = local_sigma_table.shape[1]

        total_num_gt = total_num_gt + num_gt
        total_num_det = total_num_det + num_det

        local_accumulative_recall = 0
        local_accumulative_precision = 0
        gt_flag = np.zeros((1, num_gt))
        det_flag = np.zeros((1, num_det))

        #######first check for one-to-one case##########
        (
            local_accumulative_recall,
            local_accumulative_precision,
            global_accumulative_recall,
            global_accumulative_precision,
            gt_flag,
            det_flag,
            hit_str_num, ) = one_to_one(
                local_sigma_table,
                local_tau_table,
                local_accumulative_recall,
                local_accumulative_precision,
                global_accumulative_recall,
                global_accumulative_precision,
                gt_flag,
                det_flag,
                idx,
                rec_flag, )

        hit_str_count += hit_str_num
        #######then check for one-to-many case##########
        (
            local_accumulative_recall,
            local_accumulative_precision,
            global_accumulative_recall,
            global_accumulative_precision,
            gt_flag,
            det_flag,
            hit_str_num, ) = one_to_many(
                local_sigma_table,
                local_tau_table,
                local_accumulative_recall,
                local_accumulative_precision,
                global_accumulative_recall,
                global_accumulative_precision,
                gt_flag,
                det_flag,
                idx,
                rec_flag, )
        hit_str_count += hit_str_num
        #######then check for many-to-one case##########
        (
            local_accumulative_recall,
            local_accumulative_precision,
            global_accumulative_recall,
            global_accumulative_precision,
            gt_flag,
            det_flag,
            hit_str_num, ) = many_to_one(
                local_sigma_table,
                local_tau_table,
                local_accumulative_recall,
                local_accumulative_precision,
                global_accumulative_recall,
                global_accumulative_precision,
                gt_flag,
                det_flag,
                idx,
                rec_flag, )
        hit_str_count += hit_str_num

    # Final metrics; empty inputs yield 0 (or seqerr=1) via ZeroDivisionError.
    try:
        recall = global_accumulative_recall / total_num_gt
    except ZeroDivisionError:
        recall = 0

    try:
        precision = global_accumulative_precision / total_num_det
    except ZeroDivisionError:
        precision = 0

    try:
        f_score = 2 * precision * recall / (precision + recall)
    except ZeroDivisionError:
        f_score = 0

    try:
        seqerr = 1 - float(hit_str_count) / global_accumulative_recall
    except ZeroDivisionError:
        seqerr = 1

    try:
        recall_e2e = float(hit_str_count) / total_num_gt
    except ZeroDivisionError:
        recall_e2e = 0

    try:
        precision_e2e = float(hit_str_count) / total_num_det
    except ZeroDivisionError:
        precision_e2e = 0

    try:
        f_score_e2e = 2 * precision_e2e * recall_e2e / (
            precision_e2e + recall_e2e)
    except ZeroDivisionError:
        f_score_e2e = 0

    final = {
        "total_num_gt": total_num_gt,
        "total_num_det": total_num_det,
        "global_accumulative_recall": global_accumulative_recall,
        "hit_str_count": hit_str_count,
        "recall": recall,
        "precision": precision,
        "f_score": f_score,
        "seqerr": seqerr,
        "recall_e2e": recall_e2e,
        "precision_e2e": precision_e2e,
        "f_score_e2e": f_score_e2e,
    }
    return final