player_tracker.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350
  1. import cv2
  2. import numpy as np
  3. from collections import defaultdict
  4. import scipy.spatial
  5. import supervision as sv
  6. from scipy.signal import savgol_filter
  7. class PlayerTracker:
  8. def __init__(self, max_disappeared=30, min_distance=50, track_smoothing=True):
  9. """
  10. 初始化球员跟踪器
  11. Args:
  12. max_disappeared: 最大连续丢失帧数
  13. min_distance: 最小匹配距离阈值
  14. track_smoothing: 是否平滑轨迹
  15. """
  16. # 跟踪参数
  17. self.tracks = {}
  18. self.next_id = 1
  19. self.max_disappeared = max_disappeared
  20. self.disappeared = defaultdict(int)
  21. self.min_distance = min_distance
  22. # 存储轨迹历史
  23. self.track_history = defaultdict(list)
  24. # 轨迹平滑
  25. self.track_smoothing = track_smoothing
  26. self.smooth_window = 5 # 平滑窗口大小,必须是奇数
  27. self.smooth_poly_order = 2 # 平滑多项式阶数
  28. # 初始化彩色轨迹
  29. self.colors = {}
  30. # 初始化supervision轨迹工具
  31. self.bounding_box_annotator = sv.BoxAnnotator(
  32. thickness=2,
  33. text_thickness=2,
  34. text_scale=0.5
  35. )
  36. # 注释掉TraceAnnotator,当前版本的supervision库不支持
  37. # self.trace_annotator = sv.TraceAnnotator(
  38. # thickness=2,
  39. # trace_length=20
  40. # )
  41. def update(self, detections, frame):
  42. """
  43. 更新跟踪状态
  44. Args:
  45. detections: 当前帧检测到的球员 [x1, y1, x2, y2, conf]
  46. frame: 当前视频帧
  47. Returns:
  48. 当前帧的跟踪状态 {id: (x, y, w, h)}
  49. """
  50. # 如果没有检测到任何球员
  51. if len(detections) == 0:
  52. # 增加所有跟踪记录的消失计数
  53. for track_id in list(self.tracks.keys()):
  54. self.disappeared[track_id] += 1
  55. # 如果连续消失帧数超过阈值,删除该跟踪记录
  56. if self.disappeared[track_id] > self.max_disappeared:
  57. del self.tracks[track_id]
  58. del self.disappeared[track_id]
  59. return self.tracks
  60. # 如果还没有任何跟踪记录,为所有检测结果创建新的跟踪ID
  61. if len(self.tracks) == 0:
  62. for i, det in enumerate(detections):
  63. x1, y1, x2, y2, _ = det
  64. center_x, center_y = (x1 + x2) / 2, (y1 + y2) / 2
  65. width, height = x2 - x1, y2 - y1
  66. self.tracks[self.next_id] = (center_x, center_y, width, height)
  67. self.track_history[self.next_id].append((center_x, center_y))
  68. # 为新轨迹指定颜色
  69. self.colors[self.next_id] = self._generate_color(self.next_id)
  70. self.next_id += 1
  71. else:
  72. # 计算当前检测结果与现有跟踪记录之间的距离
  73. track_centers = np.array([[x, y] for x, y, _, _ in self.tracks.values()])
  74. # 计算检测结果的中心点
  75. detection_centers = []
  76. for det in detections:
  77. x1, y1, x2, y2, _ = det
  78. center_x, center_y = (x1 + x2) / 2, (y1 + y2) / 2
  79. detection_centers.append([center_x, center_y])
  80. detection_centers = np.array(detection_centers)
  81. # 计算距离矩阵
  82. distances = scipy.spatial.distance.cdist(track_centers, detection_centers)
  83. # 使用匈牙利算法进行跟踪-检测匹配
  84. from scipy.optimize import linear_sum_assignment
  85. track_indices, detection_indices = linear_sum_assignment(distances)
  86. # 初始化已使用的跟踪ID和检测结果
  87. used_tracks = set()
  88. used_detections = set()
  89. # 遍历所有匹配结果
  90. track_ids = list(self.tracks.keys())
  91. for track_idx, det_idx in zip(track_indices, detection_indices):
  92. # 如果匹配距离太大,则忽略
  93. if distances[track_idx, det_idx] > self.min_distance:
  94. continue
  95. track_id = track_ids[track_idx]
  96. used_tracks.add(track_id)
  97. used_detections.add(det_idx)
  98. # 获取当前跟踪对象的位置和尺寸
  99. curr_x, curr_y, curr_w, curr_h = self.tracks[track_id]
  100. # 更新跟踪记录
  101. x1, y1, x2, y2, _ = detections[det_idx]
  102. new_x, new_y = (x1 + x2) / 2, (y1 + y2) / 2
  103. new_w, new_h = x2 - x1, y2 - y1
  104. # 使用动态加权平均进行平滑
  105. alpha = 0.8 # 当前检测的权重
  106. smooth_x = alpha * new_x + (1 - alpha) * curr_x
  107. smooth_y = alpha * new_y + (1 - alpha) * curr_y
  108. smooth_w = alpha * new_w + (1 - alpha) * curr_w
  109. smooth_h = alpha * new_h + (1 - alpha) * curr_h
  110. # 更新跟踪对象
  111. self.tracks[track_id] = (smooth_x, smooth_y, smooth_w, smooth_h)
  112. self.disappeared[track_id] = 0
  113. # 更新轨迹历史
  114. self.track_history[track_id].append((smooth_x, smooth_y))
  115. # 处理未匹配的跟踪记录
  116. for track_id in list(self.tracks.keys()):
  117. if track_id not in used_tracks:
  118. self.disappeared[track_id] += 1
  119. # 如果连续消失帧数超过阈值,删除该跟踪记录
  120. if self.disappeared[track_id] > self.max_disappeared:
  121. del self.tracks[track_id]
  122. del self.disappeared[track_id]
  123. if track_id in self.colors:
  124. del self.colors[track_id]
  125. # 处理未匹配的检测结果
  126. for det_idx, det in enumerate(detections):
  127. if det_idx not in used_detections:
  128. x1, y1, x2, y2, _ = det
  129. center_x, center_y = (x1 + x2) / 2, (y1 + y2) / 2
  130. width, height = x2 - x1, y2 - y1
  131. # 创建新的跟踪对象
  132. self.tracks[self.next_id] = (center_x, center_y, width, height)
  133. self.disappeared[self.next_id] = 0
  134. # 为新轨迹指定颜色
  135. self.colors[self.next_id] = self._generate_color(self.next_id)
  136. # 更新轨迹历史
  137. self.track_history[self.next_id].append((center_x, center_y))
  138. self.next_id += 1
  139. # 对轨迹进行平滑处理
  140. if self.track_smoothing:
  141. self._smooth_tracks()
  142. return self.tracks
  143. def _smooth_tracks(self):
  144. """使用Savitzky-Golay滤波器平滑轨迹"""
  145. for track_id in self.tracks.keys():
  146. # 如果轨迹长度足够,进行平滑
  147. history = self.track_history[track_id]
  148. if len(history) >= self.smooth_window:
  149. # 分离x和y坐标
  150. xs = np.array([p[0] for p in history])
  151. ys = np.array([p[1] for p in history])
  152. try:
  153. # 应用Savitzky-Golay滤波器
  154. smooth_xs = savgol_filter(xs, self.smooth_window, self.smooth_poly_order)
  155. smooth_ys = savgol_filter(ys, self.smooth_window, self.smooth_poly_order)
  156. # 更新最新位置为平滑后的位置
  157. x, y, w, h = self.tracks[track_id]
  158. self.tracks[track_id] = (smooth_xs[-1], smooth_ys[-1], w, h)
  159. # 更新历史轨迹
  160. self.track_history[track_id][-1] = (smooth_xs[-1], smooth_ys[-1])
  161. except:
  162. # 如果平滑失败,保持原样
  163. pass
  164. def _generate_color(self, track_id):
  165. """为轨迹生成一个唯一的颜色"""
  166. # 使用HSV颜色空间生成均匀分布的颜色
  167. h = (track_id * 0.1) % 1.0
  168. s = 0.8
  169. v = 0.8
  170. # 转换为RGB
  171. import colorsys
  172. r, g, b = colorsys.hsv_to_rgb(h, s, v)
  173. return (int(r * 255), int(g * 255), int(b * 255))
  174. def visualize(self, frame, draw_history=True, history_len=30):
  175. """
  176. 可视化跟踪结果
  177. Args:
  178. frame: 原始视频帧
  179. draw_history: 是否绘制历史轨迹
  180. history_len: 历史轨迹长度限制
  181. Returns:
  182. 标注后的视频帧
  183. """
  184. # 复制原始帧以避免修改
  185. annotated_frame = frame.copy()
  186. # 绘制当前跟踪的球员
  187. for track_id, (x, y, w, h) in self.tracks.items():
  188. x1, y1 = int(x - w/2), int(y - h/2)
  189. x2, y2 = int(x + w/2), int(y + h/2)
  190. # 获取该轨迹的颜色
  191. color = self.colors.get(track_id, (0, 255, 0))
  192. # 绘制边界框
  193. cv2.rectangle(annotated_frame, (x1, y1), (x2, y2), color, 2)
  194. # 添加ID文本
  195. cv2.putText(annotated_frame, str(track_id), (x1, y1 - 5),
  196. cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)
  197. # 绘制历史轨迹
  198. if draw_history and track_id in self.track_history:
  199. history = self.track_history[track_id]
  200. # 限制历史轨迹长度
  201. if history_len > 0:
  202. history = history[-history_len:]
  203. for i in range(1, len(history)):
  204. # 轨迹线的透明度随着时间减弱
  205. alpha = min(1.0, i / len(history))
  206. c = (int(color[0] * alpha), int(color[1] * alpha), int(color[2] * alpha))
  207. pt1 = (int(history[i-1][0]), int(history[i-1][1]))
  208. pt2 = (int(history[i][0]), int(history[i][1]))
  209. # 确保点在图像范围内
  210. if 0 <= pt1[0] < frame.shape[1] and 0 <= pt1[1] < frame.shape[0] and \
  211. 0 <= pt2[0] < frame.shape[1] and 0 <= pt2[1] < frame.shape[0]:
  212. cv2.line(annotated_frame, pt1, pt2, c, 2)
  213. return annotated_frame
  214. def visualize_with_supervision(self, frame, detections_sv=None):
  215. """
  216. 使用supervision库可视化跟踪结果
  217. Args:
  218. frame: 原始视频帧
  219. detections_sv: 可选的supervision检测结果
  220. Returns:
  221. 标注后的视频帧
  222. """
  223. # 如果没有提供detections_sv,从当前跟踪状态创建
  224. if detections_sv is None:
  225. # 创建空的检测结果
  226. xyxy = np.zeros((len(self.tracks), 4))
  227. confidence = np.ones(len(self.tracks))
  228. class_id = np.zeros(len(self.tracks), dtype=int)
  229. tracker_id = np.array(list(self.tracks.keys()))
  230. # 填充边界框信息
  231. for i, (track_id, (x, y, w, h)) in enumerate(self.tracks.items()):
  232. x1, y1 = x - w/2, y - h/2
  233. x2, y2 = x + w/2, y + h/2
  234. xyxy[i] = [x1, y1, x2, y2]
  235. # 创建supervision检测结果
  236. detections_sv = sv.Detections(
  237. xyxy=xyxy,
  238. confidence=confidence,
  239. class_id=class_id,
  240. tracker_id=tracker_id
  241. )
  242. # 创建标签
  243. labels = [f"ID:{tracker_id}" for tracker_id in detections_sv.tracker_id]
  244. # 使用supervision的工具进行可视化
  245. annotated_frame = frame.copy()
  246. # 绘制边界框和ID
  247. annotated_frame = self.bounding_box_annotator.annotate(
  248. scene=annotated_frame,
  249. detections=detections_sv,
  250. labels=labels
  251. )
  252. # 绘制轨迹
  253. for track_id in self.tracks.keys():
  254. history = self.track_history[track_id]
  255. # 至少需要两个点才能绘制线条
  256. if len(history) < 2:
  257. continue
  258. # 获取该轨迹的颜色
  259. color = self.colors.get(track_id, (0, 255, 0))
  260. # 绘制完整轨迹
  261. points = np.array(history, dtype=np.int32)
  262. for i in range(1, len(points)):
  263. pt1 = tuple(points[i-1])
  264. pt2 = tuple(points[i])
  265. cv2.line(annotated_frame, pt1, pt2, color, 2)
  266. return annotated_frame
  267. def get_tracks_for_visualization(self):
  268. """
  269. 获取适合可视化的轨迹数据
  270. Returns:
  271. 轨迹数据,格式为 {id: [(x1, y1), (x2, y2), ...]}
  272. """
  273. visualization_tracks = {}
  274. for track_id, history in self.track_history.items():
  275. # 只返回活跃的轨迹
  276. if track_id in self.tracks:
  277. visualization_tracks[track_id] = history.copy()
  278. return visualization_tracks, self.colors