nuonuo/experiments/exp10_auto_paraphrase.py
Fam Zheng d923aa1e31 NuoNuo: Hippocampal memory module prototype
Hopfield + Hebbian hybrid memory system for LLMs.
Two nights of experiments (16 iterations), validated on LongMemEval (ICLR 2025).

Architecture:
- Single-hop: Two-Stage Hopfield (NN top-20 → softmax settle)
- Multi-hop: Hebbian W matrix with WTA pattern separation
- 64% on LongMemEval (500 questions), retrieval-only, no LLM dependency
- 4ms latency @ 20K memories, ~1GB VRAM

Key findings:
- Hopfield attention solved noise tolerance (20% → 100% vs flat Hebbian)
- WTA pattern separation enables 20K+ capacity
- Multi-hop associative chains (6 hops, CosSim=1.0) — RAG can't do this
- MiniLM-L6 is optimal (discrimination gap > absolute similarity)
- Paraphrase cue augmentation: 55% → 100% on synthetic, 36% → 64% on benchmark
- SNN encoder viable (CosSim 0.99) but not needed for current architecture
2026-04-07 10:37:24 +01:00

221 lines
9.5 KiB
Python

"""Experiment P2: Auto Paraphrase Generation.
LLM gateway down, so test:
1. Heuristic paraphrase effect on recall (how much does crappy augmentation help?)
2. "Oracle" paraphrase (hand-crafted) vs heuristic vs none
3. Design: what makes a good paraphrase for memory augmentation?
4. Analysis: which failures are fixable by paraphrase vs need better embeddings?
"""
import sys
import time
from pathlib import Path
import torch
import torch.nn as nn
import numpy as np
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
sys.path.insert(0, str(Path(__file__).parent.parent))
from llm import generate_paraphrases_heuristic
DEVICE = "cuda"
# 20 (cue, target) memory pairs: a conversational cue and the factual payload
# it should recall. PAIRS[i] is queried via its rephrasing at PARAPHRASES[i].
PAIRS = [
("What's the weather like today?", "User checks weather every morning"),
("Let's deploy the new version", "Deployment uses GitHub Actions with k3s"),
("The database is slow again", "Missing index on users table"),
("I need to fix the authentication bug", "JWT tokens with 24h expiry in Redis"),
("The API returns 500 errors", "OOM in the Python worker"),
("Let's set up monitoring", "Prometheus + Grafana on OCI"),
("Tests failing in CI", "CI needs postgres service container"),
("Memory usage too high", "Leak in websocket handler"),
("Help with Docker setup", "docker-compose for dev, k3s for prod"),
("Log files too large", "Logs rotate daily, shipped to Loki"),
("How to add caching?", "Redis available at redis.internal:6379"),
("Frontend loads slowly", "CDN CloudFlare, 1h TTL for assets"),
("Refactor payment module", "Stripe API, webhook in payments/webhook.py"),
("Set up new server", "Ubuntu 22.04, Docker, Tailscale, monitoring"),
("Optimize search", "Elasticsearch v8, recently upgraded"),
("Backup the database", "Daily 3am UTC cron to S3"),
("Configure reverse proxy", "Traefik, not nginx"),
("Team meeting schedule", "Standup 10am London, Mon-Fri"),
("Learn a new programming language", "User has Python+Go, new to systems"),
("Review my pull request", "User prefers small PRs with clear commits"),
]
# Test queries: PARAPHRASES[i] is a rephrasing of PAIRS[i][0]; recall is
# counted correct when querying with it retrieves memory id i.
PARAPHRASES = [
"How's the weather?", "Ship the release", "DB performance terrible",
"Fix the login issue", "Server errors everywhere", "Need observability",
"CI tests breaking", "Service using too much RAM", "Docker config help",
"Logs eating disk space", "Want to add a cache layer", "Website too slow",
"Payment code needs rework", "Provision a new machine", "Search is slow",
"Need a DB backup", "Proxy configuration", "When's the standup?",
"Want to learn Rust", "Check my pull request",
]
# Oracle paraphrases: hand-crafted to cover the semantic gaps.
# Keys are indices into PAIRS; values are alternate phrasings stored as extra
# cues for that memory in the "oracle"/"oracle_all" augmentation modes.
ORACLE_PARAPHRASES = {
1: ["Ship the release", "Push to production", "Release the new build", "Deploy new code"],
3: ["Fix the login issue", "Authentication broken", "Login doesn't work", "Auth bug"],
4: ["Server errors everywhere", "Getting 500s", "Internal server error", "API is down"],
5: ["Need observability", "Set up alerts", "Monitor services", "Add monitoring"],
10: ["Add a cache layer", "Implement caching", "Cache responses"],
11: ["Website too slow", "Page loads slowly", "Frontend performance bad"],
12: ["Payment code needs rework", "Refactor payments", "Payment system restructure"],
13: ["Provision a new machine", "Need a new server", "Set up new box", "New machine setup"],
14: ["Search is slow", "Search performance", "Optimize search queries"],
17: ["When's the standup?", "Meeting time?", "Daily sync schedule", "What time is standup?"],
18: ["Want to learn Rust", "Learning Rust", "Getting into Rust", "Start with Rust"],
19: ["Check my pull request", "Look at my code", "PR review please", "Review my code changes"],
}
def cosine(a, b):
    """Scalar cosine similarity between two 1-D tensors."""
    sim = nn.functional.cosine_similarity(a[None, :], b[None, :])
    return sim.item()
