"""Experiment 7d: Two-stage retrieval for scale.

Problem: Embedding Hopfield degrades at 10K+ (80%).
Fix: Pre-filter with approximate NN (top-K), then Hopfield settle on candidates.
This is O(N) for pre-filter (can be O(log N) with FAISS) + O(K) for Hopfield.

Also planned: adaptive β based on attention entropy (low entropy = confident).
NOTE: adaptive β is not implemented in this script yet; every experiment below
uses a fixed β.
"""

import sys
import time
from pathlib import Path

import torch
import torch.nn as nn
import numpy as np

# Fall back to CPU so the module (and TwoStageHopfield) is usable without a GPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Public surface for `from ... import *`. The test_* experiment drivers are left
# out on purpose: they take a `model` argument, so a pytest module that
# star-imports this file would otherwise collect them and fail on a missing
# `model` fixture.
__all__ = ["DEVICE", "cosine", "TwoStageHopfield", "load_model", "main"]

# Cue/target pairs used by test_scale and, as its first ten entries, by
# test_diverse_queries.
_CORE_PAIRS = [
    ("What's the weather like today?", "User checks weather every morning"),
    ("Let's deploy the new version", "Deployment uses GitHub Actions with k3s"),
    ("The database is slow again", "Missing index on users table"),
    ("I need to fix the auth bug", "JWT tokens with 24h expiry in Redis"),
    ("The API returns 500 errors", "OOM in the Python worker"),
    ("Let's set up monitoring", "Prometheus + Grafana on OCI"),
    ("Tests failing in CI", "CI needs postgres service container"),
    ("Memory usage too high", "Leak in websocket handler"),
    ("Help with Docker setup", "docker-compose for dev, k3s for prod"),
    ("Log files too large", "Logs rotate daily, shipped to Loki"),
]

# Topic vocabulary for the synthetic background memories.
_INFRA_TOPICS = ["server", "database", "API", "frontend", "backend", "cache",
                 "queue", "network", "storage", "auth", "docker", "kubernetes",
                 "redis", "nginx", "postgres"]


def cosine(a, b):
    """Return the cosine similarity of two 1-D tensors as a Python float."""
    return nn.functional.cosine_similarity(a.unsqueeze(0), b.unsqueeze(0)).item()


class TwoStageHopfield:
    """Pre-filter + Hopfield settle.

    Stage 1: dot-product NN over all stored cues -> top-K candidates
             (O(N) here; could be O(log N) with an ANN index such as FAISS).
    Stage 2: Hopfield attention restricted to the K candidates (O(K) per step).
    The dot product equals cosine similarity only when embeddings are
    L2-normalised, which is how every caller in this script encodes them.
    """

    def __init__(self, beta=16.0, top_k=50):
        self.beta = beta          # softmax inverse temperature in stage 2
        self.top_k = top_k        # number of candidates kept by stage 1
        self.cue_embs = []
        self.target_embs = []
        self._cue_matrix = None     # cached torch.stack(self.cue_embs)
        self._target_matrix = None  # cached torch.stack(self.target_embs)

    def learn(self, cue_emb, target_emb):
        """Store one (cue -> target) association.

        Both tensors are detached, so the memory holds no autograd graph.
        """
        self.cue_embs.append(cue_emb.detach())
        self.target_embs.append(target_emb.detach())
        # Both stacked matrices are now stale.
        self._cue_matrix = None
        self._target_matrix = None

    def _get_cue_matrix(self):
        """Return all stored cues stacked row-wise, rebuilding the cache if stale."""
        if self._cue_matrix is None:
            self._cue_matrix = torch.stack(self.cue_embs)
        return self._cue_matrix

    def _get_target_matrix(self):
        """Return all stored targets stacked row-wise, rebuilding the cache if stale."""
        if self._target_matrix is None:
            self._target_matrix = torch.stack(self.target_embs)
        return self._target_matrix

    def recall(self, query_emb, steps=3):
        """Retrieve the target associated with the cue(s) nearest ``query_emb``.

        Args:
            query_emb: 1-D query embedding. It is compared to cues by dot
                product, so it should be L2-normalised like the stored cues.
            steps: number of Hopfield settling iterations over the candidates.

        Returns:
            ``(target, best_global, attn)``: the L2-normalised,
            attention-weighted target vector; the global index (learn order)
            of the highest-attention memory; and the attention weights over
            the ``min(top_k, N)`` candidates.

        Raises:
            RuntimeError: if called before anything has been learned.
        """
        if not self.cue_embs:
            raise RuntimeError("TwoStageHopfield.recall() called before any learn()")
        cue_mat = self._get_cue_matrix()
        target_mat = self._get_target_matrix()

        # Stage 1: fast NN pre-filter over every stored cue.
        k = min(self.top_k, cue_mat.shape[0])
        sims = query_emb @ cue_mat.T  # one score per stored cue
        top_indices = sims.topk(k).indices

        # Stage 2: Hopfield settle on the candidates only.
        cand_cues = cue_mat[top_indices]
        cand_targets = target_mat[top_indices]

        xi = query_emb
        for _ in range(steps):
            attn = torch.softmax(self.beta * (xi @ cand_cues.T), dim=0)
            # Re-normalise so the next step's dot products stay cosine-like.
            xi = nn.functional.normalize(attn @ cand_cues, dim=0)

        # Final association: attend with the settled state, read out targets.
        attn = torch.softmax(self.beta * (xi @ cand_cues.T), dim=0)
        target = attn @ cand_targets

        # Map the winning candidate back to its index in learn order.
        best_local = attn.argmax().item()
        best_global = top_indices[best_local].item()
        return nn.functional.normalize(target, dim=0), best_global, attn

    def recall_multihop(self, query_emb, hops=2, steps=3):
        """Chain ``hops`` recalls, feeding each recalled target in as the next query.

        Returns a list of ``(target, global_index)`` tuples, one per hop.
        """
        xi = query_emb
        results = []
        for _ in range(hops):
            target, idx, _ = self.recall(xi, steps=steps)
            results.append((target, idx))
            xi = target  # use target as next query
        return results


