refactor: worker mode — server offloads all LLM/exec to worker

- Split into `tori server` / `tori worker` subcommands (clap derive)
- Extract lib.rs for shared crate (agent, llm, exec, state, etc.)
- Introduce AgentUpdate channel to decouple agent loop from DB/broadcast
- New sink.rs: AgentUpdate enum + ServiceManager + handle_agent_updates
- New worker_runner.rs: connects to server WS, runs full agent loop
- Expand worker protocol: ServerToWorker (workflow_assign, comment)
  and WorkerToServer (register, result, update)
- Remove LLM from title generation (heuristic) and template selection
  (must be explicit)
- Remove KB tools (kb_search, kb_read) and remote worker tools
  (list_workers, execute_on_worker) from agent loop
- run_agent_loop/run_step_loop now take mpsc::Sender<AgentUpdate>
  instead of direct DB pool + broadcast sender
This commit is contained in:
Fam Zheng 2026-04-06 12:54:57 +01:00
parent 28a00dd2f3
commit e4ba385112
9 changed files with 1003 additions and 610 deletions

186
Cargo.lock generated
View File

@ -26,6 +26,56 @@ dependencies = [
"libc", "libc",
] ]
[[package]]
name = "anstream"
version = "0.6.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "43d5b281e737544384e969a5ccad3f1cdd24b48086a0fc1b2a5262a26b8f4f4a"
dependencies = [
"anstyle",
"anstyle-parse",
"anstyle-query",
"anstyle-wincon",
"colorchoice",
"is_terminal_polyfill",
"utf8parse",
]
[[package]]
name = "anstyle"
version = "1.0.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "940b3a0ca603d1eade50a4846a2afffd5ef57a9feac2c0e2ec2e14f9ead76000"
[[package]]
name = "anstyle-parse"
version = "0.2.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4e7644824f0aa2c7b9384579234ef10eb7efb6a0deb83f9630a49594dd9c15c2"
dependencies = [
"utf8parse",
]
[[package]]
name = "anstyle-query"
version = "1.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "40c48f72fd53cd289104fc64099abca73db4166ad86ea0b4341abe65af83dadc"
dependencies = [
"windows-sys 0.61.2",
]
[[package]]
name = "anstyle-wincon"
version = "3.0.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "291e6a250ff86cd4a820112fb8898808a366d8f9f58ce16d1f538353ad55747d"
dependencies = [
"anstyle",
"once_cell_polyfill",
"windows-sys 0.61.2",
]
[[package]] [[package]]
name = "anyhow" name = "anyhow"
version = "1.0.102" version = "1.0.102"
@ -83,7 +133,7 @@ dependencies = [
"sha1", "sha1",
"sync_wrapper", "sync_wrapper",
"tokio", "tokio",
"tokio-tungstenite", "tokio-tungstenite 0.28.0",
"tower", "tower",
"tower-layer", "tower-layer",
"tower-service", "tower-service",
@ -216,6 +266,52 @@ dependencies = [
"windows-link", "windows-link",
] ]
[[package]]
name = "clap"
version = "4.5.60"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2797f34da339ce31042b27d23607e051786132987f595b02ba4f6a6dffb7030a"
dependencies = [
"clap_builder",
"clap_derive",
]
[[package]]
name = "clap_builder"
version = "4.5.60"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "24a241312cea5059b13574bb9b3861cabf758b879c15190b37b6d6fd63ab6876"
dependencies = [
"anstream",
"anstyle",
"clap_lex",
"strsim",
]
[[package]]
name = "clap_derive"
version = "4.5.55"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a92793da1a46a5f2a02a6f4c46c6496b28c43638adea8306fcb0caa1634f24e5"
dependencies = [
"heck",
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "clap_lex"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c8d4a3bb8b1e0c1050499d1815f5ab16d04f0959b233085fb31653fbfc9d98f9"
[[package]]
name = "colorchoice"
version = "1.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d07550c9036bf2ae0c684c4297d503f838287c83c53686d05370d0e139ae570"
[[package]] [[package]]
name = "concurrent-queue" name = "concurrent-queue"
version = "2.5.0" version = "2.5.0"
@ -663,6 +759,17 @@ dependencies = [
"windows-sys 0.61.2", "windows-sys 0.61.2",
] ]
[[package]]
name = "hostname"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "617aaa3557aef3810a6369d0a99fac8a080891b68bd9f9812a1eeda0c0730cbd"
dependencies = [
"cfg-if",
"libc",
"windows-link",
]
[[package]] [[package]]
name = "http" name = "http"
version = "1.4.0" version = "1.4.0"
@ -750,7 +857,7 @@ dependencies = [
"tokio", "tokio",
"tokio-rustls", "tokio-rustls",
"tower-service", "tower-service",
"webpki-roots", "webpki-roots 1.0.4",
] ]
[[package]] [[package]]
@ -936,6 +1043,12 @@ dependencies = [
"serde", "serde",
] ]
[[package]]
name = "is_terminal_polyfill"
version = "1.70.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a6cb138bb79a146c1bd460005623e142ef0181e3d0219cb493e02f7d08a35695"
[[package]] [[package]]
name = "itoa" name = "itoa"
version = "1.0.17" version = "1.0.17"
@ -1207,6 +1320,12 @@ version = "1.21.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d"
[[package]]
name = "once_cell_polyfill"
version = "1.70.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe"
[[package]] [[package]]
name = "parking" name = "parking"
version = "2.2.1" version = "2.2.1"
@ -1564,7 +1683,7 @@ dependencies = [
"wasm-bindgen-futures", "wasm-bindgen-futures",
"wasm-streams", "wasm-streams",
"web-sys", "web-sys",
"webpki-roots", "webpki-roots 1.0.4",
] ]
[[package]] [[package]]
@ -2063,6 +2182,12 @@ dependencies = [
"unicode-properties", "unicode-properties",
] ]
[[package]]
name = "strsim"
version = "0.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f"
[[package]] [[package]]
name = "subtle" name = "subtle"
version = "2.6.1" version = "2.6.1"
@ -2234,6 +2359,22 @@ dependencies = [
"tokio", "tokio",
] ]
[[package]]
name = "tokio-tungstenite"
version = "0.26.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a9daff607c6d2bf6c16fd681ccb7eecc83e4e2cdc1ca067ffaadfca5de7f084"
dependencies = [
"futures-util",
"log",
"rustls",
"rustls-pki-types",
"tokio",
"tokio-rustls",
"tungstenite 0.26.2",
"webpki-roots 0.26.11",
]
[[package]] [[package]]
name = "tokio-tungstenite" name = "tokio-tungstenite"
version = "0.28.0" version = "0.28.0"
@ -2243,7 +2384,7 @@ dependencies = [
"futures-util", "futures-util",
"log", "log",
"tokio", "tokio",
"tungstenite", "tungstenite 0.28.0",
] ]
[[package]] [[package]]
@ -2268,7 +2409,9 @@ dependencies = [
"axum-extra", "axum-extra",
"base64", "base64",
"chrono", "chrono",
"clap",
"futures", "futures",
"hostname",
"jsonwebtoken", "jsonwebtoken",
"mime_guess", "mime_guess",
"nix", "nix",
@ -2280,6 +2423,7 @@ dependencies = [
"sqlx", "sqlx",
"time", "time",
"tokio", "tokio",
"tokio-tungstenite 0.26.2",
"tokio-util", "tokio-util",
"tower-http", "tower-http",
"tracing", "tracing",
@ -2411,6 +2555,25 @@ version = "0.2.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b"
[[package]]
name = "tungstenite"
version = "0.26.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4793cb5e56680ecbb1d843515b23b6de9a75eb04b66643e256a396d43be33c13"
dependencies = [
"bytes",
"data-encoding",
"http",
"httparse",
"log",
"rand 0.9.2",
"rustls",
"rustls-pki-types",
"sha1",
"thiserror",
"utf-8",
]
[[package]] [[package]]
name = "tungstenite" name = "tungstenite"
version = "0.28.0" version = "0.28.0"
@ -2515,6 +2678,12 @@ version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be"
[[package]]
name = "utf8parse"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821"
[[package]] [[package]]
name = "uuid" name = "uuid"
version = "1.21.0" version = "1.21.0"
@ -2709,6 +2878,15 @@ dependencies = [
"wasm-bindgen", "wasm-bindgen",
] ]
[[package]]
name = "webpki-roots"
version = "0.26.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "521bc38abb08001b01866da9f51eb7c5d647a19260e00054a8c7fd5f9e57f7a9"
dependencies = [
"webpki-roots 1.0.4",
]
[[package]] [[package]]
name = "webpki-roots" name = "webpki-roots"
version = "1.0.4" version = "1.0.4"

View File

@ -19,11 +19,14 @@ sqlx = { version = "0.8", features = ["runtime-tokio", "sqlite"] }
tower-http = { version = "0.6", features = ["cors", "fs"] } tower-http = { version = "0.6", features = ["cors", "fs"] }
reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "rustls-tls"] } reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "rustls-tls"] }
futures = "0.3" futures = "0.3"
tokio-tungstenite = { version = "0.26", features = ["rustls-tls-webpki-roots"] }
tracing = "0.1" tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] } tracing-subscriber = { version = "0.3", features = ["env-filter"] }
chrono = { version = "0.4", features = ["serde"] } chrono = { version = "0.4", features = ["serde"] }
uuid = { version = "1", features = ["v4"] } uuid = { version = "1", features = ["v4"] }
anyhow = "1" anyhow = "1"
clap = { version = "4", features = ["derive", "env"] }
hostname = "0.4"
mime_guess = "2" mime_guess = "2"
tokio-util = { version = "0.7", features = ["io"] } tokio-util = { version = "0.7", features = ["io"] }
nix = { version = "0.29", features = ["signal"] } nix = { version = "0.29", features = ["signal"] }

