12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802 |
- import json
- import numpy as np
- import scipy.io as io
- from tools.utils.utility import check_install
- from tools.utils.e2e_metric.polygon_fast import iod, area_of_intersection, area
def get_socre_A(gt_dir, pred_dict):
    """Build the DetEval sigma/tau tables for one image from in-memory dicts.

    Args:
        gt_dir: despite the name, a sequence of ground-truth dicts. Each has a
            ``"points"`` array whose ``tolist()`` rows are ``(x, y)`` pairs and
            a ``"text"`` string; an empty string marks a don't-care region.
        pred_dict: sequence of prediction dicts, each with a ``"points"`` array
            and a ``"texts"`` string.

    Returns:
        dict with ``"sigma"`` (num_gt x num_det array of intersection / gt
        area), ``"global_tau"`` (intersection / detection area) and the
        ``"global_pred_str"`` / ``"global_gt_str"`` index -> text lookups.
        Don't-care ground truths, and detections lying mostly inside them,
        are excluded from the tables.
    """

    def input_reading_mod(pred_dict):
        """Convert predictions to ``[comma-joined coordinates, text]`` pairs."""
        det = []
        # Indexed access (not iteration) so int-keyed dicts keep working.
        for i in range(len(pred_dict)):
            point = ",".join(map(str, pred_dict[i]["points"].reshape(-1)))
            det.append([point, pred_dict[i]["texts"]])
        return det

    def gt_reading_mod(gt_dict):
        """Convert GT dicts to the 6-slot row layout used by the .mat files.

        Slots: ["x:", x coords (1, n) int16, "y:", y coords (1, n) int16,
        transcription, flag]; flag "#" marks a don't-care region, "c" a
        regular one.
        """
        gt = []
        for i in range(len(gt_dict)):
            points = gt_dict[i]["points"].tolist()
            text = gt_dict[i]["text"]
            xx = [
                np.array(["x:"], dtype="<U2"),
                np.array([[p[0] for p in points]], dtype="int16"),
                np.array(["y:"], dtype="<U2"),
                np.array([[p[1] for p in points]], dtype="int16"),
                np.array(["#"], dtype="<U1"),
                np.array(["#"], dtype="<U1"),
            ]
            if text != "":
                xx[4] = np.array([text], dtype="U{}".format(len(text)))
                xx[5] = np.array(["c"], dtype="<U1")
            gt.append(xx)
        return gt

    def detection_filtering(detections, groundtruths, threshold=0.5):
        """Drop detections whose overlap with a don't-care GT exceeds threshold."""
        for gt in groundtruths:
            if (gt[5] == "#") and (gt[1].shape[1] > 1):
                gt_x = list(map(int, np.squeeze(gt[1])))
                gt_y = list(map(int, np.squeeze(gt[3])))
                for det_id, detection in enumerate(detections):
                    coords = [int(float(x)) for x in detection[0].split(",")]
                    if iod(coords[0::2], coords[1::2], gt_x, gt_y) > threshold:
                        detections[det_id] = []
                detections[:] = [item for item in detections if item != []]
        return detections

    def sigma_calculation(det_x, det_y, gt_x, gt_y):
        """
        sigma = inter_area / gt_area, rounded to 2 decimals
        """
        gt_area = area(gt_x, gt_y)
        if gt_area == 0.0:
            # Degenerate ground truth: avoid dividing by a zero area
            # (same guard as tau_calculation).
            return 0
        return np.round(
            area_of_intersection(det_x, det_y, gt_x, gt_y) / gt_area, 2)

    def tau_calculation(det_x, det_y, gt_x, gt_y):
        """
        tau = inter_area / det_area, rounded to 2 decimals
        """
        det_area = area(det_x, det_y)
        if det_area == 0.0:
            return 0
        return np.round(
            area_of_intersection(det_x, det_y, gt_x, gt_y) / det_area, 2)

    detections = input_reading_mod(pred_dict)
    groundtruths = gt_reading_mod(gt_dir)
    # filters detections overlapping with DC area
    detections = detection_filtering(detections, groundtruths)
    # Don't-care ground truths take no part in scoring.
    groundtruths = [gt for gt in groundtruths if not (gt[5] == "#")]

    local_sigma_table = np.zeros((len(groundtruths), len(detections)))
    local_tau_table = np.zeros((len(groundtruths), len(detections)))
    local_pred_str = {}
    local_gt_str = {}
    if groundtruths and detections:
        # Parse every detection once instead of once per ground truth.
        parsed_dets = []
        for detection in detections:
            coords = [int(float(x)) for x in detection[0].split(",")]
            parsed_dets.append(
                (coords[0::2], coords[1::2], detection[1].strip()))
        for gt_id, gt in enumerate(groundtruths):
            gt_x = list(map(int, np.squeeze(gt[1])))
            gt_y = list(map(int, np.squeeze(gt[3])))
            gt_seq_str = str(gt[4].tolist()[0])
            for det_id, (det_x, det_y, pred_seq_str) in enumerate(parsed_dets):
                local_sigma_table[gt_id, det_id] = sigma_calculation(
                    det_x, det_y, gt_x, gt_y)
                local_tau_table[gt_id, det_id] = tau_calculation(
                    det_x, det_y, gt_x, gt_y)
                local_pred_str[det_id] = pred_seq_str
                local_gt_str[gt_id] = gt_seq_str

    single_data = {}
    single_data["sigma"] = local_sigma_table
    single_data["global_tau"] = local_tau_table
    single_data["global_pred_str"] = local_pred_str
    single_data["global_gt_str"] = local_gt_str
    return single_data
