Hopfield + Hebbian hybrid memory system for LLMs. Two nights of experiments (16 iterations), validated on LongMemEval (ICLR 2025). Architecture: - Single-hop: Two-Stage Hopfield (NN top-20 → softmax settle) - Multi-hop: Hebbian W matrix with WTA pattern separation - 64% on LongMemEval (500 questions), retrieval-only, no LLM dependency - 4ms latency @ 20K memories, ~1GB VRAM Key findings: - Hopfield attention solved noise tolerance (20% → 100% vs flat Hebbian) - WTA pattern separation enables 20K+ capacity - Multi-hop associative chains (6 hops, CosSim=1.0) — RAG can't do this - MiniLM-L6 is optimal (discrimination gap > absolute similarity) - Paraphrase cue augmentation: 55% → 100% on synthetic, 36% → 64% on benchmark - SNN encoder viable (CosSim 0.99) but not needed for current architecture
366 lines · 14 KiB · Python
"""Experiment 4: End-to-end with real sentence embeddings.
|
|
|
|
All previous experiments used random vectors. Now test with actual semantic
|
|
embeddings from a sentence transformer model. Key questions:
|
|
|
|
1. Does pattern separation preserve semantic neighborhoods?
|
|
(Similar sentences → similar/related codes?)
|
|
2. Can we retrieve memories using paraphrased/related queries?
|
|
3. Does the multi-hop chaining work with semantic embeddings?
|
|
4. Noise tolerance: does embedding-space noise behave differently?
|
|
5. Does a learned separator trained on real data improve noise tolerance?
|
|
"""
|
|
|
|
import sys
|
|
import time
|
|
import json
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import numpy as np
|
|
|
|
# Fall back to CPU when no GPU is present so the script still runs;
# the model (MiniLM-L6) and the Hebbian codes are small enough for CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Experiment write-ups are stored next to the package, under doc/.
RESULTS_DIR = Path(__file__).parent.parent / "doc"
|
|
|
|
|
|
def cosine(a, b):
    """Cosine similarity between two 1-D tensors, returned as a float.

    Degenerate inputs (a zero vector on either side) yield 0.0 rather
    than NaN.
    """
    if a.norm() == 0 or b.norm() == 0:
        return 0.0
    batched_a, batched_b = a.unsqueeze(0), b.unsqueeze(0)
    return nn.functional.cosine_similarity(batched_a, batched_b).item()
|
|
|
|
|
|
def winner_take_all(x, k):
    """Binary k-sparse code: 1.0 at the k largest entries of x (last dim), 0 elsewhere."""
    winners = x.topk(k, dim=-1).indices
    mask = torch.zeros_like(x)
    return mask.scatter_(-1, winners, 1.0)
