noc/src/main.rs

655 lines
20 KiB
Rust

use std::collections::{HashMap, HashSet};
use std::path::{Path, PathBuf};
use std::process::Stdio;
use std::sync::Arc;
use anyhow::Result;
use chrono::{Local, NaiveDate, NaiveTime};
use serde::{Deserialize, Serialize};
use teloxide::dispatching::UpdateFilterExt;
use teloxide::net::Download;
use teloxide::prelude::*;
use teloxide::types::InputFile;
use tokio::io::AsyncBufReadExt;
use tokio::process::Command;
use tokio::sync::RwLock;
use tokio::time::Instant;
use tracing::{error, info, warn};
use uuid::Uuid;
// ── config ──────────────────────────────────────────────────────────
/// Bot configuration, read from the YAML file named by `NOC_CONFIG`
/// (`config.yaml` when the variable is unset).
#[derive(Deserialize)]
struct Config {
    /// `tg:` section — Telegram connection settings.
    tg: TgConfig,
    /// `auth:` section — chat access control.
    auth: AuthConfig,
    /// `session:` section — conversation session boundaries.
    session: SessionConfig,
}

/// Telegram settings.
#[derive(Deserialize)]
struct TgConfig {
    /// Bot API token passed to `Bot::new`.
    key: String,
}

/// Access-control settings.
#[derive(Deserialize)]
struct AuthConfig {
    /// Text a chat must send to be authenticated for the current session day.
    passphrase: String,
}

/// Session settings.
#[derive(Deserialize)]
struct SessionConfig {
    /// Local hour at which a new session day begins; messages received before
    /// this hour still belong to the previous day's session.
    refresh_hour: u32,
}
// ── persistent state ────────────────────────────────────────────────
#[derive(Serialize, Deserialize, Default)]
struct Persistent {
authed: HashMap<i64, NaiveDate>,
known_sessions: HashSet<String>,
}
struct AppState {
persist: RwLock<Persistent>,
state_path: PathBuf,
}
impl AppState {
fn load(path: PathBuf) -> Self {
let persist = std::fs::read_to_string(&path)
.ok()
.and_then(|s| serde_json::from_str(&s).ok())
.unwrap_or_default();
info!("loaded state from {}", path.display());
Self {
persist: RwLock::new(persist),
state_path: path,
}
}
async fn save(&self) {
let data = self.persist.read().await;
if let Ok(json) = serde_json::to_string_pretty(&*data) {
if let Err(e) = std::fs::write(&self.state_path, json) {
error!("save state: {e}");
}
}
}
}
// ── helpers ─────────────────────────────────────────────────────────
/// Returns the "session day" for the current local time.
///
/// A session day starts at `refresh_hour:00` local time. Before that hour, the
/// previous calendar date is returned. An out-of-range hour (>= 24) is treated
/// like a midnight refresh rather than panicking: this runs for every incoming
/// message, so a bad config value must not take down each handler call.
fn session_date(refresh_hour: u32) -> NaiveDate {
    let now = Local::now();
    let today = now.date_naive();
    match NaiveTime::from_hms_opt(refresh_hour, 0, 0) {
        Some(refresh) if now.time() < refresh => today - chrono::Duration::days(1),
        _ => today,
    }
}
/// Deterministic Claude session id for `chat_id` on the current session day.
///
/// It is a UUIDv5 (OID namespace) of `noc-<chat_id>-<YYYYMMDD>`, so every
/// message from the same chat on the same session day maps to the same session.
fn session_uuid(chat_id: i64, refresh_hour: u32) -> String {
    let date = session_date(refresh_hour);
    let seed = format!("noc-{chat_id}-{}", date.format("%Y%m%d"));
    let id = Uuid::new_v5(&Uuid::NAMESPACE_OID, seed.as_bytes());
    id.to_string()
}
/// The user's home directory from `$HOME`, falling back to `/tmp`.
///
/// Reads `HOME` as an `OsString`, so a non-UTF-8 home path is still honoured.
/// An empty `HOME` falls back too, instead of producing relative paths.
fn home_dir() -> PathBuf {
    std::env::var_os("HOME")
        .filter(|home| !home.is_empty())
        .map(PathBuf::from)
        .unwrap_or_else(|| PathBuf::from("/tmp"))
}

