Switch from fastembed to Python sentence-transformers for embedding

ort (ONNX Runtime) has no prebuilt binaries for aarch64-musl.
Use a Python subprocess with sentence-transformers instead:
- scripts/embed.py: reads JSON stdin, outputs embeddings
- kb.rs: calls Python script via tokio subprocess
- Dockerfile: install python3 + sentence-transformers

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Fam Zheng 2026-03-01 08:31:31 +00:00
parent 8483359cbc
commit fbf636868c
5 changed files with 86 additions and 766 deletions

761
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -26,4 +26,3 @@ uuid = { version = "1", features = ["v4"] }
anyhow = "1" anyhow = "1"
mime_guess = "2" mime_guess = "2"
nix = { version = "0.29", features = ["signal"] } nix = { version = "0.29", features = ["signal"] }
fastembed = { version = "5", default-features = false, features = ["hf-hub", "hf-hub-rustls-tls", "ort-download-binaries-rustls-tls"] }

View File

@ -8,13 +8,15 @@ RUN npm run build
# Stage 2: Runtime # Stage 2: Runtime
FROM alpine:3.21 FROM alpine:3.21
RUN apk add --no-cache ca-certificates curl bash RUN apk add --no-cache ca-certificates curl bash python3 py3-pip
RUN curl -LsSf https://astral.sh/uv/install.sh | sh RUN curl -LsSf https://astral.sh/uv/install.sh | sh
ENV PATH="/root/.local/bin:$PATH" ENV PATH="/root/.local/bin:$PATH"
RUN pip3 install --break-system-packages sentence-transformers
RUN mkdir -p /app/data/workspaces RUN mkdir -p /app/data/workspaces
WORKDIR /app WORKDIR /app
COPY target/aarch64-unknown-linux-musl/release/tori . COPY target/aarch64-unknown-linux-musl/release/tori .
COPY --from=frontend /app/web/dist ./web/dist/ COPY --from=frontend /app/web/dist ./web/dist/
COPY scripts/embed.py ./scripts/
COPY config.yaml . COPY config.yaml .
EXPOSE 3000 EXPOSE 3000

26
scripts/embed.py Normal file
View File

@ -0,0 +1,26 @@
#!/usr/bin/env python3
"""Generate embeddings for text chunks. Reads JSON from stdin, writes JSON to stdout.

Input:  {"texts": ["text1", "text2", ...]}
Output: {"embeddings": [[0.1, 0.2, ...], [0.3, 0.4, ...], ...]}

On invalid input, a diagnostic is written to stderr and the process exits
non-zero so the caller (which captures stderr) can surface the failure
instead of having to parse a Python traceback.
"""
import json
import sys

from sentence_transformers import SentenceTransformer

# Model name; the vector dimensionality must stay consistent with what the
# caller stores alongside previously indexed chunks.
MODEL_NAME = "all-MiniLM-L6-v2"


def main():
    """Read texts from stdin, print normalized embeddings as JSON to stdout."""
    try:
        data = json.load(sys.stdin)
        texts = data["texts"]
    except (json.JSONDecodeError, KeyError, TypeError) as exc:
        print(f"embed.py: invalid input: {exc}", file=sys.stderr)
        sys.exit(1)
    if not isinstance(texts, list):
        print("embed.py: 'texts' must be a list of strings", file=sys.stderr)
        sys.exit(1)
    if not texts:
        # Nothing to embed: answer immediately without loading the model.
        print(json.dumps({"embeddings": []}))
        return
    model = SentenceTransformer(MODEL_NAME)
    # normalize_embeddings=True yields unit vectors, so downstream cosine
    # similarity reduces to a plain dot product.
    embeddings = model.encode(texts, normalize_embeddings=True)
    print(json.dumps({"embeddings": embeddings.tolist()}))


if __name__ == "__main__":
    main()

View File