def get_socre_B(gt_dir, img_id, pred_dict):
    """Build the DetEval sigma/tau tables for one image with .mat ground truth.

    Args:
        gt_dir: directory holding ``poly_gt_img<img_id>.mat`` files; their
            ``"polygt"`` variable has one row per ground-truth polygon, read
            here as [.., x coords, .., y coords, transcription, flag] where a
            flag of "#" marks a don't-care region.
        img_id: image identifier interpolated into the file name.
        pred_dict: sequence of prediction dicts, each with a ``"points"`` array
            and a ``"texts"`` string.

    Returns:
        dict with ``"sigma"`` (num_gt x num_det array of intersection / gt
        area), ``"global_tau"`` (intersection / detection area) and the
        ``"global_pred_str"`` / ``"global_gt_str"`` index -> text lookups.
        Don't-care ground truths, and detections lying mostly inside them,
        are excluded from the tables.
    """

    def input_reading_mod(pred_dict):
        """Convert predictions to ``[comma-joined coordinates, text]`` pairs."""
        det = []
        # Indexed access (not iteration) so int-keyed dicts keep working.
        for i in range(len(pred_dict)):
            point = ",".join(map(str, pred_dict[i]["points"].reshape(-1)))
            det.append([point, pred_dict[i]["texts"]])
        return det

    def gt_reading_mod(gt_dir, gt_id):
        """Load the ``polygt`` table for one image from its .mat file."""
        gt = io.loadmat("%s/poly_gt_img%s.mat" % (gt_dir, gt_id))
        return gt["polygt"]

    def detection_filtering(detections, groundtruths, threshold=0.5):
        """Drop detections whose overlap with a don't-care GT exceeds threshold."""
        for gt in groundtruths:
            if (gt[5] == "#") and (gt[1].shape[1] > 1):
                gt_x = list(map(int, np.squeeze(gt[1])))
                gt_y = list(map(int, np.squeeze(gt[3])))
                for det_id, detection in enumerate(detections):
                    coords = [int(float(x)) for x in detection[0].split(",")]
                    if iod(coords[0::2], coords[1::2], gt_x, gt_y) > threshold:
                        detections[det_id] = []
                detections[:] = [item for item in detections if item != []]
        return detections

    def sigma_calculation(det_x, det_y, gt_x, gt_y):
        """
        sigma = inter_area / gt_area, rounded to 2 decimals
        """
        gt_area = area(gt_x, gt_y)
        if gt_area == 0.0:
            # Degenerate ground truth: avoid dividing by a zero area
            # (same guard as tau_calculation).
            return 0
        return np.round(
            area_of_intersection(det_x, det_y, gt_x, gt_y) / gt_area, 2)

    def tau_calculation(det_x, det_y, gt_x, gt_y):
        """
        tau = inter_area / det_area, rounded to 2 decimals
        """
        det_area = area(det_x, det_y)
        if det_area == 0.0:
            return 0
        return np.round(
            area_of_intersection(det_x, det_y, gt_x, gt_y) / det_area, 2)

    detections = input_reading_mod(pred_dict)
    groundtruths = gt_reading_mod(gt_dir, img_id).tolist()
    # filters detections overlapping with DC area
    detections = detection_filtering(detections, groundtruths)
    # Don't-care ground truths take no part in scoring.
    groundtruths = [gt for gt in groundtruths if not (gt[5] == "#")]

    local_sigma_table = np.zeros((len(groundtruths), len(detections)))
    local_tau_table = np.zeros((len(groundtruths), len(detections)))
    local_pred_str = {}
    local_gt_str = {}
    if groundtruths and detections:
        # Parse every detection once instead of once per ground truth.
        parsed_dets = []
        for detection in detections:
            coords = [int(float(x)) for x in detection[0].split(",")]
            parsed_dets.append(
                (coords[0::2], coords[1::2], detection[1].strip()))
        for gt_id, gt in enumerate(groundtruths):
            gt_x = list(map(int, np.squeeze(gt[1])))
            gt_y = list(map(int, np.squeeze(gt[3])))
            gt_seq_str = str(gt[4].tolist()[0])
            for det_id, (det_x, det_y, pred_seq_str) in enumerate(parsed_dets):
                local_sigma_table[gt_id, det_id] = sigma_calculation(
                    det_x, det_y, gt_x, gt_y)
                local_tau_table[gt_id, det_id] = tau_calculation(
                    det_x, det_y, gt_x, gt_y)
                local_pred_str[det_id] = pred_seq_str
                local_gt_str[gt_id] = gt_seq_str

    single_data = {}
    single_data["sigma"] = local_sigma_table
    single_data["global_tau"] = local_tau_table
    single_data["global_pred_str"] = local_pred_str
    single_data["global_gt_str"] = local_gt_str
    return single_data