File diff suppressed because it is too large Load Diff

74
src/lib.rs Normal file
View File

@ -0,0 +1,74 @@
pub mod api;
pub mod agent;
pub mod db;
pub mod kb;
pub mod llm;
pub mod exec;
pub mod state;
pub mod template;
pub mod timer;
pub mod tools;
pub mod worker;
pub mod sink;
pub mod worker_runner;
pub mod ws;
pub mod ws_worker;
use std::sync::Arc;
use serde::Deserialize;
/// Shared server-side state: the handles that `run_server` builds once at
/// startup and passes to the HTTP routers.
pub struct AppState {
/// Database wrapper; its `pool` is what the KB manager, timer runner and
/// workflow resumption are given at startup.
pub db: db::Database,
/// Configuration parsed from `config.yaml`.
pub config: Config,
/// Manager for agent/workflow execution.
pub agent_mgr: Arc<agent::AgentManager>,
/// Knowledge-base manager; optional.
/// NOTE(review): presumably `None` when KB initialization fails at startup — confirm in `run_server`.
pub kb: Option<Arc<kb::KbManager>>,
/// Root directory for stored objects (taken from the `OBJ_ROOT` environment
/// variable, falling back to `/data/obj`).
pub obj_root: String,
/// OAuth/SSO settings; `None` when no OAuth provider is configured.
pub auth: Option<api::auth::AuthConfig>,
}
/// Top-level server configuration, deserialized from `config.yaml`.
#[derive(Debug, Clone, Deserialize)]
pub struct Config {
/// LLM endpoint settings (required).
pub llm: LlmConfig,
/// HTTP listen address (required).
pub server: ServerConfig,
/// Database location (required).
pub database: DatabaseConfig,
/// Optional template repository; absent from the file means `None`.
#[serde(default)]
pub template_repo: Option<TemplateRepoConfig>,
/// Path to EC private key PEM file for JWT signing
#[serde(default)]
pub jwt_private_key: Option<String>,
}
/// Location of the project-template repository.
#[derive(Debug, Clone, Deserialize)]
pub struct TemplateRepoConfig {
/// Base URL of the Gitea instance hosting the repository.
pub gitea_url: String,
/// Repository owner (user or organization).
pub owner: String,
/// Repository name.
pub repo: String,
/// Local checkout path; when omitted, `default_repo_path()` picks it.
#[serde(default = "default_repo_path")]
pub local_path: String,
}
/// Serde default for `TemplateRepoConfig::local_path`.
///
/// Returns the absolute `/app/oseng-templates` checkout when that directory
/// exists, and otherwise the relative `oseng-templates` path.
fn default_repo_path() -> String {
    const APP_CHECKOUT: &str = "/app/oseng-templates";
    const RELATIVE_CHECKOUT: &str = "oseng-templates";

    let app_checkout_exists = std::path::Path::new(APP_CHECKOUT).is_dir();
    let chosen = if app_checkout_exists {
        APP_CHECKOUT
    } else {
        RELATIVE_CHECKOUT
    };
    chosen.to_owned()
}
/// Connection settings for the LLM endpoint.
///
/// Unlike the other config structs this one is also `Serialize`, because it is
/// embedded in `ServerToWorker::WorkflowAssign` and sent to workers.
/// NOTE(review): that means `api_key` travels over the worker WebSocket and is
/// printed by the derived `Debug` impl — confirm this is acceptable.
#[derive(Debug, Clone, serde::Serialize, Deserialize)]
pub struct LlmConfig {
/// Base URL of the LLM API.
pub base_url: String,
/// API key used to authenticate against the LLM API.
pub api_key: String,
/// Model identifier to request.
pub model: String,
}
/// HTTP server listen settings.
#[derive(Debug, Clone, Deserialize)]
pub struct ServerConfig {
/// Host name or IP address to bind to.
pub host: String,
/// TCP port to bind to.
pub port: u16,
}
/// Database settings.
#[derive(Debug, Clone, Deserialize)]
pub struct DatabaseConfig {
/// Path handed to `db::Database::new` at startup.
pub path: String,
}

