tori/src/worker.rs

181 lines
5.8 KiB
Rust

use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use serde::{Deserialize, Serialize};
use crate::state::AgentState;
/// Self-description a worker sends when it registers with the server.
///
/// All fields are plain strings; nothing in this file parses or validates
/// them. They are stored per worker and echoed in the registration log line.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct WorkerInfo {
    /// Name the worker reports for itself.
    pub name: String,
    /// CPU description (free-form).
    pub cpu: String,
    /// Memory description (free-form).
    pub memory: String,
    /// GPU description (free-form).
    pub gpu: String,
    /// Operating system description (free-form).
    pub os: String,
    /// Kernel description (free-form).
    pub kernel: String,
}
/// Book-keeping entry for one registered worker.
struct Worker {
    /// The description the worker supplied at registration.
    pub info: WorkerInfo,
    /// Sending half of the channel used to push server messages to this worker.
    pub tx: tokio::sync::mpsc::Sender<ServerToWorker>,
}
/// Result of a legacy script execution.
///
/// NOTE(review): nothing else in the visible part of this file references
/// this type; confirm it is still used before extending it.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct WorkerResult {
    /// Identifier of the job (presumably the one that produced this result).
    pub job_id: String,
    /// Exit code (presumably of the executed script).
    pub exit_code: i32,
    /// Captured standard output.
    pub stdout: String,
    /// Captured standard error.
    pub stderr: String,
}
/// Server → worker messages.
///
/// Serialized as an internally tagged enum: the variant is carried in a
/// `"type"` field whose value is the `rename` given on each variant.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(tag = "type")]
pub enum ServerToWorker {
    /// Hands a complete workflow to the worker for execution.
    #[serde(rename = "workflow_assign")]
    WorkflowAssign {
        /// Identifier of the workflow being assigned.
        workflow_id: String,
        /// Identifier of the project the workflow belongs to.
        project_id: String,
        /// Requirement text for the workflow.
        requirement: String,
        /// Optional template identifier; `None` when absent from the payload.
        #[serde(default)]
        template_id: Option<String>,
        /// Optional starting agent state; `None` when absent from the payload.
        #[serde(default)]
        initial_state: Option<AgentState>,
        /// Plan-approval flag; `false` when absent from the payload.
        #[serde(default)]
        require_plan_approval: bool,
    },
    /// Relays a user comment to the worker that is running the workflow.
    #[serde(rename = "comment")]
    Comment {
        /// Workflow the comment refers to.
        workflow_id: String,
        /// Comment text.
        content: String,
    },
}
/// Worker → server messages.
///
/// Serialized as an internally tagged enum: the variant is carried in a
/// `"type"` field whose value is the `rename` given on each variant.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(tag = "type")]
pub enum WorkerToServer {
    /// Sent by a worker to register itself.
    #[serde(rename = "register")]
    Register {
        /// The worker's self-description.
        info: WorkerInfo,
    },
    /// Carries an agent update produced while executing a workflow.
    #[serde(rename = "update")]
    Update {
        /// Workflow the update belongs to.
        workflow_id: String,
        /// The update payload.
        update: crate::sink::AgentUpdate,
    },
}
/// Registry of connected workers plus the workflow → worker assignment table.
///
/// Each map sits behind its own async `RwLock`, so the two can be locked
/// independently of each other.
pub struct WorkerManager {
    /// Registered workers, keyed by worker name.
    workers: RwLock<HashMap<String, Worker>>,
    /// Assignment table: workflow_id → worker_name.
    assignments: RwLock<HashMap<String, String>>,
}
impl WorkerManager {
    /// Capacity of the bounded channel created for each registered worker.
    const CHANNEL_CAPACITY: usize = 16;

    /// Creates an empty manager, already wrapped in an [`Arc`] for sharing.
    pub fn new() -> Arc<Self> {
        Arc::new(Self {
            workers: RwLock::new(HashMap::new()),
            assignments: RwLock::new(HashMap::new()),
        })
    }

    /// Registers a worker under `name` and returns the receiving half of the
    /// channel on which the server will deliver [`ServerToWorker`] messages.
    ///
    /// If a worker with the same name is already registered its entry is
    /// replaced, which drops the previous sender.
    pub async fn register(
        &self,
        name: String,
        info: WorkerInfo,
    ) -> tokio::sync::mpsc::Receiver<ServerToWorker> {
        let (tx, rx) = tokio::sync::mpsc::channel(Self::CHANNEL_CAPACITY);
        tracing::info!(
            "Worker registered: {} (cpu={}, mem={}, gpu={}, os={}, kernel={})",
            name, info.cpu, info.memory, info.gpu, info.os, info.kernel
        );
        self.workers.write().await.insert(name, Worker { info, tx });
        rx
    }

