181 lines
5.8 KiB
Rust
181 lines
5.8 KiB
Rust
use std::collections::HashMap;
|
|
use std::sync::Arc;
|
|
use tokio::sync::RwLock;
|
|
use serde::{Deserialize, Serialize};
|
|
|
|
use crate::state::AgentState;
|
|
|
|
/// Information reported by a worker on registration.
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct WorkerInfo {
|
|
pub name: String,
|
|
pub cpu: String,
|
|
pub memory: String,
|
|
pub gpu: String,
|
|
pub os: String,
|
|
pub kernel: String,
|
|
}
|
|
|
|
/// A registered worker.
|
|
struct Worker {
|
|
pub info: WorkerInfo,
|
|
pub tx: tokio::sync::mpsc::Sender<ServerToWorker>,
|
|
}
|
|
|
|
/// Legacy script execution result.
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct WorkerResult {
|
|
pub job_id: String,
|
|
pub exit_code: i32,
|
|
pub stdout: String,
|
|
pub stderr: String,
|
|
}
|
|
|
|
/// Messages sent from server to worker.
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
#[serde(tag = "type")]
|
|
pub enum ServerToWorker {
|
|
/// Assign a full workflow for execution.
|
|
#[serde(rename = "workflow_assign")]
|
|
WorkflowAssign {
|
|
workflow_id: String,
|
|
project_id: String,
|
|
requirement: String,
|
|
#[serde(default)]
|
|
template_id: Option<String>,
|
|
#[serde(default)]
|
|
initial_state: Option<AgentState>,
|
|
#[serde(default)]
|
|
require_plan_approval: bool,
|
|
},
|
|
/// Forward a user comment to the worker executing this workflow.
|
|
#[serde(rename = "comment")]
|
|
Comment {
|
|
workflow_id: String,
|
|
content: String,
|
|
},
|
|
}
|
|
|
|
/// Messages sent from worker to server.
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
#[serde(tag = "type")]
|
|
pub enum WorkerToServer {
|
|
/// Worker registration.
|
|
#[serde(rename = "register")]
|
|
Register { info: WorkerInfo },
|
|
/// Agent update from workflow execution.
|
|
#[serde(rename = "update")]
|
|
Update {
|
|
workflow_id: String,
|
|
update: crate::sink::AgentUpdate,
|
|
},
|
|
}
|
|
|
|
/// Manages all connected workers and workflow assignments.
|
|
pub struct WorkerManager {
|
|
workers: RwLock<HashMap<String, Worker>>,
|
|
/// workflow_id → worker_name
|
|
assignments: RwLock<HashMap<String, String>>,
|
|
}
|
|
|
|
impl WorkerManager {
|
|
pub fn new() -> Arc<Self> {
|
|
Arc::new(Self {
|
|
workers: RwLock::new(HashMap::new()),
|
|
assignments: RwLock::new(HashMap::new()),
|
|
})
|
|
}
|
|
|
|
/// Register a new worker. Returns a receiver for messages.
|
|
pub async fn register(
|
|
&self,
|
|
name: String,
|
|
info: WorkerInfo,
|
|
) -> tokio::sync::mpsc::Receiver<ServerToWorker> {
|
|
let (tx, rx) = tokio::sync::mpsc::channel(16);
|
|
tracing::info!("Worker registered: {} (cpu={}, mem={}, gpu={}, os={}, kernel={})",
|
|
name, info.cpu, info.memory, info.gpu, info.os, info.kernel);
|
|
self.workers.write().await.insert(name, Worker { info, tx });
|
|
rx
|
|
}
|
|
|
|
/// Remove a worker and clean up its assignments.
|
|
pub async fn unregister(&self, name: &str) {
|
|
self.workers.write().await.remove(name);
|
|
// Remove all workflow assignments for this worker
|
|
let mut assignments = self.assignments.write().await;
|
|
assignments.retain(|_, worker| worker != name);
|
|
tracing::info!("Worker unregistered: {}", name);
|
|
}
|
|
|
|
/// List all connected workers.
|
|
pub async fn list(&self) -> Vec<(String, WorkerInfo)> {
|
|
self.workers
|
|
.read()
|
|
.await
|
|
.iter()
|
|
.map(|(name, w)| (name.clone(), w.info.clone()))
|
|
.collect()
|
|
}
|
|
|
|
/// Assign a workflow to a worker. If `preferred` is specified, use that worker;
|
|
/// otherwise pick the first available.
|
|
pub async fn assign_workflow(&self, assign: ServerToWorker, preferred: Option<&str>) -> Result<String, String> {
|
|
let workflow_id = match &assign {
|
|
ServerToWorker::WorkflowAssign { workflow_id, .. } => workflow_id.clone(),
|
|
_ => return Err("Not a workflow assignment".into()),
|
|
};
|
|
|
|
let workers = self.workers.read().await;
|
|
let (name, worker) = if let Some(pref) = preferred {
|
|
workers.get_key_value(pref)
|
|
.ok_or_else(|| format!("Worker '{}' not available", pref))?
|
|
} else {
|
|
workers.iter().next()
|
|
.ok_or_else(|| "No workers available".to_string())?
|
|
};
|
|
|
|
worker.tx.send(assign).await.map_err(|_| {
|
|
format!("Worker '{}' disconnected", name)
|
|
})?;
|
|
|
|
let worker_name = name.clone();
|
|
drop(workers);
|
|
|
|
self.assignments.write().await.insert(workflow_id, worker_name.clone());
|
|
Ok(worker_name)
|
|
}
|
|
|
|
/// Forward a comment to the worker handling a workflow.
|
|
pub async fn forward_comment(&self, workflow_id: &str, content: &str) -> Result<(), String> {
|
|
let assignments = self.assignments.read().await;
|
|
let worker_name = assignments.get(workflow_id)
|
|
.ok_or_else(|| format!("No worker assigned for workflow {}", workflow_id))?
|
|
.clone();
|
|
drop(assignments);
|
|
|
|
let workers = self.workers.read().await;
|
|
let worker = workers.get(&worker_name)
|
|
.ok_or_else(|| format!("Worker '{}' not found", worker_name))?;
|
|
|
|
worker.tx.send(ServerToWorker::Comment {
|
|
workflow_id: workflow_id.to_string(),
|
|
content: content.to_string(),
|
|
}).await.map_err(|_| format!("Worker '{}' disconnected", worker_name))
|
|
}
|
|
|
|
/// Remove a workflow assignment (when workflow completes).
|
|
pub async fn complete_workflow(&self, workflow_id: &str) {
|
|
self.assignments.write().await.remove(workflow_id);
|
|
}
|
|
|
|
/// List all workflows assigned to a worker.
|
|
pub async fn assignments_for_worker(&self, worker_name: &str) -> Vec<String> {
|
|
self.assignments.read().await
|
|
.iter()
|
|
.filter(|(_, w)| w.as_str() == worker_name)
|
|
.map(|(wf_id, _)| wf_id.clone())
|
|
.collect()
|
|
}
|
|
}
|