/// Directory where files uploaded by users are stored (`$HOME/incoming`).
fn incoming_dir() -> PathBuf {
    home_dir().join("incoming")
}

/// Per-session directory Claude writes files into for delivery to the user
/// (`$HOME/outgoing/<sid>`).
fn outgoing_dir(sid: &str) -> PathBuf {
    home_dir().join("outgoing").join(sid)
}
// ── main ────────────────────────────────────────────────────────────
#[tokio::main]
async fn main() {
tracing_subscriber::fmt()
.with_env_filter(
tracing_subscriber::EnvFilter::from_default_env()
.add_directive("noc=info".parse().unwrap()),
)
.init();
let config_path = std::env::var("NOC_CONFIG").unwrap_or_else(|_| "config.yaml".into());
let raw = std::fs::read_to_string(&config_path)
.unwrap_or_else(|e| panic!("read {config_path}: {e}"));
let config: Config =
serde_yaml::from_str(&raw).unwrap_or_else(|e| panic!("parse config: {e}"));
let state_path = std::env::var("NOC_STATE")
.map(PathBuf::from)
.unwrap_or_else(|_| PathBuf::from("state.json"));
let state = Arc::new(AppState::load(state_path));
let _ = std::fs::create_dir_all(incoming_dir());
info!("noc bot starting");
let bot = Bot::new(&config.tg.key);
let handler = Update::filter_message().endpoint(handle);
Dispatcher::builder(bot, handler)
.dependencies(dptree::deps![state, Arc::new(config)])
.default_handler(|_| async {})
.build()
.dispatch()
.await;
}
// ── file download ───────────────────────────────────────────────────
async fn download_tg_file(bot: &Bot, file_id: &str, filename: &str) -> Result<PathBuf> {
let dir = incoming_dir();
tokio::fs::create_dir_all(&dir).await?;
let dest = dir.join(filename);
let tf = bot.get_file(file_id).await?;
let mut file = tokio::fs::File::create(&dest).await?;
bot.download_file(&tf.path, &mut file).await?;
info!("downloaded {} -> {}", filename, dest.display());
Ok(dest)
}
// ── outgoing scan ───────────────────────────────────────────────────
async fn snapshot_dir(dir: &Path) -> HashSet<PathBuf> {
let mut set = HashSet::new();
if let Ok(mut entries) = tokio::fs::read_dir(dir).await {
while let Ok(Some(entry)) = entries.next_entry().await {
let path = entry.path();
if path.is_file() {
set.insert(path);
}
}
}
set
}
async fn new_files_in(dir: &Path, before: &HashSet<PathBuf>) -> Vec<PathBuf> {
let mut files = Vec::new();
if let Ok(mut entries) = tokio::fs::read_dir(dir).await {
while let Ok(Some(entry)) = entries.next_entry().await {
let path = entry.path();
if path.is_file() && !before.contains(&path) {
files.push(path);
}
}
}
files.sort();
files
}
// ── handler ─────────────────────────────────────────────────────────
/// Dispatcher endpoint for every incoming message.
///
/// A chat must send the configured passphrase once per session day. Until then,
/// every message is answered with "not authenticated". Messages from
/// authenticated chats go to `handle_inner`. Its errors are logged and echoed to
/// the chat instead of being returned to the dispatcher.
async fn handle(
    bot: Bot,
    msg: Message,
    state: Arc<AppState>,
    config: Arc<Config>,
) -> ResponseResult<()> {
    let chat_id = msg.chat.id;
    let raw_id = chat_id.0;
    let text = msg.text().or(msg.caption()).unwrap_or("").to_string();
    let date = session_date(config.session.refresh_hour);
    let is_authed = {
        let p = state.persist.read().await;
        p.authed.get(&raw_id) == Some(&date)
    };
    if !is_authed {
        // The input is trimmed, so compare against the trimmed passphrase too;
        // otherwise a passphrase configured with surrounding whitespace could
        // never match. An empty passphrase is rejected outright, or any
        // text-less message (sticker, uncaptioned photo) would authenticate.
        let passphrase = config.auth.passphrase.trim();
        if !passphrase.is_empty() && text.trim() == passphrase {
            {
                let mut p = state.persist.write().await;
                p.authed.insert(raw_id, date);
            }
            state.save().await;
            bot.send_message(chat_id, "authenticated").await?;
            info!(chat = raw_id, "authed");
        } else {
            bot.send_message(chat_id, "not authenticated").await?;
        }
        return Ok(());
    }
    if let Err(e) = handle_inner(&bot, &msg, chat_id, &text, &state, &config).await {
        error!(chat = raw_id, "handle: {e:#}");
        let _ = bot.send_message(chat_id, format!("[error] {e:#}")).await;
    }
    Ok(())
}
async fn handle_inner(
bot: &Bot,
msg: &Message,
chat_id: ChatId,
text: &str,
state: &Arc<AppState>,
config: &Arc<Config>,
) -> Result<()> {
let mut uploaded: Vec<PathBuf> = Vec::new();
let mut download_errors: Vec<String> = Vec::new();
if let Some(doc) = msg.document() {
let name = doc.file_name.as_deref().unwrap_or("file");
match download_tg_file(bot, &doc.file.id, name).await {
Ok(p) => uploaded.push(p),
Err(e) => download_errors.push(format!("{name}: {e:#}")),
}
}
if let Some(photos) = msg.photo() {
if let Some(photo) = photos.last() {
let name = format!("photo_{}.jpg", Local::now().format("%H%M%S"));
match download_tg_file(bot, &photo.file.id, &name).await {
Ok(p) => uploaded.push(p),
Err(e) => download_errors.push(format!("photo: {e:#}")),
}
}
}
if let Some(audio) = msg.audio() {
let fallback = format!("audio_{}.ogg", Local::now().format("%H%M%S"));
let name = audio.file_name.as_deref().unwrap_or(&fallback);
match download_tg_file(bot, &audio.file.id, name).await {
Ok(p) => uploaded.push(p),
Err(e) => download_errors.push(format!("audio: {e:#}")),
}
}
if let Some(voice) = msg.voice() {
let name = format!("voice_{}.ogg", Local::now().format("%H%M%S"));
match download_tg_file(bot, &voice.file.id, &name).await {
Ok(p) => uploaded.push(p),
Err(e) => download_errors.push(format!("voice: {e:#}")),
}
}
if let Some(video) = msg.video() {
let fallback = format!("video_{}.mp4", Local::now().format("%H%M%S"));
let name = video.file_name.as_deref().unwrap_or(&fallback);
match download_tg_file(bot, &video.file.id, name).await {
Ok(p) => uploaded.push(p),
Err(e) => download_errors.push(format!("video: {e:#}")),
}
}
if let Some(vn) = msg.video_note() {
let name = format!("videonote_{}.mp4", Local::now().format("%H%M%S"));
match download_tg_file(bot, &vn.file.id, &name).await {
Ok(p) => uploaded.push(p),
Err(e) => download_errors.push(format!("video_note: {e:#}")),
}
}
if text.is_empty() && uploaded.is_empty() {
if !download_errors.is_empty() {
let err_msg = format!("[文件下载失败]\n{}", download_errors.join("\n"));
bot.send_message(chat_id, err_msg).await?;
}
return Ok(());
}
let sid = session_uuid(chat_id.0, config.session.refresh_hour);
info!(%sid, "recv");
let out_dir = outgoing_dir(&sid);
tokio::fs::create_dir_all(&out_dir).await?;
let before = snapshot_dir(&out_dir).await;
let prompt = build_prompt(text, &uploaded, &download_errors, &out_dir);
let known = state.persist.read().await.known_sessions.contains(&sid);
let result = invoke_claude_streaming(&sid, &prompt, known, bot, chat_id).await;
match &result {
Ok(_) => {
if !known {
state.persist.write().await.known_sessions.insert(sid.clone());
state.save().await;
}
}
Err(e) => {
error!(%sid, "claude: {e:#}");
let _ = bot.send_message(chat_id, format!("[error] {e:#}")).await;
}
}
// send new files from outgoing dir
let new_files = new_files_in(&out_dir, &before).await;
for path in &new_files {
info!(%sid, "sending file: {}", path.display());
if let Err(e) = bot.send_document(chat_id, InputFile::file(path)).await {
error!(%sid, "send_document {}: {e:#}", path.display());
let _ = bot
.send_message(chat_id, format!("[发送文件失败: {}]", path.display()))
.await;
}
}
Ok(())
}
/// Assembles the prompt sent to Claude, one part per line:
/// - a note for each uploaded file path;
/// - a note for each download failure;
/// - the user's text (if any);
/// - finally, a system hint naming `out_dir` as the place to write files that
///   should be sent back to the user.
fn build_prompt(text: &str, uploaded: &[PathBuf], errors: &[String], out_dir: &Path) -> String {
    let upload_notes = uploaded
        .iter()
        .map(|f| format!("[用户上传了文件: {}]", f.display()));
    let failure_notes = errors.iter().map(|e| format!("[文件下载失败: {e}]"));
    let user_text = (!text.is_empty()).then(|| text.to_string());
    let hint = format!(
        "\n[系统提示: 如果需要发送文件给用户,将文件写入 {} 目录]",
        out_dir.display()
    );
    let lines: Vec<String> = upload_notes
        .chain(failure_notes)
        .chain(user_text)
        .chain(std::iter::once(hint))
        .collect();
    lines.join("\n")
}
// ── claude bridge (streaming) ───────────────────────────────────────
/// One line of `claude --output-format stream-json` output.
///
/// Only the fields this bot reads are declared; serde ignores the rest.
#[derive(Deserialize)]
struct StreamEvent {
    /// The event's `type` tag; the bot acts on `"assistant"` and `"result"`.
    #[serde(rename = "type")]
    event_type: String,
    /// Message payload, read on `"assistant"` events.
    message: Option<AssistantMessage>,
    /// Final answer text, read on `"result"` events.
    result: Option<String>,
    /// Error flag of a `"result"` event; `false` when absent.
    #[serde(default)]
    is_error: bool,
}