def load_model():
    """Load the all-MiniLM-L6-v2 sentence-transformer on DEVICE."""
    from sentence_transformers import SentenceTransformer
    return SentenceTransformer("all-MiniLM-L6-v2", device=DEVICE)


def _encode(model, texts, **kwargs):
    """Encode ``texts`` to L2-normalised embeddings, one row per text, on DEVICE."""
    return model.encode(texts, convert_to_tensor=True,
                        normalize_embeddings=True, device=DEVICE, **kwargs)


def _learn_all(mem, cues, targets):
    """Store ``cues[i] -> targets[i]`` for every row, in order."""
    for cue, target in zip(cues, targets):
        mem.learn(cue, target)


def _evaluate(mem, para_embs, target_embs):
    """Recall each paraphrase; count those whose result is nearest their own target.

    Returns ``(n_correct, failures)``; each failure is
    ``(query_idx, predicted_idx, sim_to_correct, sim_to_predicted)``.
    """
    correct = 0
    failures = []
    for i, para in enumerate(para_embs):
        with torch.no_grad():
            recalled, _, _ = mem.recall(para)
        sims = [cosine(recalled, target) for target in target_embs]
        best = int(np.argmax(sims))
        if best == i:
            correct += 1
        else:
            failures.append((i, best, sims[i], sims[best]))
    return correct, failures


def test_scale(model):
    """Scale test: two-stage recall accuracy and latency as memory grows.

    For each background size and each K, a fresh memory holds the 10 core
    pairs plus ``n_bg`` synthetic background pairs; each paraphrase must
    recall a vector closer to its own target than to any other core target.
    """
    print("\n=== Scale Comparison ===")
    pairs = _CORE_PAIRS
    paraphrases = [
        "How's the weather outside?",
        "Push the new release",
        "DB performance terrible",
        "Login bug needs fixing",
        "Getting 500 errors",
        "Need better observability",
        "CI tests breaking",
        "Service using too much RAM",
        "Docker config help",
        "Logs eating disk space",
    ]
    cue_embs = _encode(model, [cue for cue, _ in pairs])
    target_embs = _encode(model, [tgt for _, tgt in pairs])
    para_embs = _encode(model, paraphrases)

    bg_sizes = [0, 100, 500, 1000, 5000, 10000, 20000]
    topics = _INFRA_TOPICS
    # Background text i depends only on i, so encode the largest corpus once
    # and use prefixes, instead of re-encoding for every (n_bg, K) combination.
    max_bg = max(bg_sizes)
    bg_cue_embs = _encode(
        model,
        [f"The {topics[i % len(topics)]} system has issue {i}" for i in range(max_bg)],
        batch_size=256)
    bg_target_embs = _encode(
        model,
        [f"Fix {topics[i % len(topics)]} issue {i} urgently" for i in range(max_bg)],
        batch_size=256)

    n = len(paraphrases)
    for n_bg in bg_sizes:
        for top_k in [20, 50, 100]:
            # Skip K larger than a non-empty background.
            if 0 < n_bg < top_k:
                continue
            mem = TwoStageHopfield(beta=16.0, top_k=top_k)
            _learn_all(mem, cue_embs, target_embs)
            _learn_all(mem, bg_cue_embs[:n_bg], bg_target_embs[:n_bg])

            t0 = time.perf_counter()
            correct, _ = _evaluate(mem, para_embs, target_embs)
            dt = (time.perf_counter() - t0) / n * 1000

            total = len(mem.cue_embs)
            print(f" N={total:>6}, K={top_k:>3}: "
                  f"Para={correct}/{n} ({correct/n:>3.0%}), "
                  f"time={dt:.1f}ms")
            del mem
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        if n_bg > 0:
            print()