def get_score_C(gt_label, text, pred_bboxes):
    """
    get score for CentripetalText (CT) prediction.

    Builds the DetEval sigma/tau tables for one image, using Polygon3 for the
    polygon geometry.

    Args:
        gt_label: sequence of ground-truth polygons; each item's ``.numpy()``
            has ``shape[1] // 2`` two-column vertex rows once reshaped.
        text: sequence whose ``text[i][0]`` is the transcription of
            ``gt_label[i]``; ``"###"`` marks a don't-care region.
        pred_bboxes: sequence of 2-D predicted polygons (presumably
            ``(num_points, 2)`` — TODO confirm against the caller).

    Returns:
        dict with ``"sigma"`` (num_gt x num_det intersection / gt area),
        ``"global_tau"`` (intersection / detection area) and empty
        ``"global_pred_str"`` / ``"global_gt_str"`` (no recognition scoring).
    """
    check_install("Polygon", "Polygon3")
    import Polygon as plg

    def gt_reading_mod(gt_label, text):
        """Pair each ground-truth polygon with its transcription."""
        groundtruths = []
        for i in range(len(gt_label)):
            groundtruths.append({
                "transcription": text[i][0],
                "points": gt_label[i].numpy(),
            })
        return groundtruths

    def get_intersection(pD, pG):
        """Area of ``pD & pG``; 0 when the intersection is empty."""
        pInt = pD & pG
        if len(pInt) == 0:
            return 0
        return pInt.area()

    def gt_polygon(gt):
        """Polygon for one ground truth (int32 vertices)."""
        point_num = gt["points"].shape[1] // 2
        return plg.Polygon(
            np.array(gt["points"]).reshape(point_num, 2).astype("int32"))

    def det_polygon(detection):
        """Polygon for one flattened detection.

        Swapping the even/odd entries undoes the ``[:, ::-1]`` flip applied
        when the detections were built, so the vertices are the rows of the
        original ``pred_bboxes`` item.
        """
        det_y = detection[0::2]
        det_x = detection[1::2]
        det_p = np.concatenate((np.array(det_x), np.array(det_y)))
        return plg.Polygon(det_p.reshape(2, -1).transpose())

    def detection_filtering(detections, groundtruths, threshold=0.5):
        """Drop detections whose overlap with a don't-care GT exceeds threshold.

        Overlap is intersection / detection area; zero-area detections have no
        meaningful ratio and are kept (the previous bare ``except`` reused a
        stale or undefined ratio in that case).
        """
        for gt in groundtruths:
            point_num = gt["points"].shape[1] // 2
            if gt["transcription"] == "###" and (point_num > 1):
                gt_p = gt_polygon(gt)
                drop = set()
                for det_id, detection in enumerate(detections):
                    det_p = det_polygon(detection)
                    det_area = det_p.area()
                    if det_area == 0:
                        continue
                    if get_intersection(det_p, gt_p) / det_area > threshold:
                        drop.add(det_id)
                # Filter by index: comparing numpy detections with [] is
                # unreliable (raises on shape mismatch in recent NumPy).
                detections[:] = [
                    det for det_id, det in enumerate(detections)
                    if det_id not in drop
                ]
        return detections

    def sigma_calculation(det_p, gt_p):
        """
        sigma = inter_area / gt_area
        """
        if gt_p.area() == 0.0:
            return 0
        return get_intersection(det_p, gt_p) / gt_p.area()

    def tau_calculation(det_p, gt_p):
        """
        tau = inter_area / det_area
        """
        if det_p.area() == 0.0:
            return 0
        return get_intersection(det_p, gt_p) / det_p.area()

    detections = [item[:, ::-1].reshape(-1) for item in pred_bboxes]
    groundtruths = gt_reading_mod(gt_label, text)
    # filters detections overlapping with DC area
    detections = detection_filtering(detections, groundtruths)
    # NOTE: the reference code identifies don't-care regions via 'orin' ('#');
    # here the 'anno' transcription ('###') is used instead, which may lower
    # the f-score slightly (about 0.12).
    groundtruths = [
        gt for gt in groundtruths if gt["transcription"] != "###"
    ]

    local_sigma_table = np.zeros((len(groundtruths), len(detections)))
    local_tau_table = np.zeros((len(groundtruths), len(detections)))
    if groundtruths and detections:
        # Build each polygon once instead of once per (gt, det) pair.
        det_polys = [det_polygon(det) for det in detections]
        for gt_id, gt in enumerate(groundtruths):
            gt_p = gt_polygon(gt)
            for det_id, det_p in enumerate(det_polys):
                local_sigma_table[gt_id, det_id] = sigma_calculation(
                    det_p, gt_p)
                local_tau_table[gt_id, det_id] = tau_calculation(det_p, gt_p)

    data = {}
    data["sigma"] = local_sigma_table
    data["global_tau"] = local_tau_table
    data["global_pred_str"] = ""
    data["global_gt_str"] = ""
    return data