|
|
|
|
|
|
# --- Test data: conversation-like memory pairs ---
# Each tuple is (cue, target): the cue is something a user might say in
# conversation; the target is the stored fact that should be recalled.
# Index i here lines up with PARAPHRASED_QUERIES[i] below, so the tests
# can check that a reworded cue still recalls the same target.
MEMORY_PAIRS = [
    # (context/cue, memory/fact to recall)
    ("What's the weather like today?", "User prefers to check weather every morning"),
    ("Let's deploy the new version", "The deployment pipeline uses GitHub Actions with k3s"),
    ("The database is slow again", "Last time DB was slow it was because of missing index on users table"),
    ("Can you review my pull request?", "User prefers small PRs with clear commit messages"),
    ("I need to fix the authentication bug", "Auth service uses JWT tokens with 24h expiry stored in Redis"),
    ("Let's set up monitoring", "Prometheus + Grafana stack is already running on the OCI cluster"),
    ("The API is returning 500 errors", "Last 500 error was caused by OOM in the Python worker"),
    ("I want to learn Rust", "User has strong Python and Go background, new to systems programming"),
    ("Schedule a meeting with the team", "Team standup is at 10am London time, Mon-Fri"),
    ("How do I configure nginx?", "The project uses Traefik as reverse proxy, not nginx"),
    ("The tests are failing in CI", "CI runs on Gitea Actions, tests need postgres service container"),
    ("Let's optimize the search function", "Search uses Elasticsearch, recently upgraded to v8"),
    ("I need to backup the database", "Backups run daily at 3am UTC via cron job to S3"),
    ("The memory usage is too high", "Python service has a known memory leak in the websocket handler"),
    ("Can you help with the Docker setup?", "Project uses docker-compose for local dev, k3s for production"),
    ("I want to add caching", "Redis is already available at redis.internal:6379"),
    ("The frontend is loading slowly", "CDN is CloudFlare, assets should be cached with 1h TTL"),
    ("Let's refactor the payment module", "Payment uses Stripe API, webhook handler is in payments/webhook.py"),
    ("I need to set up a new server", "Standard setup: Ubuntu 22.04, Docker, Tailscale, monitoring agent"),
    ("The log files are too large", "Logs rotate daily, kept for 30 days, shipped to Loki"),
]
|
|
|
|
# Paraphrased queries (semantically similar to cues but different wording).
# PARAPHRASED_QUERIES[i] is a rewording of MEMORY_PAIRS[i][0] — the lists
# must stay the same length and in the same order for the recall tests.
PARAPHRASED_QUERIES = [
    "How's the weather outside?",
    "We should push the new release",
    "The DB performance is terrible",
    "Please look at my code changes",
    "There's a login bug I need to fix",
    "We need better observability",
    "Getting internal server errors from the API",
    "I'm interested in learning a new language like Rust",
    "Need to organize a team meeting",
    "How to set up nginx as a web server?",
    "CI tests keep breaking",
    "The search feature needs to be faster",
    "How do I create a database backup?",
    "The service is using too much RAM",
    "Help me with Docker configuration",
    "I want to implement caching for the API",
    "The website is really slow",
    "The payment system needs restructuring",
    "Setting up a fresh Linux server",
    "Logs are eating up disk space",
]
|
|
|
|
|
|
def load_model():
    """Load a small, fast sentence transformer."""
    # Imported lazily so the rest of the module works without the package.
    from sentence_transformers import SentenceTransformer

    print("Loading sentence-transformers model...")
    encoder = SentenceTransformer("all-MiniLM-L6-v2", device=DEVICE)
    dim = encoder.get_sentence_embedding_dimension()
    print(f" Model loaded. Embedding dim: {dim}")
    return encoder
|
|
|
|
|
|
def embed_texts(model, texts):
    """Encode texts to normalized embeddings on GPU."""
    return model.encode(
        texts,
        convert_to_tensor=True,
        normalize_embeddings=True,
        device=DEVICE,
    )