View File

@ -1,77 +1,33 @@
mod api;
mod agent;
mod db;
mod kb;
mod llm;
mod exec;
pub mod state;
mod template;
mod timer;
mod tools;
mod worker;
mod ws;
mod ws_worker;
use std::sync::Arc; use std::sync::Arc;
use axum::Router; use axum::Router;
use clap::{Parser, Subcommand};
use sqlx::sqlite::SqlitePool; use sqlx::sqlite::SqlitePool;
use tower_http::cors::CorsLayer; use tower_http::cors::CorsLayer;
use tower_http::services::{ServeDir, ServeFile}; use tower_http::services::{ServeDir, ServeFile};
pub struct AppState { use tori::{agent, api, db, kb, template, timer, worker, worker_runner, ws, ws_worker};
pub db: db::Database, use tori::{AppState, Config};
pub config: Config,
pub agent_mgr: Arc<agent::AgentManager>, #[derive(Parser)]
pub kb: Option<Arc<kb::KbManager>>, #[command(name = "tori", about = "Tori AI agent orchestration")]
pub obj_root: String, struct Cli {
pub auth: Option<api::auth::AuthConfig>, #[command(subcommand)]
command: Command,
} }
#[derive(Debug, Clone, serde::Deserialize)] #[derive(Subcommand)]
pub struct Config { enum Command {
pub llm: LlmConfig, /// Start the API server
pub server: ServerConfig, Server,
pub database: DatabaseConfig, /// Start a worker that connects to the server
#[serde(default)] Worker {
pub template_repo: Option<TemplateRepoConfig>, /// Server WebSocket URL
/// Path to EC private key PEM file for JWT signing #[arg(long, env = "TORI_SERVER", default_value = "ws://127.0.0.1:3000/ws/tori/workers")]
#[serde(default)] server: String,
pub jwt_private_key: Option<String>, /// Worker name
} #[arg(long, env = "TORI_WORKER_NAME")]
name: Option<String>,
#[derive(Debug, Clone, serde::Deserialize)] },
pub struct TemplateRepoConfig {
pub gitea_url: String,
pub owner: String,
pub repo: String,
#[serde(default = "default_repo_path")]
pub local_path: String,
}
fn default_repo_path() -> String {
if std::path::Path::new("/app/oseng-templates").is_dir() {
"/app/oseng-templates".to_string()
} else {
"oseng-templates".to_string()
}
}
#[derive(Debug, Clone, serde::Deserialize)]
pub struct LlmConfig {
pub base_url: String,
pub api_key: String,
pub model: String,
}
#[derive(Debug, Clone, serde::Deserialize)]
pub struct ServerConfig {
pub host: String,
pub port: u16,
}
#[derive(Debug, Clone, serde::Deserialize)]
pub struct DatabaseConfig {
pub path: String,
} }
#[tokio::main] #[tokio::main]
@ -80,6 +36,22 @@ async fn main() -> anyhow::Result<()> {
.with_env_filter("tori=debug,tower_http=debug") .with_env_filter("tori=debug,tower_http=debug")
.init(); .init();
let cli = Cli::parse();
match cli.command {
Command::Server => run_server().await,
Command::Worker { server, name } => {
let name = name.unwrap_or_else(|| {
hostname::get()
.map(|h| h.to_string_lossy().to_string())
.unwrap_or_else(|_| "worker-1".to_string())
});
worker_runner::run(&server, &name).await
}
}
}
async fn run_server() -> anyhow::Result<()> {
let config_str = std::fs::read_to_string("config.yaml") let config_str = std::fs::read_to_string("config.yaml")
.expect("Failed to read config.yaml"); .expect("Failed to read config.yaml");
let config: Config = serde_yaml::from_str(&config_str) let config: Config = serde_yaml::from_str(&config_str)
@ -88,7 +60,6 @@ async fn main() -> anyhow::Result<()> {
let database = db::Database::new(&config.database.path).await?; let database = db::Database::new(&config.database.path).await?;
database.migrate().await?; database.migrate().await?;
// Initialize KB manager
let kb_arc = match kb::KbManager::new(database.pool.clone()) { let kb_arc = match kb::KbManager::new(database.pool.clone()) {
Ok(kb) => { Ok(kb) => {
tracing::info!("KB manager initialized"); tracing::info!("KB manager initialized");
@ -100,7 +71,6 @@ async fn main() -> anyhow::Result<()> {
} }
}; };
// Ensure template repo is cloned before serving
if let Some(ref repo_cfg) = config.template_repo { if let Some(ref repo_cfg) = config.template_repo {
template::ensure_repo_ready(repo_cfg).await; template::ensure_repo_ready(repo_cfg).await;
} }
@ -117,8 +87,6 @@ async fn main() -> anyhow::Result<()> {
); );
timer::start_timer_runner(database.pool.clone(), agent_mgr.clone()); timer::start_timer_runner(database.pool.clone(), agent_mgr.clone());
// Resume incomplete workflows after restart
resume_workflows(database.pool.clone(), agent_mgr.clone()).await; resume_workflows(database.pool.clone(), agent_mgr.clone()).await;
let obj_root = std::env::var("OBJ_ROOT").unwrap_or_else(|_| "/data/obj".to_string()); let obj_root = std::env::var("OBJ_ROOT").unwrap_or_else(|_| "/data/obj".to_string());
@ -129,7 +97,6 @@ async fn main() -> anyhow::Result<()> {
let public_url = std::env::var("PUBLIC_URL") let public_url = std::env::var("PUBLIC_URL")
.unwrap_or_else(|_| "https://tori.euphon.cloud".to_string()); .unwrap_or_else(|_| "https://tori.euphon.cloud".to_string());
// Try TikTok SSO first, then Google OAuth
if let (Ok(id), Ok(secret)) = ( if let (Ok(id), Ok(secret)) = (
std::env::var("SSO_CLIENT_ID"), std::env::var("SSO_CLIENT_ID"),
std::env::var("SSO_CLIENT_SECRET"), std::env::var("SSO_CLIENT_SECRET"),
@ -157,7 +124,7 @@ async fn main() -> anyhow::Result<()> {
public_url, public_url,
}) })
} else { } else {
tracing::warn!("No OAuth configured (set SSO_CLIENT_ID/SSO_CLIENT_SECRET or GOOGLE_CLIENT_ID/GOOGLE_CLIENT_SECRET)"); tracing::warn!("No OAuth configured");
None None
} }
}; };
@ -172,13 +139,10 @@ async fn main() -> anyhow::Result<()> {
}); });
let app = Router::new() let app = Router::new()
// Health check (public, for k8s probes)
.route("/tori/api/health", axum::routing::get(|| async { .route("/tori/api/health", axum::routing::get(|| async {
axum::Json(serde_json::json!({"status": "ok"})) axum::Json(serde_json::json!({"status": "ok"}))
})) }))
// Auth routes are public
.nest("/tori/api/auth", api::auth::router(state.clone())) .nest("/tori/api/auth", api::auth::router(state.clone()))
// Protected API routes
.nest("/tori/api", api::router(state.clone()) .nest("/tori/api", api::router(state.clone())
.layer(axum::middleware::from_fn_with_state(state.clone(), api::auth::require_auth)) .layer(axum::middleware::from_fn_with_state(state.clone(), api::auth::require_auth))
) )