def test_multihop_at_scale(model):
    """Multi-hop recall along short chains, with a 500-item background chain."""
    print("\n=== Multi-hop Two-Stage (500 bg) ===")
    chains = [
        ["What's the weather?", "Check weather before going out",
         "My coffee shop nearby", "Great latte art"],
        ["Review the code", "Found memory leak", "Leaks cause OOM", "Add k8s limits"],
        ["Deploy to prod", "Blue-green deployment", "Blue is active", "Switch to green"],
    ]

    mem = TwoStageHopfield(beta=16.0, top_k=50)
    chain_embs = []
    for chain in chains:
        embs = _encode(model, chain)
        chain_embs.append(embs)
        # Link consecutive steps: chain[i] -> chain[i + 1].
        _learn_all(mem, embs[:-1], embs[1:])

    # Background: 500 texts, also linked consecutively (bg[i] -> bg[i + 1]).
    areas = ['code', 'ops', 'ml', 'data', 'infra']
    bg = [f"Background about {areas[i % 5]} number {i}" for i in range(500)]
    bg_embs = _encode(model, bg, batch_size=256)
    _learn_all(mem, bg_embs[:-1], bg_embs[1:])

    for ci, chain in enumerate(chains):
        with torch.no_grad():
            results = mem.recall_multihop(chain_embs[ci][0], hops=len(chain) - 1)
        for hop_idx, (recalled, _) in enumerate(results):
            sim = cosine(recalled, chain_embs[ci][hop_idx + 1])
            status = "✓" if sim > 0.7 else "✗"
            print(f" {status} Chain{ci+1} hop{hop_idx+1}: sim={sim:.3f}")


def test_diverse_queries(model):
    """Larger test set (20 pairs) with more diverse paraphrases and 2000 background pairs."""
    print("\n=== Diverse Query Test (20 pairs, 2000 bg) ===")
    pairs = _CORE_PAIRS + [
        ("How to add caching?", "Redis available at redis.internal:6379"),
        ("Frontend loads slowly", "CDN CloudFlare, 1h TTL for assets"),
        ("Refactor payment module", "Stripe API, webhook in payments/webhook.py"),
        ("Set up new server", "Ubuntu 22.04, Docker, Tailscale, monitoring"),
        ("Optimize search", "Elasticsearch v8, recently upgraded"),
        ("Backup the database", "Daily 3am UTC cron to S3"),
        ("Configure reverse proxy", "Traefik, not nginx"),
        ("Team meeting schedule", "Standup 10am London, Mon-Fri"),
        ("Learn a new language", "User has Python+Go, new to systems programming"),
        ("Review my PR", "User prefers small PRs with clear commits"),
    ]
    paraphrases = [
        "How's the weather?", "Ship the release", "DB is crawling",
        "Fix the login issue", "Server errors everywhere", "Need observability",
        "CI is broken", "Too much RAM usage", "Docker help please",
        "Disk full from logs", "Want to add a cache layer", "Website too slow",
        "Payment code needs rework", "Provision a new machine", "Search is slow",
        "Need a DB backup", "Proxy configuration", "When's the standup?",
        "Want to learn Rust", "Check my pull request",
    ]
    cue_embs = _encode(model, [cue for cue, _ in pairs])
    target_embs = _encode(model, [tgt for _, tgt in pairs])
    para_embs = _encode(model, paraphrases)

    mem = TwoStageHopfield(beta=16.0, top_k=50)
    _learn_all(mem, cue_embs, target_embs)

    # 2000 diverse background pairs.
    topics = _INFRA_TOPICS + ["python", "golang", "react", "terraform", "ansible"]
    actions = ["crashed", "is slow", "needs update", "has bug", "timed out",
               "needs migration", "needs backup", "has leak", "is down", "needs config"]
    n_bg = 2000
    bg_cues = [f"The {topics[i % len(topics)]} {actions[i % len(actions)]} (ticket {i})"
               for i in range(n_bg)]
    bg_targets = [f"Fix {topics[i % len(topics)]} {actions[i % len(actions)]}: see wiki page {i}"
                  for i in range(n_bg)]
    _learn_all(mem,
               _encode(model, bg_cues, batch_size=256),
               _encode(model, bg_targets, batch_size=256))

    correct, failures = _evaluate(mem, para_embs, target_embs)
    n = len(paraphrases)
    print(f" Result: {correct}/{n} ({correct/n:.0%})")
    if failures:
        print(" Failures:")
        for qi, gi, sim_correct, sim_got in failures:
            print(f" Q: '{paraphrases[qi][:30]}...' → got [{gi}] "
                  f"(sim_correct={sim_correct:.3f}, sim_got={sim_got:.3f})")


def main():
    """Run all experiment 7d tests with the shared sentence-transformer."""
    print("=" * 60)
    print("Experiment 7d: Two-Stage Hopfield")
    print("=" * 60)

    model = load_model()
    test_scale(model)
    test_multihop_at_scale(model)
    test_diverse_queries(model)


if __name__ == "__main__":
    main()