player_detector.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. import cv2
  2. import numpy as np
  3. from ultralytics import YOLO
  4. import supervision as sv
  5. from supervision.detection.core import Detections
  6. class PlayerDetector:
  7. def __init__(self, model_path=None, confidence=0.5, iou_threshold=0.45):
  8. """
  9. 初始化球员检测器
  10. Args:
  11. model_path: YOLO模型路径,如果为None则使用预训练模型
  12. confidence: 检测置信度阈值
  13. iou_threshold: NMS IOU阈值
  14. """
  15. self.confidence_threshold = confidence
  16. self.iou_threshold = iou_threshold
  17. if model_path:
  18. self.model = YOLO(model_path)
  19. else:
  20. # 使用预训练的YOLOv8模型
  21. print("加载YOLOv8预训练模型")
  22. self.model = YOLO("yolov8n.pt")
  23. # 人员检测类别ID (COCO数据集中的person类别为0)
  24. self.person_class_id = 0
  25. # 初始化监督库的标注工具
  26. self.box_annotator = sv.BoxAnnotator(
  27. thickness=2,
  28. text_thickness=2,
  29. text_scale=0.5
  30. )
  31. def detect(self, frame):
  32. """
  33. 检测帧中的球员
  34. Args:
  35. frame: 输入的视频帧
  36. Returns:
  37. 检测到的人物边界框列表 [x1, y1, x2, y2, confidence]
  38. """
  39. # 使用YOLO模型进行检测,可以调整参数来提高精度
  40. results = self.model(
  41. frame,
  42. verbose=False,
  43. conf=self.confidence_threshold,
  44. iou=self.iou_threshold,
  45. classes=[self.person_class_id] # 只检测人员
  46. )[0]
  47. # 提取人员检测结果
  48. detections = []
  49. for r in results.boxes.data.cpu().numpy():
  50. x1, y1, x2, y2, conf, class_id = r
  51. # 已经在YOLO调用时过滤了类别,这里再次确认是否为人员
  52. if int(class_id) == self.person_class_id and conf > self.confidence_threshold:
  53. # 添加一些额外的过滤,比如尺寸过滤(可选)
  54. box_width = x2 - x1
  55. box_height = y2 - y1
  56. box_area = box_width * box_height
  57. img_area = frame.shape[0] * frame.shape[1]
  58. # 过滤掉太小或太大的检测框
  59. area_ratio = box_area / img_area
  60. if 0.0005 < area_ratio < 0.2: # 根据实际场景调整这些值
  61. detections.append([x1, y1, x2, y2, conf])
  62. return np.array(detections) if detections else np.empty((0, 5))
  63. def visualize(self, frame, detections):
  64. """
  65. 可视化检测结果
  66. Args:
  67. frame: 原始视频帧
  68. detections: 检测结果
  69. Returns:
  70. 标注后的视频帧
  71. """
  72. # 复制原始帧以避免修改
  73. annotated_frame = frame.copy()
  74. # 使用supervision库进行可视化(更好的可视化效果)
  75. if len(detections) > 0:
  76. # 转换为supervision的Detections格式
  77. sv_detections = Detections(
  78. xyxy=detections[:, :4],
  79. confidence=detections[:, 4],
  80. class_id=np.zeros(len(detections), dtype=int) # 全部是人员类别
  81. )
  82. # 添加标签
  83. labels = [f"{conf:.2f}" for conf in detections[:, 4]]
  84. # 绘制检测框和标签
  85. annotated_frame = self.box_annotator.annotate(
  86. scene=annotated_frame,
  87. detections=sv_detections,
  88. labels=labels
  89. )
  90. return annotated_frame
  91. return annotated_frame
  92. def detect_with_supervision(self, frame):
  93. """
  94. 使用Supervision库进行检测和结果格式化,提供更标准的输出
  95. Args:
  96. frame: 输入的视频帧
  97. Returns:
  98. Supervision库的Detections对象
  99. """
  100. # 使用YOLO模型进行检测
  101. results = self.model(frame, verbose=False, conf=self.confidence_threshold, classes=[self.person_class_id])[0]
  102. # 转换为supervision的Detections格式
  103. detections = sv.Detections.from_ultralytics(results)
  104. # 过滤出人员类别
  105. mask = np.array([int(cls) == self.person_class_id for cls in detections.class_id], dtype=bool)
  106. detections = detections[mask]
  107. return detections