|
|
|
|
|
|
class HebbianMemory:
    """Associative memory: Hebbian outer-product weights over sparse WTA codes.

    Cue and target embeddings are projected through two independent random
    matrices and sparsified with winner-take-all (k active units out of
    code_dim).  Each (cue, target) pair is stored by adding the outer
    product target_code x cue_code to W; recall multiplies W by a cue code
    and re-sparsifies the result.
    """

    def __init__(self, input_dim, code_dim=16384, k=20, device=None):
        # device defaults to the module-wide DEVICE so existing callers are
        # unaffected; pass e.g. device="cpu" to run without CUDA.
        self.device = DEVICE if device is None else device
        self.k = k
        self.code_dim = code_dim
        self.input_dim = input_dim
        # Independent random projections for cues and targets; the
        # 1/sqrt(input_dim) scale keeps projected activations O(1).
        scale = 1.0 / input_dim**0.5
        self.proj = torch.randn(input_dim, code_dim, device=self.device) * scale
        self.target_proj = torch.randn(input_dim, code_dim, device=self.device) * scale
        # Hebbian association matrix (accumulates target_code x cue_code).
        self.W = torch.zeros(code_dim, code_dim, device=self.device)
        self.cue_store = []  # Raw cue embeddings, for coarse NN retrieval
        self.target_store = []  # Raw target embeddings
        self.metadata = []  # Store original text for debugging

    def sep(self, x):
        """Sparse WTA code for a cue embedding."""
        return winner_take_all(x @ self.proj, self.k)

    def sep_target(self, x):
        """Sparse WTA code for a target embedding (separate projection)."""
        return winner_take_all(x @ self.target_proj, self.k)

    def learn(self, cue_emb, target_emb, cue_text="", target_text=""):
        """Store one cue→target association (one-shot Hebbian update)."""
        cc = self.sep(cue_emb)
        tc = self.sep_target(target_emb)
        self.W += torch.outer(tc, cc)
        self.cue_store.append(cue_emb.detach().clone())
        self.target_store.append(target_emb.detach().clone())
        self.metadata.append({"cue": cue_text, "target": target_text})

    def recall_direct(self, query_emb):
        """Direct WTA recall (no coarse retrieval)."""
        cc = self.sep(query_emb)
        raw = self.W @ cc
        return winner_take_all(raw, self.k)

    def recall_coarse_to_fine(self, query_emb, top_n=3):
        """Coarse: NN in embedding space. Fine: Hebbian recall from best match.

        Returns (recalled_code, best_cue_index); best_cue_index is -1 when
        the store is empty.  (top_n is currently unused; kept for API
        compatibility.)
        """
        if not self.cue_store:
            # Bug fix: previously returned a bare tensor here while the
            # populated path returns a tuple, crashing callers that unpack.
            return torch.zeros(self.code_dim, device=self.device), -1

        cue_matrix = torch.stack(self.cue_store)
        sims = nn.functional.cosine_similarity(
            query_emb.unsqueeze(0), cue_matrix, dim=-1)
        best_idx = sims.argmax()
        best_cue = self.cue_store[best_idx]

        cc = self.sep(best_cue)
        raw = self.W @ cc
        return winner_take_all(raw, self.k), best_idx.item()

    def find_nearest_target(self, recalled_code, top_n=3):
        """Given a recalled code, find which stored targets it matches.

        Returns up to top_n (index, similarity, metadata) triples, best
        match first.
        """
        target_codes = [self.sep_target(t) for t in self.target_store]
        sims = [cosine(recalled_code, tc) for tc in target_codes]
        sorted_idx = np.argsort(sims)[::-1]
        return [(int(i), sims[i], self.metadata[i]) for i in sorted_idx[:top_n]]
|
|
|
|
|
|
def test_basic_recall(model, mem):
    """Test: can we recall the correct memory for each cue?"""
    print("\n=== Test 1: Direct Recall (exact cues) ===")

    cue_texts = [p[0] for p in MEMORY_PAIRS]
    target_texts = [p[1] for p in MEMORY_PAIRS]

    correct_count = 0
    for i, (cue, target) in enumerate(zip(cue_texts, target_texts)):
        cue_emb = embed_texts(model, [cue])[0]
        recalled = mem.recall_direct(cue_emb)
        matches = mem.find_nearest_target(recalled, top_n=3)

        top_idx, top_sim, top_meta = matches[0]
        if top_idx == i:
            correct_count += 1
        elif i < 5:  # Show first few errors only
            print(f" ✗ Cue: '{cue[:40]}...'")
            print(f" Expected: [{i}] '{target[:50]}...'")
            print(f" Got: [{top_idx}] '{top_meta['target'][:50]}...' "
                  f"(sim={top_sim:.3f})")

    total = len(MEMORY_PAIRS)
    print(f" Direct recall: {correct_count}/{total} "
          f"({correct_count/total:.0%})")
    return correct_count / total