class TwoStageHopfield:
    """Two-stage associative memory.

    Stage 1 shortlists the ``top_k`` stored cues nearest the query;
    stage 2 runs a softmax (modern-Hopfield) settling loop over that
    shortlist and returns the attention-weighted target plus the memory
    id that collected the most attention mass.
    """

    def __init__(self, beta=16.0, top_k=20):
        self.beta = beta      # inverse temperature of the softmax
        self.top_k = top_k    # shortlist size for stage one
        self.cue_embs = []
        self.target_embs = []
        self.memory_ids = []

    def learn(self, cue_emb, target_emb, mid):
        """Store one (cue, target) pair tagged with memory id ``mid``."""
        self.cue_embs.append(cue_emb.detach())
        self.target_embs.append(target_emb.detach())
        self.memory_ids.append(mid)

    def recall(self, query_emb, steps=3):
        """Return (settled target embedding, best-matching memory id)."""
        cues = torch.stack(self.cue_embs)
        targets = torch.stack(self.target_embs)
        k = min(self.top_k, len(self.cue_embs))
        # Stage 1: nearest-neighbour shortlist by cue similarity.
        top_idx = (query_emb @ cues.T).topk(k).indices
        short_cues = cues[top_idx]
        short_targets = targets[top_idx]
        short_mids = [self.memory_ids[i] for i in top_idx.tolist()]
        # Stage 2: iterate the softmax update until the state settles.
        state = query_emb
        for _ in range(steps):
            weights = torch.softmax(self.beta * (state @ short_cues.T), dim=0)
            state = nn.functional.normalize(weights @ short_cues, dim=0)
        weights = torch.softmax(self.beta * (state @ short_cues.T), dim=0)
        # Sum attention mass per memory id; the heaviest id wins.
        votes = {}
        for w, mid in zip(weights.tolist(), short_mids):
            votes[mid] = votes.get(mid, 0) + w
        best_mid = max(votes, key=votes.get)
        settled = nn.functional.normalize(weights @ short_targets, dim=0)
        return settled, best_mid
