From c3eb13dad381fb70d7d9ca5d169aafb8b80a6b0d Mon Sep 17 00:00:00 2001 From: Fam Zheng Date: Thu, 9 Apr 2026 20:28:54 +0100 Subject: [PATCH] refactor: split main.rs into 7 modules, add life loop with timer system MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Structure: main.rs (534) — entry, handler, prompt building config.rs (52) — config structs state.rs (358) — AppState, SQLite, persistence tools.rs (665) — tool definitions, execution, subagent management stream.rs (776) — OpenAI/Claude streaming, system prompt display.rs (220)— markdown rendering, message formatting life.rs (87) — life loop heartbeat, timer firing New features: - Life Loop: background tokio task, 30s heartbeat, scans timers table - Timer tools: set_timer (relative/absolute/cron), list_timers, cancel_timer - inner_state table for life loop's own context - cron crate for recurring schedule parsing Zero logic changes in the refactor — pure structural split. --- Cargo.lock | 85 +++ Cargo.toml | 1 + src/config.rs | 52 ++ src/display.rs | 220 ++++++ src/life.rs | 87 +++ src/main.rs | 1818 +----------------------------------------------- src/state.rs | 358 ++++++++++ src/stream.rs | 776 +++++++++++++++++++++ src/tools.rs | 665 ++++++++++++++++++ 9 files changed, 2266 insertions(+), 1796 deletions(-) create mode 100644 src/config.rs create mode 100644 src/display.rs create mode 100644 src/life.rs create mode 100644 src/state.rs create mode 100644 src/stream.rs create mode 100644 src/tools.rs diff --git a/Cargo.lock b/Cargo.lock index 4b5a191..d8e1c6b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -161,6 +161,18 @@ version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" +[[package]] +name = "cron" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "089df96cf6a25253b4b6b6744d86f91150a3d4df546f31a95def47976b8cba97" +dependencies = [ + "chrono", + "once_cell", + "phf", + "winnow", +] + [[package]] name = "darling" version = "0.13.4" @@ -1038,6 +1050,7 @@ dependencies = [ "anyhow", "base64 0.22.1", "chrono", + "cron", "dptree", "libc", "pulldown-cmark", @@ -1150,6 +1163,48 @@ version = "2.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220" +[[package]] +name = "phf" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd6780a80ae0c52cc120a26a1a42c1ae51b247a253e4e06113d23d2c2edd078" +dependencies = [ + "phf_macros", + "phf_shared", +] + +[[package]] +name = "phf_generator" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c80231409c20246a13fddb31776fb942c38553c51e871f8cbd687a4cfb5843d" +dependencies = [ + "phf_shared", + "rand", +] + +[[package]] +name = "phf_macros" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f84ac04429c13a7ff43785d75ad27569f2951ce0ffd30a3321230db2fc727216" +dependencies = [ + "phf_generator", + "phf_shared", + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "phf_shared" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67eabc2ef2a60eb7faa00097bd1ffdb5bd28e62bf39990626a582201b7a754e5" +dependencies = [ + "siphasher", +] + [[package]] name = "pin-project" version = "1.1.11" @@ -1268,6 +1323,21 @@ version = "6.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f8dcc9c7d52a811697d2151c701e0d08956f92b0e24136cf4cf27b57a6a0d9bf" +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" + [[package]] name = "rc-box" version = "1.3.0" @@ -1647,6 +1717,12 @@ dependencies = [ "libc", ] +[[package]] +name = "siphasher" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2aa850e253778c88a04c3d7323b043aeda9d3e30d5971937c1855769763678e" + [[package]] name = "slab" version = "0.4.12" @@ -2552,6 +2628,15 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" +[[package]] +name = "winnow" +version = "0.7.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df79d97927682d2fd8adb29682d1140b343be4ac0f08fd68b7765d9c059d3945" +dependencies = [ + "memchr", +] + [[package]] name = "winreg" version = "0.50.0" diff --git a/Cargo.toml b/Cargo.toml index 9f95dfd..04de24a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,6 +7,7 @@ edition = "2021" anyhow = "1" base64 = "0.22" chrono = { version = "0.4", features = ["serde"] } +cron = "0.16" dptree = "0.3" libc = "0.2" serde = { version = "1", features = ["derive"] } diff --git a/src/config.rs b/src/config.rs new file mode 100644 index 0000000..5853462 --- /dev/null +++ b/src/config.rs @@ -0,0 +1,52 @@ +use serde::Deserialize; + +#[derive(Deserialize)] +pub struct Config { + #[serde(default = "default_name")] + pub name: String, + pub tg: TgConfig, + pub auth: AuthConfig, + pub session: SessionConfig, + #[serde(default)] + pub backend: BackendConfig, + #[serde(default)] + pub whisper_url: Option, +} + +fn default_name() -> String { + "noc".to_string() +} + +#[derive(Deserialize, Clone, Default)] +#[serde(tag = "type")] +pub enum BackendConfig { + #[serde(rename = "claude")] + #[default] + Claude, + #[serde(rename = "openai")] + OpenAI { + endpoint: String, + model: String, + #[serde(default = "default_api_key")] + api_key: String, + }, +} + +fn default_api_key() -> String { + "unused".to_string() +} + +#[derive(Deserialize)] +pub struct TgConfig { + pub key: String, +} + +#[derive(Deserialize)] +pub struct AuthConfig { + pub passphrase: String, +} + +#[derive(Deserialize)] +pub struct SessionConfig { + pub refresh_hour: u32, +} diff --git a/src/display.rs b/src/display.rs new file mode 100644 index 0000000..00ef56d --- /dev/null +++ b/src/display.rs @@ -0,0 +1,220 @@ +use std::path::PathBuf; + +use base64::Engine; +use teloxide::prelude::*; +use teloxide::types::ParseMode; + +use crate::stream::{CURSOR, TG_MSG_LIMIT}; + +pub fn truncate_for_display(s: &str) -> String { + let budget = TG_MSG_LIMIT - CURSOR.len() - 1; + if s.len() <= budget { + format!("{s}{CURSOR}") + } else { + let truncated = truncate_at_char_boundary(s, budget - 2); + format!("{truncated}\n…{CURSOR}") + } +} + +pub fn truncate_at_char_boundary(s: &str, max: usize) -> &str { + if s.len() <= max { + return s; + } + let mut end = max; + while !s.is_char_boundary(end) { + end -= 1; + } + &s[..end] +} + +pub fn escape_html(s: &str) -> String { + s.replace('&', "&") + .replace('<', "<") + .replace('>', ">") + .replace('"', """) +} + +pub fn markdown_to_telegram_html(md: &str) -> String { + use pulldown_cmark::{CodeBlockKind, Event, Options, Parser, Tag, TagEnd}; + + let mut opts = Options::empty(); + opts.insert(Options::ENABLE_STRIKETHROUGH); + + let parser = Parser::new_ext(md, opts); + let mut html = String::new(); + + for event in parser { + match event { + Event::Start(tag) => match tag { + Tag::Paragraph => {} + Tag::Heading { .. } => html.push_str(""), + Tag::BlockQuote(_) => html.push_str("
"), + Tag::CodeBlock(kind) => match kind { + CodeBlockKind::Fenced(ref lang) if !lang.is_empty() => { + html.push_str(&format!( + "
",
+                            escape_html(lang.as_ref())
+                        ));
+                    }
+                    _ => html.push_str("
"),
+                },
+                Tag::Item => html.push_str("• "),
+                Tag::Emphasis => html.push_str(""),
+                Tag::Strong => html.push_str(""),
+                Tag::Strikethrough => html.push_str(""),
+                Tag::Link { dest_url, .. } => {
+                    html.push_str(&format!(
+                        "",
+                        escape_html(dest_url.as_ref())
+                    ));
+                }
+                _ => {}
+            },
+            Event::End(tag) => match tag {
+                TagEnd::Paragraph => html.push_str("\n\n"),
+                TagEnd::Heading(_) => html.push_str("\n\n"),
+                TagEnd::BlockQuote(_) => html.push_str("
"), + TagEnd::CodeBlock => html.push_str("\n\n"), + TagEnd::List(_) => html.push('\n'), + TagEnd::Item => html.push('\n'), + TagEnd::Emphasis => html.push_str(""), + TagEnd::Strong => html.push_str("
"), + TagEnd::Strikethrough => html.push_str(""), + TagEnd::Link => html.push_str(""), + _ => {} + }, + Event::Text(text) => html.push_str(&escape_html(text.as_ref())), + Event::Code(text) => { + html.push_str(""); + html.push_str(&escape_html(text.as_ref())); + html.push_str(""); + } + Event::SoftBreak | Event::HardBreak => html.push('\n'), + Event::Rule => html.push_str("\n---\n\n"), + _ => {} + } + } + + html.trim_end().to_string() +} + +/// Send final result with HTML formatting, fallback to plain text on failure. +pub async fn send_final_result( + bot: &Bot, + chat_id: ChatId, + msg_id: Option, + use_draft: bool, + result: &str, +) { + let html = markdown_to_telegram_html(result); + + // try HTML as single message + let html_ok = if let (false, Some(id)) = (use_draft, msg_id) { + bot.edit_message_text(chat_id, id, &html) + .parse_mode(ParseMode::Html) + .await + .is_ok() + } else { + bot.send_message(chat_id, &html) + .parse_mode(ParseMode::Html) + .await + .is_ok() + }; + + if html_ok { + return; + } + + // fallback: plain text with chunking + let chunks = split_msg(result, TG_MSG_LIMIT); + if let (false, Some(id)) = (use_draft, msg_id) { + let _ = bot.edit_message_text(chat_id, id, chunks[0]).await; + for chunk in &chunks[1..] { + let _ = bot.send_message(chat_id, *chunk).await; + } + } else { + for chunk in &chunks { + let _ = bot.send_message(chat_id, *chunk).await; + } + } +} + +pub fn split_msg(s: &str, max: usize) -> Vec<&str> { + if s.len() <= max { + return vec![s]; + } + let mut parts = Vec::new(); + let mut rest = s; + while !rest.is_empty() { + if rest.len() <= max { + parts.push(rest); + break; + } + let mut end = max; + while !rest.is_char_boundary(end) { + end -= 1; + } + let (chunk, tail) = rest.split_at(end); + parts.push(chunk); + rest = tail; + } + parts +} + +/// Build user message content, with optional images/videos as multimodal input. +pub fn build_user_content( + text: &str, + scratch: &str, + media: &[PathBuf], +) -> serde_json::Value { + let full_text = if scratch.is_empty() { + text.to_string() + } else { + format!("{text}\n\n[scratch]\n{scratch}") + }; + + // collect media data (images + videos) + let mut media_parts: Vec = Vec::new(); + for path in media { + let (mime, is_video) = match path + .extension() + .and_then(|e| e.to_str()) + .map(|e| e.to_lowercase()) + .as_deref() + { + Some("jpg" | "jpeg") => ("image/jpeg", false), + Some("png") => ("image/png", false), + Some("gif") => ("image/gif", false), + Some("webp") => ("image/webp", false), + Some("mp4") => ("video/mp4", true), + Some("webm") => ("video/webm", true), + Some("mov") => ("video/quicktime", true), + _ => continue, + }; + if let Ok(data) = std::fs::read(path) { + let b64 = base64::engine::general_purpose::STANDARD.encode(&data); + let data_url = format!("data:{mime};base64,{b64}"); + if is_video { + media_parts.push(serde_json::json!({ + "type": "video_url", + "video_url": {"url": data_url} + })); + } else { + media_parts.push(serde_json::json!({ + "type": "image_url", + "image_url": {"url": data_url} + })); + } + } + } + + if media_parts.is_empty() { + // plain text — more compatible + serde_json::Value::String(full_text) + } else { + // multimodal array + let mut content = vec![serde_json::json!({"type": "text", "text": full_text})]; + content.extend(media_parts); + serde_json::Value::Array(content) + } +} diff --git a/src/life.rs b/src/life.rs new file mode 100644 index 0000000..7bca23b --- /dev/null +++ b/src/life.rs @@ -0,0 +1,87 @@ +use std::sync::Arc; + +use teloxide::prelude::*; +use tracing::{error, info}; + +use crate::config::{BackendConfig, Config}; +use crate::state::AppState; +use crate::stream::run_openai_streaming; +use crate::tools::compute_next_cron_fire; + +pub async fn life_loop(bot: Bot, state: Arc, config: Arc) { + info!("life loop started"); + let mut interval = tokio::time::interval(std::time::Duration::from_secs(30)); + + loop { + interval.tick().await; + + let due = state.due_timers().await; + if due.is_empty() { + continue; + } + + for (timer_id, chat_id_raw, label, schedule) in &due { + let chat_id = ChatId(*chat_id_raw); + info!(timer_id, %label, "timer fired"); + + // build life loop context + let persona = state.get_config("persona").await.unwrap_or_default(); + let inner = state.get_inner_state().await; + let now = chrono::Local::now().format("%Y-%m-%d %H:%M:%S").to_string(); + + let mut system_text = if persona.is_empty() { + "你叫小乖,是Fam的AI伙伴。".to_string() + } else { + persona.clone() + }; + system_text.push_str(&format!( + "\n\n[当前时间] {now}\n\n[你的内心状态]\n{}", + if inner.is_empty() { "(空)" } else { &inner } + )); + system_text.push_str( + "\n\n你可以使用工具来完成任务。你可以选择发消息给用户,也可以选择什么都不做(直接回复空文本)。\ + 可以用 update_inner_state 更新你的内心状态。\ + 输出格式:纯文本或基础Markdown,不要LaTeX或特殊Unicode。", + ); + + let messages = vec![ + serde_json::json!({"role": "system", "content": system_text}), + serde_json::json!({"role": "user", "content": format!("[timer] {label}")}), + ]; + + // call LLM (no tools for now — keep life loop simple) + if let BackendConfig::OpenAI { + ref endpoint, + ref model, + ref api_key, + } = config.backend + { + match run_openai_streaming(endpoint, model, api_key, &messages, &bot, chat_id) + .await + { + Ok(response) => { + if !response.is_empty() { + info!(timer_id, "life loop sent response ({} chars)", response.len()); + } + } + Err(e) => { + error!(timer_id, "life loop LLM error: {e:#}"); + } + } + } + + // reschedule or delete + if schedule.starts_with("cron:") { + if let Some(next) = compute_next_cron_fire(schedule) { + state.update_timer_next_fire(*timer_id, &next).await; + info!(timer_id, next = %next, "cron rescheduled"); + } else { + state.cancel_timer(*timer_id).await; + } + } else { + // one-shot: delete after firing + state.cancel_timer(*timer_id).await; + } + } + } +} diff --git a/src/main.rs b/src/main.rs index fc93173..c6365d6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,492 +1,31 @@ -use std::collections::{HashMap, HashSet}; +mod config; +mod state; +mod tools; +mod stream; +mod display; +mod life; + +use std::collections::HashSet; use std::path::{Path, PathBuf}; -use std::process::Stdio; -use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use anyhow::Result; use chrono::{Local, NaiveDate, NaiveTime}; -use serde::{Deserialize, Serialize}; use teloxide::dispatching::UpdateFilterExt; use teloxide::net::Download; use teloxide::prelude::*; use teloxide::types::InputFile; -use tokio::io::AsyncBufReadExt; -use tokio::process::Command; -use tokio::sync::RwLock; -use tokio::time::Instant; use tracing::{error, info, warn}; -use base64::Engine; use uuid::Uuid; -// ── config ────────────────────────────────────────────────────────── - -#[derive(Deserialize)] -struct Config { - #[serde(default = "default_name")] - name: String, - tg: TgConfig, - auth: AuthConfig, - session: SessionConfig, - #[serde(default)] - backend: BackendConfig, - #[serde(default)] - whisper_url: Option, -} - -fn default_name() -> String { - "noc".to_string() -} - -#[derive(Deserialize, Clone, Default)] -#[serde(tag = "type")] -enum BackendConfig { - #[serde(rename = "claude")] - #[default] - Claude, - #[serde(rename = "openai")] - OpenAI { - endpoint: String, - model: String, - #[serde(default = "default_api_key")] - api_key: String, - }, -} - -fn default_api_key() -> String { - "unused".to_string() -} - - -#[derive(Deserialize)] -struct TgConfig { - key: String, -} - -#[derive(Deserialize)] -struct AuthConfig { - passphrase: String, -} - -#[derive(Deserialize)] -struct SessionConfig { - refresh_hour: u32, -} - -// ── persistent state ──────────────────────────────────────────────── - -#[derive(Serialize, Deserialize, Default)] -struct Persistent { - authed: HashMap, - known_sessions: HashSet, -} - -#[derive(Serialize, Deserialize, Clone, Default)] -struct ConversationState { - summary: String, - messages: Vec, - total_messages: usize, -} - -const MAX_WINDOW: usize = 100; -const SLIDE_SIZE: usize = 50; - -// ── subagent & tool call ─────────────────────────────────────────── - -struct SubAgent { - task: String, - output: Arc>, - completed: Arc, - exit_code: Arc>>, - pid: Option, -} - -struct ToolCall { - id: String, - name: String, - arguments: String, -} - -fn tools_dir() -> PathBuf { - // tools/ relative to the config file location - let config_path = std::env::var("NOC_CONFIG").unwrap_or_else(|_| "config.yaml".into()); - let config_dir = Path::new(&config_path) - .parent() - .unwrap_or(Path::new(".")); - config_dir.join("tools") -} - -/// Scan tools/ directory for scripts with --schema, merge with built-in tools. -/// Called on every API request so new/updated scripts take effect immediately. -fn discover_tools() -> serde_json::Value { - let mut tools = vec![ - serde_json::json!({ - "type": "function", - "function": { - "name": "spawn_agent", - "description": "启动一个 Claude Code 子代理异步执行复杂任务。子代理可使用 shell、浏览器和搜索引擎,适合网页搜索、资料查找、技术调研、代码任务等。完成后会收到通知。", - "parameters": { - "type": "object", - "properties": { - "id": {"type": "string", "description": "简短唯一标识符(如 'research'、'fix-bug')"}, - "task": {"type": "string", "description": "给子代理的详细任务描述"} - }, - "required": ["id", "task"] - } - } - }), - serde_json::json!({ - "type": "function", - "function": { - "name": "agent_status", - "description": "查看正在运行或已完成的子代理的状态和输出", - "parameters": { - "type": "object", - "properties": { - "id": {"type": "string", "description": "子代理标识符"} - }, - "required": ["id"] - } - } - }), - serde_json::json!({ - "type": "function", - "function": { - "name": "kill_agent", - "description": "终止一个正在运行的子代理", - "parameters": { - "type": "object", - "properties": { - "id": {"type": "string", "description": "子代理标识符"} - }, - "required": ["id"] - } - } - }), - serde_json::json!({ - "type": "function", - "function": { - "name": "send_file", - "description": "通过 Telegram 向用户发送服务器上的文件,文件必须存在于服务器文件系统中。", - "parameters": { - "type": "object", - "properties": { - "path": {"type": "string", "description": "服务器上文件的绝对路径"}, - "caption": {"type": "string", "description": "可选的文件说明/描述"} - }, - "required": ["path"] - } - } - }), - serde_json::json!({ - "type": "function", - "function": { - "name": "update_scratch", - "description": "更新你的草稿区(工作笔记、状态、提醒)。草稿区内容会附加到每条用户消息中,确保你始终可见。用于跨轮次跟踪上下文。", - "parameters": { - "type": "object", - "properties": { - "content": {"type": "string", "description": "完整的草稿区内容(替换之前的内容)"} - }, - "required": ["content"] - } - } - }), - serde_json::json!({ - "type": "function", - "function": { - "name": "update_memory", - "description": "写入持久记忆槽。共 100 个槽位(0-99),跨会话保留。记忆槽内容会注入到每次对话的 system prompt 中。用于存储关键事实、用户偏好或重要上下文。内容设为空字符串可清除槽位。", - "parameters": { - "type": "object", - "properties": { - "slot_nr": {"type": "integer", "description": "槽位编号(0-99)"}, - "content": {"type": "string", "description": "要存储的内容(最多200字符),空字符串表示清除该槽位"} - }, - "required": ["slot_nr", "content"] - } - } - }), - serde_json::json!({ - "type": "function", - "function": { - "name": "gen_voice", - "description": "将文字合成为语音并直接发送给用户。", - "parameters": { - "type": "object", - "properties": { - "text": {"type": "string", "description": "要合成语音的文字内容"} - }, - "required": ["text"] - } - } - }), - ]; - - // discover script tools - let dir = tools_dir(); - if let Ok(entries) = std::fs::read_dir(&dir) { - for entry in entries.flatten() { - let path = entry.path(); - if !path.is_file() { - continue; - } - // run --schema with a short timeout - let output = std::process::Command::new(&path) - .arg("--schema") - .output(); - match output { - Ok(out) if out.status.success() => { - let stdout = String::from_utf8_lossy(&out.stdout); - match serde_json::from_str::(stdout.trim()) { - Ok(schema) => { - let name = schema["name"].as_str().unwrap_or("?"); - info!(tool = %name, path = %path.display(), "discovered script tool"); - tools.push(serde_json::json!({ - "type": "function", - "function": schema, - })); - } - Err(e) => { - warn!(path = %path.display(), "invalid --schema JSON: {e}"); - } - } - } - _ => {} // not a tool script, skip silently - } - } - } - - serde_json::Value::Array(tools) -} - -struct AppState { - persist: RwLock, - state_path: PathBuf, - db: tokio::sync::Mutex, - agents: RwLock>>, -} - -impl AppState { - fn load(path: PathBuf) -> Self { - let persist = std::fs::read_to_string(&path) - .ok() - .and_then(|s| serde_json::from_str(&s).ok()) - .unwrap_or_default(); - info!("loaded state from {}", path.display()); - - let db_path = path.parent().unwrap_or(Path::new(".")).join("noc.db"); - let conn = rusqlite::Connection::open(&db_path) - .unwrap_or_else(|e| panic!("open {}: {e}", db_path.display())); - conn.execute_batch( - "CREATE TABLE IF NOT EXISTS conversations ( - session_id TEXT PRIMARY KEY, - summary TEXT NOT NULL DEFAULT '', - total_messages INTEGER NOT NULL DEFAULT 0 - ); - CREATE TABLE IF NOT EXISTS messages ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - session_id TEXT NOT NULL, - role TEXT NOT NULL, - content TEXT NOT NULL, - created_at TEXT NOT NULL DEFAULT (datetime('now', 'localtime')) - ); - CREATE INDEX IF NOT EXISTS idx_messages_session ON messages(session_id); - CREATE TABLE IF NOT EXISTS scratch_area ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - content TEXT NOT NULL, - created_at TEXT NOT NULL DEFAULT (datetime('now')) - ); - CREATE TABLE IF NOT EXISTS config ( - key TEXT PRIMARY KEY, - value TEXT NOT NULL DEFAULT '', - create_time TEXT NOT NULL DEFAULT (datetime('now')), - update_time TEXT NOT NULL DEFAULT (datetime('now')) - ); - CREATE TABLE IF NOT EXISTS config_history ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - key TEXT NOT NULL, - value TEXT NOT NULL, - create_time TEXT NOT NULL, - update_time TEXT NOT NULL - ); - CREATE TABLE IF NOT EXISTS memory_slots ( - slot_nr INTEGER PRIMARY KEY CHECK(slot_nr BETWEEN 0 AND 99), - content TEXT NOT NULL DEFAULT '' - );", - ) - .expect("init db schema"); - - // migrations - let _ = conn.execute( - "ALTER TABLE messages ADD COLUMN created_at TEXT NOT NULL DEFAULT ''", - [], - ); - - info!("opened db {}", db_path.display()); - - Self { - persist: RwLock::new(persist), - state_path: path, - db: tokio::sync::Mutex::new(conn), - agents: RwLock::new(HashMap::new()), - } - } - - async fn save(&self) { - let data = self.persist.read().await; - if let Ok(json) = serde_json::to_string_pretty(&*data) { - if let Err(e) = std::fs::write(&self.state_path, json) { - error!("save state: {e}"); - } - } - } - - async fn load_conv(&self, sid: &str) -> ConversationState { - let db = self.db.lock().await; - let (summary, total) = db - .query_row( - "SELECT summary, total_messages FROM conversations WHERE session_id = ?1", - [sid], - |row| Ok((row.get::<_, String>(0)?, row.get::<_, usize>(1)?)), - ) - .unwrap_or_default(); - - let mut stmt = db - .prepare("SELECT role, content, created_at FROM messages WHERE session_id = ?1 ORDER BY id") - .unwrap(); - let messages: Vec = stmt - .query_map([sid], |row| { - let role: String = row.get(0)?; - let content: String = row.get(1)?; - let ts: String = row.get(2)?; - let tagged = if ts.is_empty() { - content - } else { - format!("[{ts}] {content}") - }; - Ok(serde_json::json!({"role": role, "content": tagged})) - }) - .unwrap() - .filter_map(|r| r.ok()) - .collect(); - - ConversationState { - summary, - messages, - total_messages: total, - } - } - - async fn push_message(&self, sid: &str, role: &str, content: &str) { - let db = self.db.lock().await; - let _ = db.execute( - "INSERT OR IGNORE INTO conversations (session_id) VALUES (?1)", - [sid], - ); - let _ = db.execute( - "INSERT INTO messages (session_id, role, content, created_at) VALUES (?1, ?2, ?3, datetime('now', 'localtime'))", - rusqlite::params![sid, role, content], - ); - } - - async fn message_count(&self, sid: &str) -> usize { - let db = self.db.lock().await; - db.query_row( - "SELECT COUNT(*) FROM messages WHERE session_id = ?1", - [sid], - |row| row.get(0), - ) - .unwrap_or(0) - } - - async fn slide_window(&self, sid: &str, new_summary: &str, slide_size: usize) { - let db = self.db.lock().await; - let _ = db.execute( - "DELETE FROM messages WHERE id IN ( - SELECT id FROM messages WHERE session_id = ?1 ORDER BY id LIMIT ?2 - )", - rusqlite::params![sid, slide_size], - ); - let _ = db.execute( - "UPDATE conversations SET summary = ?1, total_messages = total_messages + ?2 \ - WHERE session_id = ?3", - rusqlite::params![new_summary, slide_size, sid], - ); - } - - async fn get_oldest_messages(&self, sid: &str, count: usize) -> Vec { - let db = self.db.lock().await; - let mut stmt = db - .prepare( - "SELECT role, content FROM messages WHERE session_id = ?1 ORDER BY id LIMIT ?2", - ) - .unwrap(); - stmt.query_map(rusqlite::params![sid, count], |row| { - let role: String = row.get(0)?; - let content: String = row.get(1)?; - Ok(serde_json::json!({"role": role, "content": content})) - }) - .unwrap() - .filter_map(|r| r.ok()) - .collect() - } - - async fn get_scratch(&self) -> String { - let db = self.db.lock().await; - db.query_row( - "SELECT content FROM scratch_area ORDER BY id DESC LIMIT 1", - [], - |row| row.get(0), - ) - .unwrap_or_default() - } - - async fn push_scratch(&self, content: &str) { - let db = self.db.lock().await; - let _ = db.execute( - "INSERT INTO scratch_area (content) VALUES (?1)", - [content], - ); - } - - async fn get_config(&self, key: &str) -> Option { - let db = self.db.lock().await; - db.query_row( - "SELECT value FROM config WHERE key = ?1", - [key], - |row| row.get(0), - ) - .ok() - } - - async fn get_memory_slots(&self) -> Vec<(i32, String)> { - let db = self.db.lock().await; - let mut stmt = db - .prepare("SELECT slot_nr, content FROM memory_slots WHERE content != '' ORDER BY slot_nr") - .unwrap(); - stmt.query_map([], |row| Ok((row.get(0)?, row.get(1)?))) - .unwrap() - .filter_map(|r| r.ok()) - .collect() - } - - async fn set_memory_slot(&self, slot_nr: i32, content: &str) -> Result<()> { - if !(0..=99).contains(&slot_nr) { - anyhow::bail!("slot_nr must be 0-99, got {slot_nr}"); - } - if content.len() > 200 { - anyhow::bail!("content too long: {} chars (max 200)", content.len()); - } - let db = self.db.lock().await; - db.execute( - "INSERT INTO memory_slots (slot_nr, content) VALUES (?1, ?2) \ - ON CONFLICT(slot_nr) DO UPDATE SET content = ?2", - rusqlite::params![slot_nr, content], - )?; - Ok(()) - } -} +use config::{BackendConfig, Config}; +use display::build_user_content; +use state::{AppState, MAX_WINDOW, SLIDE_SIZE}; +use stream::{ + build_system_prompt, invoke_claude_streaming, run_claude_streaming, run_openai_with_tools, + summarize_messages, +}; +use tools::discover_tools; // ── helpers ───────────────────────────────────────────────────────── @@ -549,8 +88,13 @@ async fn main() { let handler = Update::filter_message().endpoint(handle); + let config = Arc::new(config); + + // start life loop + tokio::spawn(life::life_loop(bot.clone(), state.clone(), config.clone())); + Dispatcher::builder(bot, handler) - .dependencies(dptree::deps![state, Arc::new(config), bot_username]) + .dependencies(dptree::deps![state, config, bot_username]) .default_handler(|_| async {}) .build() .dispatch() @@ -988,1321 +532,3 @@ async fn transcribe_audio(whisper_url: &str, file_path: &Path) -> Result let json: serde_json::Value = resp.json().await?; Ok(json["text"].as_str().unwrap_or("").to_string()) } - -fn build_system_prompt(summary: &str, persona: &str, memory_slots: &[(i32, String)]) -> serde_json::Value { - let mut text = if persona.is_empty() { - String::from("你是一个AI助手。") - } else { - persona.to_string() - }; - - text.push_str( - "\n\n你可以使用提供的工具来完成任务。\ - 当需要执行命令、运行代码或启动复杂子任务时,直接调用对应的工具,不要只是描述你会怎么做。\ - 当需要搜索信息(如网页搜索、资料查找、技术调研等)时,使用 spawn_agent 启动一个子代理来完成搜索任务,\ - 子代理可以使用浏览器和搜索引擎,搜索完成后你会收到结果通知。\ - 输出格式:使用纯文本或基础Markdown(加粗、列表、代码块)。\ - 不要使用LaTeX公式($...$)、特殊Unicode符号(→←↔)或HTML标签,Telegram无法渲染这些。", - ); - - if !memory_slots.is_empty() { - text.push_str("\n\n## 持久记忆(跨会话保留)\n"); - for (nr, content) in memory_slots { - text.push_str(&format!("[{nr}] {content}\n")); - } - } - - if !summary.is_empty() { - text.push_str("\n\n## 之前的对话总结\n"); - text.push_str(summary); - } - - serde_json::json!({"role": "system", "content": text}) -} - -/// Build user message content, with optional images/videos as multimodal input. -fn build_user_content( - text: &str, - scratch: &str, - media: &[PathBuf], -) -> serde_json::Value { - let full_text = if scratch.is_empty() { - text.to_string() - } else { - format!("{text}\n\n[scratch]\n{scratch}") - }; - - // collect media data (images + videos) - let mut media_parts: Vec = Vec::new(); - for path in media { - let (mime, is_video) = match path - .extension() - .and_then(|e| e.to_str()) - .map(|e| e.to_lowercase()) - .as_deref() - { - Some("jpg" | "jpeg") => ("image/jpeg", false), - Some("png") => ("image/png", false), - Some("gif") => ("image/gif", false), - Some("webp") => ("image/webp", false), - Some("mp4") => ("video/mp4", true), - Some("webm") => ("video/webm", true), - Some("mov") => ("video/quicktime", true), - _ => continue, - }; - if let Ok(data) = std::fs::read(path) { - let b64 = base64::engine::general_purpose::STANDARD.encode(&data); - let data_url = format!("data:{mime};base64,{b64}"); - if is_video { - media_parts.push(serde_json::json!({ - "type": "video_url", - "video_url": {"url": data_url} - })); - } else { - media_parts.push(serde_json::json!({ - "type": "image_url", - "image_url": {"url": data_url} - })); - } - } - } - - if media_parts.is_empty() { - // plain text — more compatible - serde_json::Value::String(full_text) - } else { - // multimodal array - let mut content = vec![serde_json::json!({"type": "text", "text": full_text})]; - content.extend(media_parts); - serde_json::Value::Array(content) - } -} - -async fn summarize_messages( - endpoint: &str, - model: &str, - api_key: &str, - existing_summary: &str, - dropped: &[serde_json::Value], -) -> Result { - let msgs_text: String = dropped - .iter() - .filter_map(|m| { - let role = m["role"].as_str()?; - let content = m["content"].as_str()?; - Some(format!("{role}: {content}")) - }) - .collect::>() - .join("\n\n"); - - let prompt = if existing_summary.is_empty() { - format!( - "请将以下对话总结为约4000字符的摘要,保留关键信息和上下文:\n\n{}", - msgs_text - ) - } else { - format!( - "请将以下新对话内容整合到现有总结中,保持总结在约4000字符以内。\ - 保留重要信息,让较旧的话题自然淡出。\n\n\ - 现有总结:\n{}\n\n新对话:\n{}", - existing_summary, msgs_text - ) - }; - - let client = reqwest::Client::new(); - let url = format!("{}/chat/completions", endpoint.trim_end_matches('/')); - - let body = serde_json::json!({ - "model": model, - "messages": [ - {"role": "system", "content": "你是一个对话总结助手。请生成简洁但信息丰富的总结。"}, - {"role": "user", "content": prompt} - ], - }); - - let resp = client - .post(&url) - .header("Authorization", format!("Bearer {api_key}")) - .json(&body) - .send() - .await? - .error_for_status()?; - - let json: serde_json::Value = resp.json().await?; - let summary = json["choices"][0]["message"]["content"] - .as_str() - .unwrap_or("") - .to_string(); - - Ok(summary) -} - -// ── tool execution ───────────────────────────────────────────────── - -async fn execute_tool( - name: &str, - arguments: &str, - state: &Arc, - bot: &Bot, - chat_id: ChatId, - sid: &str, - config: &Arc, -) -> String { - let args: serde_json::Value = match serde_json::from_str(arguments) { - Ok(v) => v, - Err(e) => return format!("Invalid arguments: {e}"), - }; - - match name { - "spawn_agent" => { - let id = args["id"].as_str().unwrap_or("agent"); - let task = args["task"].as_str().unwrap_or(""); - spawn_agent(id, task, state, bot, chat_id, sid, config).await - } - "agent_status" => { - let id = args["id"].as_str().unwrap_or(""); - check_agent_status(id, state).await - } - "kill_agent" => { - let id = args["id"].as_str().unwrap_or(""); - kill_agent(id, state).await - } - "send_file" => { - let path_str = args["path"].as_str().unwrap_or(""); - let caption = args["caption"].as_str().unwrap_or(""); - let path = Path::new(path_str); - if !path.exists() { - return format!("File not found: {path_str}"); - } - if !path.is_file() { - return format!("Not a file: {path_str}"); - } - let input_file = InputFile::file(path); - let mut req = bot.send_document(chat_id, input_file); - if !caption.is_empty() { - req = req.caption(caption); - } - match req.await { - Ok(_) => format!("File sent: {path_str}"), - Err(e) => format!("Failed to send file: {e:#}"), - } - } - "update_scratch" => { - let content = args["content"].as_str().unwrap_or(""); - state.push_scratch(content).await; - format!("Scratch updated ({} chars)", content.len()) - } - "update_memory" => { - let slot_nr = args["slot_nr"].as_i64().unwrap_or(-1) as i32; - let content = args["content"].as_str().unwrap_or(""); - match state.set_memory_slot(slot_nr, content).await { - Ok(_) => { - if content.is_empty() { - format!("Memory slot {slot_nr} cleared") - } else { - format!("Memory slot {slot_nr} updated ({} chars)", content.len()) - } - } - Err(e) => format!("Error: {e}"), - } - } - "gen_voice" => { - let text = args["text"].as_str().unwrap_or(""); - if text.is_empty() { - return "Error: text is required".to_string(); - } - let script = tools_dir().join("gen_voice"); - let result = tokio::time::timeout( - std::time::Duration::from_secs(120), - tokio::process::Command::new(&script) - .arg(arguments) - .output(), - ) - .await; - match result { - Ok(Ok(out)) if out.status.success() => { - let path_str = String::from_utf8_lossy(&out.stdout).trim().to_string(); - let path = Path::new(&path_str); - if path.exists() { - let input_file = InputFile::file(path); - match bot.send_voice(chat_id, input_file).await { - Ok(_) => format!("语音已发送: {path_str}"), - Err(e) => format!("语音生成成功但发送失败: {e:#}"), - } - } else { - format!("语音生成失败: 输出文件不存在 ({path_str})") - } - } - Ok(Ok(out)) => { - let stderr = String::from_utf8_lossy(&out.stderr); - let stdout = String::from_utf8_lossy(&out.stdout); - format!("gen_voice failed: {stdout} {stderr}") - } - Ok(Err(e)) => format!("gen_voice exec error: {e}"), - Err(_) => "gen_voice timeout (120s)".to_string(), - } - } - _ => run_script_tool(name, arguments).await, - } -} - -async fn spawn_agent( - id: &str, - task: &str, - state: &Arc, - bot: &Bot, - chat_id: ChatId, - sid: &str, - config: &Arc, -) -> String { - // check if already exists - if state.agents.read().await.contains_key(id) { - return format!("Agent '{id}' already exists. Use agent_status to check it."); - } - - let mut child = match Command::new("claude") - .args(["--dangerously-skip-permissions", "-p", task]) - .stdout(Stdio::piped()) - .stderr(Stdio::piped()) - .spawn() - { - Ok(c) => c, - Err(e) => return format!("Failed to spawn agent: {e}"), - }; - - let pid = child.id(); - let output = Arc::new(tokio::sync::RwLock::new(String::new())); - let completed = Arc::new(AtomicBool::new(false)); - let exit_code = Arc::new(tokio::sync::RwLock::new(None)); - - let agent = Arc::new(SubAgent { - task: task.to_string(), - output: output.clone(), - completed: completed.clone(), - exit_code: exit_code.clone(), - pid, - }); - - state.agents.write().await.insert(id.to_string(), agent); - - // background task: collect output and wakeup on completion - let out = output.clone(); - let done = completed.clone(); - let ecode = exit_code.clone(); - let bot_c = bot.clone(); - let chat_id_c = chat_id; - let state_c = state.clone(); - let config_c = config.clone(); - let sid_c = sid.to_string(); - let id_c = id.to_string(); - - tokio::spawn(async move { - let stdout = child.stdout.take(); - if let Some(stdout) = stdout { - let mut lines = tokio::io::BufReader::new(stdout).lines(); - while let Ok(Some(line)) = lines.next_line().await { - let mut o = out.write().await; - o.push_str(&line); - o.push('\n'); - } - } - let status = child.wait().await; - let code = status.as_ref().ok().and_then(|s| s.code()); - *ecode.write().await = code; - done.store(true, Ordering::SeqCst); - - info!(agent = %id_c, "agent completed, exit={code:?}"); - - // wakeup: inject result and trigger LLM - let result = out.read().await.clone(); - let result_short = truncate_at_char_boundary(&result, 4000); - let wakeup = format!( - "[Agent '{id_c}' 执行完成 (exit={})]\n{result_short}", - code.unwrap_or(-1) - ); - - if let Err(e) = agent_wakeup( - &config_c, &state_c, &bot_c, chat_id_c, &sid_c, &wakeup, &id_c, - ) - .await - { - error!(agent = %id_c, "wakeup failed: {e:#}"); - let _ = bot_c - .send_message(chat_id_c, format!("[agent wakeup error] {e:#}")) - .await; - } - }); - - format!("Agent '{id}' spawned (pid={pid:?})") -} - -async fn agent_wakeup( - config: &Config, - state: &AppState, - bot: &Bot, - chat_id: ChatId, - sid: &str, - wakeup_msg: &str, - agent_id: &str, -) -> Result<()> { - match &config.backend { - BackendConfig::OpenAI { - endpoint, - model, - api_key, - } => { - state.push_message(sid, "user", wakeup_msg).await; - let conv = state.load_conv(sid).await; - let persona = state.get_config("persona").await.unwrap_or_default(); - let memory_slots = state.get_memory_slots().await; - let system_msg = build_system_prompt(&conv.summary, &persona, &memory_slots); - let mut api_messages = vec![system_msg]; - api_messages.extend(conv.messages); - - info!(agent = %agent_id, "wakeup: sending {} messages to LLM", api_messages.len()); - - let response = - run_openai_streaming(endpoint, model, api_key, &api_messages, bot, chat_id) - .await?; - - if !response.is_empty() { - state.push_message(sid, "assistant", &response).await; - } - - Ok(()) - } - _ => { - let _ = bot - .send_message(chat_id, format!("[Agent '{agent_id}' done]\n{wakeup_msg}")) - .await; - Ok(()) - } - } -} - -async fn check_agent_status(id: &str, state: &AppState) -> String { - let agents = state.agents.read().await; - match agents.get(id) { - Some(agent) => { - let status = if agent.completed.load(Ordering::SeqCst) { - let code = agent.exit_code.read().await; - format!("completed (exit={})", code.unwrap_or(-1)) - } else { - "running".to_string() - }; - let output = agent.output.read().await; - let out_preview = truncate_at_char_boundary(&output, 3000); - format!( - "Agent '{id}': {status}\nTask: {}\nOutput ({} bytes):\n{out_preview}", - agent.task, - output.len() - ) - } - None => format!("Agent '{id}' not found"), - } -} - -async fn kill_agent(id: &str, state: &AppState) -> String { - let agents = state.agents.read().await; - match agents.get(id) { - Some(agent) => { - if agent.completed.load(Ordering::SeqCst) { - return format!("Agent '{id}' already completed"); - } - if let Some(pid) = agent.pid { - unsafe { - libc::kill(pid as i32, libc::SIGTERM); - } - format!("Sent SIGTERM to agent '{id}' (pid={pid})") - } else { - format!("Agent '{id}' has no PID") - } - } - None => format!("Agent '{id}' not found"), - } -} - -async fn run_script_tool(name: &str, arguments: &str) -> String { - // find script in tools/ that matches this tool name - let dir = tools_dir(); - let entries = match std::fs::read_dir(&dir) { - Ok(e) => e, - Err(_) => return format!("Unknown tool: {name}"), - }; - - for entry in entries.flatten() { - let path = entry.path(); - if !path.is_file() { - continue; - } - // check if this script provides the requested tool - let schema_out = std::process::Command::new(&path) - .arg("--schema") - .output(); - if let Ok(out) = schema_out { - if out.status.success() { - let stdout = String::from_utf8_lossy(&out.stdout); - if let Ok(schema) = serde_json::from_str::(stdout.trim()) { - if schema["name"].as_str() == Some(name) { - // found it — execute - info!(tool = %name, path = %path.display(), "running script tool"); - let result = tokio::time::timeout( - std::time::Duration::from_secs(60), - Command::new(&path).arg(arguments).output(), - ) - .await; - - return match result { - Ok(Ok(output)) => { - let mut s = String::from_utf8_lossy(&output.stdout).to_string(); - let stderr = String::from_utf8_lossy(&output.stderr); - if !stderr.is_empty() { - if !s.is_empty() { - s.push_str("\n[stderr]\n"); - } - s.push_str(&stderr); - } - if s.is_empty() { - format!("(exit={})", output.status.code().unwrap_or(-1)) - } else { - s - } - } - Ok(Err(e)) => format!("Failed to execute {name}: {e}"), - Err(_) => "Timeout after 60s".to_string(), - }; - } - } - } - } - } - - format!("Unknown tool: {name}") -} - -// ── openai with tool call loop ───────────────────────────────────── - -#[allow(clippy::too_many_arguments)] -async fn run_openai_with_tools( - endpoint: &str, - model: &str, - api_key: &str, - mut messages: Vec, - bot: &Bot, - chat_id: ChatId, - state: &Arc, - sid: &str, - config: &Arc, - is_private: bool, -) -> Result { - let client = reqwest::Client::new(); - let url = format!("{}/chat/completions", endpoint.trim_end_matches('/')); - let tools = discover_tools(); - - loop { - let body = serde_json::json!({ - "model": model, - "messages": messages, - "tools": tools, - "stream": true, - }); - - info!("API request: {} messages, {} tools", - messages.len(), - tools.as_array().map(|a| a.len()).unwrap_or(0)); - - let resp_raw = client - .post(&url) - .header("Authorization", format!("Bearer {api_key}")) - .json(&body) - .send() - .await?; - - if !resp_raw.status().is_success() { - let status = resp_raw.status(); - let body_text = resp_raw.text().await.unwrap_or_default(); - // dump messages for debugging - for (i, m) in messages.iter().enumerate() { - let role = m["role"].as_str().unwrap_or("?"); - let content_len = m["content"].as_str().map(|s| s.len()).unwrap_or(0); - let has_tc = m.get("tool_calls").is_some(); - let has_tcid = m.get("tool_call_id").is_some(); - warn!(" msg[{i}] role={role} content_len={content_len} tool_calls={has_tc} tool_call_id={has_tcid}"); - } - error!("OpenAI API {status}: {body_text}"); - anyhow::bail!("OpenAI API {status}: {body_text}"); - } - - let mut resp = resp_raw; - - let token = bot.token().to_owned(); - let raw_chat_id = chat_id.0; - let draft_id: i64 = 1; - let mut use_draft = is_private; // sendMessageDraft only works in private chats - - let mut msg_id: Option = None; - let mut accumulated = String::new(); - let mut last_edit = Instant::now(); - let mut buffer = String::new(); - let mut done = false; - - // tool call accumulation - let mut tool_calls: Vec = Vec::new(); - let mut has_tool_calls = false; - - while let Some(chunk) = resp.chunk().await? { - if done { - break; - } - buffer.push_str(&String::from_utf8_lossy(&chunk)); - - while let Some(pos) = buffer.find('\n') { - let line = buffer[..pos].to_string(); - buffer = buffer[pos + 1..].to_string(); - - let trimmed = line.trim(); - if trimmed.is_empty() || trimmed.starts_with(':') { - continue; - } - - let data = match trimmed.strip_prefix("data: ") { - Some(d) => d, - None => continue, - }; - - if data.trim() == "[DONE]" { - done = true; - break; - } - - if let Ok(json) = serde_json::from_str::(data) { - let delta = &json["choices"][0]["delta"]; - - // handle content delta - if let Some(content) = delta["content"].as_str() { - if !content.is_empty() { - accumulated.push_str(content); - } - } - - // handle tool call delta - if let Some(tc_arr) = delta["tool_calls"].as_array() { - has_tool_calls = true; - for tc in tc_arr { - let idx = tc["index"].as_u64().unwrap_or(0) as usize; - while tool_calls.len() <= idx { - tool_calls.push(ToolCall { - id: String::new(), - name: String::new(), - arguments: String::new(), - }); - } - if let Some(id) = tc["id"].as_str() { - tool_calls[idx].id = id.to_string(); - } - if let Some(name) = tc["function"]["name"].as_str() { - tool_calls[idx].name = name.to_string(); - } - if let Some(args) = tc["function"]["arguments"].as_str() { - tool_calls[idx].arguments.push_str(args); - } - } - } - - // display update (only when there's content to show) - if accumulated.is_empty() { - continue; - } - - { - - let interval = if use_draft { - DRAFT_INTERVAL_MS - } else { - EDIT_INTERVAL_MS - }; - if last_edit.elapsed().as_millis() < interval as u128 { - continue; - } - - let display = if use_draft { - truncate_at_char_boundary(&accumulated, TG_MSG_LIMIT).to_string() - } else { - truncate_for_display(&accumulated) - }; - - if use_draft { - match send_message_draft( - &client, &token, raw_chat_id, draft_id, &display, - ) - .await - { - Ok(_) => { - last_edit = Instant::now(); - } - Err(e) => { - warn!("sendMessageDraft failed, falling back: {e:#}"); - use_draft = false; - if let Ok(sent) = - bot.send_message(chat_id, &display).await - { - msg_id = Some(sent.id); - last_edit = Instant::now(); - } - } - } - } else if let Some(id) = msg_id { - if bot - .edit_message_text(chat_id, id, &display) - .await - .is_ok() - { - last_edit = Instant::now(); - } - } else if let Ok(sent) = - bot.send_message(chat_id, &display).await - { - msg_id = Some(sent.id); - last_edit = Instant::now(); - } - } // end display block - } - } - } - - // decide what to do based on response type - if has_tool_calls && !tool_calls.is_empty() { - // append assistant message with tool calls - let tc_json: Vec = tool_calls - .iter() - .map(|tc| { - serde_json::json!({ - "id": tc.id, - "type": "function", - "function": { - "name": tc.name, - "arguments": tc.arguments, - } - }) - }) - .collect(); - - let assistant_msg = serde_json::json!({ - "role": "assistant", - "content": if accumulated.is_empty() { "" } else { &accumulated }, - "tool_calls": tc_json, - }); - messages.push(assistant_msg); - - // execute each tool - for tc in &tool_calls { - info!(tool = %tc.name, "executing tool call"); - let _ = bot - .send_message(chat_id, format!("[{}({})]", tc.name, truncate_at_char_boundary(&tc.arguments, 100))) - .await; - - let result = - execute_tool(&tc.name, &tc.arguments, state, bot, chat_id, sid, config) - .await; - - messages.push(serde_json::json!({ - "role": "tool", - "tool_call_id": tc.id, - "content": result, - })); - } - - // clear display state for next round - tool_calls.clear(); - // loop back to call API again - continue; - } - - // content response — send final result - if !accumulated.is_empty() { - send_final_result(bot, chat_id, msg_id, use_draft, &accumulated).await; - } - - return Ok(accumulated); - } -} - -// ── claude bridge (streaming) ─────────────────────────────────────── - -/// Stream JSON event types we care about. -#[derive(Deserialize)] -struct StreamEvent { - #[serde(rename = "type")] - event_type: String, - message: Option, - result: Option, - #[serde(default)] - is_error: bool, -} - -#[derive(Deserialize)] -struct AssistantMessage { - content: Vec, -} - -#[derive(Deserialize)] -struct ContentBlock { - #[serde(rename = "type")] - block_type: String, - text: Option, - name: Option, - input: Option, -} - -/// Extract all text from an assistant message's content blocks. -fn extract_text(msg: &AssistantMessage) -> String { - msg.content - .iter() - .filter(|b| b.block_type == "text") - .filter_map(|b| b.text.as_deref()) - .collect::>() - .join("") -} - -/// Extract tool use status line, e.g. "Bash: echo hello" -fn extract_tool_use(msg: &AssistantMessage) -> Option { - for block in &msg.content { - if block.block_type == "tool_use" { - let name = block.name.as_deref().unwrap_or("tool"); - let detail = block - .input - .as_ref() - .and_then(|v| { - // try common fields: command, pattern, file_path, query - v.get("command") - .or(v.get("pattern")) - .or(v.get("file_path")) - .or(v.get("query")) - .or(v.get("prompt")) - .and_then(|s| s.as_str()) - }) - .unwrap_or(""); - let detail_short = truncate_at_char_boundary(detail, 80); - return Some(format!("{name}: {detail_short}")); - } - } - None -} - -const EDIT_INTERVAL_MS: u64 = 2000; -const DRAFT_INTERVAL_MS: u64 = 1000; -const TG_MSG_LIMIT: usize = 4096; - -async fn invoke_claude_streaming( - sid: &str, - prompt: &str, - known: bool, - bot: &Bot, - chat_id: ChatId, -) -> Result { - if known { - return run_claude_streaming(&["--resume", sid], prompt, bot, chat_id).await; - } - - match run_claude_streaming(&["--resume", sid], prompt, bot, chat_id).await { - Ok(out) => { - info!(%sid, "resumed existing session"); - Ok(out) - } - Err(e) => { - warn!(%sid, "resume failed ({e:#}), creating new session"); - run_claude_streaming(&["--session-id", sid], prompt, bot, chat_id).await - } - } -} - -async fn send_message_draft( - client: &reqwest::Client, - token: &str, - chat_id: i64, - draft_id: i64, - text: &str, -) -> Result<()> { - let url = format!("https://api.telegram.org/bot{token}/sendMessageDraft"); - let resp = client - .post(&url) - .json(&serde_json::json!({ - "chat_id": chat_id, - "draft_id": draft_id, - "text": text, - })) - .send() - .await?; - let body: serde_json::Value = resp.json().await?; - if body["ok"].as_bool() != Some(true) { - anyhow::bail!("sendMessageDraft: {}", body); - } - Ok(()) -} - -async fn run_claude_streaming( - extra_args: &[&str], - prompt: &str, - bot: &Bot, - chat_id: ChatId, -) -> Result { - let mut args: Vec<&str> = vec![ - "--dangerously-skip-permissions", - "-p", - "--output-format", - "stream-json", - "--verbose", - ]; - args.extend(extra_args); - args.push(prompt); - - let mut child = Command::new("claude") - .args(&args) - .stdout(Stdio::piped()) - .stderr(Stdio::piped()) - .spawn()?; - - let stdout = child.stdout.take().unwrap(); - let mut lines = tokio::io::BufReader::new(stdout).lines(); - - // sendMessageDraft for native streaming, with editMessageText fallback - let http = reqwest::Client::new(); - let token = bot.token().to_owned(); - let raw_chat_id = chat_id.0; - let draft_id: i64 = 1; - let mut use_draft = true; - - let mut msg_id: Option = None; - let mut last_sent_text = String::new(); - let mut last_edit = Instant::now(); - let mut final_result = String::new(); - let mut is_error = false; - let mut tool_status = String::new(); - - while let Ok(Some(line)) = lines.next_line().await { - let event: StreamEvent = match serde_json::from_str(&line) { - Ok(e) => e, - Err(_) => continue, - }; - - match event.event_type.as_str() { - "assistant" => { - if let Some(amsg) = &event.message { - // determine display content - let (display_raw, new_text) = - if let Some(status) = extract_tool_use(amsg) { - tool_status = format!("[{status}]"); - let d = if last_sent_text.is_empty() { - tool_status.clone() - } else { - format!("{last_sent_text}\n\n{tool_status}") - }; - (d, None) - } else { - let text = extract_text(amsg); - if text.is_empty() || text == last_sent_text { - continue; - } - let interval = if use_draft { - DRAFT_INTERVAL_MS - } else { - EDIT_INTERVAL_MS - }; - if last_edit.elapsed().as_millis() < interval as u128 { - continue; - } - tool_status.clear(); - (text.clone(), Some(text)) - }; - - let display = if use_draft { - // draft mode: no cursor — cursor breaks monotonic text growth - truncate_at_char_boundary(&display_raw, TG_MSG_LIMIT).to_string() - } else { - truncate_for_display(&display_raw) - }; - - if use_draft { - match send_message_draft( - &http, &token, raw_chat_id, draft_id, &display, - ) - .await - { - Ok(_) => { - if let Some(t) = new_text { - last_sent_text = t; - } - last_edit = Instant::now(); - } - Err(e) => { - warn!("sendMessageDraft failed, falling back: {e:#}"); - use_draft = false; - if let Ok(sent) = - bot.send_message(chat_id, &display).await - { - msg_id = Some(sent.id); - if let Some(t) = new_text { - last_sent_text = t; - } - last_edit = Instant::now(); - } - } - } - } else if let Some(id) = msg_id { - if bot - .edit_message_text(chat_id, id, &display) - .await - .is_ok() - { - if let Some(t) = new_text { - last_sent_text = t; - } - last_edit = Instant::now(); - } - } else if let Ok(sent) = - bot.send_message(chat_id, &display).await - { - msg_id = Some(sent.id); - if let Some(t) = new_text { - last_sent_text = t; - } - last_edit = Instant::now(); - } - } - } - "result" => { - final_result = event.result.unwrap_or_default(); - is_error = event.is_error; - } - _ => {} - } - } - - // read stderr before waiting (in case child already exited) - let stderr_handle = child.stderr.take(); - let status = child.wait().await; - - // collect stderr for diagnostics - let stderr_text = if let Some(mut se) = stderr_handle { - let mut buf = String::new(); - let _ = tokio::io::AsyncReadExt::read_to_string(&mut se, &mut buf).await; - buf - } else { - String::new() - }; - - // determine error: explicit is_error from stream, or non-zero exit with no result - let has_error = is_error - || (final_result.is_empty() - && status.as_ref().map(|s| !s.success()).unwrap_or(true)); - - if has_error { - let err_detail = if !final_result.is_empty() { - final_result.clone() - } else if !stderr_text.is_empty() { - stderr_text.trim().to_string() - } else { - format!("claude exited: {:?}", status) - }; - if !use_draft { - if let Some(id) = msg_id { - let _ = bot - .edit_message_text(chat_id, id, format!("[error] {err_detail}")) - .await; - } - } - anyhow::bail!("{err_detail}"); - } - - if final_result.is_empty() { - return Ok(final_result); - } - - send_final_result(bot, chat_id, msg_id, use_draft, &final_result).await; - - Ok(final_result) -} - -// ── openai-compatible backend (streaming) ────────────────────────── - -async fn run_openai_streaming( - endpoint: &str, - model: &str, - api_key: &str, - messages: &[serde_json::Value], - bot: &Bot, - chat_id: ChatId, -) -> Result { - let client = reqwest::Client::new(); - let url = format!("{}/chat/completions", endpoint.trim_end_matches('/')); - - let body = serde_json::json!({ - "model": model, - "messages": messages, - "stream": true, - }); - - let mut resp = client - .post(&url) - .header("Authorization", format!("Bearer {api_key}")) - .json(&body) - .send() - .await? - .error_for_status()?; - - let token = bot.token().to_owned(); - let raw_chat_id = chat_id.0; - let draft_id: i64 = 1; - let mut use_draft = true; - - let mut msg_id: Option = None; - let mut accumulated = String::new(); - let mut last_edit = Instant::now(); - let mut buffer = String::new(); - let mut done = false; - - while let Some(chunk) = resp.chunk().await? { - if done { - break; - } - buffer.push_str(&String::from_utf8_lossy(&chunk)); - - while let Some(pos) = buffer.find('\n') { - let line = buffer[..pos].to_string(); - buffer = buffer[pos + 1..].to_string(); - - let trimmed = line.trim(); - if trimmed.is_empty() || trimmed.starts_with(':') { - continue; - } - - let data = match trimmed.strip_prefix("data: ") { - Some(d) => d, - None => continue, - }; - - if data.trim() == "[DONE]" { - done = true; - break; - } - - if let Ok(json) = serde_json::from_str::(data) { - if let Some(content) = json["choices"][0]["delta"]["content"].as_str() { - if content.is_empty() { - continue; - } - accumulated.push_str(content); - - let interval = if use_draft { - DRAFT_INTERVAL_MS - } else { - EDIT_INTERVAL_MS - }; - if last_edit.elapsed().as_millis() < interval as u128 { - continue; - } - - let display = if use_draft { - truncate_at_char_boundary(&accumulated, TG_MSG_LIMIT).to_string() - } else { - truncate_for_display(&accumulated) - }; - - if use_draft { - match send_message_draft( - &client, &token, raw_chat_id, draft_id, &display, - ) - .await - { - Ok(_) => { - last_edit = Instant::now(); - } - Err(e) => { - warn!("sendMessageDraft failed, falling back: {e:#}"); - use_draft = false; - if let Ok(sent) = bot.send_message(chat_id, &display).await { - msg_id = Some(sent.id); - last_edit = Instant::now(); - } - } - } - } else if let Some(id) = msg_id { - if bot.edit_message_text(chat_id, id, &display).await.is_ok() { - last_edit = Instant::now(); - } - } else if let Ok(sent) = bot.send_message(chat_id, &display).await { - msg_id = Some(sent.id); - last_edit = Instant::now(); - } - } - } - } - } - - if accumulated.is_empty() { - return Ok(accumulated); - } - - send_final_result(bot, chat_id, msg_id, use_draft, &accumulated).await; - - Ok(accumulated) -} - -const CURSOR: &str = " \u{25CE}"; - -fn truncate_for_display(s: &str) -> String { - let budget = TG_MSG_LIMIT - CURSOR.len() - 1; - if s.len() <= budget { - format!("{s}{CURSOR}") - } else { - let truncated = truncate_at_char_boundary(s, budget - 2); - format!("{truncated}\n…{CURSOR}") - } -} - -fn truncate_at_char_boundary(s: &str, max: usize) -> &str { - if s.len() <= max { - return s; - } - let mut end = max; - while !s.is_char_boundary(end) { - end -= 1; - } - &s[..end] -} - -fn escape_html(s: &str) -> String { - s.replace('&', "&") - .replace('<', "<") - .replace('>', ">") - .replace('"', """) -} - -fn markdown_to_telegram_html(md: &str) -> String { - use pulldown_cmark::{CodeBlockKind, Event, Options, Parser, Tag, TagEnd}; - - let mut opts = Options::empty(); - opts.insert(Options::ENABLE_STRIKETHROUGH); - - let parser = Parser::new_ext(md, opts); - let mut html = String::new(); - - for event in parser { - match event { - Event::Start(tag) => match tag { - Tag::Paragraph => {} - Tag::Heading { .. } => html.push_str(""), - Tag::BlockQuote(_) => html.push_str("
"), - Tag::CodeBlock(kind) => match kind { - CodeBlockKind::Fenced(ref lang) if !lang.is_empty() => { - html.push_str(&format!( - "
",
-                            escape_html(lang.as_ref())
-                        ));
-                    }
-                    _ => html.push_str("
"),
-                },
-                Tag::Item => html.push_str("• "),
-                Tag::Emphasis => html.push_str(""),
-                Tag::Strong => html.push_str(""),
-                Tag::Strikethrough => html.push_str(""),
-                Tag::Link { dest_url, .. } => {
-                    html.push_str(&format!(
-                        "",
-                        escape_html(dest_url.as_ref())
-                    ));
-                }
-                _ => {}
-            },
-            Event::End(tag) => match tag {
-                TagEnd::Paragraph => html.push_str("\n\n"),
-                TagEnd::Heading(_) => html.push_str("\n\n"),
-                TagEnd::BlockQuote(_) => html.push_str("
"), - TagEnd::CodeBlock => html.push_str("\n\n"), - TagEnd::List(_) => html.push('\n'), - TagEnd::Item => html.push('\n'), - TagEnd::Emphasis => html.push_str(""), - TagEnd::Strong => html.push_str("
"), - TagEnd::Strikethrough => html.push_str(""), - TagEnd::Link => html.push_str(""), - _ => {} - }, - Event::Text(text) => html.push_str(&escape_html(text.as_ref())), - Event::Code(text) => { - html.push_str(""); - html.push_str(&escape_html(text.as_ref())); - html.push_str(""); - } - Event::SoftBreak | Event::HardBreak => html.push('\n'), - Event::Rule => html.push_str("\n---\n\n"), - _ => {} - } - } - - html.trim_end().to_string() -} - -/// Send final result with HTML formatting, fallback to plain text on failure. -async fn send_final_result( - bot: &Bot, - chat_id: ChatId, - msg_id: Option, - use_draft: bool, - result: &str, -) { - use teloxide::types::ParseMode; - - let html = markdown_to_telegram_html(result); - - // try HTML as single message - let html_ok = if let (false, Some(id)) = (use_draft, msg_id) { - bot.edit_message_text(chat_id, id, &html) - .parse_mode(ParseMode::Html) - .await - .is_ok() - } else { - bot.send_message(chat_id, &html) - .parse_mode(ParseMode::Html) - .await - .is_ok() - }; - - if html_ok { - return; - } - - // fallback: plain text with chunking - let chunks = split_msg(result, TG_MSG_LIMIT); - if let (false, Some(id)) = (use_draft, msg_id) { - let _ = bot.edit_message_text(chat_id, id, chunks[0]).await; - for chunk in &chunks[1..] { - let _ = bot.send_message(chat_id, *chunk).await; - } - } else { - for chunk in &chunks { - let _ = bot.send_message(chat_id, *chunk).await; - } - } -} - -fn split_msg(s: &str, max: usize) -> Vec<&str> { - if s.len() <= max { - return vec![s]; - } - let mut parts = Vec::new(); - let mut rest = s; - while !rest.is_empty() { - if rest.len() <= max { - parts.push(rest); - break; - } - let mut end = max; - while !rest.is_char_boundary(end) { - end -= 1; - } - let (chunk, tail) = rest.split_at(end); - parts.push(chunk); - rest = tail; - } - parts -} diff --git a/src/state.rs b/src/state.rs new file mode 100644 index 0000000..1fa6ba4 --- /dev/null +++ b/src/state.rs @@ -0,0 +1,358 @@ +use std::collections::{HashMap, HashSet}; +use std::path::{Path, PathBuf}; +use std::sync::Arc; + +use anyhow::Result; +use chrono::NaiveDate; +use serde::{Deserialize, Serialize}; +use tokio::sync::RwLock; +use tracing::{error, info}; + +use crate::tools::SubAgent; + +// ── persistent state ──────────────────────────────────────────────── + +#[derive(Serialize, Deserialize, Default)] +pub struct Persistent { + pub authed: HashMap, + pub known_sessions: HashSet, +} + +#[derive(Serialize, Deserialize, Clone, Default)] +pub struct ConversationState { + pub summary: String, + pub messages: Vec, + pub total_messages: usize, +} + +pub const MAX_WINDOW: usize = 100; +pub const SLIDE_SIZE: usize = 50; + +pub struct AppState { + pub persist: RwLock, + pub state_path: PathBuf, + pub db: tokio::sync::Mutex, + pub agents: RwLock>>, +} + +impl AppState { + pub fn load(path: PathBuf) -> Self { + let persist = std::fs::read_to_string(&path) + .ok() + .and_then(|s| serde_json::from_str(&s).ok()) + .unwrap_or_default(); + info!("loaded state from {}", path.display()); + + let db_path = path.parent().unwrap_or(Path::new(".")).join("noc.db"); + let conn = rusqlite::Connection::open(&db_path) + .unwrap_or_else(|e| panic!("open {}: {e}", db_path.display())); + conn.execute_batch( + "CREATE TABLE IF NOT EXISTS conversations ( + session_id TEXT PRIMARY KEY, + summary TEXT NOT NULL DEFAULT '', + total_messages INTEGER NOT NULL DEFAULT 0 + ); + CREATE TABLE IF NOT EXISTS messages ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + session_id TEXT NOT NULL, + role TEXT NOT NULL, + content TEXT NOT NULL, + created_at TEXT NOT NULL DEFAULT (datetime('now', 'localtime')) + ); + CREATE INDEX IF NOT EXISTS idx_messages_session ON messages(session_id); + CREATE TABLE IF NOT EXISTS scratch_area ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + content TEXT NOT NULL, + created_at TEXT NOT NULL DEFAULT (datetime('now')) + ); + CREATE TABLE IF NOT EXISTS config ( + key TEXT PRIMARY KEY, + value TEXT NOT NULL DEFAULT '', + create_time TEXT NOT NULL DEFAULT (datetime('now')), + update_time TEXT NOT NULL DEFAULT (datetime('now')) + ); + CREATE TABLE IF NOT EXISTS config_history ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + key TEXT NOT NULL, + value TEXT NOT NULL, + create_time TEXT NOT NULL, + update_time TEXT NOT NULL + ); + CREATE TABLE IF NOT EXISTS memory_slots ( + slot_nr INTEGER PRIMARY KEY CHECK(slot_nr BETWEEN 0 AND 99), + content TEXT NOT NULL DEFAULT '' + ); + CREATE TABLE IF NOT EXISTS timers ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + chat_id INTEGER NOT NULL, + label TEXT NOT NULL, + schedule TEXT NOT NULL, + next_fire TEXT NOT NULL, + enabled INTEGER NOT NULL DEFAULT 1, + created_at TEXT NOT NULL DEFAULT (datetime('now', 'localtime')) + ); + CREATE TABLE IF NOT EXISTS inner_state ( + id INTEGER PRIMARY KEY CHECK(id = 1), + content TEXT NOT NULL DEFAULT '' + ); + INSERT OR IGNORE INTO inner_state (id, content) VALUES (1, '');", + ) + .expect("init db schema"); + + // migrations + let _ = conn.execute( + "ALTER TABLE messages ADD COLUMN created_at TEXT NOT NULL DEFAULT ''", + [], + ); + + info!("opened db {}", db_path.display()); + + Self { + persist: RwLock::new(persist), + state_path: path, + db: tokio::sync::Mutex::new(conn), + agents: RwLock::new(HashMap::new()), + } + } + + pub async fn save(&self) { + let data = self.persist.read().await; + if let Ok(json) = serde_json::to_string_pretty(&*data) { + if let Err(e) = std::fs::write(&self.state_path, json) { + error!("save state: {e}"); + } + } + } + + pub async fn load_conv(&self, sid: &str) -> ConversationState { + let db = self.db.lock().await; + let (summary, total) = db + .query_row( + "SELECT summary, total_messages FROM conversations WHERE session_id = ?1", + [sid], + |row| Ok((row.get::<_, String>(0)?, row.get::<_, usize>(1)?)), + ) + .unwrap_or_default(); + + let mut stmt = db + .prepare("SELECT role, content, created_at FROM messages WHERE session_id = ?1 ORDER BY id") + .unwrap(); + let messages: Vec = stmt + .query_map([sid], |row| { + let role: String = row.get(0)?; + let content: String = row.get(1)?; + let ts: String = row.get(2)?; + let tagged = if ts.is_empty() { + content + } else { + format!("[{ts}] {content}") + }; + Ok(serde_json::json!({"role": role, "content": tagged})) + }) + .unwrap() + .filter_map(|r| r.ok()) + .collect(); + + ConversationState { + summary, + messages, + total_messages: total, + } + } + + pub async fn push_message(&self, sid: &str, role: &str, content: &str) { + let db = self.db.lock().await; + let _ = db.execute( + "INSERT OR IGNORE INTO conversations (session_id) VALUES (?1)", + [sid], + ); + let _ = db.execute( + "INSERT INTO messages (session_id, role, content, created_at) VALUES (?1, ?2, ?3, datetime('now', 'localtime'))", + rusqlite::params![sid, role, content], + ); + } + + pub async fn message_count(&self, sid: &str) -> usize { + let db = self.db.lock().await; + db.query_row( + "SELECT COUNT(*) FROM messages WHERE session_id = ?1", + [sid], + |row| row.get(0), + ) + .unwrap_or(0) + } + + pub async fn slide_window(&self, sid: &str, new_summary: &str, slide_size: usize) { + let db = self.db.lock().await; + let _ = db.execute( + "DELETE FROM messages WHERE id IN ( + SELECT id FROM messages WHERE session_id = ?1 ORDER BY id LIMIT ?2 + )", + rusqlite::params![sid, slide_size], + ); + let _ = db.execute( + "UPDATE conversations SET summary = ?1, total_messages = total_messages + ?2 \ + WHERE session_id = ?3", + rusqlite::params![new_summary, slide_size, sid], + ); + } + + pub async fn get_oldest_messages(&self, sid: &str, count: usize) -> Vec { + let db = self.db.lock().await; + let mut stmt = db + .prepare( + "SELECT role, content FROM messages WHERE session_id = ?1 ORDER BY id LIMIT ?2", + ) + .unwrap(); + stmt.query_map(rusqlite::params![sid, count], |row| { + let role: String = row.get(0)?; + let content: String = row.get(1)?; + Ok(serde_json::json!({"role": role, "content": content})) + }) + .unwrap() + .filter_map(|r| r.ok()) + .collect() + } + + pub async fn get_scratch(&self) -> String { + let db = self.db.lock().await; + db.query_row( + "SELECT content FROM scratch_area ORDER BY id DESC LIMIT 1", + [], + |row| row.get(0), + ) + .unwrap_or_default() + } + + pub async fn push_scratch(&self, content: &str) { + let db = self.db.lock().await; + let _ = db.execute( + "INSERT INTO scratch_area (content) VALUES (?1)", + [content], + ); + } + + pub async fn get_config(&self, key: &str) -> Option { + let db = self.db.lock().await; + db.query_row( + "SELECT value FROM config WHERE key = ?1", + [key], + |row| row.get(0), + ) + .ok() + } + + pub async fn get_inner_state(&self) -> String { + let db = self.db.lock().await; + db.query_row("SELECT content FROM inner_state WHERE id = 1", [], |row| row.get(0)) + .unwrap_or_default() + } + + #[allow(dead_code)] // used by life loop tools (coming soon) + pub async fn set_inner_state(&self, content: &str) { + let db = self.db.lock().await; + let _ = db.execute( + "UPDATE inner_state SET content = ?1 WHERE id = 1", + [content], + ); + } + + pub async fn add_timer(&self, chat_id: i64, label: &str, schedule: &str, next_fire: &str) -> i64 { + let db = self.db.lock().await; + db.execute( + "INSERT INTO timers (chat_id, label, schedule, next_fire) VALUES (?1, ?2, ?3, ?4)", + rusqlite::params![chat_id, label, schedule, next_fire], + ) + .unwrap(); + db.last_insert_rowid() + } + + pub async fn list_timers(&self, chat_id: Option) -> Vec<(i64, i64, String, String, String, bool)> { + let db = self.db.lock().await; + let (sql, params): (&str, Vec>) = match chat_id { + Some(cid) => ( + "SELECT id, chat_id, label, schedule, next_fire, enabled FROM timers WHERE chat_id = ?1 ORDER BY next_fire", + vec![Box::new(cid)], + ), + None => ( + "SELECT id, chat_id, label, schedule, next_fire, enabled FROM timers ORDER BY next_fire", + vec![], + ), + }; + let mut stmt = db.prepare(sql).unwrap(); + stmt.query_map(rusqlite::params_from_iter(params), |row| { + Ok(( + row.get(0)?, + row.get(1)?, + row.get::<_, String>(2)?, + row.get::<_, String>(3)?, + row.get::<_, String>(4)?, + row.get::<_, bool>(5)?, + )) + }) + .unwrap() + .filter_map(|r| r.ok()) + .collect() + } + + pub async fn cancel_timer(&self, timer_id: i64) -> bool { + let db = self.db.lock().await; + db.execute("DELETE FROM timers WHERE id = ?1", [timer_id]).unwrap() > 0 + } + + pub async fn due_timers(&self) -> Vec<(i64, i64, String, String)> { + let db = self.db.lock().await; + let now = chrono::Local::now().format("%Y-%m-%d %H:%M:%S").to_string(); + let mut stmt = db + .prepare( + "SELECT id, chat_id, label, schedule FROM timers WHERE enabled = 1 AND next_fire <= ?1", + ) + .unwrap(); + stmt.query_map([&now], |row| { + Ok(( + row.get(0)?, + row.get(1)?, + row.get::<_, String>(2)?, + row.get::<_, String>(3)?, + )) + }) + .unwrap() + .filter_map(|r| r.ok()) + .collect() + } + + pub async fn update_timer_next_fire(&self, timer_id: i64, next_fire: &str) { + let db = self.db.lock().await; + let _ = db.execute( + "UPDATE timers SET next_fire = ?1 WHERE id = ?2", + rusqlite::params![next_fire, timer_id], + ); + } + + pub async fn get_memory_slots(&self) -> Vec<(i32, String)> { + let db = self.db.lock().await; + let mut stmt = db + .prepare("SELECT slot_nr, content FROM memory_slots WHERE content != '' ORDER BY slot_nr") + .unwrap(); + stmt.query_map([], |row| Ok((row.get(0)?, row.get(1)?))) + .unwrap() + .filter_map(|r| r.ok()) + .collect() + } + + pub async fn set_memory_slot(&self, slot_nr: i32, content: &str) -> Result<()> { + if !(0..=99).contains(&slot_nr) { + anyhow::bail!("slot_nr must be 0-99, got {slot_nr}"); + } + if content.len() > 200 { + anyhow::bail!("content too long: {} chars (max 200)", content.len()); + } + let db = self.db.lock().await; + db.execute( + "INSERT INTO memory_slots (slot_nr, content) VALUES (?1, ?2) \ + ON CONFLICT(slot_nr) DO UPDATE SET content = ?2", + rusqlite::params![slot_nr, content], + )?; + Ok(()) + } +} diff --git a/src/stream.rs b/src/stream.rs new file mode 100644 index 0000000..1429f76 --- /dev/null +++ b/src/stream.rs @@ -0,0 +1,776 @@ +use std::process::Stdio; +use std::sync::Arc; + +use anyhow::Result; +use serde::Deserialize; +use teloxide::prelude::*; +use tokio::io::AsyncBufReadExt; +use tokio::process::Command; +use tokio::time::Instant; +use tracing::{error, info, warn}; + +use crate::config::Config; +use crate::display::{ + send_final_result, truncate_at_char_boundary, truncate_for_display, +}; +use crate::state::AppState; +use crate::tools::{discover_tools, execute_tool, ToolCall}; + +pub const EDIT_INTERVAL_MS: u64 = 2000; +pub const DRAFT_INTERVAL_MS: u64 = 1000; +pub const TG_MSG_LIMIT: usize = 4096; +pub const CURSOR: &str = " \u{25CE}"; + +/// Stream JSON event types we care about. +#[derive(Deserialize)] +pub struct StreamEvent { + #[serde(rename = "type")] + pub event_type: String, + pub message: Option, + pub result: Option, + #[serde(default)] + pub is_error: bool, +} + +#[derive(Deserialize)] +pub struct AssistantMessage { + pub content: Vec, +} + +#[derive(Deserialize)] +pub struct ContentBlock { + #[serde(rename = "type")] + pub block_type: String, + pub text: Option, + pub name: Option, + pub input: Option, +} + +/// Extract all text from an assistant message's content blocks. +pub fn extract_text(msg: &AssistantMessage) -> String { + msg.content + .iter() + .filter(|b| b.block_type == "text") + .filter_map(|b| b.text.as_deref()) + .collect::>() + .join("") +} + +/// Extract tool use status line, e.g. "Bash: echo hello" +pub fn extract_tool_use(msg: &AssistantMessage) -> Option { + for block in &msg.content { + if block.block_type == "tool_use" { + let name = block.name.as_deref().unwrap_or("tool"); + let detail = block + .input + .as_ref() + .and_then(|v| { + // try common fields: command, pattern, file_path, query + v.get("command") + .or(v.get("pattern")) + .or(v.get("file_path")) + .or(v.get("query")) + .or(v.get("prompt")) + .and_then(|s| s.as_str()) + }) + .unwrap_or(""); + let detail_short = truncate_at_char_boundary(detail, 80); + return Some(format!("{name}: {detail_short}")); + } + } + None +} + +pub async fn send_message_draft( + client: &reqwest::Client, + token: &str, + chat_id: i64, + draft_id: i64, + text: &str, +) -> Result<()> { + let url = format!("https://api.telegram.org/bot{token}/sendMessageDraft"); + let resp = client + .post(&url) + .json(&serde_json::json!({ + "chat_id": chat_id, + "draft_id": draft_id, + "text": text, + })) + .send() + .await?; + let body: serde_json::Value = resp.json().await?; + if body["ok"].as_bool() != Some(true) { + anyhow::bail!("sendMessageDraft: {}", body); + } + Ok(()) +} + +// ── openai with tool call loop ───────────────────────────────────── + +#[allow(clippy::too_many_arguments)] +pub async fn run_openai_with_tools( + endpoint: &str, + model: &str, + api_key: &str, + mut messages: Vec, + bot: &Bot, + chat_id: ChatId, + state: &Arc, + sid: &str, + config: &Arc, + is_private: bool, +) -> Result { + let client = reqwest::Client::new(); + let url = format!("{}/chat/completions", endpoint.trim_end_matches('/')); + let tools = discover_tools(); + + loop { + let body = serde_json::json!({ + "model": model, + "messages": messages, + "tools": tools, + "stream": true, + }); + + info!("API request: {} messages, {} tools", + messages.len(), + tools.as_array().map(|a| a.len()).unwrap_or(0)); + + let resp_raw = client + .post(&url) + .header("Authorization", format!("Bearer {api_key}")) + .json(&body) + .send() + .await?; + + if !resp_raw.status().is_success() { + let status = resp_raw.status(); + let body_text = resp_raw.text().await.unwrap_or_default(); + // dump messages for debugging + for (i, m) in messages.iter().enumerate() { + let role = m["role"].as_str().unwrap_or("?"); + let content_len = m["content"].as_str().map(|s| s.len()).unwrap_or(0); + let has_tc = m.get("tool_calls").is_some(); + let has_tcid = m.get("tool_call_id").is_some(); + warn!(" msg[{i}] role={role} content_len={content_len} tool_calls={has_tc} tool_call_id={has_tcid}"); + } + error!("OpenAI API {status}: {body_text}"); + anyhow::bail!("OpenAI API {status}: {body_text}"); + } + + let mut resp = resp_raw; + + let token = bot.token().to_owned(); + let raw_chat_id = chat_id.0; + let draft_id: i64 = 1; + let mut use_draft = is_private; // sendMessageDraft only works in private chats + + let mut msg_id: Option = None; + let mut accumulated = String::new(); + let mut last_edit = Instant::now(); + let mut buffer = String::new(); + let mut done = false; + + // tool call accumulation + let mut tool_calls: Vec = Vec::new(); + let mut has_tool_calls = false; + + while let Some(chunk) = resp.chunk().await? { + if done { + break; + } + buffer.push_str(&String::from_utf8_lossy(&chunk)); + + while let Some(pos) = buffer.find('\n') { + let line = buffer[..pos].to_string(); + buffer = buffer[pos + 1..].to_string(); + + let trimmed = line.trim(); + if trimmed.is_empty() || trimmed.starts_with(':') { + continue; + } + + let data = match trimmed.strip_prefix("data: ") { + Some(d) => d, + None => continue, + }; + + if data.trim() == "[DONE]" { + done = true; + break; + } + + if let Ok(json) = serde_json::from_str::(data) { + let delta = &json["choices"][0]["delta"]; + + // handle content delta + if let Some(content) = delta["content"].as_str() { + if !content.is_empty() { + accumulated.push_str(content); + } + } + + // handle tool call delta + if let Some(tc_arr) = delta["tool_calls"].as_array() { + has_tool_calls = true; + for tc in tc_arr { + let idx = tc["index"].as_u64().unwrap_or(0) as usize; + while tool_calls.len() <= idx { + tool_calls.push(ToolCall { + id: String::new(), + name: String::new(), + arguments: String::new(), + }); + } + if let Some(id) = tc["id"].as_str() { + tool_calls[idx].id = id.to_string(); + } + if let Some(name) = tc["function"]["name"].as_str() { + tool_calls[idx].name = name.to_string(); + } + if let Some(args) = tc["function"]["arguments"].as_str() { + tool_calls[idx].arguments.push_str(args); + } + } + } + + // display update (only when there's content to show) + if accumulated.is_empty() { + continue; + } + + { + + let interval = if use_draft { + DRAFT_INTERVAL_MS + } else { + EDIT_INTERVAL_MS + }; + if last_edit.elapsed().as_millis() < interval as u128 { + continue; + } + + let display = if use_draft { + truncate_at_char_boundary(&accumulated, TG_MSG_LIMIT).to_string() + } else { + truncate_for_display(&accumulated) + }; + + if use_draft { + match send_message_draft( + &client, &token, raw_chat_id, draft_id, &display, + ) + .await + { + Ok(_) => { + last_edit = Instant::now(); + } + Err(e) => { + warn!("sendMessageDraft failed, falling back: {e:#}"); + use_draft = false; + if let Ok(sent) = + bot.send_message(chat_id, &display).await + { + msg_id = Some(sent.id); + last_edit = Instant::now(); + } + } + } + } else if let Some(id) = msg_id { + if bot + .edit_message_text(chat_id, id, &display) + .await + .is_ok() + { + last_edit = Instant::now(); + } + } else if let Ok(sent) = + bot.send_message(chat_id, &display).await + { + msg_id = Some(sent.id); + last_edit = Instant::now(); + } + } // end display block + } + } + } + + // decide what to do based on response type + if has_tool_calls && !tool_calls.is_empty() { + // append assistant message with tool calls + let tc_json: Vec = tool_calls + .iter() + .map(|tc| { + serde_json::json!({ + "id": tc.id, + "type": "function", + "function": { + "name": tc.name, + "arguments": tc.arguments, + } + }) + }) + .collect(); + + let assistant_msg = serde_json::json!({ + "role": "assistant", + "content": if accumulated.is_empty() { "" } else { &accumulated }, + "tool_calls": tc_json, + }); + messages.push(assistant_msg); + + // execute each tool + for tc in &tool_calls { + info!(tool = %tc.name, "executing tool call"); + let _ = bot + .send_message(chat_id, format!("[{}({})]", tc.name, truncate_at_char_boundary(&tc.arguments, 100))) + .await; + + let result = + execute_tool(&tc.name, &tc.arguments, state, bot, chat_id, sid, config) + .await; + + messages.push(serde_json::json!({ + "role": "tool", + "tool_call_id": tc.id, + "content": result, + })); + } + + // clear display state for next round + tool_calls.clear(); + // loop back to call API again + continue; + } + + // content response — send final result + if !accumulated.is_empty() { + send_final_result(bot, chat_id, msg_id, use_draft, &accumulated).await; + } + + return Ok(accumulated); + } +} + +// ── claude bridge (streaming) ─────────────────────────────────────── + +pub async fn invoke_claude_streaming( + sid: &str, + prompt: &str, + known: bool, + bot: &Bot, + chat_id: ChatId, +) -> Result { + if known { + return run_claude_streaming(&["--resume", sid], prompt, bot, chat_id).await; + } + + match run_claude_streaming(&["--resume", sid], prompt, bot, chat_id).await { + Ok(out) => { + info!(%sid, "resumed existing session"); + Ok(out) + } + Err(e) => { + warn!(%sid, "resume failed ({e:#}), creating new session"); + run_claude_streaming(&["--session-id", sid], prompt, bot, chat_id).await + } + } +} + +pub async fn run_claude_streaming( + extra_args: &[&str], + prompt: &str, + bot: &Bot, + chat_id: ChatId, +) -> Result { + let mut args: Vec<&str> = vec![ + "--dangerously-skip-permissions", + "-p", + "--output-format", + "stream-json", + "--verbose", + ]; + args.extend(extra_args); + args.push(prompt); + + let mut child = Command::new("claude") + .args(&args) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn()?; + + let stdout = child.stdout.take().unwrap(); + let mut lines = tokio::io::BufReader::new(stdout).lines(); + + // sendMessageDraft for native streaming, with editMessageText fallback + let http = reqwest::Client::new(); + let token = bot.token().to_owned(); + let raw_chat_id = chat_id.0; + let draft_id: i64 = 1; + let mut use_draft = true; + + let mut msg_id: Option = None; + let mut last_sent_text = String::new(); + let mut last_edit = Instant::now(); + let mut final_result = String::new(); + let mut is_error = false; + let mut tool_status = String::new(); + + while let Ok(Some(line)) = lines.next_line().await { + let event: StreamEvent = match serde_json::from_str(&line) { + Ok(e) => e, + Err(_) => continue, + }; + + match event.event_type.as_str() { + "assistant" => { + if let Some(amsg) = &event.message { + // determine display content + let (display_raw, new_text) = + if let Some(status) = extract_tool_use(amsg) { + tool_status = format!("[{status}]"); + let d = if last_sent_text.is_empty() { + tool_status.clone() + } else { + format!("{last_sent_text}\n\n{tool_status}") + }; + (d, None) + } else { + let text = extract_text(amsg); + if text.is_empty() || text == last_sent_text { + continue; + } + let interval = if use_draft { + DRAFT_INTERVAL_MS + } else { + EDIT_INTERVAL_MS + }; + if last_edit.elapsed().as_millis() < interval as u128 { + continue; + } + tool_status.clear(); + (text.clone(), Some(text)) + }; + + let display = if use_draft { + // draft mode: no cursor — cursor breaks monotonic text growth + truncate_at_char_boundary(&display_raw, TG_MSG_LIMIT).to_string() + } else { + truncate_for_display(&display_raw) + }; + + if use_draft { + match send_message_draft( + &http, &token, raw_chat_id, draft_id, &display, + ) + .await + { + Ok(_) => { + if let Some(t) = new_text { + last_sent_text = t; + } + last_edit = Instant::now(); + } + Err(e) => { + warn!("sendMessageDraft failed, falling back: {e:#}"); + use_draft = false; + if let Ok(sent) = + bot.send_message(chat_id, &display).await + { + msg_id = Some(sent.id); + if let Some(t) = new_text { + last_sent_text = t; + } + last_edit = Instant::now(); + } + } + } + } else if let Some(id) = msg_id { + if bot + .edit_message_text(chat_id, id, &display) + .await + .is_ok() + { + if let Some(t) = new_text { + last_sent_text = t; + } + last_edit = Instant::now(); + } + } else if let Ok(sent) = + bot.send_message(chat_id, &display).await + { + msg_id = Some(sent.id); + if let Some(t) = new_text { + last_sent_text = t; + } + last_edit = Instant::now(); + } + } + } + "result" => { + final_result = event.result.unwrap_or_default(); + is_error = event.is_error; + } + _ => {} + } + } + + // read stderr before waiting (in case child already exited) + let stderr_handle = child.stderr.take(); + let status = child.wait().await; + + // collect stderr for diagnostics + let stderr_text = if let Some(mut se) = stderr_handle { + let mut buf = String::new(); + let _ = tokio::io::AsyncReadExt::read_to_string(&mut se, &mut buf).await; + buf + } else { + String::new() + }; + + // determine error: explicit is_error from stream, or non-zero exit with no result + let has_error = is_error + || (final_result.is_empty() + && status.as_ref().map(|s| !s.success()).unwrap_or(true)); + + if has_error { + let err_detail = if !final_result.is_empty() { + final_result.clone() + } else if !stderr_text.is_empty() { + stderr_text.trim().to_string() + } else { + format!("claude exited: {:?}", status) + }; + if !use_draft { + if let Some(id) = msg_id { + let _ = bot + .edit_message_text(chat_id, id, format!("[error] {err_detail}")) + .await; + } + } + anyhow::bail!("{err_detail}"); + } + + if final_result.is_empty() { + return Ok(final_result); + } + + send_final_result(bot, chat_id, msg_id, use_draft, &final_result).await; + + Ok(final_result) +} + +// ── openai-compatible backend (streaming) ────────────────────────── + +pub async fn run_openai_streaming( + endpoint: &str, + model: &str, + api_key: &str, + messages: &[serde_json::Value], + bot: &Bot, + chat_id: ChatId, +) -> Result { + let client = reqwest::Client::new(); + let url = format!("{}/chat/completions", endpoint.trim_end_matches('/')); + + let body = serde_json::json!({ + "model": model, + "messages": messages, + "stream": true, + }); + + let mut resp = client + .post(&url) + .header("Authorization", format!("Bearer {api_key}")) + .json(&body) + .send() + .await? + .error_for_status()?; + + let token = bot.token().to_owned(); + let raw_chat_id = chat_id.0; + let draft_id: i64 = 1; + let mut use_draft = true; + + let mut msg_id: Option = None; + let mut accumulated = String::new(); + let mut last_edit = Instant::now(); + let mut buffer = String::new(); + let mut done = false; + + while let Some(chunk) = resp.chunk().await? { + if done { + break; + } + buffer.push_str(&String::from_utf8_lossy(&chunk)); + + while let Some(pos) = buffer.find('\n') { + let line = buffer[..pos].to_string(); + buffer = buffer[pos + 1..].to_string(); + + let trimmed = line.trim(); + if trimmed.is_empty() || trimmed.starts_with(':') { + continue; + } + + let data = match trimmed.strip_prefix("data: ") { + Some(d) => d, + None => continue, + }; + + if data.trim() == "[DONE]" { + done = true; + break; + } + + if let Ok(json) = serde_json::from_str::(data) { + if let Some(content) = json["choices"][0]["delta"]["content"].as_str() { + if content.is_empty() { + continue; + } + accumulated.push_str(content); + + let interval = if use_draft { + DRAFT_INTERVAL_MS + } else { + EDIT_INTERVAL_MS + }; + if last_edit.elapsed().as_millis() < interval as u128 { + continue; + } + + let display = if use_draft { + truncate_at_char_boundary(&accumulated, TG_MSG_LIMIT).to_string() + } else { + truncate_for_display(&accumulated) + }; + + if use_draft { + match send_message_draft( + &client, &token, raw_chat_id, draft_id, &display, + ) + .await + { + Ok(_) => { + last_edit = Instant::now(); + } + Err(e) => { + warn!("sendMessageDraft failed, falling back: {e:#}"); + use_draft = false; + if let Ok(sent) = bot.send_message(chat_id, &display).await { + msg_id = Some(sent.id); + last_edit = Instant::now(); + } + } + } + } else if let Some(id) = msg_id { + if bot.edit_message_text(chat_id, id, &display).await.is_ok() { + last_edit = Instant::now(); + } + } else if let Ok(sent) = bot.send_message(chat_id, &display).await { + msg_id = Some(sent.id); + last_edit = Instant::now(); + } + } + } + } + } + + if accumulated.is_empty() { + return Ok(accumulated); + } + + send_final_result(bot, chat_id, msg_id, use_draft, &accumulated).await; + + Ok(accumulated) +} + +pub fn build_system_prompt(summary: &str, persona: &str, memory_slots: &[(i32, String)]) -> serde_json::Value { + let mut text = if persona.is_empty() { + String::from("你是一个AI助手。") + } else { + persona.to_string() + }; + + text.push_str( + "\n\n你可以使用提供的工具来完成任务。\ + 当需要执行命令、运行代码或启动复杂子任务时,直接调用对应的工具,不要只是描述你会怎么做。\ + 当需要搜索信息(如网页搜索、资料查找、技术调研等)时,使用 spawn_agent 启动一个子代理来完成搜索任务,\ + 子代理可以使用浏览器和搜索引擎,搜索完成后你会收到结果通知。\ + 输出格式:使用纯文本或基础Markdown(加粗、列表、代码块)。\ + 不要使用LaTeX公式($...$)、特殊Unicode符号(→←↔)或HTML标签,Telegram无法渲染这些。", + ); + + if !memory_slots.is_empty() { + text.push_str("\n\n## 持久记忆(跨会话保留)\n"); + for (nr, content) in memory_slots { + text.push_str(&format!("[{nr}] {content}\n")); + } + } + + if !summary.is_empty() { + text.push_str("\n\n## 之前的对话总结\n"); + text.push_str(summary); + } + + serde_json::json!({"role": "system", "content": text}) +} + +pub async fn summarize_messages( + endpoint: &str, + model: &str, + api_key: &str, + existing_summary: &str, + dropped: &[serde_json::Value], +) -> Result { + let msgs_text: String = dropped + .iter() + .filter_map(|m| { + let role = m["role"].as_str()?; + let content = m["content"].as_str()?; + Some(format!("{role}: {content}")) + }) + .collect::>() + .join("\n\n"); + + let prompt = if existing_summary.is_empty() { + format!( + "请将以下对话总结为约4000字符的摘要,保留关键信息和上下文:\n\n{}", + msgs_text + ) + } else { + format!( + "请将以下新对话内容整合到现有总结中,保持总结在约4000字符以内。\ + 保留重要信息,让较旧的话题自然淡出。\n\n\ + 现有总结:\n{}\n\n新对话:\n{}", + existing_summary, msgs_text + ) + }; + + let client = reqwest::Client::new(); + let url = format!("{}/chat/completions", endpoint.trim_end_matches('/')); + + let body = serde_json::json!({ + "model": model, + "messages": [ + {"role": "system", "content": "你是一个对话总结助手。请生成简洁但信息丰富的总结。"}, + {"role": "user", "content": prompt} + ], + }); + + let resp = client + .post(&url) + .header("Authorization", format!("Bearer {api_key}")) + .json(&body) + .send() + .await? + .error_for_status()?; + + let json: serde_json::Value = resp.json().await?; + let summary = json["choices"][0]["message"]["content"] + .as_str() + .unwrap_or("") + .to_string(); + + Ok(summary) +} diff --git a/src/tools.rs b/src/tools.rs new file mode 100644 index 0000000..13fbddf --- /dev/null +++ b/src/tools.rs @@ -0,0 +1,665 @@ +use std::path::{Path, PathBuf}; +use std::process::Stdio; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; + +use anyhow::Result; +use teloxide::prelude::*; +use teloxide::types::InputFile; +use tokio::io::AsyncBufReadExt; +use tokio::process::Command; +use tokio::sync::RwLock; +use tracing::{error, info, warn}; + +use crate::config::{BackendConfig, Config}; +use crate::display::truncate_at_char_boundary; +use crate::state::AppState; +use crate::stream::{build_system_prompt, run_openai_streaming}; + +// ── subagent & tool call ─────────────────────────────────────────── + +pub struct SubAgent { + pub task: String, + pub output: Arc>, + pub completed: Arc, + pub exit_code: Arc>>, + pub pid: Option, +} + +pub struct ToolCall { + pub id: String, + pub name: String, + pub arguments: String, +} + +pub fn tools_dir() -> PathBuf { + // tools/ relative to the config file location + let config_path = std::env::var("NOC_CONFIG").unwrap_or_else(|_| "config.yaml".into()); + let config_dir = Path::new(&config_path) + .parent() + .unwrap_or(Path::new(".")); + config_dir.join("tools") +} + +/// Scan tools/ directory for scripts with --schema, merge with built-in tools. +/// Called on every API request so new/updated scripts take effect immediately. +pub fn discover_tools() -> serde_json::Value { + let mut tools = vec![ + serde_json::json!({ + "type": "function", + "function": { + "name": "spawn_agent", + "description": "启动一个 Claude Code 子代理异步执行复杂任务。子代理可使用 shell、浏览器和搜索引擎,适合网页搜索、资料查找、技术调研、代码任务等。完成后会收到通知。", + "parameters": { + "type": "object", + "properties": { + "id": {"type": "string", "description": "简短唯一标识符(如 'research'、'fix-bug')"}, + "task": {"type": "string", "description": "给子代理的详细任务描述"} + }, + "required": ["id", "task"] + } + } + }), + serde_json::json!({ + "type": "function", + "function": { + "name": "agent_status", + "description": "查看正在运行或已完成的子代理的状态和输出", + "parameters": { + "type": "object", + "properties": { + "id": {"type": "string", "description": "子代理标识符"} + }, + "required": ["id"] + } + } + }), + serde_json::json!({ + "type": "function", + "function": { + "name": "kill_agent", + "description": "终止一个正在运行的子代理", + "parameters": { + "type": "object", + "properties": { + "id": {"type": "string", "description": "子代理标识符"} + }, + "required": ["id"] + } + } + }), + serde_json::json!({ + "type": "function", + "function": { + "name": "send_file", + "description": "通过 Telegram 向用户发送服务器上的文件,文件必须存在于服务器文件系统中。", + "parameters": { + "type": "object", + "properties": { + "path": {"type": "string", "description": "服务器上文件的绝对路径"}, + "caption": {"type": "string", "description": "可选的文件说明/描述"} + }, + "required": ["path"] + } + } + }), + serde_json::json!({ + "type": "function", + "function": { + "name": "update_scratch", + "description": "更新你的草稿区(工作笔记、状态、提醒)。草稿区内容会附加到每条用户消息中,确保你始终可见。用于跨轮次跟踪上下文。", + "parameters": { + "type": "object", + "properties": { + "content": {"type": "string", "description": "完整的草稿区内容(替换之前的内容)"} + }, + "required": ["content"] + } + } + }), + serde_json::json!({ + "type": "function", + "function": { + "name": "set_timer", + "description": "Set a timer that will fire in the future. Supports: '5min'/'2h' (relative), 'once:2026-04-10 09:00' (absolute), 'cron:0 8 * * *' (recurring). When fired, you'll receive the label as a prompt.", + "parameters": { + "type": "object", + "properties": { + "schedule": {"type": "string", "description": "Timer schedule: e.g. '5min', '1h', 'once:2026-04-10 09:00', 'cron:30 8 * * *'"}, + "label": {"type": "string", "description": "What this timer is for — this text will be sent to you when it fires"} + }, + "required": ["schedule", "label"] + } + } + }), + serde_json::json!({ + "type": "function", + "function": { + "name": "list_timers", + "description": "List all active timers", + "parameters": { + "type": "object", + "properties": {}, + } + } + }), + serde_json::json!({ + "type": "function", + "function": { + "name": "cancel_timer", + "description": "Cancel a timer by ID", + "parameters": { + "type": "object", + "properties": { + "timer_id": {"type": "integer", "description": "Timer ID from list_timers"} + }, + "required": ["timer_id"] + } + } + }), + serde_json::json!({ + "type": "function", + "function": { + "name": "update_memory", + "description": "写入持久记忆槽。共 100 个槽位(0-99),跨会话保留。记忆槽内容会注入到每次对话的 system prompt 中。用于存储关键事实、用户偏好或重要上下文。内容设为空字符串可清除槽位。", + "parameters": { + "type": "object", + "properties": { + "slot_nr": {"type": "integer", "description": "槽位编号(0-99)"}, + "content": {"type": "string", "description": "要存储的内容(最多200字符),空字符串表示清除该槽位"} + }, + "required": ["slot_nr", "content"] + } + } + }), + serde_json::json!({ + "type": "function", + "function": { + "name": "gen_voice", + "description": "将文字合成为语音并直接发送给用户。", + "parameters": { + "type": "object", + "properties": { + "text": {"type": "string", "description": "要合成语音的文字内容"} + }, + "required": ["text"] + } + } + }), + ]; + + // discover script tools + let dir = tools_dir(); + if let Ok(entries) = std::fs::read_dir(&dir) { + for entry in entries.flatten() { + let path = entry.path(); + if !path.is_file() { + continue; + } + // run --schema with a short timeout + let output = std::process::Command::new(&path) + .arg("--schema") + .output(); + match output { + Ok(out) if out.status.success() => { + let stdout = String::from_utf8_lossy(&out.stdout); + match serde_json::from_str::(stdout.trim()) { + Ok(schema) => { + let name = schema["name"].as_str().unwrap_or("?"); + info!(tool = %name, path = %path.display(), "discovered script tool"); + tools.push(serde_json::json!({ + "type": "function", + "function": schema, + })); + } + Err(e) => { + warn!(path = %path.display(), "invalid --schema JSON: {e}"); + } + } + } + _ => {} // not a tool script, skip silently + } + } + } + + serde_json::Value::Array(tools) +} + +// ── tool execution ───────────────────────────────────────────────── + +pub async fn execute_tool( + name: &str, + arguments: &str, + state: &Arc, + bot: &Bot, + chat_id: ChatId, + sid: &str, + config: &Arc, +) -> String { + let args: serde_json::Value = match serde_json::from_str(arguments) { + Ok(v) => v, + Err(e) => return format!("Invalid arguments: {e}"), + }; + + match name { + "spawn_agent" => { + let id = args["id"].as_str().unwrap_or("agent"); + let task = args["task"].as_str().unwrap_or(""); + spawn_agent(id, task, state, bot, chat_id, sid, config).await + } + "agent_status" => { + let id = args["id"].as_str().unwrap_or(""); + check_agent_status(id, state).await + } + "kill_agent" => { + let id = args["id"].as_str().unwrap_or(""); + kill_agent(id, state).await + } + "send_file" => { + let path_str = args["path"].as_str().unwrap_or(""); + let caption = args["caption"].as_str().unwrap_or(""); + let path = Path::new(path_str); + if !path.exists() { + return format!("File not found: {path_str}"); + } + if !path.is_file() { + return format!("Not a file: {path_str}"); + } + let input_file = InputFile::file(path); + let mut req = bot.send_document(chat_id, input_file); + if !caption.is_empty() { + req = req.caption(caption); + } + match req.await { + Ok(_) => format!("File sent: {path_str}"), + Err(e) => format!("Failed to send file: {e:#}"), + } + } + "update_scratch" => { + let content = args["content"].as_str().unwrap_or(""); + state.push_scratch(content).await; + format!("Scratch updated ({} chars)", content.len()) + } + "set_timer" => { + let schedule = args["schedule"].as_str().unwrap_or(""); + let label = args["label"].as_str().unwrap_or(""); + match parse_next_fire(schedule) { + Ok(next) => { + let next_str = next.format("%Y-%m-%d %H:%M:%S").to_string(); + let id = state + .add_timer(chat_id.0, label, schedule, &next_str) + .await; + format!("Timer #{id} set: \"{label}\" → next fire at {next_str}") + } + Err(e) => format!("Invalid schedule '{schedule}': {e}"), + } + } + "list_timers" => { + let timers = state.list_timers(Some(chat_id.0)).await; + if timers.is_empty() { + "No active timers.".to_string() + } else { + timers + .iter() + .map(|(id, _, label, sched, next, enabled)| { + let status = if *enabled { "" } else { " [disabled]" }; + format!("#{id}: \"{label}\" ({sched}) → {next}{status}") + }) + .collect::>() + .join("\n") + } + } + "cancel_timer" => { + let tid = args["timer_id"].as_i64().unwrap_or(0); + if state.cancel_timer(tid).await { + format!("Timer #{tid} cancelled") + } else { + format!("Timer #{tid} not found") + } + } + "update_memory" => { + let slot_nr = args["slot_nr"].as_i64().unwrap_or(-1) as i32; + let content = args["content"].as_str().unwrap_or(""); + match state.set_memory_slot(slot_nr, content).await { + Ok(_) => { + if content.is_empty() { + format!("Memory slot {slot_nr} cleared") + } else { + format!("Memory slot {slot_nr} updated ({} chars)", content.len()) + } + } + Err(e) => format!("Error: {e}"), + } + } + "gen_voice" => { + let text = args["text"].as_str().unwrap_or(""); + if text.is_empty() { + return "Error: text is required".to_string(); + } + let script = tools_dir().join("gen_voice"); + let result = tokio::time::timeout( + std::time::Duration::from_secs(120), + tokio::process::Command::new(&script) + .arg(arguments) + .output(), + ) + .await; + match result { + Ok(Ok(out)) if out.status.success() => { + let path_str = String::from_utf8_lossy(&out.stdout).trim().to_string(); + let path = Path::new(&path_str); + if path.exists() { + let input_file = InputFile::file(path); + match bot.send_voice(chat_id, input_file).await { + Ok(_) => format!("语音已发送: {path_str}"), + Err(e) => format!("语音生成成功但发送失败: {e:#}"), + } + } else { + format!("语音生成失败: 输出文件不存在 ({path_str})") + } + } + Ok(Ok(out)) => { + let stderr = String::from_utf8_lossy(&out.stderr); + let stdout = String::from_utf8_lossy(&out.stdout); + format!("gen_voice failed: {stdout} {stderr}") + } + Ok(Err(e)) => format!("gen_voice exec error: {e}"), + Err(_) => "gen_voice timeout (120s)".to_string(), + } + } + _ => run_script_tool(name, arguments).await, + } +} + +pub async fn spawn_agent( + id: &str, + task: &str, + state: &Arc, + bot: &Bot, + chat_id: ChatId, + sid: &str, + config: &Arc, +) -> String { + // check if already exists + if state.agents.read().await.contains_key(id) { + return format!("Agent '{id}' already exists. Use agent_status to check it."); + } + + let mut child = match Command::new("claude") + .args(["--dangerously-skip-permissions", "-p", task]) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn() + { + Ok(c) => c, + Err(e) => return format!("Failed to spawn agent: {e}"), + }; + + let pid = child.id(); + let output = Arc::new(tokio::sync::RwLock::new(String::new())); + let completed = Arc::new(AtomicBool::new(false)); + let exit_code = Arc::new(tokio::sync::RwLock::new(None)); + + let agent = Arc::new(SubAgent { + task: task.to_string(), + output: output.clone(), + completed: completed.clone(), + exit_code: exit_code.clone(), + pid, + }); + + state.agents.write().await.insert(id.to_string(), agent); + + // background task: collect output and wakeup on completion + let out = output.clone(); + let done = completed.clone(); + let ecode = exit_code.clone(); + let bot_c = bot.clone(); + let chat_id_c = chat_id; + let state_c = state.clone(); + let config_c = config.clone(); + let sid_c = sid.to_string(); + let id_c = id.to_string(); + + tokio::spawn(async move { + let stdout = child.stdout.take(); + if let Some(stdout) = stdout { + let mut lines = tokio::io::BufReader::new(stdout).lines(); + while let Ok(Some(line)) = lines.next_line().await { + let mut o = out.write().await; + o.push_str(&line); + o.push('\n'); + } + } + let status = child.wait().await; + let code = status.as_ref().ok().and_then(|s| s.code()); + *ecode.write().await = code; + done.store(true, Ordering::SeqCst); + + info!(agent = %id_c, "agent completed, exit={code:?}"); + + // wakeup: inject result and trigger LLM + let result = out.read().await.clone(); + let result_short = truncate_at_char_boundary(&result, 4000); + let wakeup = format!( + "[Agent '{id_c}' 执行完成 (exit={})]\n{result_short}", + code.unwrap_or(-1) + ); + + if let Err(e) = agent_wakeup( + &config_c, &state_c, &bot_c, chat_id_c, &sid_c, &wakeup, &id_c, + ) + .await + { + error!(agent = %id_c, "wakeup failed: {e:#}"); + let _ = bot_c + .send_message(chat_id_c, format!("[agent wakeup error] {e:#}")) + .await; + } + }); + + format!("Agent '{id}' spawned (pid={pid:?})") +} + +pub async fn agent_wakeup( + config: &Config, + state: &AppState, + bot: &Bot, + chat_id: ChatId, + sid: &str, + wakeup_msg: &str, + agent_id: &str, +) -> Result<()> { + match &config.backend { + BackendConfig::OpenAI { + endpoint, + model, + api_key, + } => { + state.push_message(sid, "user", wakeup_msg).await; + let conv = state.load_conv(sid).await; + let persona = state.get_config("persona").await.unwrap_or_default(); + let memory_slots = state.get_memory_slots().await; + let system_msg = build_system_prompt(&conv.summary, &persona, &memory_slots); + let mut api_messages = vec![system_msg]; + api_messages.extend(conv.messages); + + info!(agent = %agent_id, "wakeup: sending {} messages to LLM", api_messages.len()); + + let response = + run_openai_streaming(endpoint, model, api_key, &api_messages, bot, chat_id) + .await?; + + if !response.is_empty() { + state.push_message(sid, "assistant", &response).await; + } + + Ok(()) + } + _ => { + let _ = bot + .send_message(chat_id, format!("[Agent '{agent_id}' done]\n{wakeup_msg}")) + .await; + Ok(()) + } + } +} + +pub async fn check_agent_status(id: &str, state: &AppState) -> String { + let agents = state.agents.read().await; + match agents.get(id) { + Some(agent) => { + let status = if agent.completed.load(Ordering::SeqCst) { + let code = agent.exit_code.read().await; + format!("completed (exit={})", code.unwrap_or(-1)) + } else { + "running".to_string() + }; + let output = agent.output.read().await; + let out_preview = truncate_at_char_boundary(&output, 3000); + format!( + "Agent '{id}': {status}\nTask: {}\nOutput ({} bytes):\n{out_preview}", + agent.task, + output.len() + ) + } + None => format!("Agent '{id}' not found"), + } +} + +pub async fn kill_agent(id: &str, state: &AppState) -> String { + let agents = state.agents.read().await; + match agents.get(id) { + Some(agent) => { + if agent.completed.load(Ordering::SeqCst) { + return format!("Agent '{id}' already completed"); + } + if let Some(pid) = agent.pid { + unsafe { + libc::kill(pid as i32, libc::SIGTERM); + } + format!("Sent SIGTERM to agent '{id}' (pid={pid})") + } else { + format!("Agent '{id}' has no PID") + } + } + None => format!("Agent '{id}' not found"), + } +} + +pub async fn run_script_tool(name: &str, arguments: &str) -> String { + // find script in tools/ that matches this tool name + let dir = tools_dir(); + let entries = match std::fs::read_dir(&dir) { + Ok(e) => e, + Err(_) => return format!("Unknown tool: {name}"), + }; + + for entry in entries.flatten() { + let path = entry.path(); + if !path.is_file() { + continue; + } + // check if this script provides the requested tool + let schema_out = std::process::Command::new(&path) + .arg("--schema") + .output(); + if let Ok(out) = schema_out { + if out.status.success() { + let stdout = String::from_utf8_lossy(&out.stdout); + if let Ok(schema) = serde_json::from_str::(stdout.trim()) { + if schema["name"].as_str() == Some(name) { + // found it — execute + info!(tool = %name, path = %path.display(), "running script tool"); + let result = tokio::time::timeout( + std::time::Duration::from_secs(60), + Command::new(&path).arg(arguments).output(), + ) + .await; + + return match result { + Ok(Ok(output)) => { + let mut s = String::from_utf8_lossy(&output.stdout).to_string(); + let stderr = String::from_utf8_lossy(&output.stderr); + if !stderr.is_empty() { + if !s.is_empty() { + s.push_str("\n[stderr]\n"); + } + s.push_str(&stderr); + } + if s.is_empty() { + format!("(exit={})", output.status.code().unwrap_or(-1)) + } else { + s + } + } + Ok(Err(e)) => format!("Failed to execute {name}: {e}"), + Err(_) => "Timeout after 60s".to_string(), + }; + } + } + } + } + } + + format!("Unknown tool: {name}") +} + +// ── schedule parsing ─────────────────────────────────────────────── + +pub fn parse_next_fire(schedule: &str) -> Result> { + let now = chrono::Local::now(); + + // relative: "5min", "2h", "30s", "1d" + if let Some(val) = schedule + .strip_suffix("min") + .or_else(|| schedule.strip_suffix("m")) + { + let mins: i64 = val.trim().parse().map_err(|e| anyhow::anyhow!("{e}"))?; + return Ok(now + chrono::Duration::minutes(mins)); + } + if let Some(val) = schedule.strip_suffix('h') { + let hours: i64 = val.trim().parse().map_err(|e| anyhow::anyhow!("{e}"))?; + return Ok(now + chrono::Duration::hours(hours)); + } + if let Some(val) = schedule.strip_suffix('s') { + let secs: i64 = val.trim().parse().map_err(|e| anyhow::anyhow!("{e}"))?; + return Ok(now + chrono::Duration::seconds(secs)); + } + if let Some(val) = schedule.strip_suffix('d') { + let days: i64 = val.trim().parse().map_err(|e| anyhow::anyhow!("{e}"))?; + return Ok(now + chrono::Duration::days(days)); + } + + // absolute: "once:2026-04-10 09:00" + if let Some(dt_str) = schedule.strip_prefix("once:") { + let dt = chrono::NaiveDateTime::parse_from_str(dt_str.trim(), "%Y-%m-%d %H:%M") + .or_else(|_| { + chrono::NaiveDateTime::parse_from_str(dt_str.trim(), "%Y-%m-%d %H:%M:%S") + }) + .map_err(|e| anyhow::anyhow!("parse datetime: {e}"))?; + return Ok(dt.and_local_timezone(chrono::Local).unwrap()); + } + + // cron: "cron:30 8 * * *" + if let Some(expr) = schedule.strip_prefix("cron:") { + let cron_schedule = expr + .trim() + .parse::() + .map_err(|e| anyhow::anyhow!("parse cron: {e}"))?; + let next = cron_schedule + .upcoming(chrono::Local) + .next() + .ok_or_else(|| anyhow::anyhow!("no upcoming time for cron"))?; + return Ok(next); + } + + anyhow::bail!("unknown schedule format: {schedule}") +} + +pub fn compute_next_cron_fire(schedule: &str) -> Option { + let expr = schedule.strip_prefix("cron:")?; + let cron_schedule = expr.trim().parse::().ok()?; + let next = cron_schedule.upcoming(chrono::Local).next()?; + Some(next.format("%Y-%m-%d %H:%M:%S").to_string()) +}