|
|
|
|
|
|
def test_paraphrase_recall(model, mem):
    """Test: can we recall memories using paraphrased queries?

    Compares two retrieval paths per query: direct WTA recall and
    coarse-to-fine (nearest stored cue, then Hebbian recall from it).
    Returns (direct_accuracy, coarse_accuracy) as floats in [0, 1].
    """
    print("\n=== Test 2: Paraphrase Recall ===")

    direct_correct = 0
    coarse_correct = 0

    # PARAPHRASED_QUERIES[i] is a rewording of MEMORY_PAIRS[i][0], so the
    # correct target for query i is target i.
    for i, query in enumerate(PARAPHRASED_QUERIES):
        query_emb = embed_texts(model, [query])[0]

        # Direct WTA recall
        recalled = mem.recall_direct(query_emb)
        matches = mem.find_nearest_target(recalled, top_n=3)
        is_direct = matches[0][0] == i
        direct_correct += is_direct  # bool adds as 0/1

        # Coarse-to-fine recall
        recalled_cf, best_idx = mem.recall_coarse_to_fine(query_emb)
        matches_cf = mem.find_nearest_target(recalled_cf, top_n=3)
        is_coarse = matches_cf[0][0] == i
        coarse_correct += is_coarse

        if i < 5:  # Show detail for the first few queries only
            status_d = "✓" if is_direct else "✗"
            status_c = "✓" if is_coarse else "✗"
            print(f" [{status_d}/{status_c}] Q: '{query[:50]}...'")
            if not is_direct:
                print(f" Direct got: [{matches[0][0]}] "
                      f"'{matches[0][2]['target'][:50]}...'")
            if is_coarse and not is_direct:
                print(f" Coarse-fine got it right! (via cue #{best_idx})")

    n = len(PARAPHRASED_QUERIES)
    print(f"\n Direct recall: {direct_correct}/{n} ({direct_correct/n:.0%})")
    print(f" Coarse-to-fine: {coarse_correct}/{n} ({coarse_correct/n:.0%})")
    return direct_correct / n, coarse_correct / n
|
|
|
|
|
|
def test_semantic_neighborhood(model, mem):
    """Test: do semantically related cues retrieve related memories?"""
    print("\n=== Test 3: Semantic Neighborhood ===")

    test_queries = [
        "server is down",  # Should relate to: API 500, deployment, monitoring
        "performance problem",  # Should relate to: DB slow, memory, search
        "security issue",  # Should relate to: auth bug, JWT tokens
        "infrastructure setup",  # Should relate to: server, Docker, k3s
    ]

    # No accuracy score here — just print the top-3 matches per query for
    # qualitative inspection.
    for query in test_queries:
        query_emb = embed_texts(model, [query])[0]
        matches = mem.find_nearest_target(mem.recall_direct(query_emb), top_n=3)

        print(f"\n Query: '{query}'")
        for rank, (idx, sim, meta) in enumerate(matches, start=1):
            print(f" #{rank} (sim={sim:.3f}): {meta['target'][:60]}...")
|
|
|
|
|
|
def test_multihop_semantic(model, mem):
    """Test: multi-hop with semantic embeddings.

    Learn: "weather" → "morning routine" → "coffee shop"
    Can we go from "weather" to "coffee shop" in 2 hops?
    """
    print("\n=== Test 4: Multi-hop with Semantic Chains ===")

    chains = [
        ["What's the weather?", "I usually check weather before going out",
         "My favorite coffee shop is around the corner", "They have great latte art"],
        ["Let's review the code", "The code review found a memory leak",
         "Memory leaks often cause OOM kills", "We need to add memory limits to k8s pods"],
        ["Deploy to production", "Production uses blue-green deployment",
         "The blue environment is currently active", "Switch DNS to green when ready"],
    ]

    # Fix: derive the embedding dimension from the model instead of
    # hard-coding 384 (that value is only valid for MiniLM-class encoders),
    # matching how main() constructs its memory.
    embed_dim = model.get_sentence_embedding_dimension()

    for chain_idx, chain in enumerate(chains):
        print(f"\n Chain {chain_idx+1}: {' → '.join([c[:20]+'...' for c in chain])}")

        # Create a separate small memory for this chain
        chain_mem = HebbianMemory(embed_dim, code_dim=8192, k=20)

        chain_embs = [embed_texts(model, [text])[0] for text in chain]

        # Learn consecutive pairs so the chain A→B→C→D is stored in W.
        for i in range(len(chain) - 1):
            chain_mem.learn(chain_embs[i], chain_embs[i+1],
                            chain[i], chain[i+1])

        # Test recall at each hop distance from the chain head.
        for hops in range(1, len(chain)):
            start_emb = chain_embs[0]
            target_code = chain_mem.sep_target(chain_embs[hops])

            # Multi-hop: repeatedly apply W and re-sparsify.
            # NOTE(review): after hop 1 the code lives in target-code space
            # while W was trained on cue codes; chaining relies on the
            # learned pairs bridging the two spaces — confirm this is the
            # intended mechanism.
            code = chain_mem.sep(start_emb)
            for _ in range(hops):
                raw = chain_mem.W @ code
                code = winner_take_all(raw, chain_mem.k)

            sim = cosine(code, target_code)
            print(f" {hops} hop(s): '{chain[0][:25]}...' → "
                  f"'{chain[hops][:25]}...' sim={sim:.4f}")