def evaluate(model, augmentation_mode, n_background=2000):
    """Test recall with different augmentation strategies.

    Args:
        model: sentence encoder exposing the SentenceTransformer ``encode``
            API (``convert_to_tensor``, ``normalize_embeddings``, ``device``).
        augmentation_mode: "none", "heuristic" (auto paraphrases),
            "oracle" (hand-crafted, where available), or "oracle_all"
            (oracle with heuristic fallback for the rest).
        n_background: number of synthetic distractor memories to add.

    Returns:
        (correct, n, failures) where ``failures`` is a list of
        (query_index, recalled_memory_id) tuples for wrong recalls.
    """
    def encode(texts, **kw):
        # All embeddings are L2-normalized tensors on DEVICE.
        return model.encode(texts, convert_to_tensor=True,
                            normalize_embeddings=True, device=DEVICE, **kw)

    cue_embs = encode([p[0] for p in PAIRS])
    target_embs = encode([p[1] for p in PAIRS])
    para_embs = encode(PARAPHRASES)
    mem = TwoStageHopfield(beta=16.0, top_k=20)

    def learn_paraphrases(paras, target_emb, mid):
        # Each paraphrase becomes an extra cue pointing at the same target.
        for emb in encode(paras):
            mem.learn(emb, target_emb, mid=mid)

    for i in range(len(PAIRS)):
        mem.learn(cue_embs[i], target_embs[i], mid=i)
        if augmentation_mode == "heuristic":
            paras = generate_paraphrases_heuristic(PAIRS[i][0], n=3)
            learn_paraphrases(paras, target_embs[i], i)
        elif augmentation_mode == "oracle":
            if i in ORACLE_PARAPHRASES:
                learn_paraphrases(ORACLE_PARAPHRASES[i], target_embs[i], i)
        elif augmentation_mode == "oracle_all":
            # Oracle where available; heuristic fallback for the rest.
            if i in ORACLE_PARAPHRASES:
                paras = ORACLE_PARAPHRASES[i]
            else:
                paras = generate_paraphrases_heuristic(PAIRS[i][0], n=3)
            learn_paraphrases(paras, target_embs[i], i)

    # Background distractors; mids start at 100 so they never collide with
    # the 0..19 ids of the real pairs.
    if n_background > 0:
        topics = ["server", "db", "api", "fe", "be", "cache"]
        bg_cues = [f"The {topics[i%6]} has issue {i}" for i in range(n_background)]
        bg_targets = [f"Fix issue {i}" for i in range(n_background)]
        bg_c = encode(bg_cues, batch_size=256)
        bg_t = encode(bg_targets, batch_size=256)
        for i in range(n_background):
            mem.learn(bg_c[i], bg_t[i], mid=100+i)

    correct = 0
    failures = []
    for i in range(len(PARAPHRASES)):
        _, best_mid = mem.recall(para_embs[i])
        if best_mid == i:
            correct += 1
        else:
            failures.append((i, best_mid))
    return correct, len(PARAPHRASES), failures
def main():
    """Sweep background sizes x augmentation modes, then analyze failures."""
    print("=" * 60)
    print("Experiment P2: Auto Paraphrase Analysis")
    print("=" * 60)
    from sentence_transformers import SentenceTransformer
    model = SentenceTransformer("all-MiniLM-L6-v2", device=DEVICE)

    for bg in [0, 500, 2000]:
        print(f"\n=== Background: {bg} ===")
        for mode in ["none", "heuristic", "oracle", "oracle_all"]:
            correct, n, failures = evaluate(model, mode, n_background=bg)
            fail_ids = [f[0] for f in failures]
            print(f" {mode:<15}: {correct}/{n} ({correct/n:.0%})"
                  + (f" | Failures: {fail_ids}" if failures else ""))

    # Analyze: which failures are fixable by an oracle paraphrase vs need
    # better embeddings?
    print("\n=== Failure Analysis (2K bg, no augmentation) ===")
    correct, n, failures = evaluate(model, "none", 2000)
    cue_texts = [p[0] for p in PAIRS]
    for qi, gi in failures:
        cue_emb = model.encode([cue_texts[qi]], convert_to_tensor=True,
                               normalize_embeddings=True, device=DEVICE)[0]
        para_emb = model.encode([PARAPHRASES[qi]], convert_to_tensor=True,
                                normalize_embeddings=True, device=DEVICE)[0]
        sim = cosine(cue_emb, para_emb)
        fixable = qi in ORACLE_PARAPHRASES
        # Bug fix: both marker strings were empty (mojibake-stripped glyphs),
        # so the oracle_fix column printed nothing either way.
        print(f" [{qi}] '{PARAPHRASES[qi][:25]}...' → got [{gi}], "
              f"cue_sim={sim:.3f}, oracle_fix={'✓' if fixable else '✗'}")


if __name__ == "__main__":
    main()