eval_det_iou.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211
  1. #!/usr/bin/env python
  2. import numpy as np
  3. from shapely.geometry import Polygon
  4. """
  5. reference from :
  6. https://github.com/MhLiao/DB/blob/3c32b808d4412680310d3d28eeb6a2d5bf1566c5/concern/icdar2015_eval/detection/iou.py#L8
  7. """
  8. class DetectionIoUEvaluator(object):
  9. def __init__(self, iou_constraint=0.5, area_precision_constraint=0.5):
  10. self.iou_constraint = iou_constraint
  11. self.area_precision_constraint = area_precision_constraint
  12. def evaluate_image(self, gt, pred):
  13. def get_union(pD, pG):
  14. return Polygon(pD).union(Polygon(pG)).area
  15. def get_intersection_over_union(pD, pG):
  16. return get_intersection(pD, pG) / get_union(pD, pG)
  17. def get_intersection(pD, pG):
  18. return Polygon(pD).intersection(Polygon(pG)).area
  19. def compute_ap(confList, matchList, numGtCare):
  20. correct = 0
  21. AP = 0
  22. if len(confList) > 0:
  23. confList = np.array(confList)
  24. matchList = np.array(matchList)
  25. sorted_ind = np.argsort(-confList)
  26. confList = confList[sorted_ind]
  27. matchList = matchList[sorted_ind]
  28. for n in range(len(confList)):
  29. match = matchList[n]
  30. if match:
  31. correct += 1
  32. AP += float(correct) / (n + 1)
  33. if numGtCare > 0:
  34. AP /= numGtCare
  35. return AP
  36. perSampleMetrics = {}
  37. matchedSum = 0
  38. numGlobalCareGt = 0
  39. numGlobalCareDet = 0
  40. precision = 0
  41. detMatched = 0
  42. iouMat = np.empty([1, 1])
  43. gtPols = []
  44. detPols = []
  45. gtPolPoints = []
  46. detPolPoints = []
  47. # Array of Ground Truth Polygons' keys marked as don't Care
  48. gtDontCarePolsNum = []
  49. # Array of Detected Polygons' matched with a don't Care GT
  50. detDontCarePolsNum = []
  51. pairs = []
  52. detMatchedNums = []
  53. evaluationLog = ''
  54. for n in range(len(gt)):
  55. points = gt[n]['points']
  56. dontCare = gt[n]['ignore']
  57. if not Polygon(points).is_valid:
  58. continue
  59. gtPol = points
  60. gtPols.append(gtPol)
  61. gtPolPoints.append(points)
  62. if dontCare:
  63. gtDontCarePolsNum.append(len(gtPols) - 1)
  64. evaluationLog += (
  65. 'GT polygons: ' + str(len(gtPols)) +
  66. (' (' + str(len(gtDontCarePolsNum)) +
  67. " don't care)\n" if len(gtDontCarePolsNum) > 0 else '\n'))
  68. for n in range(len(pred)):
  69. points = pred[n]['points']
  70. if not Polygon(points).is_valid:
  71. continue
  72. detPol = points
  73. detPols.append(detPol)
  74. detPolPoints.append(points)
  75. if len(gtDontCarePolsNum) > 0:
  76. for dontCarePol in gtDontCarePolsNum:
  77. dontCarePol = gtPols[dontCarePol]
  78. intersected_area = get_intersection(dontCarePol, detPol)
  79. pdDimensions = Polygon(detPol).area
  80. precision = (0 if pdDimensions == 0 else intersected_area /
  81. pdDimensions)
  82. if precision > self.area_precision_constraint:
  83. detDontCarePolsNum.append(len(detPols) - 1)
  84. break
  85. evaluationLog += (
  86. 'DET polygons: ' + str(len(detPols)) +
  87. (' (' + str(len(detDontCarePolsNum)) +
  88. " don't care)\n" if len(detDontCarePolsNum) > 0 else '\n'))
  89. if len(gtPols) > 0 and len(detPols) > 0:
  90. # Calculate IoU and precision matrixs
  91. outputShape = [len(gtPols), len(detPols)]
  92. iouMat = np.empty(outputShape)
  93. gtRectMat = np.zeros(len(gtPols), np.int8)
  94. detRectMat = np.zeros(len(detPols), np.int8)
  95. for gtNum in range(len(gtPols)):
  96. for detNum in range(len(detPols)):
  97. pG = gtPols[gtNum]
  98. pD = detPols[detNum]
  99. iouMat[gtNum, detNum] = get_intersection_over_union(pD, pG)
  100. for gtNum in range(len(gtPols)):
  101. for detNum in range(len(detPols)):
  102. if (gtRectMat[gtNum] == 0 and detRectMat[detNum] == 0
  103. and gtNum not in gtDontCarePolsNum
  104. and detNum not in detDontCarePolsNum):
  105. if iouMat[gtNum, detNum] > self.iou_constraint:
  106. gtRectMat[gtNum] = 1
  107. detRectMat[detNum] = 1
  108. detMatched += 1
  109. pairs.append({'gt': gtNum, 'det': detNum})
  110. detMatchedNums.append(detNum)
  111. evaluationLog += ('Match GT #' + str(gtNum) +
  112. ' with Det #' + str(detNum) +
  113. '\n')
  114. numGtCare = len(gtPols) - len(gtDontCarePolsNum)
  115. numDetCare = len(detPols) - len(detDontCarePolsNum)
  116. if numGtCare == 0:
  117. precision = float(0) if numDetCare > 0 else float(1)
  118. else:
  119. precision = 0 if numDetCare == 0 else float(
  120. detMatched) / numDetCare
  121. matchedSum += detMatched
  122. numGlobalCareGt += numGtCare
  123. numGlobalCareDet += numDetCare
  124. perSampleMetrics = {
  125. 'gtCare': numGtCare,
  126. 'detCare': numDetCare,
  127. 'detMatched': detMatched,
  128. }
  129. return perSampleMetrics
  130. def combine_results(self, results):
  131. numGlobalCareGt = 0
  132. numGlobalCareDet = 0
  133. matchedSum = 0
  134. for result in results:
  135. numGlobalCareGt += result['gtCare']
  136. numGlobalCareDet += result['detCare']
  137. matchedSum += result['detMatched']
  138. methodRecall = (0 if numGlobalCareGt == 0 else float(matchedSum) /
  139. numGlobalCareGt)
  140. methodPrecision = (0 if numGlobalCareDet == 0 else float(matchedSum) /
  141. numGlobalCareDet)
  142. methodHmean = (0 if methodRecall + methodPrecision == 0 else 2 *
  143. methodRecall * methodPrecision /
  144. (methodRecall + methodPrecision))
  145. methodMetrics = {
  146. 'precision': methodPrecision,
  147. 'recall': methodRecall,
  148. 'hmean': methodHmean,
  149. }
  150. return methodMetrics
  151. if __name__ == '__main__':
  152. evaluator = DetectionIoUEvaluator()
  153. gts = [[
  154. {
  155. 'points': [(0, 0), (1, 0), (1, 1), (0, 1)],
  156. 'text': 1234,
  157. 'ignore': False,
  158. },
  159. {
  160. 'points': [(2, 2), (3, 2), (3, 3), (2, 3)],
  161. 'text': 5678,
  162. 'ignore': False,
  163. },
  164. ]]
  165. preds = [[{
  166. 'points': [(0.1, 0.1), (1, 0), (1, 1), (0, 1)],
  167. 'text': 123,
  168. 'ignore': False,
  169. }]]
  170. results = []
  171. for gt, pred in zip(gts, preds):
  172. results.append(evaluator.evaluate_image(gt, pred))
  173. metrics = evaluator.combine_results(results)
  174. print(metrics)