239
src/sink.rs Normal file
View File

@ -0,0 +1,239 @@
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU16, Ordering};
use serde::{Deserialize, Serialize};
use sqlx::sqlite::SqlitePool;
use tokio::sync::{RwLock, broadcast, mpsc};
use crate::agent::{PlanStepInfo, WsMessage, ServiceInfo};
use crate::state::{AgentState, Artifact};
/// All updates produced by the agent loop. This is the single output interface
/// that decouples the agent logic from DB persistence and WebSocket broadcasting.
///
/// Serialized as an internally tagged enum: the variant name is stored in a
/// `kind` field next to the variant's own fields. The per-variant notes below
/// describe what the server-side `handle_agent_updates` does with each one.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "kind")]
pub enum AgentUpdate {
/// New or changed plan; broadcast to the frontend only (not persisted here).
PlanUpdate {
workflow_id: String,
steps: Vec<PlanStepInfo>,
},
/// Workflow status change; written to `workflows.status` and broadcast.
WorkflowStatus {
workflow_id: String,
status: String,
},
/// Activity text; broadcast only.
Activity {
workflow_id: String,
activity: String,
},
/// One tool execution; inserted into `execution_log` and broadcast as a
/// step status update.
ExecutionLog {
workflow_id: String,
step_order: i32,
tool_name: String,
tool_input: String,
output: String,
status: String,
},
/// One LLM call; inserted into `llm_call_log` and broadcast.
LlmCallLog {
workflow_id: String,
step_order: i32,
phase: String,
messages_count: i32,
tools_count: i32,
tool_calls: String,
text_response: String,
// Token counts are optional; stored as 32-bit signed integers by the handler.
prompt_tokens: Option<u32>,
completion_tokens: Option<u32>,
// Stored as a 32-bit signed integer by the handler.
latency_ms: i64,
},
/// Agent state at a given step; stored as JSON in `agent_state_snapshots`
/// (not broadcast).
StateSnapshot {
workflow_id: String,
step_order: i32,
state: AgentState,
},
/// Final outcome of a workflow; updates `workflows.status`, stores `report`
/// when present, and broadcasts the new status.
WorkflowComplete {
workflow_id: String,
status: String,
report: Option<String>,
},
/// Artifact produced by a step; inserted into `step_artifacts` (not broadcast).
ArtifactSave {
workflow_id: String,
step_order: i32,
artifact: Artifact,
},
/// Replacement requirement text; written to `workflows.requirement` and broadcast.
RequirementUpdate {
workflow_id: String,
requirement: String,
},
/// Error message; broadcast only. Note this is the only variant without a
/// `workflow_id`.
Error {
message: String,
},
}
/// Manages local services (start_service / stop_service tools).
/// Created per-worker or per-agent-loop.
pub struct ServiceManager {
/// Service records keyed by a string identifier, behind an async read/write lock.
/// NOTE(review): the key is presumably the service name or project id — confirm at the call sites.
pub services: RwLock<HashMap<String, ServiceInfo>>,
// Next port to hand out; advanced atomically by `allocate_port`.
next_port: AtomicU16,
}
impl ServiceManager {
pub fn new(start_port: u16) -> Arc<Self> {
Arc::new(Self {
services: RwLock::new(HashMap::new()),
next_port: AtomicU16::new(start_port),
})
}
pub fn allocate_port(&self) -> u16 {
self.next_port.fetch_add(1, Ordering::Relaxed)
}
}
/// Server-side handler: consumes AgentUpdate from channel, persists to DB and broadcasts to frontend.
pub async fn handle_agent_updates(
mut rx: mpsc::Receiver<AgentUpdate>,
pool: SqlitePool,
broadcast_tx: broadcast::Sender<WsMessage>,
) {
while let Some(update) = rx.recv().await {
match update {
AgentUpdate::PlanUpdate { workflow_id, steps } => {
let _ = broadcast_tx.send(WsMessage::PlanUpdate { workflow_id, steps });
}
AgentUpdate::WorkflowStatus { ref workflow_id, ref status } => {
let _ = sqlx::query("UPDATE workflows SET status = ? WHERE id = ?")
.bind(status)
.bind(workflow_id)
.execute(&pool)
.await;
let _ = broadcast_tx.send(WsMessage::WorkflowStatusUpdate {
workflow_id: workflow_id.clone(),
status: status.clone(),
});
}
AgentUpdate::Activity { workflow_id, activity } => {
let _ = broadcast_tx.send(WsMessage::ActivityUpdate { workflow_id, activity });
}
AgentUpdate::ExecutionLog { ref workflow_id, step_order, ref tool_name, ref tool_input, ref output, ref status } => {
let id = uuid::Uuid::new_v4().to_string();
let _ = sqlx::query(
"INSERT INTO execution_log (id, workflow_id, step_order, tool_name, tool_input, output, status, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, datetime('now'))"
)
.bind(&id)
.bind(workflow_id)
.bind(step_order)
.bind(tool_name)
.bind(tool_input)
.bind(output)
.bind(status)
.execute(&pool)
.await;
let _ = broadcast_tx.send(WsMessage::StepStatusUpdate {
step_id: id,
status: status.clone(),
output: output.clone(),
});
}
AgentUpdate::LlmCallLog { ref workflow_id, step_order, ref phase, messages_count, tools_count, ref tool_calls, ref text_response, prompt_tokens, completion_tokens, latency_ms } => {
let id = uuid::Uuid::new_v4().to_string();
let _ = sqlx::query(
"INSERT INTO llm_call_log (id, workflow_id, step_order, phase, messages_count, tools_count, tool_calls, text_response, prompt_tokens, completion_tokens, latency_ms, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, datetime('now'))"
)
.bind(&id)
.bind(workflow_id)
.bind(step_order)
.bind(phase)
.bind(messages_count)
.bind(tools_count)
.bind(tool_calls)
.bind(text_response)
.bind(prompt_tokens.map(|v| v as i32))
.bind(completion_tokens.map(|v| v as i32))
.bind(latency_ms as i32)
.execute(&pool)
.await;
let entry = crate::db::LlmCallLogEntry {
id,
workflow_id: workflow_id.clone(),
step_order,
phase: phase.clone(),
messages_count,
tools_count,
tool_calls: tool_calls.clone(),
text_response: text_response.clone(),
prompt_tokens: prompt_tokens.map(|v| v as i32),
completion_tokens: completion_tokens.map(|v| v as i32),
latency_ms: latency_ms as i32,
created_at: String::new(),
};
let _ = broadcast_tx.send(WsMessage::LlmCallLog {
workflow_id: workflow_id.clone(),
entry,
});
}
AgentUpdate::StateSnapshot { ref workflow_id, step_order, ref state } => {
let id = uuid::Uuid::new_v4().to_string();
let json = serde_json::to_string(state).unwrap_or_default();
let _ = sqlx::query(
"INSERT INTO agent_state_snapshots (id, workflow_id, step_order, state_json, created_at) VALUES (?, ?, ?, ?, datetime('now'))"
)
.bind(&id)
.bind(workflow_id)
.bind(step_order)
.bind(&json)
.execute(&pool)
.await;
}
AgentUpdate::WorkflowComplete { ref workflow_id, ref status, ref report } => {
let _ = sqlx::query("UPDATE workflows SET status = ? WHERE id = ?")
.bind(status)
.bind(workflow_id)
.execute(&pool)
.await;
if let Some(ref r) = report {
let _ = sqlx::query("UPDATE workflows SET report = ? WHERE id = ?")
.bind(r)
.bind(workflow_id)
.execute(&pool)
.await;
let _ = broadcast_tx.send(WsMessage::ReportReady {
workflow_id: workflow_id.clone(),
});
}
let _ = broadcast_tx.send(WsMessage::WorkflowStatusUpdate {
workflow_id: workflow_id.clone(),
status: status.clone(),
});
}
AgentUpdate::ArtifactSave { ref workflow_id, step_order, ref artifact } => {
let id = uuid::Uuid::new_v4().to_string();
let _ = sqlx::query(
"INSERT INTO step_artifacts (id, workflow_id, step_order, name, path, artifact_type, description) VALUES (?, ?, ?, ?, ?, ?, ?)"
)
.bind(&id)
.bind(workflow_id)
.bind(step_order)
.bind(&artifact.name)
.bind(&artifact.path)
.bind(&artifact.artifact_type)
.bind(&artifact.description)
.execute(&pool)
.await;
}
AgentUpdate::RequirementUpdate { ref workflow_id, ref requirement } => {
let _ = sqlx::query("UPDATE workflows SET requirement = ? WHERE id = ?")
.bind(requirement)
.bind(workflow_id)
.execute(&pool)
.await;
let _ = broadcast_tx.send(WsMessage::RequirementUpdate {
workflow_id: workflow_id.clone(),
requirement: requirement.clone(),
});
}
AgentUpdate::Error { message } => {
let _ = broadcast_tx.send(WsMessage::Error { message });
}
}
}
}