|
|
|
|
|
|
def test_embedding_distances(model):
    """Analyze: how far apart are original and paraphrased embeddings?"""
    print("\n=== Test 5: Embedding Distance Analysis ===")

    cue_texts = [p[0] for p in MEMORY_PAIRS]
    cue_embs = embed_texts(model, cue_texts)
    para_embs = embed_texts(model, PARAPHRASED_QUERIES)

    n = len(cue_texts)
    # Matched cue/paraphrase pairs share an index.
    same_pair_sims = [cosine(cue_embs[i], para_embs[i]) for i in range(n)]
    # All mismatched combinations form the background distribution.
    diff_pair_sims = [cosine(cue_embs[i], para_embs[j])
                      for i in range(n)
                      for j in range(n)
                      if i != j]

    print(f" Same-pair cosine sim: mean={np.mean(same_pair_sims):.4f}, "
          f"min={np.min(same_pair_sims):.4f}, max={np.max(same_pair_sims):.4f}")
    print(f" Diff-pair cosine sim: mean={np.mean(diff_pair_sims):.4f}, "
          f"min={np.min(diff_pair_sims):.4f}, max={np.max(diff_pair_sims):.4f}")
    print(f" Gap: {np.mean(same_pair_sims) - np.mean(diff_pair_sims):.4f}")

    # Show some examples
    print("\n Sample distances:")
    for i in range(5):
        print(f" '{cue_texts[i][:35]}...' ↔ '{PARAPHRASED_QUERIES[i][:35]}...' "
              f"sim={same_pair_sims[i]:.4f}")
|
|
|
|
|
|
def main():
    """Run experiment 4 end-to-end: embed, store, then exercise all recall tests."""
    print("=" * 60)
    print("Experiment 4: Real Sentence Embeddings")
    print("=" * 60)

    model = load_model()

    # Analyze embedding space first (no memory needed for this test)
    test_embedding_distances(model)

    # Build memory
    print("\n--- Building memory ---")
    embed_dim = model.get_sentence_embedding_dimension()
    mem = HebbianMemory(embed_dim, code_dim=16384, k=20)

    cue_texts = [p[0] for p in MEMORY_PAIRS]
    target_texts = [p[1] for p in MEMORY_PAIRS]

    # Batch-encode all cues/targets once, then store pair by pair.
    cue_embs = embed_texts(model, cue_texts)
    target_embs = embed_texts(model, target_texts)

    for i in range(len(MEMORY_PAIRS)):
        mem.learn(cue_embs[i], target_embs[i], cue_texts[i], target_texts[i])

    print(f" Stored {len(MEMORY_PAIRS)} memory pairs")

    # Run tests (accuracy return values are printed by each test; they are
    # not persisted to RESULTS_DIR here)
    test_basic_recall(model, mem)
    test_paraphrase_recall(model, mem)
    test_semantic_neighborhood(model, mem)
    test_multihop_semantic(model, mem)


if __name__ == "__main__":
    main()
|