    /// Removes the worker called `name` and every workflow assignment that
    /// points at it. Unknown names are a no-op apart from the log line.
    pub async fn unregister(&self, name: &str) {
        // The write guard is a temporary, so it is released before the
        // assignments lock is taken; the two locks are never held together here.
        self.workers.write().await.remove(name);
        // Remove all workflow assignments for this worker.
        let mut assignments = self.assignments.write().await;
        assignments.retain(|_, worker| worker.as_str() != name);
        tracing::info!("Worker unregistered: {}", name);
    }

    /// Returns `(name, info)` for every currently registered worker.
    /// The order is the map's iteration order and therefore unspecified.
    pub async fn list(&self) -> Vec<(String, WorkerInfo)> {
        self.workers
            .read()
            .await
            .iter()
            .map(|(name, w)| (name.clone(), w.info.clone()))
            .collect()
    }

    /// Assigns a workflow to a worker and records the assignment.
    ///
    /// If `preferred` is given, that worker is used; otherwise the first
    /// worker yielded by the map iterator is picked (an arbitrary one).
    /// Returns the name of the worker that received the workflow.
    ///
    /// # Errors
    ///
    /// Returns a message string when `assign` is not a
    /// [`ServerToWorker::WorkflowAssign`], when no suitable worker is
    /// registered, when the worker's channel is closed, or when the worker
    /// was unregistered while the message was being delivered.
    pub async fn assign_workflow(
        &self,
        assign: ServerToWorker,
        preferred: Option<&str>,
    ) -> Result<String, String> {
        let workflow_id = match &assign {
            ServerToWorker::WorkflowAssign { workflow_id, .. } => workflow_id.clone(),
            ServerToWorker::Comment { .. } => return Err("Not a workflow assignment".into()),
        };

        // Pick the worker and clone its sender inside a short lock scope.
        // `send` on a bounded channel waits while the queue is full; holding
        // the read guard across that wait would block `register`/`unregister`
        // (and, behind a queued writer, every other reader) for as long as
        // one worker is slow to drain its queue.
        let (worker_name, tx) = {
            let workers = self.workers.read().await;
            let (name, worker) = match preferred {
                Some(pref) => workers
                    .get_key_value(pref)
                    .ok_or_else(|| format!("Worker '{}' not available", pref))?,
                None => workers
                    .iter()
                    .next()
                    .ok_or_else(|| "No workers available".to_string())?,
            };
            (name.clone(), worker.tx.clone())
        };

        tx.send(assign)
            .await
            .map_err(|_| format!("Worker '{}' disconnected", worker_name))?;

        {
            // Record the assignment only if the worker is still registered.
            // The read guard is held across the insert, so a concurrent
            // `unregister` can only remove the worker (and then sweep the
            // assignment table) after this entry exists; no stale entry
            // survives for a worker that is already gone.
            let workers = self.workers.read().await;
            if !workers.contains_key(worker_name.as_str()) {
                return Err(format!("Worker '{}' disconnected", worker_name));
            }
            self.assignments
                .write()
                .await
                .insert(workflow_id, worker_name.clone());
        }

        Ok(worker_name)
    }

    /// Forwards a user comment to the worker that `workflow_id` is assigned to.
    ///
    /// # Errors
    ///
    /// Returns a message string when the workflow has no recorded assignment,
    /// when the assigned worker is no longer registered, or when the worker's
    /// channel is closed.
    pub async fn forward_comment(&self, workflow_id: &str, content: &str) -> Result<(), String> {
        // Each guard below is a temporary released at the end of its
        // statement, so no lock is held while waiting on the channel.
        let worker_name = self
            .assignments
            .read()
            .await
            .get(workflow_id)
            .cloned()
            .ok_or_else(|| format!("No worker assigned for workflow {}", workflow_id))?;

        let tx = self
            .workers
            .read()
            .await
            .get(worker_name.as_str())
            .map(|worker| worker.tx.clone())
            .ok_or_else(|| format!("Worker '{}' not found", worker_name))?;

        tx.send(ServerToWorker::Comment {
            workflow_id: workflow_id.to_string(),
            content: content.to_string(),
        })
        .await
        .map_err(|_| format!("Worker '{}' disconnected", worker_name))
    }

    /// Removes the assignment for `workflow_id` (call when the workflow
    /// completes). Unknown ids are a no-op.
    pub async fn complete_workflow(&self, workflow_id: &str) {
        self.assignments.write().await.remove(workflow_id);
    }

    /// Returns the ids of all workflows currently assigned to `worker_name`,
    /// in unspecified order.
    pub async fn assignments_for_worker(&self, worker_name: &str) -> Vec<String> {
        self.assignments
            .read()
            .await
            .iter()
            .filter(|(_, w)| w.as_str() == worker_name)
            .map(|(wf_id, _)| wf_id.clone())
            .collect()
    }
}