tori/src/main.rs
Fam Zheng d4d9edeb78 feat: startup git clone for template repo + pass config through
- ensure_repo_ready() at startup: clone if missing, fetch if exists
- TemplateRepoConfig gains local_path field
- list_all_templates/select_template/extract_repo_template accept repo config
- Remove hardcoded repo_dir(), use config.local_path
2026-03-09 08:42:23 +00:00

179 lines
4.9 KiB
Rust

mod api;
mod agent;
mod db;
mod kb;
mod llm;
mod exec;
pub mod state;
mod template;
mod timer;
mod tools;
mod ws;
use std::sync::Arc;
use axum::Router;
use sqlx::sqlite::SqlitePool;
use tower_http::cors::CorsLayer;
use tower_http::services::{ServeDir, ServeFile};
/// Shared application state.
///
/// `main` builds exactly one instance, wraps it in an `Arc`, and hands it to the
/// `/tori/api` router.
pub struct AppState {
    /// Database handle; `main` runs `migrate()` on it before building the state.
    pub db: db::Database,
    /// Parsed contents of `config.yaml`.
    pub config: Config,
    /// Agent manager, also shared with the timer runner, workflow resume and the
    /// `/ws/tori` WebSocket router.
    pub agent_mgr: Arc<agent::AgentManager>,
    /// Knowledge-base manager; `None` when `kb::KbManager::new` failed at startup.
    pub kb: Option<Arc<kb::KbManager>>,
    /// Root path also given to the `/api/obj` router; taken from the `OBJ_ROOT`
    /// environment variable, defaulting to `/data/obj`.
    pub obj_root: String,
}
/// Top-level application configuration, deserialized from `config.yaml`.
#[derive(Clone, Debug, serde::Deserialize)]
pub struct Config {
    /// `llm:` section: LLM endpoint settings.
    pub llm: LlmConfig,
    /// `server:` section: HTTP listen address.
    pub server: ServerConfig,
    /// `database:` section: database location.
    pub database: DatabaseConfig,
    /// Optional `template_repo:` section. When present, `main` calls
    /// `template::ensure_repo_ready` on it before serving. When absent, it is `None`.
    #[serde(default)]
    pub template_repo: Option<TemplateRepoConfig>,
}
/// Location of the template repository, both remote and on local disk.
#[derive(Clone, Debug, serde::Deserialize)]
pub struct TemplateRepoConfig {
    /// Base URL of the Gitea server hosting the repository (presumably including
    /// scheme; confirm against `template::ensure_repo_ready`).
    pub gitea_url: String,
    /// Repository owner (user or organisation) on the Gitea server.
    pub owner: String,
    /// Repository name.
    pub repo: String,
    /// Local checkout directory. If the key is missing, `default_repo_path()`
    /// picks the value when the config is deserialized.
    #[serde(default = "default_repo_path")]
    pub local_path: String,
}
/// Default value for `TemplateRepoConfig::local_path`.
///
/// Returns the absolute path `/app/oseng-templates` if that directory exists at
/// the moment this is called (presumably a container layout; confirm).
/// Otherwise returns the relative path `oseng-templates`, which resolves
/// against the process working directory.
fn default_repo_path() -> String {
    const CONTAINER_DIR: &str = "/app/oseng-templates";
    let chosen = match std::path::Path::new(CONTAINER_DIR).is_dir() {
        true => CONTAINER_DIR,
        false => "oseng-templates",
    };
    chosen.to_owned()
}
/// LLM backend settings (`llm:` section of `config.yaml`).
///
/// `Debug` is implemented by hand instead of derived so that `api_key` is never
/// printed in plain text. `Config` derives `Debug` and is kept in the shared app
/// state, so a derived impl would leak the key through any `{:?}` of the
/// configuration.
#[derive(Clone, serde::Deserialize)]
pub struct LlmConfig {
    /// Base URL of the LLM API.
    pub base_url: String,
    /// API key for the LLM backend; shown as `<redacted>` in `Debug` output.
    pub api_key: String,
    /// Model identifier to use.
    pub model: String,
}

