app.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298
  1. import os
  2. import uuid
  3. import time
  4. from flask import Flask, request, jsonify, send_from_directory, Response
  5. from flask_cors import CORS
  6. import cv2
  7. import numpy as np
  8. import json
  9. import tempfile
  10. from werkzeug.utils import secure_filename
  11. import threading
  12. from models.player_detector import PlayerDetector
  13. from models.player_tracker import PlayerTracker
  14. from utils.field_transformer import FieldTransformer
  15. app = Flask(__name__, static_folder='uploads')
  16. CORS(app)
  17. # 配置上传文件夹
  18. UPLOAD_FOLDER = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'uploads')
  19. os.makedirs(UPLOAD_FOLDER, exist_ok=True)
  20. os.makedirs(os.path.join(UPLOAD_FOLDER, 'videos'), exist_ok=True)
  21. os.makedirs(os.path.join(UPLOAD_FOLDER, 'results'), exist_ok=True)
  22. os.makedirs(os.path.join(UPLOAD_FOLDER, 'frames'), exist_ok=True) # 存储视频帧
  23. os.makedirs(os.path.join(UPLOAD_FOLDER, 'visualizations'), exist_ok=True) # 存储可视化结果
  24. # 初始化模型
  25. player_detector = PlayerDetector(confidence=0.6) # 提高置信度阈值
  26. player_tracker = PlayerTracker()
  27. field_transformer = FieldTransformer()
  28. # 存储处理任务状态
  29. processing_tasks = {}
  30. @app.route('/api/upload', methods=['POST'])
  31. def upload_file():
  32. if 'video' not in request.files:
  33. return jsonify({'error': '没有上传文件'}), 400
  34. file = request.files['video']
  35. if file.filename == '':
  36. return jsonify({'error': '没有选择文件'}), 400
  37. # 生成唯一ID并保存视频
  38. video_id = str(uuid.uuid4())
  39. filename = secure_filename(file.filename)
  40. video_path = os.path.join(UPLOAD_FOLDER, 'videos', f"{video_id}_{filename}")
  41. file.save(video_path)
  42. # 创建处理任务
  43. processing_tasks[video_id] = {
  44. 'status': 'processing',
  45. 'progress': 0,
  46. 'video_path': video_path,
  47. 'filename': filename
  48. }
  49. # 异步处理视频分析
  50. threading.Thread(target=process_video_async, args=(video_path, video_id)).start()
  51. return jsonify({
  52. 'success': True,
  53. 'videoId': video_id,
  54. 'message': '视频上传成功,正在处理中',
  55. 'status': 'processing'
  56. })
  57. def process_video_async(video_path, video_id):
  58. """异步处理视频"""
  59. try:
  60. result = process_video(video_path, video_id)
  61. processing_tasks[video_id]['status'] = 'completed'
  62. processing_tasks[video_id]['result'] = result
  63. except Exception as e:
  64. processing_tasks[video_id]['status'] = 'error'
  65. processing_tasks[video_id]['error'] = str(e)
  66. print(f"处理视频时出错: {str(e)}")
  67. def process_video(video_path, video_id):
  68. """处理视频并提取球员轨迹"""
  69. # 读取视频
  70. cap = cv2.VideoCapture(video_path)
  71. # 获取视频基本信息
  72. fps = cap.get(cv2.CAP_PROP_FPS)
  73. total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
  74. width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
  75. height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
  76. # 创建帧文件夹
  77. frames_dir = os.path.join(UPLOAD_FOLDER, 'frames', video_id)
  78. os.makedirs(frames_dir, exist_ok=True)
  79. # 创建可视化文件夹
  80. vis_dir = os.path.join(UPLOAD_FOLDER, 'visualizations', video_id)
  81. os.makedirs(vis_dir, exist_ok=True)
  82. # 存储轨迹数据
  83. all_tracks = []
  84. frame_count = 0
  85. # 第一帧用于检测场地(可选)
  86. ret, first_frame = cap.read()
  87. if ret:
  88. # 尝试自动检测场地
  89. field_detected = field_transformer.auto_detect_field(first_frame)
  90. # 如果自动检测失败,使用默认值
  91. if not field_detected:
  92. print("场地自动检测失败,使用默认变换")
  93. # 重置视频读取
  94. cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
  95. while cap.isOpened():
  96. ret, frame = cap.read()
  97. if not ret:
  98. break
  99. # 更新进度
  100. if video_id in processing_tasks:
  101. processing_tasks[video_id]['progress'] = min(95, int((frame_count / total_frames) * 100))
  102. # 保存帧
  103. frame_path = os.path.join(frames_dir, f"frame_{frame_count:06d}.jpg")
  104. cv2.imwrite(frame_path, frame)
  105. # 检测球员
  106. detections = player_detector.detect(frame)
  107. # 跟踪球员
  108. tracks = player_tracker.update(detections, frame)
  109. # 变换坐标到标准足球场
  110. field_positions = field_transformer.transform_to_field(tracks, frame)
  111. # 可视化跟踪结果
  112. vis_frame = player_tracker.visualize(frame, draw_history=True)
  113. vis_path = os.path.join(vis_dir, f"vis_{frame_count:06d}.jpg")
  114. cv2.imwrite(vis_path, vis_frame)
  115. # 保存当前帧的轨迹数据
  116. frame_data = {
  117. 'frame': frame_count,
  118. 'tracks': []
  119. }
  120. for track_id, position in field_positions.items():
  121. # 从原始跟踪数据获取球员位置和尺寸
  122. if track_id in tracks:
  123. x, y, w, h = tracks[track_id]
  124. x1, y1 = int(x - w/2), int(y - h/2)
  125. x2, y2 = int(x + w/2), int(y + h/2)
  126. frame_data['tracks'].append({
  127. 'id': int(track_id),
  128. 'x': float(position[0]),
  129. 'y': float(position[1]),
  130. 'z': 0.0, # 高度默认为0
  131. 'image_pos': {
  132. 'x1': x1,
  133. 'y1': y1,
  134. 'x2': x2,
  135. 'y2': y2
  136. }
  137. })
  138. all_tracks.append(frame_data)
  139. frame_count += 1
  140. cap.release()
  141. # 保存结果到JSON文件
  142. result_path = os.path.join(UPLOAD_FOLDER, 'results', f"{video_id}_tracks.json")
  143. with open(result_path, 'w') as f:
  144. json.dump(all_tracks, f)
  145. # 保存视频信息
  146. video_info = {
  147. 'fps': fps,
  148. 'total_frames': total_frames,
  149. 'width': width,
  150. 'height': height,
  151. 'filename': os.path.basename(video_path)
  152. }
  153. video_info_path = os.path.join(UPLOAD_FOLDER, 'results', f"{video_id}_info.json")
  154. with open(video_info_path, 'w') as f:
  155. json.dump(video_info, f)
  156. return {
  157. 'trackFile': f"/api/results/{video_id}_tracks.json",
  158. 'infoFile': f"/api/results/{video_id}_info.json",
  159. 'totalFrames': frame_count,
  160. 'framesDir': f"/api/frames/{video_id}",
  161. 'visualizationsDir': f"/api/visualizations/{video_id}",
  162. 'fps': fps
  163. }
  164. @app.route('/api/status/<video_id>', methods=['GET'])
  165. def get_status(video_id):
  166. """获取视频处理状态"""
  167. if video_id not in processing_tasks:
  168. return jsonify({'error': '找不到处理任务'}), 404
  169. task = processing_tasks[video_id]
  170. response = {
  171. 'status': task['status'],
  172. 'progress': task['progress']
  173. }
  174. if task['status'] == 'completed' and 'result' in task:
  175. response['result'] = task['result']
  176. elif task['status'] == 'error' and 'error' in task:
  177. response['error'] = task['error']
  178. return jsonify(response)
  179. @app.route('/api/results/<filename>', methods=['GET'])
  180. def get_result(filename):
  181. return send_from_directory(os.path.join(UPLOAD_FOLDER, 'results'), filename)
  182. @app.route('/api/frames/<video_id>/<frame_filename>', methods=['GET'])
  183. def get_frame(video_id, frame_filename):
  184. return send_from_directory(os.path.join(UPLOAD_FOLDER, 'frames', video_id), frame_filename)
  185. @app.route('/api/visualizations/<video_id>/<vis_filename>', methods=['GET'])
  186. def get_visualization(video_id, vis_filename):
  187. return send_from_directory(os.path.join(UPLOAD_FOLDER, 'visualizations', video_id), vis_filename)
  188. @app.route('/api/video/<video_id>', methods=['GET'])
  189. def get_video(video_id):
  190. """获取原始视频"""
  191. # 查找对应的视频文件
  192. for filename in os.listdir(os.path.join(UPLOAD_FOLDER, 'videos')):
  193. if filename.startswith(video_id):
  194. return send_from_directory(os.path.join(UPLOAD_FOLDER, 'videos'), filename)
  195. return jsonify({'error': '找不到视频文件'}), 404
  196. @app.route('/api/generate_visualization_video/<video_id>', methods=['GET'])
  197. def generate_visualization_video(video_id):
  198. """生成包含轨迹可视化的视频"""
  199. if video_id not in processing_tasks or processing_tasks[video_id]['status'] != 'completed':
  200. return jsonify({'error': '视频尚未处理完成'}), 400
  201. vis_dir = os.path.join(UPLOAD_FOLDER, 'visualizations', video_id)
  202. if not os.path.exists(vis_dir) or len(os.listdir(vis_dir)) == 0:
  203. return jsonify({'error': '找不到可视化结果'}), 404
  204. try:
  205. # 获取视频信息
  206. with open(os.path.join(UPLOAD_FOLDER, 'results', f"{video_id}_info.json"), 'r') as f:
  207. video_info = json.load(f)
  208. fps = video_info.get('fps', 30)
  209. width = video_info.get('width', 1280)
  210. height = video_info.get('height', 720)
  211. # 输出视频路径
  212. output_path = os.path.join(UPLOAD_FOLDER, 'videos', f"{video_id}_visualization.mp4")
  213. # 获取所有可视化帧文件
  214. frame_files = sorted([f for f in os.listdir(vis_dir) if f.startswith('vis_')])
  215. if not frame_files:
  216. return jsonify({'error': '没有找到可视化帧'}), 404
  217. # 读取第一帧获取尺寸
  218. first_frame = cv2.imread(os.path.join(vis_dir, frame_files[0]))
  219. h, w, _ = first_frame.shape
  220. # 创建视频写入器
  221. fourcc = cv2.VideoWriter_fourcc(*'mp4v')
  222. out = cv2.VideoWriter(output_path, fourcc, fps, (w, h))
  223. # 写入每一帧
  224. for frame_file in frame_files:
  225. frame = cv2.imread(os.path.join(vis_dir, frame_file))
  226. out.write(frame)
  227. out.release()
  228. return jsonify({
  229. 'success': True,
  230. 'video_url': f"/api/videos/{video_id}_visualization.mp4"
  231. })
  232. except Exception as e:
  233. return jsonify({'error': f'生成可视化视频失败: {str(e)}'}), 500
  234. @app.route('/api/videos/<filename>', methods=['GET'])
  235. def get_visualization_video(filename):
  236. return send_from_directory(os.path.join(UPLOAD_FOLDER, 'videos'), filename)
  237. if __name__ == '__main__':
  238. app.run(debug=True, host='0.0.0.0', port=5000)