From 128f2481c07f8c568e0cd3df50f3a068f86746f0 Mon Sep 17 00:00:00 2001 From: Fam Zheng Date: Thu, 9 Apr 2026 16:38:28 +0100 Subject: [PATCH] add tool calling, SQLite persistence, group chat, image vision, voice transcription Major features: - OpenAI function calling with tool call loop (streaming SSE parsing) - Built-in tools: spawn_agent (async claude -p), agent_status, kill_agent, update_scratch, send_file - Script-based tool discovery: tools/ dir with --schema convention - Feishu todo management script (tools/manage_todo) - SQLite persistence: conversations, messages, config, scratch_area tables - Sliding window context (100 msgs, slide 50, auto-summarize) - Conversation summary generation via LLM on window slide - Group chat support with independent session contexts - Image understanding: multimodal vision input (base64 to API) - Voice transcription via faster-whisper Docker service - Configurable persona stored in DB - diag command for session diagnostics - System prompt restructured: persona + tool instructions separated - RUST_BACKTRACE=1 in service, clippy in deploy pipeline - .gitignore for config/state/db files --- .gitignore | 6 +- Cargo.lock | 92 ++++ Cargo.toml | 5 +- Makefile | 8 +- doc/todo.md | 52 +- noc.service.in | 1 + src/main.rs | 1179 ++++++++++++++++++++++++++++++++++++++++++-- tests/tool_call.rs | 361 ++++++++++++++ tools/manage_todo | 187 +++++++ 9 files changed, 1840 insertions(+), 51 deletions(-) create mode 100644 tests/tool_call.rs create mode 100755 tools/manage_todo diff --git a/.gitignore b/.gitignore index f775163..ed28664 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,5 @@ -/target config.yaml -config.hera.yaml +config.*.yaml state.json -noc.service +state.*.json +*.db diff --git a/Cargo.lock b/Cargo.lock index f2ec589..4b5a191 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,18 @@ # It is not intended for manual editing. 
version = 4 +[[package]] +name = "ahash" +version = "0.8.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a15f179cd60c4584b8a8c596927aadc462e27f2ca70c04e0071964a73ba7a75" +dependencies = [ + "cfg-if", + "once_cell", + "version_check", + "zerocopy", +] + [[package]] name = "aho-corasick" version = "1.1.4" @@ -258,6 +270,18 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "fallible-iterator" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2acce4a10f12dc2fb14a218589d4f1f62ef011b2d0cc4b3cb1bba8e94da14649" + +[[package]] +name = "fallible-streaming-iterator" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a" + [[package]] name = "fastrand" version = "2.4.0" @@ -465,6 +489,15 @@ dependencies = [ "tracing", ] +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" +dependencies = [ + "ahash", +] + [[package]] name = "hashbrown" version = "0.15.5" @@ -480,6 +513,15 @@ version = "0.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" +[[package]] +name = "hashlink" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ba4ff7128dee98c7dc9794b6a411377e1404dba1c97deb8d1a55297bd25d8af" +dependencies = [ + "hashbrown 0.14.5", +] + [[package]] name = "heck" version = "0.4.1" @@ -886,6 +928,17 @@ version = "0.2.184" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "48f5d2a454e16a5ea0f4ced81bd44e4cfc7bd3a507b61887c99fd3538b28e4af" +[[package]] +name = "libsqlite3-sys" +version = "0.30.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"2e99fb7a497b1e3339bc746195567ed8d3e24945ecd636e3619d20b9de9e9149" +dependencies = [ + "cc", + "pkg-config", + "vcpkg", +] + [[package]] name = "linux-raw-sys" version = "0.12.1" @@ -983,10 +1036,13 @@ name = "noc" version = "0.1.0" dependencies = [ "anyhow", + "base64 0.22.1", "chrono", "dptree", + "libc", "pulldown-cmark", "reqwest 0.12.28", + "rusqlite", "serde", "serde_json", "serde_yaml", @@ -1300,6 +1356,7 @@ dependencies = [ "bytes", "encoding_rs", "futures-core", + "futures-util", "h2 0.4.13", "http 1.4.0", "http-body 1.0.1", @@ -1311,6 +1368,7 @@ dependencies = [ "js-sys", "log", "mime", + "mime_guess", "native-tls", "percent-encoding", "pin-project-lite", @@ -1344,6 +1402,20 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "rusqlite" +version = "0.32.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7753b721174eb8ff87a9a0e799e2d7bc3749323e773db92e0984debb00019d6e" +dependencies = [ + "bitflags 2.11.0", + "fallible-iterator", + "fallible-streaming-iterator", + "hashlink", + "libsqlite3-sys", + "smallvec", +] + [[package]] name = "rustc_version" version = "0.4.1" @@ -2607,6 +2679,26 @@ dependencies = [ "synstructure", ] +[[package]] +name = "zerocopy" +version = "0.8.48" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eed437bf9d6692032087e337407a86f04cd8d6a16a37199ed57949d415bd68e9" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.8.48" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70e3cd084b1788766f53af483dd21f93881ff30d7320490ec3ef7526d203bad4" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + [[package]] name = "zerofrom" version = "0.1.7" diff --git a/Cargo.toml b/Cargo.toml index 2a22d0c..9f95dfd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,13 +5,16 @@ edition = "2021" [dependencies] anyhow = "1" +base64 = "0.22" chrono = { version = "0.4", features = ["serde"] } dptree = 
"0.3" +libc = "0.2" serde = { version = "1", features = ["derive"] } serde_json = "1" serde_yaml = "0.9" pulldown-cmark = "0.12" -reqwest = { version = "0.12", features = ["json"] } +reqwest = { version = "0.12", features = ["json", "multipart"] } +rusqlite = { version = "0.32", features = ["bundled"] } teloxide = { version = "0.12", features = ["macros"] } tokio = { version = "1", features = ["full"] } uuid = { version = "1", features = ["v5"] } diff --git a/Makefile b/Makefile index 3b9ab7a..6958a19 100644 --- a/Makefile +++ b/Makefile @@ -2,15 +2,19 @@ REPO := $(shell pwd) HERA := heradev HERA_DIR := noc -.PHONY: build deploy deploy-hera +.PHONY: build test deploy deploy-hera build: cargo build --release +test: + cargo clippy -- -D warnings + cargo test -- --nocapture + noc.service: noc.service.in sed -e 's|@REPO@|$(REPO)|g' -e 's|@PATH@|$(PATH)|g' $< > $@ -deploy: build noc.service +deploy: test build noc.service mkdir -p ~/bin ~/.config/systemd/user systemctl --user stop noc 2>/dev/null || true install target/release/noc ~/bin/noc diff --git a/doc/todo.md b/doc/todo.md index 701f7d0..6d958d6 100644 --- a/doc/todo.md +++ b/doc/todo.md @@ -1,10 +1,44 @@ -# TODO +# noc roadmap -- [ ] Streaming responses — edit message as claude output arrives instead of waiting for full completion -- [ ] Markdown formatting — parse claude output and send with TG MarkdownV2 -- [ ] Timeout handling — kill claude if it hangs beyond a threshold -- [ ] Graceful shutdown on SIGTERM -- [ ] `/reset` command to force new session without waiting for 5am -- [ ] Rate limiting per chat -- [ ] Voice message support — STT (whisper.cpp) → text → claude -- [ ] Video/audio file transcription +## "会呼吸的助手" — 让 noc 活着 + +核心理念:noc 不应该只在收到消息时才被唤醒,而是一个持续运行、有自己节奏的存在。 + +### 主动行为 +- [ ] 定时任务 (cron):LLM 可以自己设置提醒、定期检查 +- [ ] 事件驱动:监控文件变化、git push、CI 状态等,主动通知 +- [ ] 晨间/晚间报告:每天自动汇总待办、提醒重要事项 +- [ ] 情境感知:根据时间、地点、日历自动调整行为 + +### 记忆与成长 +- [ ] 长期记忆 (MEMORY.md):跨 session 的持久化记忆 +- [ ] 语义搜索:基于 embedding 的记忆检索 +- [ ] 
自我反思:定期回顾对话质量,优化自己的行为 + +### 感知能力 +- [x] 图片理解:multimodal vision input +- [x] 语音转录:whisper API 转文字 +- [ ] 屏幕/截图分析 +- [ ] 链接预览/摘要 + +### 交互体验 +- [x] 群组支持:独立上下文 +- [x] 流式输出:sendMessageDraft + editMessageText +- [x] Markdown 渲染 +- [ ] Typing indicator +- [ ] Inline keyboard 交互 +- [ ] 语音回复 (TTS) + +### 工具生态 +- [x] 脚本工具发现 (tools/ + --schema) +- [x] 异步子代理 (spawn_agent) +- [x] 飞书待办管理 +- [ ] Web search / fetch +- [ ] 更多脚本工具 +- [ ] MCP 协议支持 + +### 可靠性 +- [ ] API 重试策略 (指数退避) +- [ ] 用量追踪 +- [ ] Context pruning (只裁工具输出) +- [ ] Model failover diff --git a/noc.service.in b/noc.service.in index c0021dd..b6a7c0e 100644 --- a/noc.service.in +++ b/noc.service.in @@ -10,6 +10,7 @@ ExecStart=%h/bin/noc Restart=on-failure RestartSec=5 Environment=RUST_LOG=noc=info +Environment=RUST_BACKTRACE=1 Environment=NOC_CONFIG=@REPO@/config.yaml Environment=NOC_STATE=@REPO@/state.json Environment=PATH=@PATH@ diff --git a/src/main.rs b/src/main.rs index 8d8cc80..f6043b3 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,7 @@ use std::collections::{HashMap, HashSet}; use std::path::{Path, PathBuf}; use std::process::Stdio; +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use anyhow::Result; @@ -15,6 +16,7 @@ use tokio::process::Command; use tokio::sync::RwLock; use tokio::time::Instant; use tracing::{error, info, warn}; +use base64::Engine; use uuid::Uuid; // ── config ────────────────────────────────────────────────────────── @@ -28,16 +30,19 @@ struct Config { session: SessionConfig, #[serde(default)] backend: BackendConfig, + #[serde(default)] + whisper_url: Option, } fn default_name() -> String { "noc".to_string() } -#[derive(Deserialize, Clone)] +#[derive(Deserialize, Clone, Default)] #[serde(tag = "type")] enum BackendConfig { #[serde(rename = "claude")] + #[default] Claude, #[serde(rename = "openai")] OpenAI { @@ -52,11 +57,6 @@ fn default_api_key() -> String { "unused".to_string() } -impl Default for BackendConfig { - fn default() -> Self { - BackendConfig::Claude - } 
-} #[derive(Deserialize)] struct TgConfig { @@ -81,10 +81,161 @@ struct Persistent { known_sessions: HashSet, } +#[derive(Serialize, Deserialize, Clone, Default)] +struct ConversationState { + summary: String, + messages: Vec, + total_messages: usize, +} + +const MAX_WINDOW: usize = 100; +const SLIDE_SIZE: usize = 50; + +// ── subagent & tool call ─────────────────────────────────────────── + +struct SubAgent { + task: String, + output: Arc>, + completed: Arc, + exit_code: Arc>>, + pid: Option, +} + +struct ToolCall { + id: String, + name: String, + arguments: String, +} + +fn tools_dir() -> PathBuf { + // tools/ relative to the config file location + let config_path = std::env::var("NOC_CONFIG").unwrap_or_else(|_| "config.yaml".into()); + let config_dir = Path::new(&config_path) + .parent() + .unwrap_or(Path::new(".")); + config_dir.join("tools") +} + +/// Scan tools/ directory for scripts with --schema, merge with built-in tools. +/// Called on every API request so new/updated scripts take effect immediately. +fn discover_tools() -> serde_json::Value { + let mut tools = vec![ + serde_json::json!({ + "type": "function", + "function": { + "name": "spawn_agent", + "description": "Spawn a Claude Code subagent to handle a complex task asynchronously. You'll be notified when it completes.", + "parameters": { + "type": "object", + "properties": { + "id": {"type": "string", "description": "Short unique identifier (e.g. 
'research', 'fix-bug')"}, + "task": {"type": "string", "description": "Detailed task description for the agent"} + }, + "required": ["id", "task"] + } + } + }), + serde_json::json!({ + "type": "function", + "function": { + "name": "agent_status", + "description": "Check the current status and output of a running or completed agent", + "parameters": { + "type": "object", + "properties": { + "id": {"type": "string", "description": "The agent identifier"} + }, + "required": ["id"] + } + } + }), + serde_json::json!({ + "type": "function", + "function": { + "name": "kill_agent", + "description": "Terminate a running agent", + "parameters": { + "type": "object", + "properties": { + "id": {"type": "string", "description": "The agent identifier"} + }, + "required": ["id"] + } + } + }), + serde_json::json!({ + "type": "function", + "function": { + "name": "send_file", + "description": "Send a file from the server to the user via Telegram. The file must exist on the server filesystem.", + "parameters": { + "type": "object", + "properties": { + "path": {"type": "string", "description": "Absolute path to the file on the server"}, + "caption": {"type": "string", "description": "Optional caption/description for the file"} + }, + "required": ["path"] + } + } + }), + serde_json::json!({ + "type": "function", + "function": { + "name": "update_scratch", + "description": "Update your scratch area (working notes, state, reminders). This content is appended to every user message so you always see it. 
Use it to track ongoing context across turns.", + "parameters": { + "type": "object", + "properties": { + "content": {"type": "string", "description": "The full scratch area content (replaces previous)"} + }, + "required": ["content"] + } + } + }), + ]; + + // discover script tools + let dir = tools_dir(); + if let Ok(entries) = std::fs::read_dir(&dir) { + for entry in entries.flatten() { + let path = entry.path(); + if !path.is_file() { + continue; + } + // run --schema with a short timeout + let output = std::process::Command::new(&path) + .arg("--schema") + .output(); + match output { + Ok(out) if out.status.success() => { + let stdout = String::from_utf8_lossy(&out.stdout); + match serde_json::from_str::(stdout.trim()) { + Ok(schema) => { + let name = schema["name"].as_str().unwrap_or("?"); + info!(tool = %name, path = %path.display(), "discovered script tool"); + tools.push(serde_json::json!({ + "type": "function", + "function": schema, + })); + } + Err(e) => { + warn!(path = %path.display(), "invalid --schema JSON: {e}"); + } + } + } + _ => {} // not a tool script, skip silently + } + } + } + + serde_json::Value::Array(tools) +} + struct AppState { persist: RwLock, state_path: PathBuf, - conversations: RwLock>>, + db: tokio::sync::Mutex, + agents: RwLock>>, } impl AppState { @@ -94,10 +245,50 @@ impl AppState { .and_then(|s| serde_json::from_str(&s).ok()) .unwrap_or_default(); info!("loaded state from {}", path.display()); + + let db_path = path.parent().unwrap_or(Path::new(".")).join("noc.db"); + let conn = rusqlite::Connection::open(&db_path) + .unwrap_or_else(|e| panic!("open {}: {e}", db_path.display())); + conn.execute_batch( + "CREATE TABLE IF NOT EXISTS conversations ( + session_id TEXT PRIMARY KEY, + summary TEXT NOT NULL DEFAULT '', + total_messages INTEGER NOT NULL DEFAULT 0 + ); + CREATE TABLE IF NOT EXISTS messages ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + session_id TEXT NOT NULL, + role TEXT NOT NULL, + content TEXT NOT NULL + ); + CREATE 
INDEX IF NOT EXISTS idx_messages_session ON messages(session_id); + CREATE TABLE IF NOT EXISTS scratch_area ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + content TEXT NOT NULL, + created_at TEXT NOT NULL DEFAULT (datetime('now')) + ); + CREATE TABLE IF NOT EXISTS config ( + key TEXT PRIMARY KEY, + value TEXT NOT NULL DEFAULT '', + create_time TEXT NOT NULL DEFAULT (datetime('now')), + update_time TEXT NOT NULL DEFAULT (datetime('now')) + ); + CREATE TABLE IF NOT EXISTS config_history ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + key TEXT NOT NULL, + value TEXT NOT NULL, + create_time TEXT NOT NULL, + update_time TEXT NOT NULL + );", + ) + .expect("init db schema"); + info!("opened db {}", db_path.display()); + Self { persist: RwLock::new(persist), state_path: path, - conversations: RwLock::new(HashMap::new()), + db: tokio::sync::Mutex::new(conn), + agents: RwLock::new(HashMap::new()), } } @@ -109,6 +300,119 @@ impl AppState { } } } + + async fn load_conv(&self, sid: &str) -> ConversationState { + let db = self.db.lock().await; + let (summary, total) = db + .query_row( + "SELECT summary, total_messages FROM conversations WHERE session_id = ?1", + [sid], + |row| Ok((row.get::<_, String>(0)?, row.get::<_, usize>(1)?)), + ) + .unwrap_or_default(); + + let mut stmt = db + .prepare("SELECT role, content FROM messages WHERE session_id = ?1 ORDER BY id") + .unwrap(); + let messages: Vec = stmt + .query_map([sid], |row| { + let role: String = row.get(0)?; + let content: String = row.get(1)?; + Ok(serde_json::json!({"role": role, "content": content})) + }) + .unwrap() + .filter_map(|r| r.ok()) + .collect(); + + ConversationState { + summary, + messages, + total_messages: total, + } + } + + async fn push_message(&self, sid: &str, role: &str, content: &str) { + let db = self.db.lock().await; + let _ = db.execute( + "INSERT OR IGNORE INTO conversations (session_id) VALUES (?1)", + [sid], + ); + let _ = db.execute( + "INSERT INTO messages (session_id, role, content) VALUES (?1, ?2, 
?3)", + rusqlite::params![sid, role, content], + ); + } + + async fn message_count(&self, sid: &str) -> usize { + let db = self.db.lock().await; + db.query_row( + "SELECT COUNT(*) FROM messages WHERE session_id = ?1", + [sid], + |row| row.get(0), + ) + .unwrap_or(0) + } + + async fn slide_window(&self, sid: &str, new_summary: &str, slide_size: usize) { + let db = self.db.lock().await; + let _ = db.execute( + "DELETE FROM messages WHERE id IN ( + SELECT id FROM messages WHERE session_id = ?1 ORDER BY id LIMIT ?2 + )", + rusqlite::params![sid, slide_size], + ); + let _ = db.execute( + "UPDATE conversations SET summary = ?1, total_messages = total_messages + ?2 \ + WHERE session_id = ?3", + rusqlite::params![new_summary, slide_size, sid], + ); + } + + async fn get_oldest_messages(&self, sid: &str, count: usize) -> Vec { + let db = self.db.lock().await; + let mut stmt = db + .prepare( + "SELECT role, content FROM messages WHERE session_id = ?1 ORDER BY id LIMIT ?2", + ) + .unwrap(); + stmt.query_map(rusqlite::params![sid, count], |row| { + let role: String = row.get(0)?; + let content: String = row.get(1)?; + Ok(serde_json::json!({"role": role, "content": content})) + }) + .unwrap() + .filter_map(|r| r.ok()) + .collect() + } + + async fn get_scratch(&self) -> String { + let db = self.db.lock().await; + db.query_row( + "SELECT content FROM scratch_area ORDER BY id DESC LIMIT 1", + [], + |row| row.get(0), + ) + .unwrap_or_default() + } + + async fn push_scratch(&self, content: &str) { + let db = self.db.lock().await; + let _ = db.execute( + "INSERT INTO scratch_area (content) VALUES (?1)", + [content], + ); + } + + async fn get_config(&self, key: &str) -> Option { + let db = self.db.lock().await; + db.query_row( + "SELECT value FROM config WHERE key = ?1", + [key], + |row| row.get(0), + ) + .ok() + } + } // ── helpers ───────────────────────────────────────────────────────── @@ -165,13 +469,15 @@ async fn main() { let _ = std::fs::create_dir_all(incoming_dir()); - 
info!("noc bot starting"); - let bot = Bot::new(&config.tg.key); + let me = bot.get_me().await.unwrap(); + let bot_username = Arc::new(me.username.clone().unwrap_or_default()); + info!(username = %bot_username, "noc bot starting"); + let handler = Update::filter_message().endpoint(handle); Dispatcher::builder(bot, handler) - .dependencies(dptree::deps![state, Arc::new(config)]) + .dependencies(dptree::deps![state, Arc::new(config), bot_username]) .default_handler(|_| async {}) .build() .dispatch() @@ -229,8 +535,10 @@ async fn handle( msg: Message, state: Arc, config: Arc, + _bot_username: Arc, ) -> ResponseResult<()> { let chat_id = msg.chat.id; + let is_private = msg.chat.is_private(); let text = msg.text().or(msg.caption()).unwrap_or("").to_string(); let raw_id = chat_id.0; let date = session_date(config.session.refresh_hour); @@ -255,7 +563,7 @@ async fn handle( return Ok(()); } - if let Err(e) = handle_inner(&bot, &msg, chat_id, &text, &state, &config).await { + if let Err(e) = handle_inner(&bot, &msg, chat_id, &text, is_private, &state, &config).await { error!(chat = raw_id, "handle: {e:#}"); let _ = bot.send_message(chat_id, format!("[error] {e:#}")).await; } @@ -268,11 +576,13 @@ async fn handle_inner( msg: &Message, chat_id: ChatId, text: &str, + is_private: bool, state: &Arc, config: &Arc, ) -> Result<()> { let mut uploaded: Vec = Vec::new(); let mut download_errors: Vec = Vec::new(); + let mut transcriptions: Vec = Vec::new(); if let Some(doc) = msg.document() { let name = doc.file_name.as_deref().unwrap_or("file"); @@ -296,7 +606,20 @@ async fn handle_inner( let fallback = format!("audio_{}.ogg", Local::now().format("%H%M%S")); let name = audio.file_name.as_deref().unwrap_or(&fallback); match download_tg_file(bot, &audio.file.id, name).await { - Ok(p) => uploaded.push(p), + Ok(p) => { + if let Some(url) = &config.whisper_url { + match transcribe_audio(url, &p).await { + Ok(t) if !t.is_empty() => transcriptions.push(t), + Ok(_) => uploaded.push(p), + 
Err(e) => { + warn!("transcribe failed: {e:#}"); + uploaded.push(p); + } + } + } else { + uploaded.push(p); + } + } Err(e) => download_errors.push(format!("audio: {e:#}")), } } @@ -304,7 +627,20 @@ async fn handle_inner( if let Some(voice) = msg.voice() { let name = format!("voice_{}.ogg", Local::now().format("%H%M%S")); match download_tg_file(bot, &voice.file.id, &name).await { - Ok(p) => uploaded.push(p), + Ok(p) => { + if let Some(url) = &config.whisper_url { + match transcribe_audio(url, &p).await { + Ok(t) if !t.is_empty() => transcriptions.push(t), + Ok(_) => uploaded.push(p), + Err(e) => { + warn!("transcribe failed: {e:#}"); + uploaded.push(p); + } + } + } else { + uploaded.push(p); + } + } Err(e) => download_errors.push(format!("voice: {e:#}")), } } @@ -326,7 +662,7 @@ async fn handle_inner( } } - if text.is_empty() && uploaded.is_empty() { + if text.is_empty() && uploaded.is_empty() && transcriptions.is_empty() { if !download_errors.is_empty() { let err_msg = format!("[文件下载失败]\n{}", download_errors.join("\n")); bot.send_message(chat_id, err_msg).await?; @@ -341,12 +677,44 @@ async fn handle_inner( tokio::fs::create_dir_all(&out_dir).await?; let before = snapshot_dir(&out_dir).await; - let prompt = build_prompt(text, &uploaded, &download_errors, &out_dir); + // handle diag command (OpenAI backend only) + if text.trim() == "diag" { + if let BackendConfig::OpenAI { .. 
} = &config.backend { + let conv = state.load_conv(&sid).await; + let count = state.message_count(&sid).await; + let persona = state.get_config("persona").await.unwrap_or_default(); + let scratch = state.get_scratch().await; + let diag = format!( + "session: {sid}\n\ + window: {count}/{MAX_WINDOW} (slide at {MAX_WINDOW}, drop {SLIDE_SIZE})\n\ + total processed: {}\n\n\ + persona ({} chars):\n{}\n\n\ + scratch ({} chars):\n{}\n\n\ + summary ({} chars):\n{}", + conv.total_messages + count, + persona.len(), + if persona.is_empty() { "(default)" } else { &persona }, + scratch.len(), + if scratch.is_empty() { "(empty)" } else { &scratch }, + conv.summary.len(), + if conv.summary.is_empty() { + "(empty)".to_string() + } else { + conv.summary + } + ); + bot.send_message(chat_id, diag).await?; + return Ok(()); + } + } + + let prompt = build_prompt(text, &uploaded, &download_errors, &transcriptions); match &config.backend { BackendConfig::Claude => { let known = state.persist.read().await.known_sessions.contains(&sid); - let result = invoke_claude_streaming(&sid, &prompt, known, bot, chat_id).await; + let result = + invoke_claude_streaming(&sid, &prompt, known, bot, chat_id).await; match &result { Ok(_) => { if !known { @@ -365,18 +733,67 @@ async fn handle_inner( model, api_key, } => { - let mut messages = { - let convos = state.conversations.read().await; - convos.get(&sid).cloned().unwrap_or_default() - }; - messages.push(serde_json::json!({"role": "user", "content": &prompt})); - match run_openai_streaming(endpoint, model, api_key, &messages, bot, chat_id).await { + let conv = state.load_conv(&sid).await; + let persona = state.get_config("persona").await.unwrap_or_default(); + let system_msg = build_system_prompt(&conv.summary, &persona); + + let mut api_messages = vec![system_msg]; + api_messages.extend(conv.messages); + + let scratch = state.get_scratch().await; + let user_content = build_user_content(&prompt, &scratch, &uploaded); + 
api_messages.push(serde_json::json!({"role": "user", "content": user_content})); + + match run_openai_with_tools( + endpoint, model, api_key, api_messages, bot, chat_id, state, &sid, config, is_private, + ) + .await + { Ok(response) => { + state.push_message(&sid, "user", &prompt).await; if !response.is_empty() { - messages - .push(serde_json::json!({"role": "assistant", "content": &response})); + state.push_message(&sid, "assistant", &response).await; + } + + // sliding window + let count = state.message_count(&sid).await; + if count >= MAX_WINDOW { + info!(%sid, "sliding window: {count} messages, summarizing oldest {SLIDE_SIZE}"); + let _ = bot + .send_message(chat_id, "[整理记忆中...]") + .await; + + let to_summarize = + state.get_oldest_messages(&sid, SLIDE_SIZE).await; + let current_summary = { + let db = state.db.lock().await; + db.query_row( + "SELECT summary FROM conversations WHERE session_id = ?1", + [&sid], + |row| row.get::<_, String>(0), + ) + .unwrap_or_default() + }; + + match summarize_messages( + endpoint, + model, + api_key, + ¤t_summary, + &to_summarize, + ) + .await + { + Ok(new_summary) => { + state.slide_window(&sid, &new_summary, SLIDE_SIZE).await; + let remaining = state.message_count(&sid).await; + info!(%sid, "window slid, {remaining} messages remain, summary {} chars", new_summary.len()); + } + Err(e) => { + warn!(%sid, "summarize failed: {e:#}, keeping all messages"); + } + } } - state.conversations.write().await.insert(sid.clone(), messages); } Err(e) => { error!(%sid, "openai: {e:#}"); @@ -401,9 +818,18 @@ async fn handle_inner( Ok(()) } -fn build_prompt(text: &str, uploaded: &[PathBuf], errors: &[String], out_dir: &Path) -> String { +fn build_prompt( + text: &str, + uploaded: &[PathBuf], + errors: &[String], + transcriptions: &[String], +) -> String { let mut parts = Vec::new(); + for t in transcriptions { + parts.push(format!("[语音消息] {t}")); + } + for f in uploaded { parts.push(format!("[用户上传了文件: {}]", f.display())); } @@ -416,14 +842,696 
@@ fn build_prompt(text: &str, uploaded: &[PathBuf], errors: &[String], out_dir: &P parts.push(text.to_string()); } - parts.push(format!( - "\n[系统提示: 如果需要发送文件给用户,将文件写入 {} 目录]", - out_dir.display() - )); - parts.join("\n") } +async fn transcribe_audio(whisper_url: &str, file_path: &Path) -> Result { + let client = reqwest::Client::new(); + let url = format!("{}/v1/audio/transcriptions", whisper_url.trim_end_matches('/')); + let file_bytes = tokio::fs::read(file_path).await?; + let file_name = file_path + .file_name() + .and_then(|n| n.to_str()) + .unwrap_or("audio.ogg") + .to_string(); + let part = reqwest::multipart::Part::bytes(file_bytes) + .file_name(file_name) + .mime_str("audio/ogg")?; + let form = reqwest::multipart::Form::new() + .part("file", part) + .text("model", "base"); + let resp = client.post(&url).multipart(form).send().await?.error_for_status()?; + let json: serde_json::Value = resp.json().await?; + Ok(json["text"].as_str().unwrap_or("").to_string()) +} + +fn build_system_prompt(summary: &str, persona: &str) -> serde_json::Value { + let mut text = if persona.is_empty() { + String::from("你是一个AI助手。") + } else { + persona.to_string() + }; + + text.push_str( + "\n\n你可以使用提供的工具来完成任务。\ + 当需要执行命令、运行代码或启动复杂子任务时,直接调用对应的工具,不要只是描述你会怎么做。\ + 输出格式:使用纯文本或基础Markdown(加粗、列表、代码块)。\ + 不要使用LaTeX公式($...$)、特殊Unicode符号(→←↔)或HTML标签,Telegram无法渲染这些。", + ); + + if !summary.is_empty() { + text.push_str("\n\n## 之前的对话总结\n"); + text.push_str(summary); + } + + serde_json::json!({"role": "system", "content": text}) +} + +/// Build user message content, with optional images as multimodal input. 
+fn build_user_content( + text: &str, + scratch: &str, + images: &[PathBuf], +) -> serde_json::Value { + let full_text = if scratch.is_empty() { + text.to_string() + } else { + format!("{text}\n\n[scratch]\n{scratch}") + }; + + // collect image data + let mut image_parts: Vec = Vec::new(); + for path in images { + let mime = match path + .extension() + .and_then(|e| e.to_str()) + .map(|e| e.to_lowercase()) + .as_deref() + { + Some("jpg" | "jpeg") => "image/jpeg", + Some("png") => "image/png", + Some("gif") => "image/gif", + Some("webp") => "image/webp", + _ => continue, + }; + if let Ok(data) = std::fs::read(path) { + let b64 = base64::engine::general_purpose::STANDARD.encode(&data); + image_parts.push(serde_json::json!({ + "type": "image_url", + "image_url": {"url": format!("data:{mime};base64,{b64}")} + })); + } + } + + if image_parts.is_empty() { + // plain text — more compatible + serde_json::Value::String(full_text) + } else { + // multimodal array + let mut content = vec![serde_json::json!({"type": "text", "text": full_text})]; + content.extend(image_parts); + serde_json::Value::Array(content) + } +} + +async fn summarize_messages( + endpoint: &str, + model: &str, + api_key: &str, + existing_summary: &str, + dropped: &[serde_json::Value], +) -> Result { + let msgs_text: String = dropped + .iter() + .filter_map(|m| { + let role = m["role"].as_str()?; + let content = m["content"].as_str()?; + Some(format!("{role}: {content}")) + }) + .collect::>() + .join("\n\n"); + + let prompt = if existing_summary.is_empty() { + format!( + "请将以下对话总结为约4000字符的摘要,保留关键信息和上下文:\n\n{}", + msgs_text + ) + } else { + format!( + "请将以下新对话内容整合到现有总结中,保持总结在约4000字符以内。\ + 保留重要信息,让较旧的话题自然淡出。\n\n\ + 现有总结:\n{}\n\n新对话:\n{}", + existing_summary, msgs_text + ) + }; + + let client = reqwest::Client::new(); + let url = format!("{}/chat/completions", endpoint.trim_end_matches('/')); + + let body = serde_json::json!({ + "model": model, + "messages": [ + {"role": "system", "content": 
"你是一个对话总结助手。请生成简洁但信息丰富的总结。"},
            {"role": "user", "content": prompt}
        ],
    });

    let resp = client
        .post(&url)
        .header("Authorization", format!("Bearer {api_key}"))
        .json(&body)
        .send()
        .await?
        .error_for_status()?;

    let json: serde_json::Value = resp.json().await?;
    let summary = json["choices"][0]["message"]["content"]
        .as_str()
        .unwrap_or("")
        .to_string();

    Ok(summary)
}

// ── tool execution ─────────────────────────────────────────────────

/// Dispatch one LLM tool call and return the text fed back to the model as
/// the tool result.
///
/// `arguments` is the raw JSON string produced by the model. Built-in tools
/// are matched by name; any other name is resolved against the scripts in the
/// tools directory by `run_script_tool`. Failures are reported in the returned
/// string instead of being propagated, so the model can react to them.
async fn execute_tool(
    name: &str,
    arguments: &str,
    state: &Arc<AppState>,
    bot: &Bot,
    chat_id: ChatId,
    sid: &str,
    config: &Arc<Config>,
) -> String {
    // Models commonly send an empty string for tools without parameters;
    // treat that as an empty object instead of rejecting it as invalid JSON.
    let arguments = if arguments.trim().is_empty() {
        "{}"
    } else {
        arguments
    };
    let args: serde_json::Value = match serde_json::from_str(arguments) {
        Ok(v) => v,
        Err(e) => return format!("Invalid arguments: {e}"),
    };

    match name {
        "spawn_agent" => {
            let id = args["id"].as_str().unwrap_or("agent");
            let task = args["task"].as_str().unwrap_or("");
            if task.trim().is_empty() {
                // Spawning `claude -p ""` only burns a process for nothing.
                return "spawn_agent: 'task' is required".to_string();
            }
            spawn_agent(id, task, state, bot, chat_id, sid, config).await
        }
        "agent_status" => {
            let id = args["id"].as_str().unwrap_or("");
            check_agent_status(id, state).await
        }
        "kill_agent" => {
            let id = args["id"].as_str().unwrap_or("");
            kill_agent(id, state).await
        }
        "send_file" => {
            let path_str = args["path"].as_str().unwrap_or("");
            let caption = args["caption"].as_str().unwrap_or("");
            let path = Path::new(path_str);
            if !path.exists() {
                return format!("File not found: {path_str}");
            }
            if !path.is_file() {
                return format!("Not a file: {path_str}");
            }
            let input_file = InputFile::file(path);
            let mut req = bot.send_document(chat_id, input_file);
            if !caption.is_empty() {
                req = req.caption(caption);
            }
            match req.await {
                Ok(_) => format!("File sent: {path_str}"),
                Err(e) => format!("Failed to send file: {e:#}"),
            }
        }
        "update_scratch" => {
            let content = args["content"].as_str().unwrap_or("");
            state.push_scratch(content).await;
            // Count characters, not bytes: the message says "chars" and the
            // content is frequently non-ASCII.
            format!("Scratch updated ({} chars)", content.chars().count())
        }
        _ => run_script_tool(name, arguments).await,
    }
}

/// Append every line read from `pipe` to `sink`, each prefixed with `prefix`.
/// Returns immediately when the pipe was not captured.
async fn drain_child_lines<R>(
    pipe: Option<R>,
    sink: Arc<tokio::sync::RwLock<String>>,
    prefix: &str,
) where
    R: tokio::io::AsyncRead + Unpin,
{
    let Some(pipe) = pipe else { return };
    let mut lines = tokio::io::BufReader::new(pipe).lines();
    while let Ok(Some(line)) = lines.next_line().await {
        let mut buf = sink.write().await;
        buf.push_str(prefix);
        buf.push_str(&line);
        buf.push('\n');
    }
}

/// Start a background `claude -p <task>` process registered under `id`.
///
/// The child's stdout and stderr are collected into the agent's shared output
/// buffer. When the process exits, its (truncated) output is injected into the
/// conversation `sid` through `agent_wakeup`. Returns a human-readable status
/// line for the model; spawn failures and duplicate ids are reported the same
/// way.
async fn spawn_agent(
    id: &str,
    task: &str,
    state: &Arc<AppState>,
    bot: &Bot,
    chat_id: ChatId,
    sid: &str,
    config: &Arc<Config>,
) -> String {
    // Hold the write lock across check + spawn + insert so two concurrent
    // calls with the same id cannot both pass the existence check.
    let mut agents = state.agents.write().await;
    if agents.contains_key(id) {
        return format!("Agent '{id}' already exists. Use agent_status to check it.");
    }

    let mut child = match Command::new("claude")
        .args(["--dangerously-skip-permissions", "-p", task])
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .spawn()
    {
        Ok(c) => c,
        Err(e) => return format!("Failed to spawn agent: {e}"),
    };

    let pid = child.id();
    let output = Arc::new(tokio::sync::RwLock::new(String::new()));
    let completed = Arc::new(AtomicBool::new(false));
    let exit_code = Arc::new(tokio::sync::RwLock::new(None));

    let agent = Arc::new(SubAgent {
        task: task.to_string(),
        output: output.clone(),
        completed: completed.clone(),
        exit_code: exit_code.clone(),
        pid,
    });

    agents.insert(id.to_string(), agent);
    drop(agents);

    // Owned copies for the background task.
    let bot_c = bot.clone();
    let state_c = state.clone();
    let config_c = config.clone();
    let sid_c = sid.to_string();
    let id_c = id.to_string();

    // background task: collect output and wakeup on completion
    tokio::spawn(async move {
        let stdout = child.stdout.take();
        let stderr = child.stderr.take();
        // Drain both pipes concurrently. stderr is piped too, so leaving it
        // unread would block the child as soon as the pipe buffer fills up.
        tokio::join!(
            drain_child_lines(stdout, output.clone(), ""),
            drain_child_lines(stderr, output.clone(), "[stderr] "),
        );

        let status = child.wait().await;
        let code = status.as_ref().ok().and_then(|s| s.code());
        *exit_code.write().await = code;
        completed.store(true, Ordering::SeqCst);

        info!(agent = %id_c, "agent completed, exit={code:?}");

        // wakeup: inject result and trigger LLM
        let result = output.read().await.clone();
        let result_short = truncate_at_char_boundary(&result, 4000);
        let wakeup = format!(
            "[Agent '{id_c}' 执行完成 (exit={})]\n{result_short}",
            code.unwrap_or(-1)
        );

        if let Err(e) = agent_wakeup(
            &config_c, &state_c, &bot_c, chat_id, &sid_c, &wakeup, &id_c,
        )
        .await
        {
            error!(agent = %id_c, "wakeup failed: {e:#}");
            let _ = bot_c
                .send_message(chat_id, format!("[agent wakeup error] {e:#}"))
                .await;
        }
    });

    format!("Agent '{id}' spawned (pid={pid:?})")
}

/// Feed a finished agent's result back into conversation `sid`.
///
/// With the OpenAI backend the wakeup text is stored as a user message, the
/// conversation is replayed to the model through `run_openai_streaming`, and a
/// non-empty reply is stored as the assistant message. With any other backend
/// the wakeup text is simply forwarded to the chat.
async fn agent_wakeup(
    config: &Config,
    state: &AppState,
    bot: &Bot,
    chat_id: ChatId,
    sid: &str,
    wakeup_msg: &str,
    agent_id: &str,
) -> Result<()> {
    match &config.backend {
        BackendConfig::OpenAI {
            endpoint,
            model,
            api_key,
        } => {
            state.push_message(sid, "user", wakeup_msg).await;
            let conv = state.load_conv(sid).await;
            let persona = state.get_config("persona").await.unwrap_or_default();
            let system_msg = build_system_prompt(&conv.summary, &persona);
            let mut api_messages = vec![system_msg];
            api_messages.extend(conv.messages);

            info!(agent = %agent_id, "wakeup: sending {} messages to LLM", api_messages.len());

            let response =
                run_openai_streaming(endpoint, model, api_key, &api_messages, bot, chat_id)
                    .await?;

            if !response.is_empty() {
                state.push_message(sid, "assistant", &response).await;
            }

            Ok(())
        }
        _ => {
            // Best effort: a failed notification is not worth failing the wakeup.
            let _ = bot
                .send_message(chat_id, format!("[Agent '{agent_id}' done]\n{wakeup_msg}"))
                .await;
            Ok(())
        }
    }
}

/// Describe agent `id`: running/completed state, its task and a preview
/// (first 3000 bytes, cut at a character boundary) of the output so far.
async fn check_agent_status(id: &str, state: &AppState) -> String {
    let agents = state.agents.read().await;
    match agents.get(id) {
        Some(agent) => {
            let status = if agent.completed.load(Ordering::SeqCst) {
                let code = agent.exit_code.read().await;
                format!("completed (exit={})", code.unwrap_or(-1))
            } else {
                "running".to_string()
            };
            let output = agent.output.read().await;
            let out_preview = truncate_at_char_boundary(&output, 3000);
            format!(
                "Agent '{id}': {status}\nTask: {}\nOutput ({} bytes):\n{out_preview}",
                agent.task,
                output.len()
            )
        }
        None => format!("Agent '{id}' not found"),
    }
}

/// Send SIGTERM to the process of agent `id` and report what happened.
async fn kill_agent(id: &str, state: &AppState) -> String {
    let agents = state.agents.read().await;
    let Some(agent) = agents.get(id) else {
        return format!("Agent '{id}' not found");
    };
    if agent.completed.load(Ordering::SeqCst) {
        return format!("Agent '{id}' already completed");
    }
    let Some(pid) = agent.pid else {
        return format!("Agent '{id}' has no PID");
    };

    // kill(2) treats 0 and negative values as "a process group", so a pid that
    // is zero or wraps when converted to pid_t must never reach the call.
    let raw_pid = match libc::pid_t::try_from(pid) {
        Ok(p) if p > 0 => p,
        _ => return format!("Agent '{id}' has an invalid PID ({pid})"),
    };

    // SAFETY: kill(2) takes plain integers and has no memory-safety
    // preconditions.
    let rc = unsafe { libc::kill(raw_pid, libc::SIGTERM) };
    if rc == 0 {
        format!("Sent SIGTERM to agent '{id}' (pid={pid})")
    } else {
        let err = std::io::Error::last_os_error();
        format!("Failed to signal agent '{id}' (pid={pid}): {err}")
    }
}

/// Return true when the script at `path` advertises a tool called `name`
/// in the JSON it prints for `--schema`.
async fn script_provides_tool(path: &Path, name: &str) -> bool {
    const SCHEMA_PROBE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);

    let probe = tokio::time::timeout(
        SCHEMA_PROBE_TIMEOUT,
        // kill_on_drop: a probe that times out must not linger.
        Command::new(path).arg("--schema").kill_on_drop(true).output(),
    )
    .await;

    let Ok(Ok(out)) = probe else { return false };
    if !out.status.success() {
        return false;
    }
    let stdout = String::from_utf8_lossy(&out.stdout);
    serde_json::from_str::<serde_json::Value>(stdout.trim())
        .map(|schema| schema["name"].as_str() == Some(name))
        .unwrap_or(false)
}

/// Run the script in the tools directory that provides tool `name`, passing
/// the raw JSON `arguments` as its single argument.
///
/// Returns the script's stdout, followed by a `[stderr]` section when stderr
/// is non-empty, or `(exit=N)` when it printed nothing. The script is killed
/// if it runs longer than 60 seconds.
async fn run_script_tool(name: &str, arguments: &str) -> String {
    const SCRIPT_TOOL_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(60);

    // find script in tools/ that matches this tool name
    let dir = tools_dir();
    let entries = match std::fs::read_dir(&dir) {
        Ok(e) => e,
        Err(_) => return format!("Unknown tool: {name}"),
    };

    for entry in entries.flatten() {
        let path = entry.path();
        if !path.is_file() {
            continue;
        }
        // The probe uses the async process API: a blocking
        // std::process::Command here would stall the runtime thread.
        if !script_provides_tool(&path, name).await {
            continue;
        }

        info!(tool = %name, path = %path.display(), "running script tool");
        let result = tokio::time::timeout(
            SCRIPT_TOOL_TIMEOUT,
            // kill_on_drop: without it a timed-out script keeps running after
            // its output future is dropped.
            Command::new(&path)
                .arg(arguments)
                .kill_on_drop(true)
                .output(),
        )
        .await;

        return match result {
            Ok(Ok(output)) => {
                let mut s = String::from_utf8_lossy(&output.stdout).to_string();
                let stderr = String::from_utf8_lossy(&output.stderr);
                if !stderr.is_empty() {
                    if !s.is_empty() {
                        s.push_str("\n[stderr]\n");
                    }
                    s.push_str(&stderr);
                }
                if s.is_empty() {
                    format!("(exit={})", output.status.code().unwrap_or(-1))
                } else {
                    s
                }
            }
            Ok(Err(e)) => format!("Failed to execute {name}: {e}"),
            Err(_) => "Timeout after 60s".to_string(),
        };
    }

    format!("Unknown tool: {name}")
}

// ── openai with tool call loop ─────────────────────────────────────

/// Extract the payload of an SSE `data:` line. Blank lines, `:` comments and
/// other SSE fields yield `None`. The space after the colon is optional.
fn sse_data_payload(line: &str) -> Option<&str> {
    let trimmed = line.trim();
    if trimmed.is_empty() || trimmed.starts_with(':') {
        return None;
    }
    trimmed.strip_prefix("data:").map(str::trim)
}

/// Fold one streamed `delta.tool_calls` array into the calls accumulated so
/// far. `id` and `name` replace the stored value; `arguments` fragments are
/// appended. Deltas whose index is `max_calls` or larger are ignored.
fn merge_tool_call_deltas(
    tool_calls: &mut Vec<ToolCall>,
    deltas: &[serde_json::Value],
    max_calls: usize,
) {
    for tc in deltas {
        let idx = usize::try_from(tc["index"].as_u64().unwrap_or(0)).unwrap_or(usize::MAX);
        if idx >= max_calls {
            // A bogus index must not make us allocate an arbitrarily large Vec.
            warn!("ignoring tool call delta with out-of-range index {idx}");
            continue;
        }
        while tool_calls.len() <= idx {
            tool_calls.push(ToolCall {
                id: String::new(),
                name: String::new(),
                arguments: String::new(),
            });
        }
        if let Some(id) = tc["id"].as_str() {
            tool_calls[idx].id = id.to_string();
        }
        if let Some(name) = tc["function"]["name"].as_str() {
            tool_calls[idx].name = name.to_string();
        }
        if let Some(args) = tc["function"]["arguments"].as_str() {
            tool_calls[idx].arguments.push_str(args);
        }
    }
}

/// Run a streaming chat completion, executing tool calls until the model
/// answers with plain content.
///
/// Each round POSTs `messages` plus the discovered tool schemas, parses the
/// SSE stream, and shows the partial answer in the chat (message draft in
/// private chats, an edited message otherwise). If the model requested tools,
/// they are executed through `execute_tool`, their results are appended to
/// `messages`, and the API is called again. Returns the final content.
///
/// Fails on transport errors, on a non-success HTTP status, or when the model
/// keeps requesting tools for more than `MAX_TOOL_ROUNDS` rounds.
#[allow(clippy::too_many_arguments)]
async fn run_openai_with_tools(
    endpoint: &str,
    model: &str,
    api_key: &str,
    mut messages: Vec<serde_json::Value>,
    bot: &Bot,
    chat_id: ChatId,
    state: &Arc<AppState>,
    sid: &str,
    config: &Arc<Config>,
    is_private: bool,
) -> Result<String> {
    // A model that never stops calling tools would otherwise loop forever.
    const MAX_TOOL_ROUNDS: usize = 32;
    const MAX_TOOL_CALLS_PER_TURN: usize = 64;

    let client = reqwest::Client::new();
    let url = format!("{}/chat/completions", endpoint.trim_end_matches('/'));
    let tools = discover_tools();
    let mut tool_rounds = 0usize;

    loop {
        let body = serde_json::json!({
            "model": model,
            "messages": messages,
            "tools": tools,
            "stream": true,
        });

        info!(
            "API request: {} messages, {} tools",
            messages.len(),
            tools.as_array().map(|a| a.len()).unwrap_or(0)
        );

        let resp_raw = client
            .post(&url)
            .header("Authorization", format!("Bearer {api_key}"))
            .json(&body)
            .send()
            .await?;

        if !resp_raw.status().is_success() {
            let status = resp_raw.status();
            let body_text = resp_raw.text().await.unwrap_or_default();
            // dump messages for debugging
            for (i, m) in messages.iter().enumerate() {
                let role = m["role"].as_str().unwrap_or("?");
                let content_len = m["content"].as_str().map(|s| s.len()).unwrap_or(0);
                let has_tc = m.get("tool_calls").is_some();
                let has_tcid = m.get("tool_call_id").is_some();
                warn!("  msg[{i}] role={role} content_len={content_len} tool_calls={has_tc} tool_call_id={has_tcid}");
            }
            error!("OpenAI API {status}: {body_text}");
            anyhow::bail!("OpenAI API {status}: {body_text}");
        }

        let mut resp = resp_raw;

        let token = bot.token().to_owned();
        let raw_chat_id = chat_id.0;
        let draft_id: i64 = 1;
        let mut use_draft = is_private; // sendMessageDraft only works in private chats

        let mut msg_id = None;
        let mut accumulated = String::new();
        let mut last_edit = Instant::now();
        // Raw bytes not yet terminated by '\n'. Decoding happens per complete
        // line: decoding each network chunk on its own corrupts multi-byte
        // characters that straddle a chunk boundary.
        let mut pending: Vec<u8> = Vec::new();

        // tool call accumulation
        let mut tool_calls: Vec<ToolCall> = Vec::new();

        'stream: while let Some(chunk) = resp.chunk().await? {
            pending.extend_from_slice(&chunk);

            while let Some(pos) = pending.iter().position(|&b| b == b'\n') {
                let raw_line: Vec<u8> = pending.drain(..=pos).collect();
                let line = String::from_utf8_lossy(&raw_line);

                let Some(data) = sse_data_payload(&line) else {
                    continue;
                };
                if data == "[DONE]" {
                    break 'stream;
                }
                let Ok(json) = serde_json::from_str::<serde_json::Value>(data) else {
                    continue;
                };
                let delta = &json["choices"][0]["delta"];

                // handle content delta
                if let Some(content) = delta["content"].as_str() {
                    accumulated.push_str(content);
                }

                // handle tool call delta
                if let Some(deltas) = delta["tool_calls"].as_array() {
                    merge_tool_call_deltas(&mut tool_calls, deltas, MAX_TOOL_CALLS_PER_TURN);
                }

                // display update (only when there's content to show)
                if accumulated.is_empty() {
                    continue;
                }

                let interval = if use_draft {
                    DRAFT_INTERVAL_MS
                } else {
                    EDIT_INTERVAL_MS
                };
                if last_edit.elapsed().as_millis() < interval as u128 {
                    continue;
                }

                let display = if use_draft {
                    truncate_at_char_boundary(&accumulated, TG_MSG_LIMIT).to_string()
                } else {
                    truncate_for_display(&accumulated)
                };

                if use_draft {
                    match send_message_draft(&client, &token, raw_chat_id, draft_id, &display)
                        .await
                    {
                        Ok(_) => {
                            last_edit = Instant::now();
                        }
                        Err(e) => {
                            warn!("sendMessageDraft failed, falling back: {e:#}");
                            use_draft = false;
                            if let Ok(sent) = bot.send_message(chat_id, &display).await {
                                msg_id = Some(sent.id);
                                last_edit = Instant::now();
                            }
                        }
                    }
                } else if let Some(id) = msg_id {
                    if bot
                        .edit_message_text(chat_id, id, &display)
                        .await
                        .is_ok()
                    {
                        last_edit = Instant::now();
                    }
                } else if let Ok(sent) = bot.send_message(chat_id, &display).await {
                    msg_id = Some(sent.id);
                    last_edit = Instant::now();
                }
            }
        }

        // decide what to do based on response type
        if !tool_calls.is_empty() {
            tool_rounds += 1;
            if tool_rounds > MAX_TOOL_ROUNDS {
                anyhow::bail!("tool call loop exceeded {MAX_TOOL_ROUNDS} rounds");
            }

            // append assistant message with tool calls
            let tc_json: Vec<serde_json::Value> = tool_calls
                .iter()
                .map(|tc| {
                    serde_json::json!({
                        "id": tc.id,
                        "type": "function",
                        "function": {
                            "name": tc.name,
                            "arguments": tc.arguments,
                        }
                    })
                })
                .collect();

            messages.push(serde_json::json!({
                "role": "assistant",
                "content": accumulated.as_str(),
                "tool_calls": tc_json,
            }));

            // execute each tool
            for tc in &tool_calls {
                info!(tool = %tc.name, "executing tool call");
                let _ = bot
                    .send_message(
                        chat_id,
                        format!(
                            "[{}({})]",
                            tc.name,
                            truncate_at_char_boundary(&tc.arguments, 100)
                        ),
                    )
                    .await;

                let result =
                    execute_tool(&tc.name, &tc.arguments, state, bot, chat_id, sid, config).await;

                messages.push(serde_json::json!({
                    "role": "tool",
                    "tool_call_id": tc.id,
                    "content": result,
                }));
            }

            // loop back to call API again; per-round state is rebuilt above
            continue;
        }

        // content response — send final result
        if !accumulated.is_empty() {
            send_final_result(bot, chat_id, msg_id, use_draft, &accumulated).await;
        }

        return Ok(accumulated);
    }
}

// ── claude bridge (streaming)
─────────────────────────────────────── /// Stream JSON event types we care about. @@ -952,8 +2060,8 @@ async fn send_final_result( let html = markdown_to_telegram_html(result); // try HTML as single message - let html_ok = if !use_draft && msg_id.is_some() { - bot.edit_message_text(chat_id, msg_id.unwrap(), &html) + let html_ok = if let (false, Some(id)) = (use_draft, msg_id) { + bot.edit_message_text(chat_id, id, &html) .parse_mode(ParseMode::Html) .await .is_ok() @@ -970,8 +2078,7 @@ async fn send_final_result( // fallback: plain text with chunking let chunks = split_msg(result, TG_MSG_LIMIT); - if !use_draft && msg_id.is_some() { - let id = msg_id.unwrap(); + if let (false, Some(id)) = (use_draft, msg_id) { let _ = bot.edit_message_text(chat_id, id, chunks[0]).await; for chunk in &chunks[1..] { let _ = bot.send_message(chat_id, *chunk).await; diff --git a/tests/tool_call.rs b/tests/tool_call.rs new file mode 100644 index 0000000..28e7caa --- /dev/null +++ b/tests/tool_call.rs @@ -0,0 +1,361 @@ +//! Integration test: verify tool call round-trip with Ollama's OpenAI-compatible API. +//! Requires Ollama running at OLLAMA_URL (default: http://100.84.7.49:11434). 
+ +use serde_json::json; + +const OLLAMA_URL: &str = "http://100.84.7.49:11434/v1"; +const MODEL: &str = "gemma4:31b"; + +fn tools() -> serde_json::Value { + json!([{ + "type": "function", + "function": { + "name": "calculator", + "description": "Calculate a math expression", + "parameters": { + "type": "object", + "properties": { + "expression": {"type": "string", "description": "Math expression to evaluate"} + }, + "required": ["expression"] + } + } + }]) +} + +/// Test non-streaming tool call round-trip +#[tokio::test] +async fn test_tool_call_roundtrip_non_streaming() { + let client = reqwest::Client::new(); + let url = format!("{OLLAMA_URL}/chat/completions"); + + // Round 1: ask the model to use the calculator + let body = json!({ + "model": MODEL, + "messages": [ + {"role": "user", "content": "What is 2+2? Use the calculator tool."} + ], + "tools": tools(), + }); + + let resp = client.post(&url).json(&body).send().await.unwrap(); + assert!(resp.status().is_success(), "Round 1 failed: {}", resp.status()); + + let result: serde_json::Value = resp.json().await.unwrap(); + let choice = &result["choices"][0]; + assert_eq!( + choice["finish_reason"].as_str().unwrap(), + "tool_calls", + "Expected tool_calls finish_reason, got: {choice}" + ); + + let tool_calls = choice["message"]["tool_calls"].as_array().unwrap(); + assert!(!tool_calls.is_empty(), "No tool calls returned"); + + let tc = &tool_calls[0]; + let call_id = tc["id"].as_str().unwrap(); + let func_name = tc["function"]["name"].as_str().unwrap(); + assert_eq!(func_name, "calculator"); + + // Round 2: send tool result back + let body2 = json!({ + "model": MODEL, + "messages": [ + {"role": "user", "content": "What is 2+2? 
Use the calculator tool."}, + { + "role": "assistant", + "content": "", + "tool_calls": [{ + "id": call_id, + "type": "function", + "function": { + "name": func_name, + "arguments": tc["function"]["arguments"].as_str().unwrap() + } + }] + }, + { + "role": "tool", + "tool_call_id": call_id, + "content": "4" + } + ], + "tools": tools(), + }); + + let resp2 = client.post(&url).json(&body2).send().await.unwrap(); + let status2 = resp2.status(); + let body2_text = resp2.text().await.unwrap(); + assert!( + status2.is_success(), + "Round 2 failed ({status2}): {body2_text}" + ); + + let result2: serde_json::Value = serde_json::from_str(&body2_text).unwrap(); + let content = result2["choices"][0]["message"]["content"] + .as_str() + .unwrap_or(""); + assert!(!content.is_empty(), "Expected content in round 2 response"); + println!("Round 2 response: {content}"); +} + +/// Test tool call with conversation history (simulates real scenario) +#[tokio::test] +async fn test_tool_call_with_history() { + let client = reqwest::Client::new(); + let url = format!("{OLLAMA_URL}/chat/completions"); + + // Simulate real message history with system prompt + let body = json!({ + "model": MODEL, + "stream": true, + "messages": [ + {"role": "system", "content": "你是一个AI助手。你可以使用提供的工具来完成任务。当需要执行命令、运行代码或启动复杂子任务时,直接调用对应的工具,不要只是描述你会怎么做。"}, + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "Hello!"}, + {"role": "user", "content": "What is 3+4? 
Use the calculator."} + ], + "tools": tools(), + }); + + // Round 1: expect tool call + let mut resp = client.post(&url).json(&body).send().await.unwrap(); + assert!(resp.status().is_success(), "Round 1 failed: {}", resp.status()); + + let mut buffer = String::new(); + let mut tc_id = String::new(); + let mut tc_name = String::new(); + let mut tc_args = String::new(); + let mut has_tc = false; + + while let Some(chunk) = resp.chunk().await.unwrap() { + buffer.push_str(&String::from_utf8_lossy(&chunk)); + while let Some(pos) = buffer.find('\n') { + let line = buffer[..pos].to_string(); + buffer = buffer[pos + 1..].to_string(); + if let Some(data) = line.trim().strip_prefix("data: ") { + if data.trim() == "[DONE]" { break; } + if let Ok(j) = serde_json::from_str::(data) { + if let Some(arr) = j["choices"][0]["delta"]["tool_calls"].as_array() { + has_tc = true; + for tc in arr { + if let Some(id) = tc["id"].as_str() { tc_id = id.into(); } + if let Some(n) = tc["function"]["name"].as_str() { tc_name = n.into(); } + if let Some(a) = tc["function"]["arguments"].as_str() { tc_args.push_str(a); } + } + } + } + } + } + } + + assert!(has_tc, "Expected tool call, got content only"); + println!("Tool: {tc_name}({tc_args}) id={tc_id}"); + + // Round 2: tool result → expect content + let body2 = json!({ + "model": MODEL, + "stream": true, + "messages": [ + {"role": "system", "content": "你是一个AI助手。"}, + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "Hello!"}, + {"role": "user", "content": "What is 3+4? 
Use the calculator."}, + {"role": "assistant", "content": "", "tool_calls": [{"id": tc_id, "type": "function", "function": {"name": tc_name, "arguments": tc_args}}]}, + {"role": "tool", "tool_call_id": tc_id, "content": "7"} + ], + "tools": tools(), + }); + + let resp2 = client.post(&url).json(&body2).send().await.unwrap(); + let status = resp2.status(); + if !status.is_success() { + let err = resp2.text().await.unwrap(); + panic!("Round 2 failed ({status}): {err}"); + } + + let mut resp2 = client.post(&url).json(&body2).send().await.unwrap(); + let mut content = String::new(); + let mut buf2 = String::new(); + while let Some(chunk) = resp2.chunk().await.unwrap() { + buf2.push_str(&String::from_utf8_lossy(&chunk)); + while let Some(pos) = buf2.find('\n') { + let line = buf2[..pos].to_string(); + buf2 = buf2[pos + 1..].to_string(); + if let Some(data) = line.trim().strip_prefix("data: ") { + if data.trim() == "[DONE]" { break; } + if let Ok(j) = serde_json::from_str::(data) { + if let Some(c) = j["choices"][0]["delta"]["content"].as_str() { + content.push_str(c); + } + } + } + } + } + + println!("Final response: {content}"); + assert!(!content.is_empty(), "Expected non-empty content in round 2"); +} + +/// Test multimodal image input +#[tokio::test] +async fn test_image_multimodal() { + let client = reqwest::Client::new(); + let url = format!("{OLLAMA_URL}/chat/completions"); + + // 2x2 red PNG generated by PIL + let b64 = "iVBORw0KGgoAAAANSUhEUgAAAAIAAAACCAIAAAD91JpzAAAAFklEQVR4nGP8z8DAwMDAxMDAwMDAAAANHQEDasKb6QAAAABJRU5ErkJggg=="; + + let body = json!({ + "model": MODEL, + "messages": [{ + "role": "user", + "content": [ + {"type": "text", "text": "What color is this image? 
Reply with just the color name."}, + {"type": "image_url", "image_url": {"url": format!("data:image/png;base64,{b64}")}} + ] + }], + }); + + let resp = client.post(&url).json(&body).send().await.unwrap(); + let status = resp.status(); + let text = resp.text().await.unwrap(); + assert!(status.is_success(), "Multimodal request failed ({status}): {text}"); + + let result: serde_json::Value = serde_json::from_str(&text).unwrap(); + let content = result["choices"][0]["message"]["content"] + .as_str() + .unwrap_or(""); + println!("Image description: {content}"); + assert!(!content.is_empty(), "Expected non-empty response for image"); +} + +/// Test streaming tool call round-trip (matches our actual code path) +#[tokio::test] +async fn test_tool_call_roundtrip_streaming() { + let client = reqwest::Client::new(); + let url = format!("{OLLAMA_URL}/chat/completions"); + + // Round 1: streaming, get tool calls + let body = json!({ + "model": MODEL, + "stream": true, + "messages": [ + {"role": "user", "content": "What is 7*6? 
Use the calculator tool."} + ], + "tools": tools(), + }); + + let mut resp = client.post(&url).json(&body).send().await.unwrap(); + assert!(resp.status().is_success(), "Round 1 streaming failed"); + + // Parse SSE to extract tool calls + let mut buffer = String::new(); + let mut tool_call_id = String::new(); + let mut tool_call_name = String::new(); + let mut tool_call_args = String::new(); + let mut has_tool_calls = false; + + while let Some(chunk) = resp.chunk().await.unwrap() { + buffer.push_str(&String::from_utf8_lossy(&chunk)); + + while let Some(pos) = buffer.find('\n') { + let line = buffer[..pos].to_string(); + buffer = buffer[pos + 1..].to_string(); + + let trimmed = line.trim(); + if let Some(data) = trimmed.strip_prefix("data: ") { + if data.trim() == "[DONE]" { + break; + } + if let Ok(json) = serde_json::from_str::(data) { + let delta = &json["choices"][0]["delta"]; + if let Some(tc_arr) = delta["tool_calls"].as_array() { + has_tool_calls = true; + for tc in tc_arr { + if let Some(id) = tc["id"].as_str() { + tool_call_id = id.to_string(); + } + if let Some(name) = tc["function"]["name"].as_str() { + tool_call_name = name.to_string(); + } + if let Some(args) = tc["function"]["arguments"].as_str() { + tool_call_args.push_str(args); + } + } + } + } + } + } + } + + assert!(has_tool_calls, "No tool calls in streaming response"); + assert_eq!(tool_call_name, "calculator"); + println!("Tool call: {tool_call_name}({tool_call_args}) id={tool_call_id}"); + + // Round 2: send tool result, streaming + let body2 = json!({ + "model": MODEL, + "stream": true, + "messages": [ + {"role": "user", "content": "What is 7*6? 
Use the calculator tool."}, + { + "role": "assistant", + "content": "", + "tool_calls": [{ + "id": tool_call_id, + "type": "function", + "function": { + "name": tool_call_name, + "arguments": tool_call_args + } + }] + }, + { + "role": "tool", + "tool_call_id": tool_call_id, + "content": "42" + } + ], + "tools": tools(), + }); + + let resp2 = client.post(&url).json(&body2).send().await.unwrap(); + let status2 = resp2.status(); + if !status2.is_success() { + let err = resp2.text().await.unwrap(); + panic!("Round 2 streaming failed ({status2}): {err}"); + } + + // Collect content from streaming response + let mut resp2 = client + .post(&url) + .json(&body2) + .send() + .await + .unwrap(); + let mut content = String::new(); + let mut buffer2 = String::new(); + + while let Some(chunk) = resp2.chunk().await.unwrap() { + buffer2.push_str(&String::from_utf8_lossy(&chunk)); + while let Some(pos) = buffer2.find('\n') { + let line = buffer2[..pos].to_string(); + buffer2 = buffer2[pos + 1..].to_string(); + let trimmed = line.trim(); + if let Some(data) = trimmed.strip_prefix("data: ") { + if data.trim() == "[DONE]" { + break; + } + if let Ok(json) = serde_json::from_str::(data) { + if let Some(c) = json["choices"][0]["delta"]["content"].as_str() { + content.push_str(c); + } + } + } + } + } + + assert!(!content.is_empty(), "Expected content in round 2 streaming"); + println!("Round 2 streaming content: {content}"); +} diff --git a/tools/manage_todo b/tools/manage_todo new file mode 100755 index 0000000..ca18c5d --- /dev/null +++ b/tools/manage_todo @@ -0,0 +1,187 @@ +#!/usr/bin/env -S uv run --script +# /// script +# requires-python = ">=3.11" +# dependencies = ["requests"] +# /// +"""Feishu Bitable todo manager. 
+ +Usage: + ./fam-todo.py list-undone List open todos + ./fam-todo.py list-done List completed todos + ./fam-todo.py add Add a new todo + ./fam-todo.py mark-done <record_id> Mark as done + ./fam-todo.py mark-undone <record_id> Mark as undone + ./fam-todo.py --schema Print tool schema JSON +""" + +import json +import sys +import requests + +APP_ID = "cli_a7f042e93d385013" +APP_SECRET = "ht4FCjQ8JJ65ZPUWlff6ldFBmaP0mxqY" +APP_TOKEN = "SSoGbmGFoazJkUs7bbfcaSG8n7f" +TABLE_ID = "tblIA2biceDpvr35" +BASE_URL = "https://open.feishu.cn/open-apis" + +ACTIONS = ["list-undone", "list-done", "add", "mark-done", "mark-undone"] + +SCHEMA = { + "name": "fam_todo", + "description": "管理 Fam 的飞书待办事项表格。", + "parameters": { + "type": "object", + "properties": { + "action": { + "type": "string", + "enum": ACTIONS, + "description": "操作类型", + }, + "title": { + "type": "string", + "description": "待办标题 (add 时必填)", + }, + "record_id": { + "type": "string", + "description": "记录ID (mark-done/mark-undone 时必填)", + }, + }, + "required": ["action"], + }, +} + + +def get_token(): + r = requests.post( + f"{BASE_URL}/auth/v3/tenant_access_token/internal/", + json={"app_id": APP_ID, "app_secret": APP_SECRET}, + ) + r.raise_for_status() + return r.json()["tenant_access_token"] + + +def headers(): + return {"Authorization": f"Bearer {get_token()}", "Content-Type": "application/json"} + + +def api(method, path, **kwargs): + url = f"{BASE_URL}/bitable/v1/apps/{APP_TOKEN}/tables/{TABLE_ID}{path}" + r = requests.request(method, url, headers=headers(), **kwargs) + r.raise_for_status() + return r.json() + + +def format_field(v): + if isinstance(v, list): + return "".join( + seg.get("text", str(seg)) if isinstance(seg, dict) else str(seg) + for seg in v + ) + return str(v) + + +def list_records(done_filter): + """List records. 
done_filter: True=done only, False=undone only.""" + data = api("GET", "/records", params={"page_size": 500}) + items = data.get("data", {}).get("items", []) + if not items: + return "No records found." + + lines = [] + for item in items: + fields = item.get("fields", {}) + is_done = bool(fields.get("Done")) + if is_done != done_filter: + continue + + rid = item["record_id"] + title = format_field(fields.get("Item", "")) + priority = fields.get("Priority", "") + notes = format_field(fields.get("Notes", "")) + + parts = [f"[{rid}] {title}"] + if priority: + parts.append(f" P: {priority}") + if notes: + preview = notes[:80].replace("\n", " ") + parts.append(f" Note: {preview}") + lines.append("\n".join(parts)) + + if not lines: + label = "completed" if done_filter else "open" + return f"No {label} todos." + return "\n".join(lines) + + +def add_record(title): + data = api("POST", "/records", json={"fields": {"Item": title}}) + rid = data.get("data", {}).get("record", {}).get("record_id", "?") + return f"Added [{rid}]: {title}" + + +def mark_done(record_id): + api("PUT", f"/records/{record_id}", json={"fields": {"Done": True}}) + return f"Marked [{record_id}] as done" + + +def mark_undone(record_id): + api("PUT", f"/records/{record_id}", json={"fields": {"Done": False}}) + return f"Marked [{record_id}] as undone" + + +def main(): + if len(sys.argv) < 2 or sys.argv[1] in ("--help", "-h"): + print(__doc__.strip()) + sys.exit(0) + + if sys.argv[1] == "--schema": + print(json.dumps(SCHEMA, ensure_ascii=False)) + sys.exit(0) + + arg = sys.argv[1] + if not arg.startswith("{"): + args = {"action": arg} + if len(sys.argv) > 2: + args["title"] = " ".join(sys.argv[2:]) + args["record_id"] = sys.argv[2] # also set record_id for mark-* + else: + try: + args = json.loads(arg) + except json.JSONDecodeError as e: + print(f"Invalid JSON: {e}") + sys.exit(1) + + action = args.get("action", "") + try: + if action == "list-undone": + print(list_records(done_filter=False)) + elif action 
== "list-done": + print(list_records(done_filter=True)) + elif action == "add": + title = args.get("title", "") + if not title: + print("Error: title is required") + sys.exit(1) + print(add_record(title)) + elif action == "mark-done": + rid = args.get("record_id", "") + if not rid: + print("Error: record_id is required") + sys.exit(1) + print(mark_done(rid)) + elif action == "mark-undone": + rid = args.get("record_id", "") + if not rid: + print("Error: record_id is required") + sys.exit(1) + print(mark_undone(rid)) + else: + print(f"Unknown action: {action}. Valid: {', '.join(ACTIONS)}") + sys.exit(1) + except Exception as e: + print(f"Error: {e}") + sys.exit(1) + + +if __name__ == "__main__": + main()