View File

@ -3,7 +3,6 @@ use std::path::{Path, PathBuf};
use serde::Deserialize; use serde::Deserialize;
use crate::TemplateRepoConfig; use crate::TemplateRepoConfig;
use crate::llm::{ChatMessage, LlmClient};
use crate::tools::ExternalToolManager; use crate::tools::ExternalToolManager;
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
@ -463,42 +462,6 @@ pub fn is_repo_template(template_id: &str) -> bool {
template_id.contains('/') template_id.contains('/')
} }
// --- LLM template selection ---
pub async fn select_template(llm: &LlmClient, requirement: &str, repo_cfg: Option<&TemplateRepoConfig>) -> Option<String> {
let all = list_all_templates(repo_cfg).await;
if all.is_empty() {
return None;
}
let listing: String = all
.iter()
.map(|t| format!("- id: {}\n 名称: {}\n 描述: {}", t.id, t.name, t.description))
.collect::<Vec<_>>()
.join("\n");
let prompt = format!(
"以下是可用的项目模板:\n{}\n\n用户需求:{}\n\n选择最匹配的模板 ID如果都不合适则回复 none。只回复模板 ID 或 none不要其他内容。",
listing, requirement
);
let response = llm
.chat(vec![
ChatMessage::system("你是一个模板选择助手。根据用户需求选择最合适的项目模板。只回复模板 ID 或 none。"),
ChatMessage::user(&prompt),
])
.await
.ok()?;
let answer = response.trim().to_lowercase();
tracing::info!("Template selection LLM response: '{}' (available: {:?})",
answer, all.iter().map(|t| t.id.as_str()).collect::<Vec<_>>());
if answer == "none" {
return None;
}
all.iter().find(|t| t.id == answer).map(|t| t.id.clone())
}
// --- Template loading --- // --- Template loading ---

