refactor: worker mode — server offloads all LLM/exec to worker
- Split into `tori server` / `tori worker` subcommands (clap derive) - Extract lib.rs for shared crate (agent, llm, exec, state, etc.) - Introduce AgentUpdate channel to decouple agent loop from DB/broadcast - New sink.rs: AgentUpdate enum + ServiceManager + handle_agent_updates - New worker_runner.rs: connects to server WS, runs full agent loop - Expand worker protocol: ServerToWorker (workflow_assign, comment) and WorkerToServer (register, result, update) - Remove LLM from title generation (heuristic) and template selection (must be explicit) - Remove KB tools (kb_search, kb_read) and remote worker tools (list_workers, execute_on_worker) from agent loop - run_agent_loop/run_step_loop now take mpsc::Sender<AgentUpdate> instead of direct DB pool + broadcast sender
This commit is contained in:
parent
28a00dd2f3
commit
e4ba385112
186
Cargo.lock
generated
186
Cargo.lock
generated
@ -26,6 +26,56 @@ dependencies = [
|
|||||||
"libc",
|
"libc",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "anstream"
|
||||||
|
version = "0.6.21"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "43d5b281e737544384e969a5ccad3f1cdd24b48086a0fc1b2a5262a26b8f4f4a"
|
||||||
|
dependencies = [
|
||||||
|
"anstyle",
|
||||||
|
"anstyle-parse",
|
||||||
|
"anstyle-query",
|
||||||
|
"anstyle-wincon",
|
||||||
|
"colorchoice",
|
||||||
|
"is_terminal_polyfill",
|
||||||
|
"utf8parse",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "anstyle"
|
||||||
|
version = "1.0.14"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "940b3a0ca603d1eade50a4846a2afffd5ef57a9feac2c0e2ec2e14f9ead76000"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "anstyle-parse"
|
||||||
|
version = "0.2.7"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "4e7644824f0aa2c7b9384579234ef10eb7efb6a0deb83f9630a49594dd9c15c2"
|
||||||
|
dependencies = [
|
||||||
|
"utf8parse",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "anstyle-query"
|
||||||
|
version = "1.1.5"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "40c48f72fd53cd289104fc64099abca73db4166ad86ea0b4341abe65af83dadc"
|
||||||
|
dependencies = [
|
||||||
|
"windows-sys 0.61.2",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "anstyle-wincon"
|
||||||
|
version = "3.0.11"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "291e6a250ff86cd4a820112fb8898808a366d8f9f58ce16d1f538353ad55747d"
|
||||||
|
dependencies = [
|
||||||
|
"anstyle",
|
||||||
|
"once_cell_polyfill",
|
||||||
|
"windows-sys 0.61.2",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "anyhow"
|
name = "anyhow"
|
||||||
version = "1.0.102"
|
version = "1.0.102"
|
||||||
@ -83,7 +133,7 @@ dependencies = [
|
|||||||
"sha1",
|
"sha1",
|
||||||
"sync_wrapper",
|
"sync_wrapper",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tokio-tungstenite",
|
"tokio-tungstenite 0.28.0",
|
||||||
"tower",
|
"tower",
|
||||||
"tower-layer",
|
"tower-layer",
|
||||||
"tower-service",
|
"tower-service",
|
||||||
@ -216,6 +266,52 @@ dependencies = [
|
|||||||
"windows-link",
|
"windows-link",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "clap"
|
||||||
|
version = "4.5.60"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "2797f34da339ce31042b27d23607e051786132987f595b02ba4f6a6dffb7030a"
|
||||||
|
dependencies = [
|
||||||
|
"clap_builder",
|
||||||
|
"clap_derive",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "clap_builder"
|
||||||
|
version = "4.5.60"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "24a241312cea5059b13574bb9b3861cabf758b879c15190b37b6d6fd63ab6876"
|
||||||
|
dependencies = [
|
||||||
|
"anstream",
|
||||||
|
"anstyle",
|
||||||
|
"clap_lex",
|
||||||
|
"strsim",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "clap_derive"
|
||||||
|
version = "4.5.55"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "a92793da1a46a5f2a02a6f4c46c6496b28c43638adea8306fcb0caa1634f24e5"
|
||||||
|
dependencies = [
|
||||||
|
"heck",
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "clap_lex"
|
||||||
|
version = "1.1.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "c8d4a3bb8b1e0c1050499d1815f5ab16d04f0959b233085fb31653fbfc9d98f9"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "colorchoice"
|
||||||
|
version = "1.0.5"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "1d07550c9036bf2ae0c684c4297d503f838287c83c53686d05370d0e139ae570"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "concurrent-queue"
|
name = "concurrent-queue"
|
||||||
version = "2.5.0"
|
version = "2.5.0"
|
||||||
@ -663,6 +759,17 @@ dependencies = [
|
|||||||
"windows-sys 0.61.2",
|
"windows-sys 0.61.2",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "hostname"
|
||||||
|
version = "0.4.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "617aaa3557aef3810a6369d0a99fac8a080891b68bd9f9812a1eeda0c0730cbd"
|
||||||
|
dependencies = [
|
||||||
|
"cfg-if",
|
||||||
|
"libc",
|
||||||
|
"windows-link",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "http"
|
name = "http"
|
||||||
version = "1.4.0"
|
version = "1.4.0"
|
||||||
@ -750,7 +857,7 @@ dependencies = [
|
|||||||
"tokio",
|
"tokio",
|
||||||
"tokio-rustls",
|
"tokio-rustls",
|
||||||
"tower-service",
|
"tower-service",
|
||||||
"webpki-roots",
|
"webpki-roots 1.0.4",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@ -936,6 +1043,12 @@ dependencies = [
|
|||||||
"serde",
|
"serde",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "is_terminal_polyfill"
|
||||||
|
version = "1.70.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "a6cb138bb79a146c1bd460005623e142ef0181e3d0219cb493e02f7d08a35695"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "itoa"
|
name = "itoa"
|
||||||
version = "1.0.17"
|
version = "1.0.17"
|
||||||
@ -1207,6 +1320,12 @@ version = "1.21.3"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d"
|
checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "once_cell_polyfill"
|
||||||
|
version = "1.70.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "parking"
|
name = "parking"
|
||||||
version = "2.2.1"
|
version = "2.2.1"
|
||||||
@ -1564,7 +1683,7 @@ dependencies = [
|
|||||||
"wasm-bindgen-futures",
|
"wasm-bindgen-futures",
|
||||||
"wasm-streams",
|
"wasm-streams",
|
||||||
"web-sys",
|
"web-sys",
|
||||||
"webpki-roots",
|
"webpki-roots 1.0.4",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@ -2063,6 +2182,12 @@ dependencies = [
|
|||||||
"unicode-properties",
|
"unicode-properties",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "strsim"
|
||||||
|
version = "0.11.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "subtle"
|
name = "subtle"
|
||||||
version = "2.6.1"
|
version = "2.6.1"
|
||||||
@ -2234,6 +2359,22 @@ dependencies = [
|
|||||||
"tokio",
|
"tokio",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "tokio-tungstenite"
|
||||||
|
version = "0.26.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "7a9daff607c6d2bf6c16fd681ccb7eecc83e4e2cdc1ca067ffaadfca5de7f084"
|
||||||
|
dependencies = [
|
||||||
|
"futures-util",
|
||||||
|
"log",
|
||||||
|
"rustls",
|
||||||
|
"rustls-pki-types",
|
||||||
|
"tokio",
|
||||||
|
"tokio-rustls",
|
||||||
|
"tungstenite 0.26.2",
|
||||||
|
"webpki-roots 0.26.11",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "tokio-tungstenite"
|
name = "tokio-tungstenite"
|
||||||
version = "0.28.0"
|
version = "0.28.0"
|
||||||
@ -2243,7 +2384,7 @@ dependencies = [
|
|||||||
"futures-util",
|
"futures-util",
|
||||||
"log",
|
"log",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tungstenite",
|
"tungstenite 0.28.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@ -2268,7 +2409,9 @@ dependencies = [
|
|||||||
"axum-extra",
|
"axum-extra",
|
||||||
"base64",
|
"base64",
|
||||||
"chrono",
|
"chrono",
|
||||||
|
"clap",
|
||||||
"futures",
|
"futures",
|
||||||
|
"hostname",
|
||||||
"jsonwebtoken",
|
"jsonwebtoken",
|
||||||
"mime_guess",
|
"mime_guess",
|
||||||
"nix",
|
"nix",
|
||||||
@ -2280,6 +2423,7 @@ dependencies = [
|
|||||||
"sqlx",
|
"sqlx",
|
||||||
"time",
|
"time",
|
||||||
"tokio",
|
"tokio",
|
||||||
|
"tokio-tungstenite 0.26.2",
|
||||||
"tokio-util",
|
"tokio-util",
|
||||||
"tower-http",
|
"tower-http",
|
||||||
"tracing",
|
"tracing",
|
||||||
@ -2411,6 +2555,25 @@ version = "0.2.5"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b"
|
checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "tungstenite"
|
||||||
|
version = "0.26.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "4793cb5e56680ecbb1d843515b23b6de9a75eb04b66643e256a396d43be33c13"
|
||||||
|
dependencies = [
|
||||||
|
"bytes",
|
||||||
|
"data-encoding",
|
||||||
|
"http",
|
||||||
|
"httparse",
|
||||||
|
"log",
|
||||||
|
"rand 0.9.2",
|
||||||
|
"rustls",
|
||||||
|
"rustls-pki-types",
|
||||||
|
"sha1",
|
||||||
|
"thiserror",
|
||||||
|
"utf-8",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "tungstenite"
|
name = "tungstenite"
|
||||||
version = "0.28.0"
|
version = "0.28.0"
|
||||||
@ -2515,6 +2678,12 @@ version = "1.0.4"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be"
|
checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "utf8parse"
|
||||||
|
version = "0.2.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "uuid"
|
name = "uuid"
|
||||||
version = "1.21.0"
|
version = "1.21.0"
|
||||||
@ -2709,6 +2878,15 @@ dependencies = [
|
|||||||
"wasm-bindgen",
|
"wasm-bindgen",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "webpki-roots"
|
||||||
|
version = "0.26.11"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "521bc38abb08001b01866da9f51eb7c5d647a19260e00054a8c7fd5f9e57f7a9"
|
||||||
|
dependencies = [
|
||||||
|
"webpki-roots 1.0.4",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "webpki-roots"
|
name = "webpki-roots"
|
||||||
version = "1.0.4"
|
version = "1.0.4"
|
||||||
|
|||||||
@ -19,11 +19,14 @@ sqlx = { version = "0.8", features = ["runtime-tokio", "sqlite"] }
|
|||||||
tower-http = { version = "0.6", features = ["cors", "fs"] }
|
tower-http = { version = "0.6", features = ["cors", "fs"] }
|
||||||
reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "rustls-tls"] }
|
reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "rustls-tls"] }
|
||||||
futures = "0.3"
|
futures = "0.3"
|
||||||
|
tokio-tungstenite = { version = "0.26", features = ["rustls-tls-webpki-roots"] }
|
||||||
tracing = "0.1"
|
tracing = "0.1"
|
||||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||||
chrono = { version = "0.4", features = ["serde"] }
|
chrono = { version = "0.4", features = ["serde"] }
|
||||||
uuid = { version = "1", features = ["v4"] }
|
uuid = { version = "1", features = ["v4"] }
|
||||||
anyhow = "1"
|
anyhow = "1"
|
||||||
|
clap = { version = "4", features = ["derive", "env"] }
|
||||||
|
hostname = "0.4"
|
||||||
mime_guess = "2"
|
mime_guess = "2"
|
||||||
tokio-util = { version = "0.7", features = ["io"] }
|
tokio-util = { version = "0.7", features = ["io"] }
|
||||||
nix = { version = "0.29", features = ["signal"] }
|
nix = { version = "0.29", features = ["signal"] }
|
||||||
|
|||||||
704
src/agent.rs
704
src/agent.rs
File diff suppressed because it is too large
Load Diff
74
src/lib.rs
Normal file
74
src/lib.rs
Normal file
@ -0,0 +1,74 @@
|
|||||||
|
pub mod api;
|
||||||
|
pub mod agent;
|
||||||
|
pub mod db;
|
||||||
|
pub mod kb;
|
||||||
|
pub mod llm;
|
||||||
|
pub mod exec;
|
||||||
|
pub mod state;
|
||||||
|
pub mod template;
|
||||||
|
pub mod timer;
|
||||||
|
pub mod tools;
|
||||||
|
pub mod worker;
|
||||||
|
pub mod sink;
|
||||||
|
pub mod worker_runner;
|
||||||
|
pub mod ws;
|
||||||
|
pub mod ws_worker;
|
||||||
|
|
||||||
|
use std::sync::Arc;
|
||||||
|
use serde::Deserialize;
|
||||||
|
|
||||||
|
pub struct AppState {
|
||||||
|
pub db: db::Database,
|
||||||
|
pub config: Config,
|
||||||
|
pub agent_mgr: Arc<agent::AgentManager>,
|
||||||
|
pub kb: Option<Arc<kb::KbManager>>,
|
||||||
|
pub obj_root: String,
|
||||||
|
pub auth: Option<api::auth::AuthConfig>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
|
pub struct Config {
|
||||||
|
pub llm: LlmConfig,
|
||||||
|
pub server: ServerConfig,
|
||||||
|
pub database: DatabaseConfig,
|
||||||
|
#[serde(default)]
|
||||||
|
pub template_repo: Option<TemplateRepoConfig>,
|
||||||
|
/// Path to EC private key PEM file for JWT signing
|
||||||
|
#[serde(default)]
|
||||||
|
pub jwt_private_key: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
|
pub struct TemplateRepoConfig {
|
||||||
|
pub gitea_url: String,
|
||||||
|
pub owner: String,
|
||||||
|
pub repo: String,
|
||||||
|
#[serde(default = "default_repo_path")]
|
||||||
|
pub local_path: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_repo_path() -> String {
|
||||||
|
if std::path::Path::new("/app/oseng-templates").is_dir() {
|
||||||
|
"/app/oseng-templates".to_string()
|
||||||
|
} else {
|
||||||
|
"oseng-templates".to_string()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, serde::Serialize, Deserialize)]
|
||||||
|
pub struct LlmConfig {
|
||||||
|
pub base_url: String,
|
||||||
|
pub api_key: String,
|
||||||
|
pub model: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
|
pub struct ServerConfig {
|
||||||
|
pub host: String,
|
||||||
|
pub port: u16,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
|
pub struct DatabaseConfig {
|
||||||
|
pub path: String,
|
||||||
|
}
|
||||||
114
src/main.rs
114
src/main.rs
@ -1,77 +1,33 @@
|
|||||||
mod api;
|
|
||||||
mod agent;
|
|
||||||
mod db;
|
|
||||||
mod kb;
|
|
||||||
mod llm;
|
|
||||||
mod exec;
|
|
||||||
pub mod state;
|
|
||||||
mod template;
|
|
||||||
mod timer;
|
|
||||||
mod tools;
|
|
||||||
mod worker;
|
|
||||||
mod ws;
|
|
||||||
mod ws_worker;
|
|
||||||
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use axum::Router;
|
use axum::Router;
|
||||||
|
use clap::{Parser, Subcommand};
|
||||||
use sqlx::sqlite::SqlitePool;
|
use sqlx::sqlite::SqlitePool;
|
||||||
use tower_http::cors::CorsLayer;
|
use tower_http::cors::CorsLayer;
|
||||||
use tower_http::services::{ServeDir, ServeFile};
|
use tower_http::services::{ServeDir, ServeFile};
|
||||||
|
|
||||||
pub struct AppState {
|
use tori::{agent, api, db, kb, template, timer, worker, worker_runner, ws, ws_worker};
|
||||||
pub db: db::Database,
|
use tori::{AppState, Config};
|
||||||
pub config: Config,
|
|
||||||
pub agent_mgr: Arc<agent::AgentManager>,
|
#[derive(Parser)]
|
||||||
pub kb: Option<Arc<kb::KbManager>>,
|
#[command(name = "tori", about = "Tori AI agent orchestration")]
|
||||||
pub obj_root: String,
|
struct Cli {
|
||||||
pub auth: Option<api::auth::AuthConfig>,
|
#[command(subcommand)]
|
||||||
|
command: Command,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, serde::Deserialize)]
|
#[derive(Subcommand)]
|
||||||
pub struct Config {
|
enum Command {
|
||||||
pub llm: LlmConfig,
|
/// Start the API server
|
||||||
pub server: ServerConfig,
|
Server,
|
||||||
pub database: DatabaseConfig,
|
/// Start a worker that connects to the server
|
||||||
#[serde(default)]
|
Worker {
|
||||||
pub template_repo: Option<TemplateRepoConfig>,
|
/// Server WebSocket URL
|
||||||
/// Path to EC private key PEM file for JWT signing
|
#[arg(long, env = "TORI_SERVER", default_value = "ws://127.0.0.1:3000/ws/tori/workers")]
|
||||||
#[serde(default)]
|
server: String,
|
||||||
pub jwt_private_key: Option<String>,
|
/// Worker name
|
||||||
}
|
#[arg(long, env = "TORI_WORKER_NAME")]
|
||||||
|
name: Option<String>,
|
||||||
#[derive(Debug, Clone, serde::Deserialize)]
|
},
|
||||||
pub struct TemplateRepoConfig {
|
|
||||||
pub gitea_url: String,
|
|
||||||
pub owner: String,
|
|
||||||
pub repo: String,
|
|
||||||
#[serde(default = "default_repo_path")]
|
|
||||||
pub local_path: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
fn default_repo_path() -> String {
|
|
||||||
if std::path::Path::new("/app/oseng-templates").is_dir() {
|
|
||||||
"/app/oseng-templates".to_string()
|
|
||||||
} else {
|
|
||||||
"oseng-templates".to_string()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, serde::Deserialize)]
|
|
||||||
pub struct LlmConfig {
|
|
||||||
pub base_url: String,
|
|
||||||
pub api_key: String,
|
|
||||||
pub model: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, serde::Deserialize)]
|
|
||||||
pub struct ServerConfig {
|
|
||||||
pub host: String,
|
|
||||||
pub port: u16,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, serde::Deserialize)]
|
|
||||||
pub struct DatabaseConfig {
|
|
||||||
pub path: String,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
@ -80,6 +36,22 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
.with_env_filter("tori=debug,tower_http=debug")
|
.with_env_filter("tori=debug,tower_http=debug")
|
||||||
.init();
|
.init();
|
||||||
|
|
||||||
|
let cli = Cli::parse();
|
||||||
|
|
||||||
|
match cli.command {
|
||||||
|
Command::Server => run_server().await,
|
||||||
|
Command::Worker { server, name } => {
|
||||||
|
let name = name.unwrap_or_else(|| {
|
||||||
|
hostname::get()
|
||||||
|
.map(|h| h.to_string_lossy().to_string())
|
||||||
|
.unwrap_or_else(|_| "worker-1".to_string())
|
||||||
|
});
|
||||||
|
worker_runner::run(&server, &name).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn run_server() -> anyhow::Result<()> {
|
||||||
let config_str = std::fs::read_to_string("config.yaml")
|
let config_str = std::fs::read_to_string("config.yaml")
|
||||||
.expect("Failed to read config.yaml");
|
.expect("Failed to read config.yaml");
|
||||||
let config: Config = serde_yaml::from_str(&config_str)
|
let config: Config = serde_yaml::from_str(&config_str)
|
||||||
@ -88,7 +60,6 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
let database = db::Database::new(&config.database.path).await?;
|
let database = db::Database::new(&config.database.path).await?;
|
||||||
database.migrate().await?;
|
database.migrate().await?;
|
||||||
|
|
||||||
// Initialize KB manager
|
|
||||||
let kb_arc = match kb::KbManager::new(database.pool.clone()) {
|
let kb_arc = match kb::KbManager::new(database.pool.clone()) {
|
||||||
Ok(kb) => {
|
Ok(kb) => {
|
||||||
tracing::info!("KB manager initialized");
|
tracing::info!("KB manager initialized");
|
||||||
@ -100,7 +71,6 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Ensure template repo is cloned before serving
|
|
||||||
if let Some(ref repo_cfg) = config.template_repo {
|
if let Some(ref repo_cfg) = config.template_repo {
|
||||||
template::ensure_repo_ready(repo_cfg).await;
|
template::ensure_repo_ready(repo_cfg).await;
|
||||||
}
|
}
|
||||||
@ -117,8 +87,6 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
);
|
);
|
||||||
|
|
||||||
timer::start_timer_runner(database.pool.clone(), agent_mgr.clone());
|
timer::start_timer_runner(database.pool.clone(), agent_mgr.clone());
|
||||||
|
|
||||||
// Resume incomplete workflows after restart
|
|
||||||
resume_workflows(database.pool.clone(), agent_mgr.clone()).await;
|
resume_workflows(database.pool.clone(), agent_mgr.clone()).await;
|
||||||
|
|
||||||
let obj_root = std::env::var("OBJ_ROOT").unwrap_or_else(|_| "/data/obj".to_string());
|
let obj_root = std::env::var("OBJ_ROOT").unwrap_or_else(|_| "/data/obj".to_string());
|
||||||
@ -129,7 +97,6 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
let public_url = std::env::var("PUBLIC_URL")
|
let public_url = std::env::var("PUBLIC_URL")
|
||||||
.unwrap_or_else(|_| "https://tori.euphon.cloud".to_string());
|
.unwrap_or_else(|_| "https://tori.euphon.cloud".to_string());
|
||||||
|
|
||||||
// Try TikTok SSO first, then Google OAuth
|
|
||||||
if let (Ok(id), Ok(secret)) = (
|
if let (Ok(id), Ok(secret)) = (
|
||||||
std::env::var("SSO_CLIENT_ID"),
|
std::env::var("SSO_CLIENT_ID"),
|
||||||
std::env::var("SSO_CLIENT_SECRET"),
|
std::env::var("SSO_CLIENT_SECRET"),
|
||||||
@ -157,7 +124,7 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
public_url,
|
public_url,
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
tracing::warn!("No OAuth configured (set SSO_CLIENT_ID/SSO_CLIENT_SECRET or GOOGLE_CLIENT_ID/GOOGLE_CLIENT_SECRET)");
|
tracing::warn!("No OAuth configured");
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -172,13 +139,10 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
});
|
});
|
||||||
|
|
||||||
let app = Router::new()
|
let app = Router::new()
|
||||||
// Health check (public, for k8s probes)
|
|
||||||
.route("/tori/api/health", axum::routing::get(|| async {
|
.route("/tori/api/health", axum::routing::get(|| async {
|
||||||
axum::Json(serde_json::json!({"status": "ok"}))
|
axum::Json(serde_json::json!({"status": "ok"}))
|
||||||
}))
|
}))
|
||||||
// Auth routes are public
|
|
||||||
.nest("/tori/api/auth", api::auth::router(state.clone()))
|
.nest("/tori/api/auth", api::auth::router(state.clone()))
|
||||||
// Protected API routes
|
|
||||||
.nest("/tori/api", api::router(state.clone())
|
.nest("/tori/api", api::router(state.clone())
|
||||||
.layer(axum::middleware::from_fn_with_state(state.clone(), api::auth::require_auth))
|
.layer(axum::middleware::from_fn_with_state(state.clone(), api::auth::require_auth))
|
||||||
)
|
)
|
||||||
|
|||||||
239
src/sink.rs
Normal file
239
src/sink.rs
Normal file
@ -0,0 +1,239 @@
|
|||||||
|
use std::collections::HashMap;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::sync::atomic::{AtomicU16, Ordering};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use sqlx::sqlite::SqlitePool;
|
||||||
|
use tokio::sync::{RwLock, broadcast, mpsc};
|
||||||
|
|
||||||
|
use crate::agent::{PlanStepInfo, WsMessage, ServiceInfo};
|
||||||
|
use crate::state::{AgentState, Artifact};
|
||||||
|
|
||||||
|
/// All updates produced by the agent loop. This is the single output interface
|
||||||
|
/// that decouples the agent logic from DB persistence and WebSocket broadcasting.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[serde(tag = "kind")]
|
||||||
|
pub enum AgentUpdate {
|
||||||
|
PlanUpdate {
|
||||||
|
workflow_id: String,
|
||||||
|
steps: Vec<PlanStepInfo>,
|
||||||
|
},
|
||||||
|
WorkflowStatus {
|
||||||
|
workflow_id: String,
|
||||||
|
status: String,
|
||||||
|
},
|
||||||
|
Activity {
|
||||||
|
workflow_id: String,
|
||||||
|
activity: String,
|
||||||
|
},
|
||||||
|
ExecutionLog {
|
||||||
|
workflow_id: String,
|
||||||
|
step_order: i32,
|
||||||
|
tool_name: String,
|
||||||
|
tool_input: String,
|
||||||
|
output: String,
|
||||||
|
status: String,
|
||||||
|
},
|
||||||
|
LlmCallLog {
|
||||||
|
workflow_id: String,
|
||||||
|
step_order: i32,
|
||||||
|
phase: String,
|
||||||
|
messages_count: i32,
|
||||||
|
tools_count: i32,
|
||||||
|
tool_calls: String,
|
||||||
|
text_response: String,
|
||||||
|
prompt_tokens: Option<u32>,
|
||||||
|
completion_tokens: Option<u32>,
|
||||||
|
latency_ms: i64,
|
||||||
|
},
|
||||||
|
StateSnapshot {
|
||||||
|
workflow_id: String,
|
||||||
|
step_order: i32,
|
||||||
|
state: AgentState,
|
||||||
|
},
|
||||||
|
WorkflowComplete {
|
||||||
|
workflow_id: String,
|
||||||
|
status: String,
|
||||||
|
report: Option<String>,
|
||||||
|
},
|
||||||
|
ArtifactSave {
|
||||||
|
workflow_id: String,
|
||||||
|
step_order: i32,
|
||||||
|
artifact: Artifact,
|
||||||
|
},
|
||||||
|
RequirementUpdate {
|
||||||
|
workflow_id: String,
|
||||||
|
requirement: String,
|
||||||
|
},
|
||||||
|
Error {
|
||||||
|
message: String,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Manages local services (start_service / stop_service tools).
|
||||||
|
/// Created per-worker or per-agent-loop.
|
||||||
|
pub struct ServiceManager {
|
||||||
|
pub services: RwLock<HashMap<String, ServiceInfo>>,
|
||||||
|
next_port: AtomicU16,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ServiceManager {
|
||||||
|
pub fn new(start_port: u16) -> Arc<Self> {
|
||||||
|
Arc::new(Self {
|
||||||
|
services: RwLock::new(HashMap::new()),
|
||||||
|
next_port: AtomicU16::new(start_port),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn allocate_port(&self) -> u16 {
|
||||||
|
self.next_port.fetch_add(1, Ordering::Relaxed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Server-side handler: consumes AgentUpdate from channel, persists to DB and broadcasts to frontend.
|
||||||
|
pub async fn handle_agent_updates(
|
||||||
|
mut rx: mpsc::Receiver<AgentUpdate>,
|
||||||
|
pool: SqlitePool,
|
||||||
|
broadcast_tx: broadcast::Sender<WsMessage>,
|
||||||
|
) {
|
||||||
|
while let Some(update) = rx.recv().await {
|
||||||
|
match update {
|
||||||
|
AgentUpdate::PlanUpdate { workflow_id, steps } => {
|
||||||
|
let _ = broadcast_tx.send(WsMessage::PlanUpdate { workflow_id, steps });
|
||||||
|
}
|
||||||
|
AgentUpdate::WorkflowStatus { ref workflow_id, ref status } => {
|
||||||
|
let _ = sqlx::query("UPDATE workflows SET status = ? WHERE id = ?")
|
||||||
|
.bind(status)
|
||||||
|
.bind(workflow_id)
|
||||||
|
.execute(&pool)
|
||||||
|
.await;
|
||||||
|
let _ = broadcast_tx.send(WsMessage::WorkflowStatusUpdate {
|
||||||
|
workflow_id: workflow_id.clone(),
|
||||||
|
status: status.clone(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
AgentUpdate::Activity { workflow_id, activity } => {
|
||||||
|
let _ = broadcast_tx.send(WsMessage::ActivityUpdate { workflow_id, activity });
|
||||||
|
}
|
||||||
|
AgentUpdate::ExecutionLog { ref workflow_id, step_order, ref tool_name, ref tool_input, ref output, ref status } => {
|
||||||
|
let id = uuid::Uuid::new_v4().to_string();
|
||||||
|
let _ = sqlx::query(
|
||||||
|
"INSERT INTO execution_log (id, workflow_id, step_order, tool_name, tool_input, output, status, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, datetime('now'))"
|
||||||
|
)
|
||||||
|
.bind(&id)
|
||||||
|
.bind(workflow_id)
|
||||||
|
.bind(step_order)
|
||||||
|
.bind(tool_name)
|
||||||
|
.bind(tool_input)
|
||||||
|
.bind(output)
|
||||||
|
.bind(status)
|
||||||
|
.execute(&pool)
|
||||||
|
.await;
|
||||||
|
let _ = broadcast_tx.send(WsMessage::StepStatusUpdate {
|
||||||
|
step_id: id,
|
||||||
|
status: status.clone(),
|
||||||
|
output: output.clone(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
AgentUpdate::LlmCallLog { ref workflow_id, step_order, ref phase, messages_count, tools_count, ref tool_calls, ref text_response, prompt_tokens, completion_tokens, latency_ms } => {
|
||||||
|
let id = uuid::Uuid::new_v4().to_string();
|
||||||
|
let _ = sqlx::query(
|
||||||
|
"INSERT INTO llm_call_log (id, workflow_id, step_order, phase, messages_count, tools_count, tool_calls, text_response, prompt_tokens, completion_tokens, latency_ms, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, datetime('now'))"
|
||||||
|
)
|
||||||
|
.bind(&id)
|
||||||
|
.bind(workflow_id)
|
||||||
|
.bind(step_order)
|
||||||
|
.bind(phase)
|
||||||
|
.bind(messages_count)
|
||||||
|
.bind(tools_count)
|
||||||
|
.bind(tool_calls)
|
||||||
|
.bind(text_response)
|
||||||
|
.bind(prompt_tokens.map(|v| v as i32))
|
||||||
|
.bind(completion_tokens.map(|v| v as i32))
|
||||||
|
.bind(latency_ms as i32)
|
||||||
|
.execute(&pool)
|
||||||
|
.await;
|
||||||
|
let entry = crate::db::LlmCallLogEntry {
|
||||||
|
id,
|
||||||
|
workflow_id: workflow_id.clone(),
|
||||||
|
step_order,
|
||||||
|
phase: phase.clone(),
|
||||||
|
messages_count,
|
||||||
|
tools_count,
|
||||||
|
tool_calls: tool_calls.clone(),
|
||||||
|
text_response: text_response.clone(),
|
||||||
|
prompt_tokens: prompt_tokens.map(|v| v as i32),
|
||||||
|
completion_tokens: completion_tokens.map(|v| v as i32),
|
||||||
|
latency_ms: latency_ms as i32,
|
||||||
|
created_at: String::new(),
|
||||||
|
};
|
||||||
|
let _ = broadcast_tx.send(WsMessage::LlmCallLog {
|
||||||
|
workflow_id: workflow_id.clone(),
|
||||||
|
entry,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
AgentUpdate::StateSnapshot { ref workflow_id, step_order, ref state } => {
|
||||||
|
let id = uuid::Uuid::new_v4().to_string();
|
||||||
|
let json = serde_json::to_string(state).unwrap_or_default();
|
||||||
|
let _ = sqlx::query(
|
||||||
|
"INSERT INTO agent_state_snapshots (id, workflow_id, step_order, state_json, created_at) VALUES (?, ?, ?, ?, datetime('now'))"
|
||||||
|
)
|
||||||
|
.bind(&id)
|
||||||
|
.bind(workflow_id)
|
||||||
|
.bind(step_order)
|
||||||
|
.bind(&json)
|
||||||
|
.execute(&pool)
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
AgentUpdate::WorkflowComplete { ref workflow_id, ref status, ref report } => {
|
||||||
|
let _ = sqlx::query("UPDATE workflows SET status = ? WHERE id = ?")
|
||||||
|
.bind(status)
|
||||||
|
.bind(workflow_id)
|
||||||
|
.execute(&pool)
|
||||||
|
.await;
|
||||||
|
if let Some(ref r) = report {
|
||||||
|
let _ = sqlx::query("UPDATE workflows SET report = ? WHERE id = ?")
|
||||||
|
.bind(r)
|
||||||
|
.bind(workflow_id)
|
||||||
|
.execute(&pool)
|
||||||
|
.await;
|
||||||
|
let _ = broadcast_tx.send(WsMessage::ReportReady {
|
||||||
|
workflow_id: workflow_id.clone(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
let _ = broadcast_tx.send(WsMessage::WorkflowStatusUpdate {
|
||||||
|
workflow_id: workflow_id.clone(),
|
||||||
|
status: status.clone(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
AgentUpdate::ArtifactSave { ref workflow_id, step_order, ref artifact } => {
|
||||||
|
let id = uuid::Uuid::new_v4().to_string();
|
||||||
|
let _ = sqlx::query(
|
||||||
|
"INSERT INTO step_artifacts (id, workflow_id, step_order, name, path, artifact_type, description) VALUES (?, ?, ?, ?, ?, ?, ?)"
|
||||||
|
)
|
||||||
|
.bind(&id)
|
||||||
|
.bind(workflow_id)
|
||||||
|
.bind(step_order)
|
||||||
|
.bind(&artifact.name)
|
||||||
|
.bind(&artifact.path)
|
||||||
|
.bind(&artifact.artifact_type)
|
||||||
|
.bind(&artifact.description)
|
||||||
|
.execute(&pool)
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
AgentUpdate::RequirementUpdate { ref workflow_id, ref requirement } => {
|
||||||
|
let _ = sqlx::query("UPDATE workflows SET requirement = ? WHERE id = ?")
|
||||||
|
.bind(requirement)
|
||||||
|
.bind(workflow_id)
|
||||||
|
.execute(&pool)
|
||||||
|
.await;
|
||||||
|
let _ = broadcast_tx.send(WsMessage::RequirementUpdate {
|
||||||
|
workflow_id: workflow_id.clone(),
|
||||||
|
requirement: requirement.clone(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
AgentUpdate::Error { message } => {
|
||||||
|
let _ = broadcast_tx.send(WsMessage::Error { message });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -3,7 +3,6 @@ use std::path::{Path, PathBuf};
|
|||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
|
|
||||||
use crate::TemplateRepoConfig;
|
use crate::TemplateRepoConfig;
|
||||||
use crate::llm::{ChatMessage, LlmClient};
|
|
||||||
use crate::tools::ExternalToolManager;
|
use crate::tools::ExternalToolManager;
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
@ -463,42 +462,6 @@ pub fn is_repo_template(template_id: &str) -> bool {
|
|||||||
template_id.contains('/')
|
template_id.contains('/')
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- LLM template selection ---
|
|
||||||
|
|
||||||
pub async fn select_template(llm: &LlmClient, requirement: &str, repo_cfg: Option<&TemplateRepoConfig>) -> Option<String> {
|
|
||||||
let all = list_all_templates(repo_cfg).await;
|
|
||||||
if all.is_empty() {
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
|
|
||||||
let listing: String = all
|
|
||||||
.iter()
|
|
||||||
.map(|t| format!("- id: {}\n 名称: {}\n 描述: {}", t.id, t.name, t.description))
|
|
||||||
.collect::<Vec<_>>()
|
|
||||||
.join("\n");
|
|
||||||
|
|
||||||
let prompt = format!(
|
|
||||||
"以下是可用的项目模板:\n{}\n\n用户需求:{}\n\n选择最匹配的模板 ID,如果都不合适则回复 none。只回复模板 ID 或 none,不要其他内容。",
|
|
||||||
listing, requirement
|
|
||||||
);
|
|
||||||
|
|
||||||
let response = llm
|
|
||||||
.chat(vec![
|
|
||||||
ChatMessage::system("你是一个模板选择助手。根据用户需求选择最合适的项目模板。只回复模板 ID 或 none。"),
|
|
||||||
ChatMessage::user(&prompt),
|
|
||||||
])
|
|
||||||
.await
|
|
||||||
.ok()?;
|
|
||||||
|
|
||||||
let answer = response.trim().to_lowercase();
|
|
||||||
tracing::info!("Template selection LLM response: '{}' (available: {:?})",
|
|
||||||
answer, all.iter().map(|t| t.id.as_str()).collect::<Vec<_>>());
|
|
||||||
if answer == "none" {
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
|
|
||||||
all.iter().find(|t| t.id == answer).map(|t| t.id.clone())
|
|
||||||
}
|
|
||||||
|
|
||||||
// --- Template loading ---
|
// --- Template loading ---
|
||||||
|
|
||||||
|
|||||||
@ -131,3 +131,56 @@ impl WorkerManager {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// --- Extended protocol for workflow execution ---
|
||||||
|
|
||||||
|
use crate::LlmConfig;
|
||||||
|
use crate::state::AgentState;
|
||||||
|
use crate::llm::Tool;
|
||||||
|
|
||||||
|
/// Messages sent from server to worker.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[serde(tag = "type")]
|
||||||
|
pub enum ServerToWorker {
|
||||||
|
/// Assign a full workflow for execution.
|
||||||
|
#[serde(rename = "workflow_assign")]
|
||||||
|
WorkflowAssign {
|
||||||
|
workflow_id: String,
|
||||||
|
project_id: String,
|
||||||
|
requirement: String,
|
||||||
|
workdir: String,
|
||||||
|
instructions: String,
|
||||||
|
llm_config: LlmConfig,
|
||||||
|
#[serde(default)]
|
||||||
|
initial_state: Option<AgentState>,
|
||||||
|
#[serde(default)]
|
||||||
|
require_plan_approval: bool,
|
||||||
|
#[serde(default)]
|
||||||
|
external_tools: Vec<Tool>,
|
||||||
|
},
|
||||||
|
/// Forward a user comment to the worker executing this workflow.
|
||||||
|
#[serde(rename = "comment")]
|
||||||
|
Comment {
|
||||||
|
workflow_id: String,
|
||||||
|
content: String,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Messages sent from worker to server.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[serde(tag = "type")]
|
||||||
|
pub enum WorkerToServer {
|
||||||
|
/// Worker registration.
|
||||||
|
#[serde(rename = "register")]
|
||||||
|
Register { info: WorkerInfo },
|
||||||
|
/// Script execution result (legacy).
|
||||||
|
#[serde(rename = "result")]
|
||||||
|
Result(WorkerResult),
|
||||||
|
/// Agent update from workflow execution.
|
||||||
|
#[serde(rename = "update")]
|
||||||
|
Update {
|
||||||
|
workflow_id: String,
|
||||||
|
#[serde(flatten)]
|
||||||
|
update: crate::sink::AgentUpdate,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|||||||
203
src/worker_runner.rs
Normal file
203
src/worker_runner.rs
Normal file
@ -0,0 +1,203 @@
|
|||||||
|
use std::sync::Arc;
|
||||||
|
use futures::{SinkExt, StreamExt};
|
||||||
|
use tokio::sync::mpsc;
|
||||||
|
use tokio_tungstenite::{connect_async, tungstenite::Message};
|
||||||
|
|
||||||
|
use crate::agent::{self, AgentEvent};
|
||||||
|
use crate::exec::LocalExecutor;
|
||||||
|
use crate::llm::LlmClient;
|
||||||
|
use crate::sink::{AgentUpdate, ServiceManager};
|
||||||
|
use crate::worker::{ServerToWorker, WorkerInfo, WorkerToServer};
|
||||||
|
|
||||||
|
fn collect_worker_info(name: &str) -> WorkerInfo {
|
||||||
|
let cpu = std::fs::read_to_string("/proc/cpuinfo")
|
||||||
|
.ok()
|
||||||
|
.and_then(|s| {
|
||||||
|
s.lines()
|
||||||
|
.find(|l| l.starts_with("model name"))
|
||||||
|
.map(|l| l.split(':').nth(1).unwrap_or("").trim().to_string())
|
||||||
|
})
|
||||||
|
.unwrap_or_else(|| "unknown".into());
|
||||||
|
|
||||||
|
let memory = std::fs::read_to_string("/proc/meminfo")
|
||||||
|
.ok()
|
||||||
|
.and_then(|s| {
|
||||||
|
s.lines()
|
||||||
|
.find(|l| l.starts_with("MemTotal"))
|
||||||
|
.and_then(|l| l.split_whitespace().nth(1))
|
||||||
|
.and_then(|kb| kb.parse::<u64>().ok())
|
||||||
|
.map(|kb| format!("{:.1} GB", kb as f64 / 1_048_576.0))
|
||||||
|
})
|
||||||
|
.unwrap_or_else(|| "unknown".into());
|
||||||
|
|
||||||
|
let gpu = std::process::Command::new("nvidia-smi")
|
||||||
|
.arg("--query-gpu=name")
|
||||||
|
.arg("--format=csv,noheader")
|
||||||
|
.output()
|
||||||
|
.ok()
|
||||||
|
.and_then(|o| {
|
||||||
|
if o.status.success() {
|
||||||
|
Some(String::from_utf8_lossy(&o.stdout).trim().to_string())
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.unwrap_or_else(|| "none".into());
|
||||||
|
|
||||||
|
WorkerInfo {
|
||||||
|
name: name.to_string(),
|
||||||
|
cpu,
|
||||||
|
memory,
|
||||||
|
gpu,
|
||||||
|
os: std::env::consts::OS.to_string(),
|
||||||
|
kernel: std::process::Command::new("uname")
|
||||||
|
.arg("-r")
|
||||||
|
.output()
|
||||||
|
.ok()
|
||||||
|
.map(|o| String::from_utf8_lossy(&o.stdout).trim().to_string())
|
||||||
|
.unwrap_or_else(|| "unknown".into()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn run(server_url: &str, worker_name: &str) -> anyhow::Result<()> {
|
||||||
|
tracing::info!("Tori worker '{}' connecting to {}", worker_name, server_url);
|
||||||
|
|
||||||
|
loop {
|
||||||
|
match connect_and_run(server_url, worker_name).await {
|
||||||
|
Ok(()) => {
|
||||||
|
tracing::info!("Worker connection closed, reconnecting in 5s...");
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::error!("Worker error: {}, reconnecting in 5s...", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tokio::time::sleep(std::time::Duration::from_secs(5)).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn connect_and_run(server_url: &str, worker_name: &str) -> anyhow::Result<()> {
|
||||||
|
let (ws_stream, _) = connect_async(server_url).await?;
|
||||||
|
let (mut ws_tx, mut ws_rx) = ws_stream.split();
|
||||||
|
|
||||||
|
// Register
|
||||||
|
let info = collect_worker_info(worker_name);
|
||||||
|
let register_msg = serde_json::to_string(&WorkerToServer::Register { info })?;
|
||||||
|
ws_tx.send(Message::Text(register_msg.into())).await?;
|
||||||
|
|
||||||
|
// Wait for registration ack
|
||||||
|
while let Some(msg) = ws_rx.next().await {
|
||||||
|
match msg? {
|
||||||
|
Message::Text(text) => {
|
||||||
|
let v: serde_json::Value = serde_json::from_str(&text)?;
|
||||||
|
if v["type"] == "registered" {
|
||||||
|
tracing::info!("Registered as '{}'", v["name"]);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Message::Close(_) => anyhow::bail!("Connection closed during registration"),
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let svc_mgr = ServiceManager::new(9100);
|
||||||
|
let ws_tx = Arc::new(tokio::sync::Mutex::new(ws_tx));
|
||||||
|
|
||||||
|
// Channel for forwarding comments to the running workflow
|
||||||
|
let comment_tx: Arc<tokio::sync::Mutex<Option<mpsc::Sender<AgentEvent>>>> =
|
||||||
|
Arc::new(tokio::sync::Mutex::new(None));
|
||||||
|
|
||||||
|
// Main message loop
|
||||||
|
while let Some(msg) = ws_rx.next().await {
|
||||||
|
let text = match msg? {
|
||||||
|
Message::Text(t) => t,
|
||||||
|
Message::Close(_) => break,
|
||||||
|
_ => continue,
|
||||||
|
};
|
||||||
|
|
||||||
|
let server_msg: ServerToWorker = match serde_json::from_str(&text) {
|
||||||
|
Ok(m) => m,
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!("Failed to parse server message: {}", e);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
match server_msg {
|
||||||
|
ServerToWorker::WorkflowAssign {
|
||||||
|
workflow_id,
|
||||||
|
project_id,
|
||||||
|
requirement,
|
||||||
|
workdir,
|
||||||
|
instructions,
|
||||||
|
llm_config,
|
||||||
|
initial_state,
|
||||||
|
require_plan_approval,
|
||||||
|
external_tools: _,
|
||||||
|
} => {
|
||||||
|
tracing::info!("Received workflow: {} (project {})", workflow_id, project_id);
|
||||||
|
|
||||||
|
let llm = LlmClient::new(&llm_config);
|
||||||
|
let exec = LocalExecutor::new(None);
|
||||||
|
|
||||||
|
// update channel → serialize → WebSocket
|
||||||
|
let (update_tx, mut update_rx) = mpsc::channel::<AgentUpdate>(64);
|
||||||
|
let ws_tx_clone = ws_tx.clone();
|
||||||
|
let wf_id_clone = workflow_id.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
while let Some(update) = update_rx.recv().await {
|
||||||
|
let msg = WorkerToServer::Update {
|
||||||
|
workflow_id: wf_id_clone.clone(),
|
||||||
|
update,
|
||||||
|
};
|
||||||
|
if let Ok(json) = serde_json::to_string(&msg) {
|
||||||
|
let mut tx = ws_tx_clone.lock().await;
|
||||||
|
if tx.send(Message::Text(json.into())).await.is_err() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// event channel for comments
|
||||||
|
let (evt_tx, mut evt_rx) = mpsc::channel::<AgentEvent>(32);
|
||||||
|
*comment_tx.lock().await = Some(evt_tx);
|
||||||
|
|
||||||
|
let _ = tokio::fs::create_dir_all(&workdir).await;
|
||||||
|
|
||||||
|
let result = agent::run_agent_loop(
|
||||||
|
&llm, &exec, &update_tx, &mut evt_rx,
|
||||||
|
&project_id, &workflow_id, &requirement, &workdir, &svc_mgr,
|
||||||
|
&instructions, initial_state, None, require_plan_approval,
|
||||||
|
).await;
|
||||||
|
|
||||||
|
let final_status = if result.is_ok() { "done" } else { "failed" };
|
||||||
|
if let Err(e) = &result {
|
||||||
|
tracing::error!("Workflow {} failed: {}", workflow_id, e);
|
||||||
|
let _ = update_tx.send(AgentUpdate::Error {
|
||||||
|
message: format!("Agent error: {}", e),
|
||||||
|
}).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
let _ = update_tx.send(AgentUpdate::WorkflowComplete {
|
||||||
|
workflow_id: workflow_id.clone(),
|
||||||
|
status: final_status.into(),
|
||||||
|
report: None,
|
||||||
|
}).await;
|
||||||
|
|
||||||
|
*comment_tx.lock().await = None;
|
||||||
|
tracing::info!("Workflow {} completed: {}", workflow_id, final_status);
|
||||||
|
}
|
||||||
|
|
||||||
|
ServerToWorker::Comment { workflow_id, content } => {
|
||||||
|
if let Some(ref tx) = *comment_tx.lock().await {
|
||||||
|
let _ = tx.send(AgentEvent::Comment {
|
||||||
|
workflow_id,
|
||||||
|
content,
|
||||||
|
}).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user