def combine_results(all_data, rec_flag=True):
    """Aggregate per-image DetEval overlap tables into dataset-level metrics.

    Each entry of ``all_data`` is a dict with the keys produced by the
    per-image scoring helpers:

    * ``"sigma"``: (num_gt, num_det) table of intersection / gt area;
    * ``"global_tau"``: (num_gt, num_det) table of intersection / det area;
    * ``"global_pred_str"`` / ``"global_gt_str"``: index -> transcription
      lookups, only read when ``rec_flag`` is true.

    Matching runs three passes per image, in order: one-to-one, one-to-many
    (one gt split over several detections) and many-to-one (several gts
    merged into one detection). Split/merge matches earn reduced credit.

    Args:
        all_data: iterable of per-image dicts as described above.
        rec_flag: when true, a matched pair also counts as an end-to-end hit
            if its transcriptions are equal ignoring case.

    Returns:
        dict with the detection metrics (``recall``, ``precision``,
        ``f_score``), the end-to-end metrics (``recall_e2e``,
        ``precision_e2e``, ``f_score_e2e``, ``seqerr``) and the raw counters
        they come from. A ratio with a zero denominator falls back to 0
        (``seqerr`` falls back to 1).
    """
    tr = 0.7  # sigma threshold
    tp = 0.6  # tau threshold
    fsc_k = 0.8  # credit for a split / merge match
    k = 2  # minimum number of overlapping partners for split / merge

    def strings_match(pred_str, gt_str):
        """Case-insensitive equality (exact equality is checked first)."""
        return pred_str == gt_str or pred_str.lower() == gt_str.lower()

    def any_gt_matches(pred_str, gt_ids, gt_strs):
        """True if ``pred_str`` matches any listed gt that has a transcription."""
        for gt_id in gt_ids.tolist():
            if gt_id in gt_strs and strings_match(pred_str, gt_strs[gt_id]):
                return True
        return False

    def match_image(sigma, tau, gt_strs, pred_strs):
        """Run the three matching passes on one image.

        Returns:
            (credits, hits): ``credits`` lists the (recall, precision)
            increments in the order they were earned; ``hits`` counts
            matches whose transcriptions agree (0 unless ``rec_flag``).
        """
        num_gt = sigma.shape[0]
        num_det = sigma.shape[1]
        gt_flag = np.zeros((1, num_gt))
        det_flag = np.zeros((1, num_det))
        credits = []
        hits = 0

        # Pass 1: one-to-one. A gt matches when it has exactly one sigma
        # candidate and one tau candidate, and those candidate detections have
        # no other gt above the same thresholds.
        for gt_id in range(num_gt):
            sigma_cands = np.where(sigma[gt_id, :] > tr)[0]
            tau_cands = np.where(tau[gt_id, :] > tp)[0]
            n_det_sigma = np.where(sigma[:, sigma_cands] > tr)[0].shape[0]
            n_det_tau = np.where(tau[:, tau_cands] > tp)[0].shape[0]
            if not (sigma_cands.shape[0] == 1 and tau_cands.shape[0] == 1
                    and n_det_sigma == 1 and n_det_tau == 1):
                continue
            credits.append((1.0, 1.0))
            gt_flag[0, gt_id] = 1
            if rec_flag:
                gt_str = gt_strs[gt_id]
                if strings_match(pred_strs[int(sigma_cands[0])], gt_str):
                    hits += 1
            det_flag[0, sigma_cands] = 1

        # Pass 2: one-to-many (one gt split over several detections).
        for gt_id in range(num_gt):
            if gt_flag[0, gt_id] > 0:
                continue  # already matched
            if np.where(sigma[gt_id, :] > 0)[0].shape[0] < k:
                continue
            cands = np.where((tau[gt_id, :] >= tp) & (det_flag[0, :] == 0))[0]
            if cands.shape[0] == 1:
                det_id = int(cands[0])
                # A single qualifying detection degenerates to one-to-one.
                if tau[gt_id, det_id] >= tp and sigma[gt_id, det_id] >= tr:
                    credits.append((1.0, 1.0))
                    gt_flag[0, gt_id] = 1
                    det_flag[0, cands] = 1
                    if rec_flag:
                        gt_str = gt_strs[gt_id]
                        if strings_match(pred_strs[det_id], gt_str):
                            hits += 1
            elif np.sum(sigma[gt_id, cands]) >= tr:
                gt_flag[0, gt_id] = 1
                det_flag[0, cands] = 1
                if rec_flag:
                    gt_str = gt_strs[gt_id]
                    if strings_match(pred_strs[int(cands[0])], gt_str):
                        hits += 1
                credits.append((fsc_k, cands.shape[0] * fsc_k))

        # Pass 3: many-to-one (several gts merged into one detection).
        for det_id in range(num_det):
            if det_flag[0, det_id] > 0:
                continue  # already matched
            if np.where(tau[:, det_id] > 0)[0].shape[0] < k:
                continue
            # NOTE(review): the sigma column is compared against tp (not tr)
            # here; kept as-is so scores do not change.
            cands = np.where((sigma[:, det_id] >= tp) & (gt_flag[0, :] == 0))[0]
            if cands.shape[0] == 1:
                gt_id = int(cands[0])
                # A single qualifying gt degenerates to one-to-one.
                if tau[gt_id, det_id] >= tp and sigma[gt_id, det_id] >= tr:
                    credits.append((1.0, 1.0))
                    gt_flag[0, cands] = 1
                    det_flag[0, det_id] = 1
                    if rec_flag and any_gt_matches(
                            pred_strs[det_id], cands, gt_strs):
                        hits += 1
            elif np.sum(tau[cands, det_id]) >= tp:
                det_flag[0, det_id] = 1
                gt_flag[0, cands] = 1
                if rec_flag and any_gt_matches(
                        pred_strs[det_id], cands, gt_strs):
                    hits += 1
                credits.append((cands.shape[0] * fsc_k, fsc_k))

        return credits, hits

    def ratio(num, den):
        """``num / den``, or 0 when ``den`` is zero."""
        try:
            return num / den
        except ZeroDivisionError:
            return 0

    global_accumulative_recall = 0
    global_accumulative_precision = 0
    total_num_gt = 0
    total_num_det = 0
    hit_str_count = 0
    for data in all_data:
        sigma = np.array(data["sigma"])
        # asarray lets tau, like sigma, be given as nested lists.
        tau = np.asarray(data["global_tau"])
        total_num_gt = total_num_gt + sigma.shape[0]
        total_num_det = total_num_det + sigma.shape[1]
        credits, hits = match_image(sigma, tau, data["global_gt_str"],
                                    data["global_pred_str"])
        # Add increments one at a time, in match order, so the float running
        # totals are accumulated in exactly the same sequence as before.
        for recall_inc, precision_inc in credits:
            global_accumulative_recall = global_accumulative_recall + recall_inc
            global_accumulative_precision = (
                global_accumulative_precision + precision_inc)
        hit_str_count += hits

    recall = ratio(global_accumulative_recall, total_num_gt)
    precision = ratio(global_accumulative_precision, total_num_det)
    f_score = ratio(2 * precision * recall, precision + recall)
    try:
        seqerr = 1 - float(hit_str_count) / global_accumulative_recall
    except ZeroDivisionError:
        seqerr = 1
    recall_e2e = ratio(float(hit_str_count), total_num_gt)
    precision_e2e = ratio(float(hit_str_count), total_num_det)
    f_score_e2e = ratio(2 * precision_e2e * recall_e2e,
                        precision_e2e + recall_e2e)

    final = {
        "total_num_gt": total_num_gt,
        "total_num_det": total_num_det,
        "global_accumulative_recall": global_accumulative_recall,
        "hit_str_count": hit_str_count,
        "recall": recall,
        "precision": precision,
        "f_score": f_score,
        "seqerr": seqerr,
        "recall_e2e": recall_e2e,
        "precision_e2e": precision_e2e,
        "f_score_e2e": f_score_e2e,
    }
    return final
|