/// Payload of an `"assistant"` event.
#[derive(Deserialize)]
struct AssistantMessage {
    content: Vec<ContentBlock>,
}

/// One content block of an assistant message.
#[derive(Deserialize)]
struct ContentBlock {
    /// Block kind; `"text"` and `"tool_use"` are inspected.
    #[serde(rename = "type")]
    block_type: String,
    /// Text of a `"text"` block.
    text: Option<String>,
    /// Tool name of a `"tool_use"` block.
    name: Option<String>,
    /// Tool arguments of a `"tool_use"` block, kept as raw JSON.
    input: Option<serde_json::Value>,
}
/// Concatenates the text of every `"text"` block in `msg`, in order and with
/// no separator. Returns an empty string when there is none.
fn extract_text(msg: &AssistantMessage) -> String {
    let mut out = String::new();
    for block in &msg.content {
        if block.block_type != "text" {
            continue;
        }
        if let Some(piece) = &block.text {
            out.push_str(piece);
        }
    }
    out
}
/// Builds a status line such as `"Bash: echo hello"` from the first `"tool_use"`
/// block in `msg`, or returns `None` if there is none.
///
/// The detail is the first of `command`, `pattern`, `file_path`, `query`,
/// `prompt` that exists in the tool input. It is used only if that value is a
/// string, and is cut to at most 80 bytes on a char boundary. A missing tool
/// name shows as `tool`.
fn extract_tool_use(msg: &AssistantMessage) -> Option<String> {
    const DETAIL_KEYS: [&str; 5] = ["command", "pattern", "file_path", "query", "prompt"];

    let block = msg.content.iter().find(|b| b.block_type == "tool_use")?;
    let name = block.name.as_deref().unwrap_or("tool");
    let detail = match &block.input {
        Some(input) => DETAIL_KEYS
            .iter()
            .find_map(|key| input.get(*key))
            .and_then(|value| value.as_str())
            .unwrap_or(""),
        None => "",
    };
    let short = truncate_at_char_boundary(detail, 80);
    Some(format!("{name}: {short}"))
}
/// Minimum delay between streaming edits of assistant text, in milliseconds.
const EDIT_INTERVAL_MS: u64 = 5_000;
/// Telegram message length limit. The helpers in this file apply it as a
/// UTF-8 byte budget.
const TG_MSG_LIMIT: usize = 4_096;
/// Runs `prompt` in Claude session `sid`, streaming progress into the chat.
///
/// The session is always resumed first. A session already `known` to have run
/// is only ever resumed. For an unknown session, a failed resume falls back to
/// creating the session with `--session-id`.
async fn invoke_claude_streaming(
    sid: &str,
    prompt: &str,
    known: bool,
    bot: &Bot,
    chat_id: ChatId,
) -> Result<String> {
    let resume_args = ["--resume", sid];
    let attempt = run_claude_streaming(&resume_args, prompt, bot, chat_id).await;
    if known {
        return attempt;
    }
    if let Err(e) = &attempt {
        warn!(%sid, "resume failed ({e:#}), creating new session");
        return run_claude_streaming(&["--session-id", sid], prompt, bot, chat_id).await;
    }
    info!(%sid, "resumed existing session");
    attempt
}
async fn run_claude_streaming(
extra_args: &[&str],
prompt: &str,
bot: &Bot,
chat_id: ChatId,
) -> Result<String> {
let mut args: Vec<&str> = vec![
"--dangerously-skip-permissions",
"-p",
"--output-format",
"stream-json",
"--verbose",
];
args.extend(extra_args);
args.push(prompt);
let mut child = Command::new("claude")
.args(&args)
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.spawn()?;
let stdout = child.stdout.take().unwrap();
let mut lines = tokio::io::BufReader::new(stdout).lines();
// send placeholder immediately so user knows we're on it
let mut msg_id: Option<teloxide::types::MessageId> = match bot.send_message(chat_id, CURSOR).await {
Ok(sent) => Some(sent.id),
Err(_) => None,
};
let mut last_sent_text = String::new();
let mut last_edit = Instant::now();
let mut final_result = String::new();
let mut is_error = false;
let mut tool_status = String::new(); // current tool use status line
while let Ok(Some(line)) = lines.next_line().await {
let event: StreamEvent = match serde_json::from_str(&line) {
Ok(e) => e,
Err(_) => continue,
};
match event.event_type.as_str() {
"assistant" => {
if let Some(msg) = &event.message {
// check for tool use — show status
if let Some(status) = extract_tool_use(msg) {
tool_status = format!("[{status}]");
let display = if last_sent_text.is_empty() {
tool_status.clone()
} else {
format!("{last_sent_text}\n\n{tool_status}")
};
let display = truncate_for_display(&display);
if let Some(id) = msg_id {
let _ = bot.edit_message_text(chat_id, id, &display).await;
} else if let Ok(sent) = bot.send_message(chat_id, &display).await {
msg_id = Some(sent.id);
}
last_edit = Instant::now();
continue;
}
// check for text content
let text = extract_text(msg);
if text.is_empty() || text == last_sent_text {
continue;
}
// throttle edits
if last_edit.elapsed().as_millis() < EDIT_INTERVAL_MS as u128 {
continue;
}
tool_status.clear();
let display = truncate_for_display(&text);
if let Some(id) = msg_id {
if bot.edit_message_text(chat_id, id, &display).await.is_ok() {
last_sent_text = text;
last_edit = Instant::now();
}
} else if let Ok(sent) = bot.send_message(chat_id, &display).await {
msg_id = Some(sent.id);
last_sent_text = text;
last_edit = Instant::now();
}
}
}
"result" => {
final_result = event.result.unwrap_or_default();
is_error = event.is_error;
}
_ => {}
}
}
// read stderr before waiting (in case child already exited)
let stderr_handle = child.stderr.take();
let status = child.wait().await;
// collect stderr for diagnostics
let stderr_text = if let Some(mut se) = stderr_handle {
let mut buf = String::new();
let _ = tokio::io::AsyncReadExt::read_to_string(&mut se, &mut buf).await;
buf
} else {
String::new()
};
// determine error: explicit is_error from stream, or non-zero exit with no result
let has_error = is_error
|| (final_result.is_empty()
&& status.as_ref().map(|s| !s.success()).unwrap_or(true));
if has_error {
let err_detail = if !final_result.is_empty() {
final_result.clone()
} else if !stderr_text.is_empty() {
stderr_text.trim().to_string()
} else {
format!("claude exited: {:?}", status)
};
if let Some(id) = msg_id {
let _ = bot
.edit_message_text(chat_id, id, format!("[error] {err_detail}"))
.await;
}
anyhow::bail!("{err_detail}");
}
if final_result.is_empty() {
return Ok(final_result);
}
// final update: replace streaming message with complete result
let chunks: Vec<&str> = split_msg(&final_result, TG_MSG_LIMIT);
if let Some(id) = msg_id {
// edit first message with final text
let _ = bot.edit_message_text(chat_id, id, chunks[0]).await;
// send remaining chunks as new messages
for chunk in &chunks[1..] {
let _ = bot.send_message(chat_id, *chunk).await;
}
} else {
// never got to send a streaming message, send all now
for chunk in &chunks {
let _ = bot.send_message(chat_id, *chunk).await;
}
}
Ok(final_result)
}
/// Suffix marking a message as still in progress (a space plus U+25CE).
const CURSOR: &str = " \u{25CE}";
/// Formats `s` as an in-progress message: the text followed by `CURSOR`.
///
/// Text longer than the byte budget is cut on a char boundary and the cursor
/// moves to its own line, keeping the result under `TG_MSG_LIMIT` bytes.
fn truncate_for_display(s: &str) -> String {
    let budget = TG_MSG_LIMIT - CURSOR.len() - 1;
    if s.len() > budget {
        // Reserve space for the newline separating the text from the cursor.
        let head = truncate_at_char_boundary(s, budget - 2);
        return format!("{head}\n{CURSOR}");
    }
    format!("{s}{CURSOR}")
}
/// Returns the longest prefix of `s` that is at most `max` bytes long and ends
/// on a UTF-8 character boundary.
fn truncate_at_char_boundary(s: &str, max: usize) -> &str {
    if max >= s.len() {
        return s;
    }
    // Index 0 is always a boundary, so the search cannot come up empty.
    let cut = (0..=max)
        .rev()
        .find(|&i| s.is_char_boundary(i))
        .unwrap_or(0);
    &s[..cut]
}
/// Splits `s` into consecutive slices of at most `max` bytes each, never
/// cutting a UTF-8 character.
///
/// An empty `s` yields a single empty slice. If one character is wider than
/// `max` (possible only for tiny `max`), it is emitted on its own, exceeding
/// `max`. Previously that case made zero progress and looped forever.
fn split_msg(s: &str, max: usize) -> Vec<&str> {
    if s.len() <= max {
        return vec![s];
    }
    let mut parts = Vec::new();
    let mut rest = s;
    while rest.len() > max {
        let mut end = max;
        while !rest.is_char_boundary(end) {
            end -= 1;
        }
        if end == 0 {
            // The first character alone exceeds `max`: take it whole so the loop
            // always advances.
            end = rest.chars().next().map_or(rest.len(), char::len_utf8);
        }
        let (chunk, tail) = rest.split_at(end);
        parts.push(chunk);
        rest = tail;
    }
    if !rest.is_empty() {
        parts.push(rest);
    }
    parts
}