tori/src/ws_worker.rs
Fam Zheng c56bfd9377 feat: status_reason field for workflows + proper failure logging
- Add status_reason column to workflows table (migration)
- AgentUpdate::WorkflowStatus and WorkflowComplete carry reason
- Dispatch failure logs to execution_log with reason
- Worker disconnect marks orphaned workflows as failed with reason
- All status transitions now have traceable cause
2026-04-06 20:33:41 +01:00

166 lines
6.3 KiB
Rust

use std::sync::Arc;
use axum::{
extract::{State, WebSocketUpgrade, ws::{Message, WebSocket}},
response::Response,
routing::get,
Router,
};
use futures::{SinkExt, StreamExt};
use sqlx::sqlite::SqlitePool;
use tokio::sync::broadcast;
use crate::agent::WsMessage;
use crate::worker::{WorkerInfo, WorkerManager, WorkerToServer};
/// Shared state handed to every worker WebSocket connection.
///
/// Built once by [`router`] and cloned (via `Arc`) into each connection handler.
pub struct WsWorkerState {
    /// Worker registry: used to register/unregister workers, query their
    /// workflow assignments and mark workflows complete.
    pub mgr: Arc<WorkerManager>,
    /// Database pool used to look up and update workflows and execution logs.
    pub pool: SqlitePool,
    /// Given a project id, returns the broadcast sender on which updates for
    /// that project are published.
    pub broadcast_fn: Arc<
        dyn Fn(&str) -> broadcast::Sender<WsMessage> + Send + Sync,
    >,
}
/// Builds the router that accepts worker WebSocket connections at `/`.
///
/// `mgr`, `pool` and `broadcast_fn` are bundled into a single shared
/// [`WsWorkerState`] that every connection handler receives.
pub fn router(
    mgr: Arc<WorkerManager>,
    pool: SqlitePool,
    broadcast_fn: Arc<dyn Fn(&str) -> broadcast::Sender<WsMessage> + Send + Sync>,
) -> Router {
    let shared = Arc::new(WsWorkerState {
        mgr,
        pool,
        broadcast_fn,
    });
    Router::new().route("/", get(ws_handler)).with_state(shared)
}
async fn ws_handler(
ws: WebSocketUpgrade,
State(state): State<Arc<WsWorkerState>>,
) -> Response {
ws.on_upgrade(move |socket| handle_worker_socket(socket, state))
}
async fn handle_worker_socket(socket: WebSocket, state: Arc<WsWorkerState>) {
let (mut sender, mut receiver) = socket.split();
// First message must be registration
let (name, mut msg_rx) = loop {
match receiver.next().await {
Some(Ok(Message::Text(text))) => {
match serde_json::from_str::<serde_json::Value>(&text) {
Ok(v) if v["type"] == "register" => {
let info: WorkerInfo = match serde_json::from_value(v["info"].clone()) {
Ok(i) => i,
Err(_) => {
let _ = sender.send(Message::Text(
r#"{"type":"error","message":"Invalid worker info"}"#.into(),
)).await;
return;
}
};
let name = info.name.clone();
let rx = state.mgr.register(name.clone(), info).await;
let ack = serde_json::json!({ "type": "registered", "name": &name });
let _ = sender.send(Message::Text(ack.to_string().into())).await;
break (name, rx);
}
_ => {
let _ = sender.send(Message::Text(
r#"{"type":"error","message":"First message must be register"}"#.into(),
)).await;
return;
}
}
}
Some(Ok(Message::Close(_))) | None => return,
_ => continue,
}
};
let name_clone = name.clone();
let mgr_for_cleanup = state.mgr.clone();
// Task: send ServerToWorker messages from msg_rx to the WebSocket
let send_task = tokio::spawn(async move {
while let Some(msg) = msg_rx.recv().await {
if let Ok(json) = serde_json::to_string(&msg) {
if sender.send(Message::Text(json.into())).await.is_err() {
break;
}
}
}
});
// Task: receive WorkerToServer messages from the WebSocket → process
let state_clone = state.clone();
let recv_task = tokio::spawn(async move {
while let Some(Ok(msg)) = receiver.next().await {
match msg {
Message::Text(text) => {
match serde_json::from_str::<WorkerToServer>(&text) {
Ok(worker_msg) => {
handle_worker_message(&state_clone, worker_msg).await;
}
Err(e) => {
let preview: String = text.chars().take(200).collect();
tracing::warn!("Failed to parse worker message: {} — raw: {}", e, preview);
}
}
}
Message::Close(_) => break,
_ => {}
}
}
});
tokio::select! {
_ = send_task => {},
_ = recv_task => {},
}
// Log reason for any orphaned workflows before cleanup
let orphan_workflows: Vec<String> = {
let assignments = mgr_for_cleanup.assignments_for_worker(&name_clone).await;
assignments
};
if !orphan_workflows.is_empty() {
let reason = format!("Worker '{}' 断开连接", name_clone);
for wf_id in &orphan_workflows {
let _ = sqlx::query("UPDATE workflows SET status = 'failed', status_reason = ? WHERE id = ? AND status IN ('executing', 'planning')")
.bind(&reason).bind(wf_id).execute(&state.pool).await;
let log_id = uuid::Uuid::new_v4().to_string();
let _ = sqlx::query(
"INSERT INTO execution_log (id, workflow_id, step_order, tool_name, tool_input, output, status, created_at) VALUES (?, ?, 0, 'system', 'worker_disconnect', ?, 'failed', datetime('now'))"
).bind(&log_id).bind(wf_id).bind(&reason).execute(&state.pool).await;
tracing::warn!("Workflow {} orphaned: {}", wf_id, reason);
}
}
mgr_for_cleanup.unregister(&name_clone).await;
}
async fn handle_worker_message(state: &WsWorkerState, msg: WorkerToServer) {
match msg {
WorkerToServer::Register { .. } => {
// Already handled during initial handshake
}
WorkerToServer::Update { workflow_id, update } => {
// Get project_id for broadcasting (look up from DB)
let project_id: Option<String> = sqlx::query_scalar(
"SELECT project_id FROM workflows WHERE id = ?"
)
.bind(&workflow_id)
.fetch_optional(&state.pool)
.await
.ok()
.flatten();
let broadcast_tx = if let Some(ref pid) = project_id {
Some((state.broadcast_fn)(pid))
} else {
None
};
// Check if this is a workflow completion
if let crate::sink::AgentUpdate::WorkflowComplete { ref workflow_id, .. } = update {
state.mgr.complete_workflow(workflow_id).await;
}
// Process the update: write to DB + broadcast
crate::sink::handle_single_update(&update, &state.pool, broadcast_tx.as_ref()).await;
}
}
}