import cv2 import numpy as np from collections import defaultdict import scipy.spatial import supervision as sv from scipy.signal import savgol_filter class PlayerTracker: def __init__(self, max_disappeared=30, min_distance=50, track_smoothing=True): """ 初始化球员跟踪器 Args: max_disappeared: 最大连续丢失帧数 min_distance: 最小匹配距离阈值 track_smoothing: 是否平滑轨迹 """ # 跟踪参数 self.tracks = {} self.next_id = 1 self.max_disappeared = max_disappeared self.disappeared = defaultdict(int) self.min_distance = min_distance # 存储轨迹历史 self.track_history = defaultdict(list) # 轨迹平滑 self.track_smoothing = track_smoothing self.smooth_window = 5 # 平滑窗口大小,必须是奇数 self.smooth_poly_order = 2 # 平滑多项式阶数 # 初始化彩色轨迹 self.colors = {} # 初始化supervision轨迹工具 self.bounding_box_annotator = sv.BoxAnnotator( thickness=2, text_thickness=2, text_scale=0.5 ) # 注释掉TraceAnnotator,当前版本的supervision库不支持 # self.trace_annotator = sv.TraceAnnotator( # thickness=2, # trace_length=20 # ) def update(self, detections, frame): """ 更新跟踪状态 Args: detections: 当前帧检测到的球员 [x1, y1, x2, y2, conf] frame: 当前视频帧 Returns: 当前帧的跟踪状态 {id: (x, y, w, h)} """ # 如果没有检测到任何球员 if len(detections) == 0: # 增加所有跟踪记录的消失计数 for track_id in list(self.tracks.keys()): self.disappeared[track_id] += 1 # 如果连续消失帧数超过阈值,删除该跟踪记录 if self.disappeared[track_id] > self.max_disappeared: del self.tracks[track_id] del self.disappeared[track_id] return self.tracks # 如果还没有任何跟踪记录,为所有检测结果创建新的跟踪ID if len(self.tracks) == 0: for i, det in enumerate(detections): x1, y1, x2, y2, _ = det center_x, center_y = (x1 + x2) / 2, (y1 + y2) / 2 width, height = x2 - x1, y2 - y1 self.tracks[self.next_id] = (center_x, center_y, width, height) self.track_history[self.next_id].append((center_x, center_y)) # 为新轨迹指定颜色 self.colors[self.next_id] = self._generate_color(self.next_id) self.next_id += 1 else: # 计算当前检测结果与现有跟踪记录之间的距离 track_centers = np.array([[x, y] for x, y, _, _ in self.tracks.values()]) # 计算检测结果的中心点 detection_centers = [] for det in detections: x1, y1, x2, y2, _ = det center_x, center_y = (x1 + x2) / 2, (y1 + y2) / 2 detection_centers.append([center_x, center_y]) detection_centers = np.array(detection_centers) # 计算距离矩阵 distances = scipy.spatial.distance.cdist(track_centers, detection_centers) # 使用匈牙利算法进行跟踪-检测匹配 from scipy.optimize import linear_sum_assignment track_indices, detection_indices = linear_sum_assignment(distances) # 初始化已使用的跟踪ID和检测结果 used_tracks = set() used_detections = set() # 遍历所有匹配结果 track_ids = list(self.tracks.keys()) for track_idx, det_idx in zip(track_indices, detection_indices): # 如果匹配距离太大,则忽略 if distances[track_idx, det_idx] > self.min_distance: continue track_id = track_ids[track_idx] used_tracks.add(track_id) used_detections.add(det_idx) # 获取当前跟踪对象的位置和尺寸 curr_x, curr_y, curr_w, curr_h = self.tracks[track_id] # 更新跟踪记录 x1, y1, x2, y2, _ = detections[det_idx] new_x, new_y = (x1 + x2) / 2, (y1 + y2) / 2 new_w, new_h = x2 - x1, y2 - y1 # 使用动态加权平均进行平滑 alpha = 0.8 # 当前检测的权重 smooth_x = alpha * new_x + (1 - alpha) * curr_x smooth_y = alpha * new_y + (1 - alpha) * curr_y smooth_w = alpha * new_w + (1 - alpha) * curr_w smooth_h = alpha * new_h + (1 - alpha) * curr_h # 更新跟踪对象 self.tracks[track_id] = (smooth_x, smooth_y, smooth_w, smooth_h) self.disappeared[track_id] = 0 # 更新轨迹历史 self.track_history[track_id].append((smooth_x, smooth_y)) # 处理未匹配的跟踪记录 for track_id in list(self.tracks.keys()): if track_id not in used_tracks: self.disappeared[track_id] += 1 # 如果连续消失帧数超过阈值,删除该跟踪记录 if self.disappeared[track_id] > self.max_disappeared: del self.tracks[track_id] del self.disappeared[track_id] if track_id in self.colors: del self.colors[track_id] # 处理未匹配的检测结果 for det_idx, det in enumerate(detections): if det_idx not in used_detections: x1, y1, x2, y2, _ = det center_x, center_y = (x1 + x2) / 2, (y1 + y2) / 2 width, height = x2 - x1, y2 - y1 # 创建新的跟踪对象 self.tracks[self.next_id] = (center_x, center_y, width, height) self.disappeared[self.next_id] = 0 # 为新轨迹指定颜色 self.colors[self.next_id] = self._generate_color(self.next_id) # 更新轨迹历史 self.track_history[self.next_id].append((center_x, center_y)) self.next_id += 1 # 对轨迹进行平滑处理 if self.track_smoothing: self._smooth_tracks() return self.tracks def _smooth_tracks(self): """使用Savitzky-Golay滤波器平滑轨迹""" for track_id in self.tracks.keys(): # 如果轨迹长度足够,进行平滑 history = self.track_history[track_id] if len(history) >= self.smooth_window: # 分离x和y坐标 xs = np.array([p[0] for p in history]) ys = np.array([p[1] for p in history]) try: # 应用Savitzky-Golay滤波器 smooth_xs = savgol_filter(xs, self.smooth_window, self.smooth_poly_order) smooth_ys = savgol_filter(ys, self.smooth_window, self.smooth_poly_order) # 更新最新位置为平滑后的位置 x, y, w, h = self.tracks[track_id] self.tracks[track_id] = (smooth_xs[-1], smooth_ys[-1], w, h) # 更新历史轨迹 self.track_history[track_id][-1] = (smooth_xs[-1], smooth_ys[-1]) except: # 如果平滑失败,保持原样 pass def _generate_color(self, track_id): """为轨迹生成一个唯一的颜色""" # 使用HSV颜色空间生成均匀分布的颜色 h = (track_id * 0.1) % 1.0 s = 0.8 v = 0.8 # 转换为RGB import colorsys r, g, b = colorsys.hsv_to_rgb(h, s, v) return (int(r * 255), int(g * 255), int(b * 255)) def visualize(self, frame, draw_history=True, history_len=30): """ 可视化跟踪结果 Args: frame: 原始视频帧 draw_history: 是否绘制历史轨迹 history_len: 历史轨迹长度限制 Returns: 标注后的视频帧 """ # 复制原始帧以避免修改 annotated_frame = frame.copy() # 绘制当前跟踪的球员 for track_id, (x, y, w, h) in self.tracks.items(): x1, y1 = int(x - w/2), int(y - h/2) x2, y2 = int(x + w/2), int(y + h/2) # 获取该轨迹的颜色 color = self.colors.get(track_id, (0, 255, 0)) # 绘制边界框 cv2.rectangle(annotated_frame, (x1, y1), (x2, y2), color, 2) # 添加ID文本 cv2.putText(annotated_frame, str(track_id), (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2) # 绘制历史轨迹 if draw_history and track_id in self.track_history: history = self.track_history[track_id] # 限制历史轨迹长度 if history_len > 0: history = history[-history_len:] for i in range(1, len(history)): # 轨迹线的透明度随着时间减弱 alpha = min(1.0, i / len(history)) c = (int(color[0] * alpha), int(color[1] * alpha), int(color[2] * alpha)) pt1 = (int(history[i-1][0]), int(history[i-1][1])) pt2 = (int(history[i][0]), int(history[i][1])) # 确保点在图像范围内 if 0 <= pt1[0] < frame.shape[1] and 0 <= pt1[1] < frame.shape[0] and \ 0 <= pt2[0] < frame.shape[1] and 0 <= pt2[1] < frame.shape[0]: cv2.line(annotated_frame, pt1, pt2, c, 2) return annotated_frame def visualize_with_supervision(self, frame, detections_sv=None): """ 使用supervision库可视化跟踪结果 Args: frame: 原始视频帧 detections_sv: 可选的supervision检测结果 Returns: 标注后的视频帧 """ # 如果没有提供detections_sv,从当前跟踪状态创建 if detections_sv is None: # 创建空的检测结果 xyxy = np.zeros((len(self.tracks), 4)) confidence = np.ones(len(self.tracks)) class_id = np.zeros(len(self.tracks), dtype=int) tracker_id = np.array(list(self.tracks.keys())) # 填充边界框信息 for i, (track_id, (x, y, w, h)) in enumerate(self.tracks.items()): x1, y1 = x - w/2, y - h/2 x2, y2 = x + w/2, y + h/2 xyxy[i] = [x1, y1, x2, y2] # 创建supervision检测结果 detections_sv = sv.Detections( xyxy=xyxy, confidence=confidence, class_id=class_id, tracker_id=tracker_id ) # 创建标签 labels = [f"ID:{tracker_id}" for tracker_id in detections_sv.tracker_id] # 使用supervision的工具进行可视化 annotated_frame = frame.copy() # 绘制边界框和ID annotated_frame = self.bounding_box_annotator.annotate( scene=annotated_frame, detections=detections_sv, labels=labels ) # 绘制轨迹 for track_id in self.tracks.keys(): history = self.track_history[track_id] # 至少需要两个点才能绘制线条 if len(history) < 2: continue # 获取该轨迹的颜色 color = self.colors.get(track_id, (0, 255, 0)) # 绘制完整轨迹 points = np.array(history, dtype=np.int32) for i in range(1, len(points)): pt1 = tuple(points[i-1]) pt2 = tuple(points[i]) cv2.line(annotated_frame, pt1, pt2, color, 2) return annotated_frame def get_tracks_for_visualization(self): """ 获取适合可视化的轨迹数据 Returns: 轨迹数据,格式为 {id: [(x1, y1), (x2, y2), ...]} """ visualization_tracks = {} for track_id, history in self.track_history.items(): # 只返回活跃的轨迹 if track_id in self.tracks: visualization_tracks[track_id] = history.copy() return visualization_tracks, self.colors