217 lines
7.7 KiB
Rust
217 lines
7.7 KiB
Rust
use std::sync::Arc;
|
|
use futures::{SinkExt, StreamExt};
|
|
use tokio::sync::mpsc;
|
|
use tokio_tungstenite::{connect_async, tungstenite::Message};
|
|
|
|
use crate::agent::{self, AgentEvent};
|
|
use crate::exec::LocalExecutor;
|
|
use crate::llm::LlmClient;
|
|
use crate::sink::{AgentUpdate, ServiceManager};
|
|
use crate::worker::{ServerToWorker, WorkerInfo, WorkerToServer};
|
|
|
|
fn collect_worker_info(name: &str) -> WorkerInfo {
|
|
let cpu = std::fs::read_to_string("/proc/cpuinfo")
|
|
.ok()
|
|
.and_then(|s| {
|
|
s.lines()
|
|
.find(|l| l.starts_with("model name"))
|
|
.map(|l| l.split(':').nth(1).unwrap_or("").trim().to_string())
|
|
})
|
|
.unwrap_or_else(|| "unknown".into());
|
|
|
|
let memory = std::fs::read_to_string("/proc/meminfo")
|
|
.ok()
|
|
.and_then(|s| {
|
|
s.lines()
|
|
.find(|l| l.starts_with("MemTotal"))
|
|
.and_then(|l| l.split_whitespace().nth(1))
|
|
.and_then(|kb| kb.parse::<u64>().ok())
|
|
.map(|kb| format!("{:.1} GB", kb as f64 / 1_048_576.0))
|
|
})
|
|
.unwrap_or_else(|| "unknown".into());
|
|
|
|
let gpu = std::process::Command::new("nvidia-smi")
|
|
.arg("--query-gpu=name")
|
|
.arg("--format=csv,noheader")
|
|
.output()
|
|
.ok()
|
|
.and_then(|o| {
|
|
if o.status.success() {
|
|
Some(String::from_utf8_lossy(&o.stdout).trim().to_string())
|
|
} else {
|
|
None
|
|
}
|
|
})
|
|
.unwrap_or_else(|| "none".into());
|
|
|
|
WorkerInfo {
|
|
name: name.to_string(),
|
|
cpu,
|
|
memory,
|
|
gpu,
|
|
os: std::env::consts::OS.to_string(),
|
|
kernel: std::process::Command::new("uname")
|
|
.arg("-r")
|
|
.output()
|
|
.ok()
|
|
.map(|o| String::from_utf8_lossy(&o.stdout).trim().to_string())
|
|
.unwrap_or_else(|| "unknown".into()),
|
|
}
|
|
}
|
|
|
|
pub async fn run(server_url: &str, worker_name: &str, llm_config: &crate::LlmConfig) -> anyhow::Result<()> {
|
|
tracing::info!("Tori worker '{}' connecting to {} (model={})", worker_name, server_url, llm_config.model);
|
|
|
|
loop {
|
|
match connect_and_run(server_url, worker_name, llm_config).await {
|
|
Ok(()) => {
|
|
tracing::info!("Worker connection closed, reconnecting in 5s...");
|
|
}
|
|
Err(e) => {
|
|
tracing::error!("Worker error: {}, reconnecting in 5s...", e);
|
|
}
|
|
}
|
|
tokio::time::sleep(std::time::Duration::from_secs(5)).await;
|
|
}
|
|
}
|
|
|
|
async fn connect_and_run(server_url: &str, worker_name: &str, llm_config: &crate::LlmConfig) -> anyhow::Result<()> {
|
|
let (ws_stream, _) = connect_async(server_url).await?;
|
|
let (mut ws_tx, mut ws_rx) = ws_stream.split();
|
|
|
|
// Register
|
|
let info = collect_worker_info(worker_name);
|
|
let register_msg = serde_json::to_string(&WorkerToServer::Register { info })?;
|
|
ws_tx.send(Message::Text(register_msg.into())).await?;
|
|
|
|
// Wait for registration ack
|
|
while let Some(msg) = ws_rx.next().await {
|
|
match msg? {
|
|
Message::Text(text) => {
|
|
let v: serde_json::Value = serde_json::from_str(&text)?;
|
|
if v["type"] == "registered" {
|
|
tracing::info!("Registered as '{}'", v["name"]);
|
|
break;
|
|
}
|
|
}
|
|
Message::Close(_) => anyhow::bail!("Connection closed during registration"),
|
|
_ => {}
|
|
}
|
|
}
|
|
|
|
let svc_mgr = ServiceManager::new(9100);
|
|
let ws_tx = Arc::new(tokio::sync::Mutex::new(ws_tx));
|
|
|
|
// Ping task to keep connection alive
|
|
let ping_tx = ws_tx.clone();
|
|
tokio::spawn(async move {
|
|
let mut interval = tokio::time::interval(std::time::Duration::from_secs(30));
|
|
loop {
|
|
interval.tick().await;
|
|
let mut tx = ping_tx.lock().await;
|
|
if tx.send(Message::Ping(vec![].into())).await.is_err() {
|
|
break;
|
|
}
|
|
}
|
|
});
|
|
|
|
// Channel for forwarding comments to the running workflow
|
|
let comment_tx: Arc<tokio::sync::Mutex<Option<mpsc::Sender<AgentEvent>>>> =
|
|
Arc::new(tokio::sync::Mutex::new(None));
|
|
|
|
// Main message loop
|
|
while let Some(msg) = ws_rx.next().await {
|
|
let text = match msg? {
|
|
Message::Text(t) => t,
|
|
Message::Close(_) => break,
|
|
Message::Pong(_) => continue,
|
|
_ => continue,
|
|
};
|
|
|
|
let server_msg: ServerToWorker = match serde_json::from_str(&text) {
|
|
Ok(m) => m,
|
|
Err(e) => {
|
|
tracing::warn!("Failed to parse server message: {}", e);
|
|
continue;
|
|
}
|
|
};
|
|
|
|
match server_msg {
|
|
ServerToWorker::WorkflowAssign {
|
|
workflow_id,
|
|
project_id,
|
|
requirement,
|
|
template_id: _,
|
|
initial_state,
|
|
require_plan_approval,
|
|
} => {
|
|
tracing::info!("Received workflow: {} (project {})", workflow_id, project_id);
|
|
|
|
let llm = LlmClient::new(llm_config);
|
|
let exec = LocalExecutor::new(None);
|
|
let workdir = format!("workspaces/{}", project_id);
|
|
let instructions = String::new(); // TODO: load from template
|
|
|
|
// update channel → serialize → WebSocket
|
|
let (update_tx, mut update_rx) = mpsc::channel::<AgentUpdate>(64);
|
|
let ws_tx_clone = ws_tx.clone();
|
|
let wf_id_clone = workflow_id.clone();
|
|
tokio::spawn(async move {
|
|
while let Some(update) = update_rx.recv().await {
|
|
let msg = WorkerToServer::Update {
|
|
workflow_id: wf_id_clone.clone(),
|
|
update,
|
|
};
|
|
if let Ok(json) = serde_json::to_string(&msg) {
|
|
let mut tx = ws_tx_clone.lock().await;
|
|
if tx.send(Message::Text(json.into())).await.is_err() {
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
});
|
|
|
|
// event channel for comments
|
|
let (evt_tx, mut evt_rx) = mpsc::channel::<AgentEvent>(32);
|
|
*comment_tx.lock().await = Some(evt_tx);
|
|
|
|
let _ = tokio::fs::create_dir_all(&workdir).await;
|
|
|
|
let result = agent::run_agent_loop(
|
|
&llm, &exec, &update_tx, &mut evt_rx,
|
|
&project_id, &workflow_id, &requirement, &workdir, &svc_mgr,
|
|
&instructions, initial_state, None, require_plan_approval,
|
|
).await;
|
|
|
|
let final_status = if result.is_ok() { "done" } else { "failed" };
|
|
if let Err(e) = &result {
|
|
tracing::error!("Workflow {} failed: {}", workflow_id, e);
|
|
let _ = update_tx.send(AgentUpdate::Error {
|
|
message: format!("Agent error: {}", e),
|
|
}).await;
|
|
}
|
|
|
|
let _ = update_tx.send(AgentUpdate::WorkflowComplete {
|
|
workflow_id: workflow_id.clone(),
|
|
status: final_status.into(),
|
|
report: None,
|
|
}).await;
|
|
|
|
*comment_tx.lock().await = None;
|
|
tracing::info!("Workflow {} completed: {}", workflow_id, final_status);
|
|
}
|
|
|
|
ServerToWorker::Comment { workflow_id, content } => {
|
|
if let Some(ref tx) = *comment_tx.lock().await {
|
|
let _ = tx.send(AgentEvent::Comment {
|
|
workflow_id,
|
|
content,
|
|
}).await;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
Ok(())
|
|
}
|