@ -1,11 +1,10 @@
use anyhow::Result; use anyhow::Result;
use sqlx::sqlite::SqlitePool; use sqlx::sqlite::SqlitePool;
use std::sync::Mutex; use std::process::Stdio;
const TOP_K: usize = 5; const TOP_K: usize = 5;
pub struct KbManager { pub struct KbManager {
embedder: Mutex<fastembed::TextEmbedding>,
pool: SqlitePool, pool: SqlitePool,
} }
@ -25,14 +24,10 @@ struct Chunk {
impl KbManager { impl KbManager {
pub fn new(pool: SqlitePool) -> Result<Self> { pub fn new(pool: SqlitePool) -> Result<Self> {
let embedder = fastembed::TextEmbedding::try_new( Ok(Self { pool })
fastembed::InitOptions::new(fastembed::EmbeddingModel::AllMiniLML6V2)
.with_show_download_progress(true),
)?;
Ok(Self { embedder: Mutex::new(embedder), pool })
} }
/// Re-index: chunk the content, embed, store in SQLite /// Re-index: chunk the content, embed via Python, store in SQLite
pub async fn index(&self, content: &str) -> Result<()> { pub async fn index(&self, content: &str) -> Result<()> {
// Clear old chunks // Clear old chunks
sqlx::query("DELETE FROM kb_chunks") sqlx::query("DELETE FROM kb_chunks")
@ -45,7 +40,7 @@ impl KbManager {
} }
let texts: Vec<String> = chunks.iter().map(|c| c.content.clone()).collect(); let texts: Vec<String> = chunks.iter().map(|c| c.content.clone()).collect();
let embeddings = self.embedder.lock().unwrap().embed(texts, None)?; let embeddings = compute_embeddings(&texts).await?;
for (chunk, embedding) in chunks.iter().zip(embeddings.into_iter()) { for (chunk, embedding) in chunks.iter().zip(embeddings.into_iter()) {
let vec_bytes = embedding_to_bytes(&embedding); let vec_bytes = embedding_to_bytes(&embedding);
@ -66,7 +61,7 @@ impl KbManager {
/// Search KB by query, returns top-k results /// Search KB by query, returns top-k results
pub async fn search(&self, query: &str) -> Result<Vec<SearchResult>> { pub async fn search(&self, query: &str) -> Result<Vec<SearchResult>> {
let query_embeddings = self.embedder.lock().unwrap().embed(vec![query.to_string()], None)?; let query_embeddings = compute_embeddings(&[query.to_string()]).await?;
let query_vec = query_embeddings let query_vec = query_embeddings
.into_iter() .into_iter()
.next() .next()
@ -102,7 +97,50 @@ impl KbManager {
} }
} }
/// Call the Python helper script to compute embeddings for `texts`.
///
/// Protocol: writes `{"texts": [...]}` as JSON to the script's stdin and
/// parses `{"embeddings": [[f32, ...], ...]}` from its stdout. The script
/// path defaults to `/app/scripts/embed.py` (where the Dockerfile installs
/// it) but can be overridden with the `EMBED_SCRIPT` environment variable.
///
/// # Errors
/// Fails if the process cannot be spawned, if it exits non-zero (its stderr
/// is included in the error), or if its output is not the expected JSON
/// shape or does not contain one vector per input text.
async fn compute_embeddings(texts: &[String]) -> Result<Vec<Vec<f32>>> {
    // Nothing to embed: skip spawning the (slow to start) Python process.
    if texts.is_empty() {
        return Ok(Vec::new());
    }
    let script = std::env::var("EMBED_SCRIPT")
        .unwrap_or_else(|_| "/app/scripts/embed.py".to_string());
    let input = serde_json::json!({ "texts": texts });
    let mut child = tokio::process::Command::new("python3")
        .arg(&script)
        .stdin(Stdio::piped())
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .spawn()?;
    {
        use tokio::io::AsyncWriteExt;
        // A missing stdin pipe is an error, not something to skip silently:
        // the script would block waiting for input that never arrives.
        let mut stdin = child
            .stdin
            .take()
            .ok_or_else(|| anyhow::anyhow!("failed to open embed script stdin"))?;
        stdin.write_all(input.to_string().as_bytes()).await?;
        // stdin drops here, closing the pipe so the script sees EOF.
    }
    let output = child.wait_with_output().await?;
    if !output.status.success() {
        let stderr = String::from_utf8_lossy(&output.stderr);
        anyhow::bail!("Embedding script failed: {}", stderr);
    }
    let mut result: serde_json::Value = serde_json::from_slice(&output.stdout)?;
    // Deserialize directly instead of walking the JSON by hand; this rejects
    // malformed entries instead of silently coercing them to 0.0.
    let embeddings: Vec<Vec<f32>> = serde_json::from_value(result["embeddings"].take())?;
    if embeddings.len() != texts.len() {
        anyhow::bail!(
            "Embedding script returned {} vectors for {} texts",
            embeddings.len(),
            texts.len()
        );
    }
    Ok(embeddings)
}
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 { fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
if a.len() != b.len() || a.is_empty() {
return 0.0;
}
let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum(); let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt(); let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt(); let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
@ -131,7 +169,6 @@ fn split_chunks(content: &str) -> Vec<Chunk> {
for line in content.lines() { for line in content.lines() {
if line.starts_with("## ") { if line.starts_with("## ") {
// Save previous chunk
let text = current_lines.join("\n").trim().to_string(); let text = current_lines.join("\n").trim().to_string();
if !text.is_empty() { if !text.is_empty() {
chunks.push(Chunk { chunks.push(Chunk {
@ -150,7 +187,6 @@ fn split_chunks(content: &str) -> Vec<Chunk> {
} }
} }
// Last chunk
let text = current_lines.join("\n").trim().to_string(); let text = current_lines.join("\n").trim().to_string();
if !text.is_empty() { if !text.is_empty() {
chunks.push(Chunk { chunks.push(Chunk {