123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350 |
- import cv2
- import numpy as np
- from collections import defaultdict
- import scipy.spatial
- import supervision as sv
- from scipy.signal import savgol_filter
- class PlayerTracker:
- def __init__(self, max_disappeared=30, min_distance=50, track_smoothing=True):
- """
- 初始化球员跟踪器
-
- Args:
- max_disappeared: 最大连续丢失帧数
- min_distance: 最小匹配距离阈值
- track_smoothing: 是否平滑轨迹
- """
- # 跟踪参数
- self.tracks = {}
- self.next_id = 1
- self.max_disappeared = max_disappeared
- self.disappeared = defaultdict(int)
- self.min_distance = min_distance
-
- # 存储轨迹历史
- self.track_history = defaultdict(list)
-
- # 轨迹平滑
- self.track_smoothing = track_smoothing
- self.smooth_window = 5 # 平滑窗口大小,必须是奇数
- self.smooth_poly_order = 2 # 平滑多项式阶数
-
- # 初始化彩色轨迹
- self.colors = {}
-
- # 初始化supervision轨迹工具
- self.bounding_box_annotator = sv.BoxAnnotator(
- thickness=2,
- text_thickness=2,
- text_scale=0.5
- )
-
- # 注释掉TraceAnnotator,当前版本的supervision库不支持
- # self.trace_annotator = sv.TraceAnnotator(
- # thickness=2,
- # trace_length=20
- # )
-
- def update(self, detections, frame):
- """
- 更新跟踪状态
-
- Args:
- detections: 当前帧检测到的球员 [x1, y1, x2, y2, conf]
- frame: 当前视频帧
-
- Returns:
- 当前帧的跟踪状态 {id: (x, y, w, h)}
- """
- # 如果没有检测到任何球员
- if len(detections) == 0:
- # 增加所有跟踪记录的消失计数
- for track_id in list(self.tracks.keys()):
- self.disappeared[track_id] += 1
-
- # 如果连续消失帧数超过阈值,删除该跟踪记录
- if self.disappeared[track_id] > self.max_disappeared:
- del self.tracks[track_id]
- del self.disappeared[track_id]
-
- return self.tracks
-
- # 如果还没有任何跟踪记录,为所有检测结果创建新的跟踪ID
- if len(self.tracks) == 0:
- for i, det in enumerate(detections):
- x1, y1, x2, y2, _ = det
- center_x, center_y = (x1 + x2) / 2, (y1 + y2) / 2
- width, height = x2 - x1, y2 - y1
-
- self.tracks[self.next_id] = (center_x, center_y, width, height)
- self.track_history[self.next_id].append((center_x, center_y))
-
- # 为新轨迹指定颜色
- self.colors[self.next_id] = self._generate_color(self.next_id)
-
- self.next_id += 1
- else:
- # 计算当前检测结果与现有跟踪记录之间的距离
- track_centers = np.array([[x, y] for x, y, _, _ in self.tracks.values()])
-
- # 计算检测结果的中心点
- detection_centers = []
- for det in detections:
- x1, y1, x2, y2, _ = det
- center_x, center_y = (x1 + x2) / 2, (y1 + y2) / 2
- detection_centers.append([center_x, center_y])
-
- detection_centers = np.array(detection_centers)
-
- # 计算距离矩阵
- distances = scipy.spatial.distance.cdist(track_centers, detection_centers)
-
- # 使用匈牙利算法进行跟踪-检测匹配
- from scipy.optimize import linear_sum_assignment
- track_indices, detection_indices = linear_sum_assignment(distances)
-
- # 初始化已使用的跟踪ID和检测结果
- used_tracks = set()
- used_detections = set()
-
- # 遍历所有匹配结果
- track_ids = list(self.tracks.keys())
- for track_idx, det_idx in zip(track_indices, detection_indices):
- # 如果匹配距离太大,则忽略
- if distances[track_idx, det_idx] > self.min_distance:
- continue
-
- track_id = track_ids[track_idx]
- used_tracks.add(track_id)
- used_detections.add(det_idx)
-
- # 获取当前跟踪对象的位置和尺寸
- curr_x, curr_y, curr_w, curr_h = self.tracks[track_id]
-
- # 更新跟踪记录
- x1, y1, x2, y2, _ = detections[det_idx]
- new_x, new_y = (x1 + x2) / 2, (y1 + y2) / 2
- new_w, new_h = x2 - x1, y2 - y1
-
- # 使用动态加权平均进行平滑
- alpha = 0.8 # 当前检测的权重
- smooth_x = alpha * new_x + (1 - alpha) * curr_x
- smooth_y = alpha * new_y + (1 - alpha) * curr_y
- smooth_w = alpha * new_w + (1 - alpha) * curr_w
- smooth_h = alpha * new_h + (1 - alpha) * curr_h
-
- # 更新跟踪对象
- self.tracks[track_id] = (smooth_x, smooth_y, smooth_w, smooth_h)
- self.disappeared[track_id] = 0
-
- # 更新轨迹历史
- self.track_history[track_id].append((smooth_x, smooth_y))
-
- # 处理未匹配的跟踪记录
- for track_id in list(self.tracks.keys()):
- if track_id not in used_tracks:
- self.disappeared[track_id] += 1
-
- # 如果连续消失帧数超过阈值,删除该跟踪记录
- if self.disappeared[track_id] > self.max_disappeared:
- del self.tracks[track_id]
- del self.disappeared[track_id]
- if track_id in self.colors:
- del self.colors[track_id]
-
- # 处理未匹配的检测结果
- for det_idx, det in enumerate(detections):
- if det_idx not in used_detections:
- x1, y1, x2, y2, _ = det
- center_x, center_y = (x1 + x2) / 2, (y1 + y2) / 2
- width, height = x2 - x1, y2 - y1
-
- # 创建新的跟踪对象
- self.tracks[self.next_id] = (center_x, center_y, width, height)
- self.disappeared[self.next_id] = 0
-
- # 为新轨迹指定颜色
- self.colors[self.next_id] = self._generate_color(self.next_id)
-
- # 更新轨迹历史
- self.track_history[self.next_id].append((center_x, center_y))
- self.next_id += 1
-
- # 对轨迹进行平滑处理
- if self.track_smoothing:
- self._smooth_tracks()
-
- return self.tracks
-
- def _smooth_tracks(self):
- """使用Savitzky-Golay滤波器平滑轨迹"""
- for track_id in self.tracks.keys():
- # 如果轨迹长度足够,进行平滑
- history = self.track_history[track_id]
- if len(history) >= self.smooth_window:
- # 分离x和y坐标
- xs = np.array([p[0] for p in history])
- ys = np.array([p[1] for p in history])
-
- try:
- # 应用Savitzky-Golay滤波器
- smooth_xs = savgol_filter(xs, self.smooth_window, self.smooth_poly_order)
- smooth_ys = savgol_filter(ys, self.smooth_window, self.smooth_poly_order)
-
- # 更新最新位置为平滑后的位置
- x, y, w, h = self.tracks[track_id]
- self.tracks[track_id] = (smooth_xs[-1], smooth_ys[-1], w, h)
-
- # 更新历史轨迹
- self.track_history[track_id][-1] = (smooth_xs[-1], smooth_ys[-1])
- except:
- # 如果平滑失败,保持原样
- pass
-
- def _generate_color(self, track_id):
- """为轨迹生成一个唯一的颜色"""
- # 使用HSV颜色空间生成均匀分布的颜色
- h = (track_id * 0.1) % 1.0
- s = 0.8
- v = 0.8
-
- # 转换为RGB
- import colorsys
- r, g, b = colorsys.hsv_to_rgb(h, s, v)
- return (int(r * 255), int(g * 255), int(b * 255))
-
- def visualize(self, frame, draw_history=True, history_len=30):
- """
- 可视化跟踪结果
-
- Args:
- frame: 原始视频帧
- draw_history: 是否绘制历史轨迹
- history_len: 历史轨迹长度限制
-
- Returns:
- 标注后的视频帧
- """
- # 复制原始帧以避免修改
- annotated_frame = frame.copy()
-
- # 绘制当前跟踪的球员
- for track_id, (x, y, w, h) in self.tracks.items():
- x1, y1 = int(x - w/2), int(y - h/2)
- x2, y2 = int(x + w/2), int(y + h/2)
-
- # 获取该轨迹的颜色
- color = self.colors.get(track_id, (0, 255, 0))
-
- # 绘制边界框
- cv2.rectangle(annotated_frame, (x1, y1), (x2, y2), color, 2)
-
- # 添加ID文本
- cv2.putText(annotated_frame, str(track_id), (x1, y1 - 5),
- cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)
-
- # 绘制历史轨迹
- if draw_history and track_id in self.track_history:
- history = self.track_history[track_id]
-
- # 限制历史轨迹长度
- if history_len > 0:
- history = history[-history_len:]
-
- for i in range(1, len(history)):
- # 轨迹线的透明度随着时间减弱
- alpha = min(1.0, i / len(history))
- c = (int(color[0] * alpha), int(color[1] * alpha), int(color[2] * alpha))
-
- pt1 = (int(history[i-1][0]), int(history[i-1][1]))
- pt2 = (int(history[i][0]), int(history[i][1]))
-
- # 确保点在图像范围内
- if 0 <= pt1[0] < frame.shape[1] and 0 <= pt1[1] < frame.shape[0] and \
- 0 <= pt2[0] < frame.shape[1] and 0 <= pt2[1] < frame.shape[0]:
- cv2.line(annotated_frame, pt1, pt2, c, 2)
-
- return annotated_frame
-
- def visualize_with_supervision(self, frame, detections_sv=None):
- """
- 使用supervision库可视化跟踪结果
-
- Args:
- frame: 原始视频帧
- detections_sv: 可选的supervision检测结果
-
- Returns:
- 标注后的视频帧
- """
- # 如果没有提供detections_sv,从当前跟踪状态创建
- if detections_sv is None:
- # 创建空的检测结果
- xyxy = np.zeros((len(self.tracks), 4))
- confidence = np.ones(len(self.tracks))
- class_id = np.zeros(len(self.tracks), dtype=int)
- tracker_id = np.array(list(self.tracks.keys()))
-
- # 填充边界框信息
- for i, (track_id, (x, y, w, h)) in enumerate(self.tracks.items()):
- x1, y1 = x - w/2, y - h/2
- x2, y2 = x + w/2, y + h/2
- xyxy[i] = [x1, y1, x2, y2]
-
- # 创建supervision检测结果
- detections_sv = sv.Detections(
- xyxy=xyxy,
- confidence=confidence,
- class_id=class_id,
- tracker_id=tracker_id
- )
-
- # 创建标签
- labels = [f"ID:{tracker_id}" for tracker_id in detections_sv.tracker_id]
-
- # 使用supervision的工具进行可视化
- annotated_frame = frame.copy()
-
- # 绘制边界框和ID
- annotated_frame = self.bounding_box_annotator.annotate(
- scene=annotated_frame,
- detections=detections_sv,
- labels=labels
- )
-
- # 绘制轨迹
- for track_id in self.tracks.keys():
- history = self.track_history[track_id]
-
- # 至少需要两个点才能绘制线条
- if len(history) < 2:
- continue
-
- # 获取该轨迹的颜色
- color = self.colors.get(track_id, (0, 255, 0))
-
- # 绘制完整轨迹
- points = np.array(history, dtype=np.int32)
- for i in range(1, len(points)):
- pt1 = tuple(points[i-1])
- pt2 = tuple(points[i])
- cv2.line(annotated_frame, pt1, pt2, color, 2)
-
- return annotated_frame
-
- def get_tracks_for_visualization(self):
- """
- 获取适合可视化的轨迹数据
-
- Returns:
- 轨迹数据,格式为 {id: [(x1, y1), (x2, y2), ...]}
- """
- visualization_tracks = {}
-
- for track_id, history in self.track_history.items():
- # 只返回活跃的轨迹
- if track_id in self.tracks:
- visualization_tracks[track_id] = history.copy()
-
- return visualization_tracks, self.colors
|