View File

@ -131,3 +131,56 @@ impl WorkerManager {
} }
} }
} }
// --- Extended protocol for workflow execution ---
use crate::LlmConfig;
use crate::state::AgentState;
use crate::llm::Tool;
/// Messages sent from server to worker.
///
/// Serialized as an internally tagged enum: the `type` field carries the
/// variant name (`workflow_assign` or `comment`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum ServerToWorker {
/// Assign a full workflow for execution.
#[serde(rename = "workflow_assign")]
WorkflowAssign {
workflow_id: String,
project_id: String,
/// Requirement text the agent loop works from.
requirement: String,
/// Working directory on the worker; the worker creates it if missing.
workdir: String,
instructions: String,
/// LLM endpoint settings the worker uses to build its own client.
llm_config: LlmConfig,
/// Optional state handed to the agent loop as its starting state;
/// absent in the message means `None`.
#[serde(default)]
initial_state: Option<AgentState>,
/// Passed through to the agent loop; defaults to `false` when absent.
#[serde(default)]
require_plan_approval: bool,
/// Defaults to empty when absent. The worker runner currently ignores
/// this field.
#[serde(default)]
external_tools: Vec<Tool>,
},
/// Forward a user comment to the worker executing this workflow.
#[serde(rename = "comment")]
Comment {
workflow_id: String,
content: String,
},
}
/// Messages sent from worker to server.
///
/// Serialized as an internally tagged enum: the `type` field carries the
/// variant name (`register`, `result` or `update`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum WorkerToServer {
/// Worker registration.
#[serde(rename = "register")]
Register { info: WorkerInfo },
/// Script execution result (legacy).
#[serde(rename = "result")]
Result(WorkerResult),
/// Agent update from workflow execution.
///
/// The update's own fields (including its `kind` tag) are flattened into
/// this message next to `type` and `workflow_id`.
/// NOTE(review): most `AgentUpdate` variants already contain a
/// `workflow_id` field, so the flattened form looks like it would carry
/// that key twice — confirm the server-side parser accepts such messages.
#[serde(rename = "update")]
Update {
workflow_id: String,
#[serde(flatten)]
update: crate::sink::AgentUpdate,
},
}