impl std::fmt::Debug for LlmConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LlmConfig")
            .field("base_url", &self.base_url)
            // Never format the secret itself.
            .field("api_key", &"<redacted>")
            .field("model", &self.model)
            .finish()
    }
}
/// HTTP listen settings (`server:` section of `config.yaml`).
///
/// `main` formats these as `"{host}:{port}"` and binds a TCP listener to it.
#[derive(Clone, Debug, serde::Deserialize)]
pub struct ServerConfig {
    /// Host or IP address to bind.
    pub host: String,
    /// TCP port to bind.
    pub port: u16,
}
/// Database settings (`database:` section of `config.yaml`).
#[derive(Clone, Debug, serde::Deserialize)]
pub struct DatabaseConfig {
    /// Location passed to `db::Database::new` at startup.
    pub path: String,
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
tracing_subscriber::fmt()
.with_env_filter("tori=debug,tower_http=debug")
.init();
let config_str = std::fs::read_to_string("config.yaml")
.expect("Failed to read config.yaml");
let config: Config = serde_yaml::from_str(&config_str)
.expect("Failed to parse config.yaml");
let database = db::Database::new(&config.database.path).await?;
database.migrate().await?;
// Initialize KB manager
let kb_arc = match kb::KbManager::new(database.pool.clone()) {
Ok(kb) => {
tracing::info!("KB manager initialized");
Some(Arc::new(kb))
}
Err(e) => {
tracing::warn!("KB manager init failed (will retry on use): {}", e);
None
}
};
// Ensure template repo is cloned before serving
if let Some(ref repo_cfg) = config.template_repo {
template::ensure_repo_ready(repo_cfg).await;
}
let agent_mgr = agent::AgentManager::new(
database.pool.clone(),
config.llm.clone(),
config.template_repo.clone(),
kb_arc.clone(),
);
timer::start_timer_runner(database.pool.clone(), agent_mgr.clone());
// Resume incomplete workflows after restart
resume_workflows(database.pool.clone(), agent_mgr.clone()).await;
let obj_root = std::env::var("OBJ_ROOT").unwrap_or_else(|_| "/data/obj".to_string());
let state = Arc::new(AppState {
db: database,
config: config.clone(),
agent_mgr: agent_mgr.clone(),
kb: kb_arc,
obj_root: obj_root.clone(),
});
let app = Router::new()
.nest("/tori/api", api::router(state))
.nest("/api/obj", api::obj::router(obj_root.clone()))
.route("/api/obj/", axum::routing::get({
let r = obj_root;
move || api::obj::root_listing(r)
}))
.nest("/ws/tori", ws::router(agent_mgr))
.nest_service("/tori", ServeDir::new("web/dist").fallback(ServeFile::new("web/dist/index.html")))
.route("/", axum::routing::get(|| async {
axum::response::Redirect::permanent("/tori/")
}))
.layer(CorsLayer::permissive());
let addr = format!("{}:{}", &config.server.host, config.server.port);
tracing::info!("Tori server listening on {}", addr);
let listener = tokio::net::TcpListener::bind(&addr).await?;
axum::serve(listener, app).await?;
Ok(())
}
/// Re-queue unfinished workflows after a restart.
///
/// Selects workflows that meet both conditions:
/// - status is `pending`, `planning` or `executing`;
/// - the owning project has `deleted = 0`.
///
/// Rows are ordered by `created_at` ascending. Each one is sent back to the
/// agent manager as a `NewRequirement` event with `template_id: None`.
///
/// A failed query is logged and ends the function early. Nothing is
/// propagated, so the caller's startup continues.
async fn resume_workflows(pool: SqlitePool, agent_mgr: Arc<agent::AgentManager>) {
    const INCOMPLETE_WORKFLOWS_SQL: &str = "SELECT w.id, w.project_id, w.requirement FROM workflows w \
        JOIN projects p ON w.project_id = p.id \
        WHERE w.status IN ('pending', 'planning', 'executing') \
        AND p.deleted = 0 \
        ORDER BY w.created_at ASC";

    let pending: Vec<(String, String, String)> =
        match sqlx::query_as(INCOMPLETE_WORKFLOWS_SQL).fetch_all(&pool).await {
            Err(e) => {
                tracing::error!("Failed to query incomplete workflows: {}", e);
                return;
            }
            Ok(found) => found,
        };

    let count = pending.len();
    if count == 0 {
        tracing::info!("No incomplete workflows to resume");
        return;
    }
    tracing::info!("Resuming {} incomplete workflow(s)", count);

    for row in pending {
        let (wf_id, proj_id, req_text) = row;
        tracing::info!("Resuming workflow {} (project {})", wf_id, proj_id);
        let event = agent::AgentEvent::NewRequirement {
            workflow_id: wf_id,
            requirement: req_text,
            template_id: None,
        };
        agent_mgr.send_event(&proj_id, event).await;
    }
}