use std::sync::Arc; use axum::{ extract::{State, WebSocketUpgrade, ws::{Message, WebSocket}}, response::Response, routing::get, Router, }; use futures::{SinkExt, StreamExt}; use sqlx::sqlite::SqlitePool; use tokio::sync::broadcast; use crate::agent::WsMessage; use crate::worker::{WorkerInfo, WorkerManager, WorkerToServer}; pub struct WsWorkerState { pub mgr: Arc, pub pool: SqlitePool, pub broadcast_fn: Arc broadcast::Sender + Send + Sync>, } pub fn router(mgr: Arc, pool: SqlitePool, broadcast_fn: Arc broadcast::Sender + Send + Sync>) -> Router { let state = Arc::new(WsWorkerState { mgr, pool, broadcast_fn }); Router::new() .route("/", get(ws_handler)) .with_state(state) } async fn ws_handler( ws: WebSocketUpgrade, State(state): State>, ) -> Response { ws.on_upgrade(move |socket| handle_worker_socket(socket, state)) } async fn handle_worker_socket(socket: WebSocket, state: Arc) { let (mut sender, mut receiver) = socket.split(); // First message must be registration let (name, mut msg_rx) = loop { match receiver.next().await { Some(Ok(Message::Text(text))) => { match serde_json::from_str::(&text) { Ok(v) if v["type"] == "register" => { let info: WorkerInfo = match serde_json::from_value(v["info"].clone()) { Ok(i) => i, Err(_) => { let _ = sender.send(Message::Text( r#"{"type":"error","message":"Invalid worker info"}"#.into(), )).await; return; } }; let name = info.name.clone(); let rx = state.mgr.register(name.clone(), info).await; let ack = serde_json::json!({ "type": "registered", "name": &name }); let _ = sender.send(Message::Text(ack.to_string().into())).await; break (name, rx); } _ => { let _ = sender.send(Message::Text( r#"{"type":"error","message":"First message must be register"}"#.into(), )).await; return; } } } Some(Ok(Message::Close(_))) | None => return, _ => continue, } }; let name_clone = name.clone(); let mgr_for_cleanup = state.mgr.clone(); // Task: send ServerToWorker messages from msg_rx to the WebSocket let send_task = tokio::spawn(async move { while let Some(msg) = msg_rx.recv().await { if let Ok(json) = serde_json::to_string(&msg) { if sender.send(Message::Text(json.into())).await.is_err() { break; } } } }); // Task: receive WorkerToServer messages from the WebSocket → process let state_clone = state.clone(); let recv_task = tokio::spawn(async move { while let Some(Ok(msg)) = receiver.next().await { match msg { Message::Text(text) => { match serde_json::from_str::(&text) { Ok(worker_msg) => { handle_worker_message(&state_clone, worker_msg).await; } Err(e) => { let preview: String = text.chars().take(200).collect(); tracing::warn!("Failed to parse worker message: {} — raw: {}", e, preview); } } } Message::Close(_) => break, _ => {} } } }); tokio::select! { _ = send_task => {}, _ = recv_task => {}, } // Log reason for any orphaned workflows before cleanup let orphan_workflows: Vec = { let assignments = mgr_for_cleanup.assignments_for_worker(&name_clone).await; assignments }; if !orphan_workflows.is_empty() { let reason = format!("Worker '{}' 断开连接", name_clone); for wf_id in &orphan_workflows { let _ = sqlx::query("UPDATE workflows SET status = 'failed', status_reason = ? WHERE id = ? AND status IN ('executing', 'planning')") .bind(&reason).bind(wf_id).execute(&state.pool).await; let log_id = uuid::Uuid::new_v4().to_string(); let _ = sqlx::query( "INSERT INTO execution_log (id, workflow_id, step_order, tool_name, tool_input, output, status, created_at) VALUES (?, ?, 0, 'system', 'worker_disconnect', ?, 'failed', datetime('now'))" ).bind(&log_id).bind(wf_id).bind(&reason).execute(&state.pool).await; tracing::warn!("Workflow {} orphaned: {}", wf_id, reason); } } mgr_for_cleanup.unregister(&name_clone).await; } async fn handle_worker_message(state: &WsWorkerState, msg: WorkerToServer) { match msg { WorkerToServer::Register { .. } => { // Already handled during initial handshake } WorkerToServer::Update { workflow_id, update } => { // Get project_id for broadcasting (look up from DB) let project_id: Option = sqlx::query_scalar( "SELECT project_id FROM workflows WHERE id = ?" ) .bind(&workflow_id) .fetch_optional(&state.pool) .await .ok() .flatten(); let broadcast_tx = if let Some(ref pid) = project_id { Some((state.broadcast_fn)(pid)) } else { None }; // Check if this is a workflow completion if let crate::sink::AgentUpdate::WorkflowComplete { ref workflow_id, .. } = update { state.mgr.complete_workflow(workflow_id).await; } // Process the update: write to DB + broadcast crate::sink::handle_single_update(&update, &state.pool, broadcast_tx.as_ref()).await; } } }