203
src/worker_runner.rs Normal file
View File

@ -0,0 +1,203 @@
use std::sync::Arc;
use futures::{SinkExt, StreamExt};
use tokio::sync::mpsc;
use tokio_tungstenite::{connect_async, tungstenite::Message};
use crate::agent::{self, AgentEvent};
use crate::exec::LocalExecutor;
use crate::llm::LlmClient;
use crate::sink::{AgentUpdate, ServiceManager};
use crate::worker::{ServerToWorker, WorkerInfo, WorkerToServer};
/// Builds the `WorkerInfo` sent at registration time.
///
/// Every probe is best-effort and falls back to a placeholder:
/// * CPU: text after the first `:` on the first `model name` line of
///   `/proc/cpuinfo`, else `"unknown"`.
/// * Memory: `MemTotal` from `/proc/meminfo` (kB) rendered as `"<x.y> GB"`,
///   else `"unknown"`.
/// * GPU: trimmed output of `nvidia-smi --query-gpu=name --format=csv,noheader`
///   when the command succeeds, else `"none"`.
/// * Kernel: trimmed output of `uname -r` when it can be run, else `"unknown"`.
fn collect_worker_info(name: &str) -> WorkerInfo {
    /// Trimmed, lossily decoded stdout of a finished child process.
    fn stdout_text(out: &std::process::Output) -> String {
        String::from_utf8_lossy(&out.stdout).trim().to_string()
    }

    let mut cpu = String::from("unknown");
    if let Ok(cpuinfo) = std::fs::read_to_string("/proc/cpuinfo") {
        if let Some(line) = cpuinfo.lines().find(|l| l.starts_with("model name")) {
            cpu = line.split(':').nth(1).unwrap_or("").trim().to_string();
        }
    }

    let mut memory = String::from("unknown");
    if let Ok(meminfo) = std::fs::read_to_string("/proc/meminfo") {
        let total_kb = meminfo
            .lines()
            .find(|l| l.starts_with("MemTotal"))
            .and_then(|l| l.split_whitespace().nth(1))
            .and_then(|field| field.parse::<u64>().ok());
        if let Some(kb) = total_kb {
            memory = format!("{:.1} GB", kb as f64 / 1_048_576.0);
        }
    }

    let gpu_probe = std::process::Command::new("nvidia-smi")
        .arg("--query-gpu=name")
        .arg("--format=csv,noheader")
        .output();
    let gpu = match gpu_probe {
        Ok(out) if out.status.success() => stdout_text(&out),
        _ => String::from("none"),
    };

    // Unlike the GPU probe, the exit status of `uname` is not checked.
    let kernel = match std::process::Command::new("uname").arg("-r").output() {
        Ok(out) => stdout_text(&out),
        Err(_) => String::from("unknown"),
    };

    WorkerInfo {
        name: name.to_string(),
        cpu,
        memory,
        gpu,
        os: std::env::consts::OS.to_string(),
        kernel,
    }
}
pub async fn run(server_url: &str, worker_name: &str) -> anyhow::Result<()> {
tracing::info!("Tori worker '{}' connecting to {}", worker_name, server_url);
loop {
match connect_and_run(server_url, worker_name).await {
Ok(()) => {
tracing::info!("Worker connection closed, reconnecting in 5s...");
}
Err(e) => {
tracing::error!("Worker error: {}, reconnecting in 5s...", e);
}
}
tokio::time::sleep(std::time::Duration::from_secs(5)).await;
}
}
/// Runs one WebSocket session with the server: connect, register, then execute
/// assigned workflows one at a time until the connection ends.
///
/// While a workflow is executing, the socket is still being read (the agent
/// loop and the socket are driven together with `tokio::select!`). This is what
/// lets `Comment` messages reach the running agent loop; awaiting the agent
/// loop on its own would leave them unread until the workflow had finished.
/// Workflow assignments that arrive in the meantime are queued and executed
/// afterwards, in arrival order.
///
/// Returns `Ok(())` when the server closes the connection and an error when
/// connecting, registering or reading from the socket fails. If the socket
/// goes away in the middle of a workflow, the workflow is still run to
/// completion before this function returns.
async fn connect_and_run(server_url: &str, worker_name: &str) -> anyhow::Result<()> {
    let (ws_stream, _) = connect_async(server_url).await?;
    let (mut ws_tx, mut ws_rx) = ws_stream.split();

    // Register
    let info = collect_worker_info(worker_name);
    let register_msg = serde_json::to_string(&WorkerToServer::Register { info })?;
    ws_tx.send(Message::Text(register_msg.into())).await?;

    // Wait for registration ack
    let mut registered = false;
    while let Some(msg) = ws_rx.next().await {
        match msg? {
            Message::Text(text) => {
                let v: serde_json::Value = serde_json::from_str(&text)?;
                if v["type"] == "registered" {
                    tracing::info!("Registered as '{}'", v["name"]);
                    registered = true;
                    break;
                }
            }
            Message::Close(_) => anyhow::bail!("Connection closed during registration"),
            _ => {}
        }
    }
    if !registered {
        // The stream ended without an ack; do not keep polling a finished stream.
        anyhow::bail!("Connection closed during registration");
    }

    let svc_mgr = ServiceManager::new(9100);
    let ws_tx = Arc::new(tokio::sync::Mutex::new(ws_tx));

    // Messages read from the socket while a workflow was running that still
    // need to be handled (i.e. further workflow assignments).
    let mut queued = std::collections::VecDeque::<ServerToWorker>::new();

    // Main message loop
    loop {
        let server_msg = match queued.pop_front() {
            Some(m) => m,
            None => {
                let text = match ws_rx.next().await {
                    None => break,
                    Some(msg) => match msg? {
                        Message::Text(t) => t,
                        Message::Close(_) => break,
                        _ => continue,
                    },
                };
                match serde_json::from_str::<ServerToWorker>(&text) {
                    Ok(m) => m,
                    Err(e) => {
                        tracing::warn!("Failed to parse server message: {}", e);
                        continue;
                    }
                }
            }
        };

        match server_msg {
            ServerToWorker::WorkflowAssign {
                workflow_id,
                project_id,
                requirement,
                workdir,
                instructions,
                llm_config,
                initial_state,
                require_plan_approval,
                external_tools: _,
            } => {
                tracing::info!("Received workflow: {} (project {})", workflow_id, project_id);

                let llm = LlmClient::new(&llm_config);
                let exec = LocalExecutor::new(None);

                // update channel → serialize → WebSocket
                let (update_tx, mut update_rx) = mpsc::channel::<AgentUpdate>(64);
                let ws_tx_clone = ws_tx.clone();
                let wf_id_clone = workflow_id.clone();
                tokio::spawn(async move {
                    while let Some(update) = update_rx.recv().await {
                        let msg = WorkerToServer::Update {
                            workflow_id: wf_id_clone.clone(),
                            update,
                        };
                        if let Ok(json) = serde_json::to_string(&msg) {
                            let mut tx = ws_tx_clone.lock().await;
                            if tx.send(Message::Text(json.into())).await.is_err() {
                                break;
                            }
                        }
                    }
                });

                // event channel for comments
                let (evt_tx, mut evt_rx) = mpsc::channel::<AgentEvent>(32);

                let _ = tokio::fs::create_dir_all(&workdir).await;

                // Set once the socket has closed or failed; reading stops then,
                // but the workflow is still driven to completion.
                let mut ws_open = true;
                let mut ws_error: Option<anyhow::Error> = None;

                let result = {
                    let agent_fut = agent::run_agent_loop(
                        &llm, &exec, &update_tx, &mut evt_rx,
                        &project_id, &workflow_id, &requirement, &workdir, &svc_mgr,
                        &instructions, initial_state, None, require_plan_approval,
                    );
                    tokio::pin!(agent_fut);

                    let outcome = loop {
                        tokio::select! {
                            res = &mut agent_fut => break res,
                            incoming = ws_rx.next(), if ws_open => {
                                match incoming {
                                    None | Some(Ok(Message::Close(_))) => {
                                        ws_open = false;
                                    }
                                    Some(Err(e)) => {
                                        ws_open = false;
                                        ws_error = Some(e.into());
                                    }
                                    Some(Ok(Message::Text(text))) => {
                                        match serde_json::from_str::<ServerToWorker>(&text) {
                                            Ok(ServerToWorker::Comment { workflow_id, content }) => {
                                                // try_send, not send().await: the agent loop is
                                                // not polled while this branch runs, so waiting
                                                // on a full channel here could never finish.
                                                let event = AgentEvent::Comment { workflow_id, content };
                                                if let Err(e) = evt_tx.try_send(event) {
                                                    tracing::warn!("Dropping comment for running workflow: {}", e);
                                                }
                                            }
                                            Ok(other) => queued.push_back(other),
                                            Err(e) => {
                                                tracing::warn!("Failed to parse server message: {}", e);
                                            }
                                        }
                                    }
                                    Some(Ok(_)) => {}
                                }
                            }
                        }
                    };
                    outcome
                };

                let final_status = if result.is_ok() { "done" } else { "failed" };
                if let Err(e) = &result {
                    tracing::error!("Workflow {} failed: {}", workflow_id, e);
                    let _ = update_tx.send(AgentUpdate::Error {
                        message: format!("Agent error: {}", e),
                    }).await;
                }

                let _ = update_tx.send(AgentUpdate::WorkflowComplete {
                    workflow_id: workflow_id.clone(),
                    status: final_status.into(),
                    report: None,
                }).await;

                tracing::info!("Workflow {} completed: {}", workflow_id, final_status);

                if !ws_open {
                    // The connection went away during the workflow; end the
                    // session so the caller can reconnect.
                    return match ws_error {
                        Some(e) => Err(e),
                        None => Ok(()),
                    };
                }
            }
            ServerToWorker::Comment { workflow_id, .. } => {
                // No workflow is running at this point, so there is nothing to
                // deliver the comment to.
                tracing::debug!("Ignoring comment for workflow {}: no workflow is running", workflow_id);
            }
        }
    }

    Ok(())
}