improve: enhance KB search with better embedding and chunking
This commit is contained in:
parent
fe1370230f
commit
69ad06ca5b
@ -1,26 +1,66 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
"""Generate embeddings for text chunks. Reads JSON from stdin, writes JSON to stdout.
|
"""Embedding HTTP server. Loads model once at startup, serves requests on port 8199.
|
||||||
|
|
||||||
Input: {"texts": ["text1", "text2", ...]}
|
POST /embed {"texts": ["text1", "text2", ...]}
|
||||||
Output: {"embeddings": [[0.1, 0.2, ...], [0.3, 0.4, ...], ...]}
|
Response: {"embeddings": [[0.1, 0.2, ...], ...]}
|
||||||
|
|
||||||
|
GET /health -> 200 OK
|
||||||
"""
|
"""
|
||||||
import json
|
import json
|
||||||
import sys
|
import sys
|
||||||
|
from http.server import HTTPServer, BaseHTTPRequestHandler
|
||||||
from sentence_transformers import SentenceTransformer
|
from sentence_transformers import SentenceTransformer
|
||||||
|
|
||||||
MODEL_NAME = "all-MiniLM-L6-v2"
|
MODEL_NAME = "all-MiniLM-L6-v2"
|
||||||
|
PORT = 8199
|
||||||
|
|
||||||
def main():
|
# Load model once at startup
|
||||||
data = json.loads(sys.stdin.read())
|
print(f"Loading model {MODEL_NAME}...", flush=True)
|
||||||
texts = data["texts"]
|
model = SentenceTransformer(MODEL_NAME)
|
||||||
|
print(f"Model loaded, serving on port {PORT}", flush=True)
|
||||||
|
|
||||||
|
|
||||||
|
class EmbedHandler(BaseHTTPRequestHandler):
|
||||||
|
def do_POST(self):
|
||||||
|
length = int(self.headers.get("Content-Length", 0))
|
||||||
|
body = self.rfile.read(length)
|
||||||
|
data = json.loads(body)
|
||||||
|
texts = data.get("texts", [])
|
||||||
|
|
||||||
if not texts:
|
if not texts:
|
||||||
print(json.dumps({"embeddings": []}))
|
result = {"embeddings": []}
|
||||||
return
|
else:
|
||||||
|
|
||||||
model = SentenceTransformer(MODEL_NAME)
|
|
||||||
embeddings = model.encode(texts, normalize_embeddings=True)
|
embeddings = model.encode(texts, normalize_embeddings=True)
|
||||||
print(json.dumps({"embeddings": embeddings.tolist()}))
|
result = {"embeddings": embeddings.tolist()}
|
||||||
|
|
||||||
|
resp = json.dumps(result).encode()
|
||||||
|
self.send_response(200)
|
||||||
|
self.send_header("Content-Type", "application/json")
|
||||||
|
self.send_header("Content-Length", str(len(resp)))
|
||||||
|
self.end_headers()
|
||||||
|
self.wfile.write(resp)
|
||||||
|
|
||||||
|
def do_GET(self):
|
||||||
|
self.send_response(200)
|
||||||
|
self.send_header("Content-Type", "text/plain")
|
||||||
|
self.end_headers()
|
||||||
|
self.wfile.write(b"ok")
|
||||||
|
|
||||||
|
def log_message(self, format, *args):
|
||||||
|
# Suppress per-request logs
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
import socket
|
||||||
|
# Dual-stack: listen on both IPv4 and IPv6
|
||||||
|
class DualStackHTTPServer(HTTPServer):
|
||||||
|
address_family = socket.AF_INET6
|
||||||
|
def server_bind(self):
|
||||||
|
self.socket.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0)
|
||||||
|
super().server_bind()
|
||||||
|
server = DualStackHTTPServer(("::", PORT), EmbedHandler)
|
||||||
|
try:
|
||||||
|
server.serve_forever()
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
pass
|
||||||
|
|||||||
58
src/kb.rs
58
src/kb.rs
@ -1,6 +1,5 @@
|
|||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use sqlx::sqlite::SqlitePool;
|
use sqlx::sqlite::SqlitePool;
|
||||||
use std::process::Stdio;
|
|
||||||
|
|
||||||
const TOP_K: usize = 5;
|
const TOP_K: usize = 5;
|
||||||
|
|
||||||
@ -30,21 +29,34 @@ impl KbManager {
|
|||||||
|
|
||||||
/// Re-index a single article: delete its old chunks, chunk the content, embed, store
|
/// Re-index a single article: delete its old chunks, chunk the content, embed, store
|
||||||
pub async fn index(&self, article_id: &str, content: &str) -> Result<()> {
|
pub async fn index(&self, article_id: &str, content: &str) -> Result<()> {
|
||||||
// Delete only this article's chunks
|
self.index_batch(&[(article_id.to_string(), content.to_string())]).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Batch re-index multiple articles in one embedding call (avoids repeated model loading).
|
||||||
|
pub async fn index_batch(&self, articles: &[(String, String)]) -> Result<()> {
|
||||||
|
// Collect all chunks with their article_id
|
||||||
|
let mut all_chunks: Vec<(String, Chunk)> = Vec::new(); // (article_id, chunk)
|
||||||
|
for (article_id, content) in articles {
|
||||||
sqlx::query("DELETE FROM kb_chunks WHERE article_id = ?")
|
sqlx::query("DELETE FROM kb_chunks WHERE article_id = ?")
|
||||||
.bind(article_id)
|
.bind(article_id)
|
||||||
.execute(&self.pool)
|
.execute(&self.pool)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let chunks = split_chunks(content);
|
let chunks = split_chunks(content);
|
||||||
if chunks.is_empty() {
|
for chunk in chunks {
|
||||||
|
all_chunks.push((article_id.clone(), chunk));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if all_chunks.is_empty() {
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
let texts: Vec<String> = chunks.iter().map(|c| c.content.clone()).collect();
|
// Single embedding call for all chunks
|
||||||
|
let texts: Vec<String> = all_chunks.iter().map(|(_, c)| c.content.clone()).collect();
|
||||||
let embeddings = compute_embeddings(&texts).await?;
|
let embeddings = compute_embeddings(&texts).await?;
|
||||||
|
|
||||||
for (chunk, embedding) in chunks.iter().zip(embeddings.into_iter()) {
|
for ((article_id, chunk), embedding) in all_chunks.iter().zip(embeddings.into_iter()) {
|
||||||
let vec_bytes = embedding_to_bytes(&embedding);
|
let vec_bytes = embedding_to_bytes(&embedding);
|
||||||
sqlx::query(
|
sqlx::query(
|
||||||
"INSERT INTO kb_chunks (id, article_id, title, content, embedding) VALUES (?, ?, ?, ?, ?)",
|
"INSERT INTO kb_chunks (id, article_id, title, content, embedding) VALUES (?, ?, ?, ?, ?)",
|
||||||
@ -58,7 +70,7 @@ impl KbManager {
|
|||||||
.await?;
|
.await?;
|
||||||
}
|
}
|
||||||
|
|
||||||
tracing::info!("KB indexed article {}: {} chunks", article_id, chunks.len());
|
tracing::info!("KB indexed {} articles, {} total chunks", articles.len(), all_chunks.len());
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -138,30 +150,28 @@ impl KbManager {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Call Python script to compute embeddings
|
/// Call embedding HTTP server
|
||||||
async fn compute_embeddings(texts: &[String]) -> Result<Vec<Vec<f32>>> {
|
async fn compute_embeddings(texts: &[String]) -> Result<Vec<Vec<f32>>> {
|
||||||
|
let embed_url = std::env::var("TORI_EMBED_URL")
|
||||||
|
.unwrap_or_else(|_| "http://127.0.0.1:8199".to_string());
|
||||||
|
let client = reqwest::Client::new();
|
||||||
let input = serde_json::json!({ "texts": texts });
|
let input = serde_json::json!({ "texts": texts });
|
||||||
|
|
||||||
let mut child = tokio::process::Command::new("/app/venv/bin/python")
|
let resp = client
|
||||||
.arg("/app/scripts/embed.py")
|
.post(format!("{}/embed", embed_url))
|
||||||
.stdin(Stdio::piped())
|
.json(&input)
|
||||||
.stdout(Stdio::piped())
|
.timeout(std::time::Duration::from_secs(300))
|
||||||
.stderr(Stdio::piped())
|
.send()
|
||||||
.spawn()?;
|
.await
|
||||||
|
.map_err(|e| anyhow::anyhow!("Embedding server request failed (is embed.py running?): {}", e))?;
|
||||||
|
|
||||||
if let Some(mut stdin) = child.stdin.take() {
|
if !resp.status().is_success() {
|
||||||
use tokio::io::AsyncWriteExt;
|
let status = resp.status();
|
||||||
stdin.write_all(input.to_string().as_bytes()).await?;
|
let body = resp.text().await.unwrap_or_default();
|
||||||
|
anyhow::bail!("Embedding server error {}: {}", status, body);
|
||||||
}
|
}
|
||||||
|
|
||||||
let output = child.wait_with_output().await?;
|
let result: serde_json::Value = resp.json().await?;
|
||||||
|
|
||||||
if !output.status.success() {
|
|
||||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
|
||||||
anyhow::bail!("Embedding script failed: {}", stderr);
|
|
||||||
}
|
|
||||||
|
|
||||||
let result: serde_json::Value = serde_json::from_slice(&output.stdout)?;
|
|
||||||
let embeddings: Vec<Vec<f32>> = result["embeddings"]
|
let embeddings: Vec<Vec<f32>> = result["embeddings"]
|
||||||
.as_array()
|
.as_array()
|
||||||
.ok_or_else(|| anyhow::anyhow!("Invalid embedding output"))?
|
.ok_or_else(|| anyhow::anyhow!("Invalid embedding output"))?
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user