field_transformer.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. import cv2
  2. import numpy as np
  3. class FieldTransformer:
  4. """用于将视频中的像素坐标转换为场地坐标的工具类"""
  5. def __init__(self):
  6. # 标准橄榄球场尺寸(单位:米)
  7. self.field_length = 100 # 大约100码(不包括端区)
  8. self.field_width = 50 # 大约53.3码
  9. # 默认变换矩阵(假设俯视图)
  10. self.perspective_matrix = None
  11. # 是否已自动检测场地
  12. self.field_detected = False
  13. # 场地边界点(图像坐标系)
  14. self.field_corners = None
  15. def auto_detect_field(self, frame):
  16. """尝试自动检测球场区域"""
  17. try:
  18. # 转换为灰度图
  19. gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
  20. # 应用高斯模糊
  21. blur = cv2.GaussianBlur(gray, (5, 5), 0)
  22. # 边缘检测
  23. edges = cv2.Canny(blur, 50, 150)
  24. # 膨胀边缘以连接间隙
  25. dilated = cv2.dilate(edges, np.ones((3, 3), np.uint8), iterations=2)
  26. # 查找轮廓
  27. contours, _ = cv2.findContours(dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
  28. # 按面积排序
  29. contours = sorted(contours, key=cv2.contourArea, reverse=True)
  30. # 尝试找到矩形区域(假设场地为矩形)
  31. for contour in contours[:5]: # 只检查最大的5个轮廓
  32. # 近似轮廓
  33. epsilon = 0.02 * cv2.arcLength(contour, True)
  34. approx = cv2.approxPolyDP(contour, epsilon, True)
  35. # 如果近似为四边形,可能是场地
  36. if len(approx) == 4:
  37. # 进一步处理,确认是否为场地
  38. field_area = cv2.contourArea(approx)
  39. frame_area = frame.shape[0] * frame.shape[1]
  40. # 如果面积占总面积的比例合适,则认为找到了场地
  41. if 0.2 < field_area / frame_area < 0.95:
  42. # 存储场地角点
  43. self.field_corners = approx.reshape(4, 2)
  44. # 根据场地角点计算透视变换矩阵
  45. self._calculate_perspective_matrix()
  46. self.field_detected = True
  47. return True
  48. # 未找到合适的场地轮廓,使用默认变换
  49. self._setup_default_transform(frame)
  50. return False
  51. except Exception as e:
  52. print(f"场地自动检测失败: {str(e)}")
  53. self._setup_default_transform(frame)
  54. return False
  55. def _setup_default_transform(self, frame):
  56. """设置默认变换矩阵(假设俯视图)"""
  57. height, width = frame.shape[:2]
  58. # 默认场地框(图像四角)
  59. self.field_corners = np.array([
  60. [0.1 * width, 0.1 * height], # 左上
  61. [0.9 * width, 0.1 * height], # 右上
  62. [0.9 * width, 0.9 * height], # 右下
  63. [0.1 * width, 0.9 * height] # 左下
  64. ], dtype=np.float32)
  65. self._calculate_perspective_matrix()
  66. def _calculate_perspective_matrix(self):
  67. """计算透视变换矩阵"""
  68. # 标准化角点顺序:左上,右上,右下,左下
  69. rect = self._order_points(self.field_corners)
  70. # 目标场地坐标(标准场地,单位:米)
  71. dst = np.array([
  72. [0, 0], # 左上
  73. [self.field_length, 0], # 右上
  74. [self.field_length, self.field_width], # 右下
  75. [0, self.field_width] # 左下
  76. ], dtype=np.float32)
  77. # 计算透视变换矩阵
  78. self.perspective_matrix = cv2.getPerspectiveTransform(rect, dst)
  79. def _order_points(self, pts):
  80. """对四边形角点进行排序: 左上,右上,右下,左下"""
  81. # 初始化结果数组
  82. rect = np.zeros((4, 2), dtype=np.float32)
  83. # 计算点的和,左上角的和最小,右下角的和最大
  84. s = pts.sum(axis=1)
  85. rect[0] = pts[np.argmin(s)] # 左上
  86. rect[2] = pts[np.argmax(s)] # 右下
  87. # 计算差,右上角的差最小,左下角的差最大
  88. diff = np.diff(pts, axis=1)
  89. rect[1] = pts[np.argmin(diff)] # 右上
  90. rect[3] = pts[np.argmax(diff)] # 左下
  91. return rect
  92. def transform_to_field(self, tracks, frame=None):
  93. """将球员位置从像素坐标转换到场地坐标
  94. Args:
  95. tracks: {id: (x, y, w, h)} 形式的字典,x,y是中心点
  96. frame: 可选的原始视频帧,用于自动检测场地
  97. Returns:
  98. {id: (field_x, field_y)} 形式的字典
  99. """
  100. # 如果没有初始化透视矩阵且提供了frame
  101. if self.perspective_matrix is None and frame is not None:
  102. self.auto_detect_field(frame)
  103. # 如果仍未初始化,使用默认值
  104. if self.perspective_matrix is None:
  105. # 创建一个假想的frame
  106. mock_frame = np.zeros((720, 1280, 3), dtype=np.uint8)
  107. self._setup_default_transform(mock_frame)
  108. # 转换每个轨迹到场地坐标
  109. field_positions = {}
  110. for track_id, (x, y, w, h) in tracks.items():
  111. # 转换中心点
  112. pt = np.array([[x, y]], dtype=np.float32)
  113. pt = cv2.perspectiveTransform(pt.reshape(-1, 1, 2), self.perspective_matrix)
  114. # 存储场地坐标
  115. field_positions[track_id] = (float(pt[0][0][0]), float(pt[0][0][1]))
  116. return field_positions
  117. def convert_field_to_image(self, field_positions, frame_shape):
  118. """将场地坐标转换回图像坐标
  119. Args:
  120. field_positions: {id: (field_x, field_y)} 形式的字典
  121. frame_shape: 原始帧的形状(高度, 宽度)
  122. Returns:
  123. {id: (image_x, image_y)} 形式的字典
  124. """
  125. if self.perspective_matrix is None:
  126. # 创建一个假想的frame
  127. mock_frame = np.zeros((frame_shape[0], frame_shape[1], 3), dtype=np.uint8)
  128. self._setup_default_transform(mock_frame)
  129. # 计算逆变换矩阵
  130. inv_matrix = np.linalg.inv(self.perspective_matrix)
  131. # 转换每个场地坐标到图像坐标
  132. image_positions = {}
  133. for track_id, (field_x, field_y) in field_positions.items():
  134. # 转换点
  135. pt = np.array([[[field_x, field_y]]], dtype=np.float32)
  136. pt = cv2.perspectiveTransform(pt, inv_matrix)
  137. # 存储图像坐标
  138. image_positions[track_id] = (float(pt[0][0][0]), float(pt[0][0][1]))
  139. return image_positions
  140. def visualize_field_transform(self, frame):
  141. """在帧上可视化场地变换"""
  142. if self.field_corners is None:
  143. self.auto_detect_field(frame)
  144. vis_frame = frame.copy()
  145. # 绘制检测到的场地边界
  146. if self.field_corners is not None:
  147. # 将角点转换为整数坐标
  148. corners = self.field_corners.astype(np.int32)
  149. # 绘制场地边界
  150. cv2.polylines(vis_frame, [corners], True, (0, 255, 0), 2)
  151. # 标记角点
  152. for i, (x, y) in enumerate(corners):
  153. cv2.circle(vis_frame, (x, y), 5, (0, 0, 255), -1)
  154. cv2.putText(vis_frame, str(i), (x-10, y-10),
  155. cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)
  156. # 添加说明文本
  157. cv2.putText(vis_frame, "检测到场地", (10, 30),
  158. cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
  159. else:
  160. cv2.putText(vis_frame, "未检测到场地,使用默认变换", (10, 30),
  161. cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
  162. return vis_frame