Hopfield + Hebbian hybrid memory system for LLMs. Two nights of experiments (16 iterations), validated on LongMemEval (ICLR 2025). Architecture: - Single-hop: Two-Stage Hopfield (NN top-20 → softmax settle) - Multi-hop: Hebbian W matrix with WTA pattern separation - 64% on LongMemEval (500 questions), retrieval-only, no LLM dependency - 4ms latency @ 20K memories, ~1GB VRAM Key findings: - Hopfield attention solved noise tolerance (20% → 100% vs flat Hebbian) - WTA pattern separation enables 20K+ capacity - Multi-hop associative chains (6 hops, CosSim=1.0) — RAG can't do this - MiniLM-L6 is optimal (discrimination gap > absolute similarity) - Paraphrase cue augmentation: 55% → 100% on synthetic, 36% → 64% on benchmark - SNN encoder viable (CosSim 0.99) but not needed for current architecture
369 lines · 13 KiB · Python
"""Experiment 2e: Noise-tolerant retrieval.
|
||
|
||
Problem: WTA pattern separation is brittle to noise in cue embeddings.
|
||
Real use case requires retrieving from semantically similar (not identical) cues.
|
||
|
||
Approaches to test:
|
||
1. Soft-WTA: Use softmax temperature instead of hard top-k
|
||
2. Multi-probe: Multiple noisy retrievals + voting
|
||
3. Coarse-to-fine: Nearest-neighbor in embedding space → exact Hebbian recall
|
||
4. Learned similarity-preserving hash: train the separator to be noise-robust
|
||
5. Wider k: trade capacity for noise robustness
|
||
"""
|
||
|
||
import sys
|
||
import time
|
||
import json
|
||
from pathlib import Path
|
||
|
||
import torch
|
||
import torch.nn as nn
|
||
import numpy as np
|
||
|
||
DEVICE = "cuda"
|
||
RESULTS_DIR = Path(__file__).parent.parent / "doc"
|
||
|
||
|
||
def cosine(a, b):
    """Cosine similarity between two 1-D tensors, returned as a Python float.

    Returns 0.0 when either vector has zero norm (cosine is undefined there).
    """
    if a.norm() == 0 or b.norm() == 0:
        return 0.0
    sim = nn.functional.cosine_similarity(a.unsqueeze(0), b.unsqueeze(0))
    return sim.item()
def winner_take_all(x, k):
    """Hard WTA binarization: 1.0 at the k largest entries of the last dim, else 0.0."""
    winners = torch.topk(x, k, dim=-1).indices
    code = torch.zeros_like(x)
    return code.scatter_(-1, winners, 1.0)
class SoftWTASeparator(nn.Module):
    """Soft winner-take-all using a temperature-scaled softmax.

    Produces soft sparse codes instead of hard binary ones: more robust to
    cue noise, at the cost of weaker pattern discrimination.
    """

    def __init__(self, input_dim, code_dim, temperature=0.1):
        super().__init__()
        self.temperature = temperature
        # Fixed random projection scaled ~1/sqrt(fan-in); registered as a
        # buffer so it follows .to(device) but is never trained.
        self.register_buffer(
            'proj', torch.randn(input_dim, code_dim) * (1.0 / input_dim**0.5)
        )

    def forward(self, x):
        """Project ``x`` and return its softmax code over the last dim."""
        logits = x @ self.proj
        # Low temperature → sharper (sparser) code; high → more spread.
        return torch.softmax(logits / self.temperature, dim=-1)
class MultiProbeSeparator(nn.Module):
    """Ensemble of random projections with majority-vote binarization.

    Each probe projects the input and takes a hard top-k code; a unit is
    active in the final code only when more than half the probes agree.
    """

    def __init__(self, input_dim, code_dim, k_active, num_probes=8):
        super().__init__()
        self.k_active = k_active
        self.num_probes = num_probes
        # One independent fixed random projection per probe.
        self.register_buffer(
            'projs',
            torch.randn(num_probes, input_dim, code_dim) * (1.0 / input_dim**0.5),
        )

    def forward(self, x):
        """Return the binary majority-vote code for 1-D input ``x``."""
        tally = torch.zeros(self.projs.shape[2], device=x.device)
        for proj in self.projs:
            tally = tally + winner_take_all(x @ proj, self.k_active)
        # Active only where a strict majority of probes voted for the unit.
        majority = self.num_probes / 2
        return (tally > majority).float()
class CoarseToFineMemory(nn.Module):
    """Coarse: nearest-neighbor in embedding space.
    Fine: exact Hebbian recall from the nearest stored cue.

    This is the most practical approach: the Hebbian matrix provides the
    association storage, while retrieval is bootstrapped by embedding
    similarity, so a noisy cue is first snapped to the closest cue that was
    actually stored.
    """

    def __init__(self, input_dim, code_dim=16384, k_active=20):
        super().__init__()
        self.code_dim = code_dim
        self.k_active = k_active

        # Fixed random projections for cue and target codes. Created without
        # an explicit device so .to(device) works, matching the other memory
        # classes (previously pinned to the global DEVICE at construction).
        proj = torch.randn(input_dim, code_dim) * (1.0 / input_dim**0.5)
        self.register_buffer('proj', proj)
        target_proj = torch.randn(input_dim, code_dim) * (1.0 / input_dim**0.5)
        self.register_buffer('target_proj', target_proj)

        # Hebbian association matrix; updated in-place, never trained by SGD.
        self.W = nn.Parameter(torch.zeros(code_dim, code_dim),
                              requires_grad=False)

        # Raw cue embeddings kept for the coarse nearest-neighbor stage.
        self.cue_store = []

    def separate(self, x, proj):
        """Project ``x`` and binarize to a k-active WTA code."""
        h = x @ proj
        return winner_take_all(h, self.k_active)

    def learn(self, cue, target):
        """Store the cue for NN lookup and apply the Hebbian outer-product update."""
        self.cue_store.append(cue.detach().clone())
        cue_code = self.separate(cue, self.proj)
        target_code = self.separate(target, self.target_proj)
        self.W.data += torch.outer(target_code, cue_code)

    def recall(self, query):
        """Coarse: find nearest stored cue. Fine: Hebbian recall."""
        if not self.cue_store:
            # Nothing stored yet: all-zero code on this module's device.
            return torch.zeros(self.code_dim, device=self.W.device)

        # Coarse stage: cosine nearest neighbor over all stored cues.
        cue_matrix = torch.stack(self.cue_store)  # [N, dim]
        sims = nn.functional.cosine_similarity(
            query.unsqueeze(0), cue_matrix, dim=-1)  # [N]
        best_cue = self.cue_store[sims.argmax().item()]

        # Fine stage: exact Hebbian recall keyed by the matched (clean) cue.
        cue_code = self.separate(best_cue, self.proj)
        raw = self.W @ cue_code
        return winner_take_all(raw, self.k_active)
def test_approach(name, mem_class, num_pairs=100, noise_levels=None, **kwargs):
|
||
"""Generic test harness."""
|
||
if noise_levels is None:
|
||
noise_levels = [0.0, 0.1, 0.2, 0.5, 1.0, 2.0]
|
||
|
||
input_dim = 768
|
||
cues = [nn.functional.normalize(torch.randn(input_dim, device=DEVICE), dim=0)
|
||
for _ in range(num_pairs)]
|
||
targets = [nn.functional.normalize(torch.randn(input_dim, device=DEVICE), dim=0)
|
||
for _ in range(num_pairs)]
|
||
|
||
mem = mem_class(**kwargs).to(DEVICE) if not isinstance(mem_class, nn.Module) else mem_class
|
||
|
||
# Learn
|
||
for i in range(num_pairs):
|
||
mem.learn(cues[i], targets[i])
|
||
|
||
results = {}
|
||
for noise_std in noise_levels:
|
||
correct_sims = []
|
||
for i in range(num_pairs):
|
||
noisy_cue = cues[i] + torch.randn_like(cues[i]) * noise_std
|
||
noisy_cue = nn.functional.normalize(noisy_cue, dim=0)
|
||
|
||
recalled = mem.recall(noisy_cue)
|
||
|
||
# Compare to target code
|
||
if hasattr(mem, 'target_separator'):
|
||
target_code = mem.target_separator(targets[i])
|
||
elif hasattr(mem, 'target_proj'):
|
||
target_code = winner_take_all(targets[i] @ mem.target_proj, mem.k_active)
|
||
else:
|
||
target_code = targets[i]
|
||
|
||
cs = cosine(recalled, target_code)
|
||
correct_sims.append(cs)
|
||
|
||
mc = np.mean(correct_sims)
|
||
exact = np.mean([s > 0.99 for s in correct_sims])
|
||
results[noise_std] = {"mean_cos": mc, "exact_rate": exact}
|
||
print(f" {name}: noise={noise_std:.2f} → CosSim={mc:.4f}, Exact={exact:.2%}")
|
||
|
||
return results
|
||
|
||
|
||
# --- Approach-specific memory classes ---
|
||
|
||
class SoftWTAMemory(nn.Module):
    """Hebbian memory whose codes come from soft (softmax) WTA separators."""

    def __init__(self, input_dim=768, code_dim=16384, temperature=0.1):
        super().__init__()
        self.separator = SoftWTASeparator(input_dim, code_dim, temperature)
        self.target_separator = SoftWTASeparator(input_dim, code_dim, temperature)
        # Association matrix, updated only by outer products (no gradients).
        self.W = nn.Parameter(torch.zeros(code_dim, code_dim), requires_grad=False)

    def learn(self, cue, target):
        """Bind the cue code to the target code via a Hebbian outer product."""
        cue_code = self.separator(cue)
        target_code = self.target_separator(target)
        self.W.data += torch.outer(target_code, cue_code)

    def recall(self, cue):
        """Return the raw (unbinarized) recalled code for ``cue``."""
        return self.W @ self.separator(cue)
class MultiProbeMemory(nn.Module):
    """Hebbian memory over majority-vote multi-probe binary codes."""

    def __init__(self, input_dim=768, code_dim=8192, k_active=20, num_probes=16):
        super().__init__()
        self.separator = MultiProbeSeparator(input_dim, code_dim, k_active, num_probes)
        self.target_separator = MultiProbeSeparator(input_dim, code_dim, k_active, num_probes)
        self.k_active = k_active
        # Association matrix, updated only by outer products (no gradients).
        self.W = nn.Parameter(torch.zeros(code_dim, code_dim), requires_grad=False)

    def learn(self, cue, target):
        """Hebbian outer-product binding of the two voted codes."""
        cue_code = self.separator(cue)
        target_code = self.target_separator(target)
        self.W.data += torch.outer(target_code, cue_code)

    def recall(self, cue):
        """Recall through W, then re-binarize to exactly k_active units."""
        activation = self.W @ self.separator(cue)
        return winner_take_all(activation, self.k_active)
class WiderKMemory(nn.Module):
    """Just use wider k — simple and might work.

    Trades code sparsity (capacity) for overlap between the codes of similar
    cues (noise robustness).
    """

    def __init__(self, input_dim=768, code_dim=16384, k_active=200):
        super().__init__()
        self.k_active = k_active
        # Fixed random projections for cue and target codes.
        proj = torch.randn(input_dim, code_dim) * (1.0 / input_dim**0.5)
        self.register_buffer('proj', proj)
        target_proj = torch.randn(input_dim, code_dim) * (1.0 / input_dim**0.5)
        self.register_buffer('target_proj', target_proj)
        # Association matrix, updated only by outer products (no gradients).
        self.W = nn.Parameter(torch.zeros(code_dim, code_dim), requires_grad=False)

    def learn(self, cue, target):
        """Hebbian outer-product update from hard WTA codes."""
        cc = winner_take_all(cue @ self.proj, self.k_active)
        tc = winner_take_all(target @ self.target_proj, self.k_active)
        self.W.data += torch.outer(tc, cc)

    def recall(self, cue):
        """Recall through W, then re-binarize to k_active units."""
        cc = winner_take_all(cue @ self.proj, self.k_active)
        raw = self.W @ cc
        return winner_take_all(raw, self.k_active)

    # BUGFIX: the previous ``target_separator`` property returned None, which
    # made ``hasattr(mem, 'target_separator')`` True in the generic harness
    # and crashed it with ``None(...)``. The property is removed so lookups
    # fall through to the correct ``target_proj`` code path.
def _random_pairs(num_pairs, dim=768):
    """Random unit-norm cue/target embedding pairs on DEVICE."""
    cues = [nn.functional.normalize(torch.randn(dim, device=DEVICE), dim=0)
            for _ in range(num_pairs)]
    targets = [nn.functional.normalize(torch.randn(dim, device=DEVICE), dim=0)
               for _ in range(num_pairs)]
    return cues, targets


def _noise_sweep(mem, cues, targets, noise_levels, target_code_fn):
    """Teach all pairs to ``mem``, then measure recall CosSim per noise level.

    Args:
        mem: memory module exposing ``learn(cue, target)`` / ``recall(cue)``.
        cues, targets: parallel lists of unit-norm embeddings.
        noise_levels: Gaussian noise stds to sweep.
        target_code_fn: maps pair index i → the reference code for targets[i]
            (this varies per memory class, so the caller supplies it).

    Returns:
        dict mapping noise std → mean cosine similarity.
    """
    for cue, target in zip(cues, targets):
        mem.learn(cue, target)

    results = {}
    for ns in noise_levels:
        sims = []
        for i, cue in enumerate(cues):
            noisy = nn.functional.normalize(cue + torch.randn_like(cue) * ns, dim=0)
            recalled = mem.recall(noisy)
            sims.append(cosine(recalled, target_code_fn(i)))
        mc = np.mean(sims)
        print(f" noise={ns:.2f}: CosSim={mc:.4f}")
        results[ns] = mc
    return results


def main():
    """Run all four noise-tolerance approaches; save JSON and print a summary.

    The four evaluation loops were previously copy-pasted inline; they are
    now one shared ``_noise_sweep`` parameterized by the target-code rule.
    """
    print("=" * 60)
    print("Experiment 2e: Noise-Tolerant Retrieval")
    print("=" * 60)

    noise_levels = [0.0, 0.05, 0.1, 0.2, 0.5, 1.0]
    num_pairs = 100
    all_results = {}

    # 1. Soft WTA: sweep softmax temperature.
    print("\n=== 1. Soft WTA ===")
    for temp in [0.01, 0.05, 0.1, 0.5]:
        print(f"\n-- temperature={temp} --")
        mem = SoftWTAMemory(temperature=temp).to(DEVICE)
        cues, targets = _random_pairs(num_pairs)
        all_results[f"soft_wta_t{temp}"] = _noise_sweep(
            mem, cues, targets, noise_levels,
            lambda i: mem.target_separator(targets[i]))

    # 2. Multi-probe: sweep ensemble size.
    print("\n=== 2. Multi-Probe ===")
    for n_probes in [4, 8, 16, 32]:
        print(f"\n-- probes={n_probes} --")
        mem = MultiProbeMemory(num_probes=n_probes).to(DEVICE)
        cues, targets = _random_pairs(num_pairs)
        all_results[f"multiprobe_{n_probes}"] = _noise_sweep(
            mem, cues, targets, noise_levels,
            lambda i: mem.target_separator(targets[i]))

    # 3. Coarse-to-fine: NN lookup + exact Hebbian recall.
    print("\n=== 3. Coarse-to-Fine (NN + Hebbian) ===")
    mem = CoarseToFineMemory(768).to(DEVICE)
    cues, targets = _random_pairs(num_pairs)
    all_results["coarse_to_fine"] = _noise_sweep(
        mem, cues, targets, noise_levels,
        lambda i: winner_take_all(targets[i] @ mem.target_proj, mem.k_active))

    # 4. Wider k: trade capacity for noise robustness.
    print("\n=== 4. Wider K ===")
    for k in [50, 100, 200, 500, 1000]:
        print(f"\n-- k={k} --")
        mem = WiderKMemory(k_active=k).to(DEVICE)
        cues, targets = _random_pairs(num_pairs)
        # Bind k as a default argument: a bare closure would late-bind and
        # see only the final k of the loop.
        all_results[f"wider_k_{k}"] = _noise_sweep(
            mem, cues, targets, noise_levels,
            lambda i, k=k: winner_take_all(targets[i] @ mem.target_proj, k))

    # Save with plain str/float values so json.dump accepts numpy scalars.
    RESULTS_DIR.mkdir(parents=True, exist_ok=True)
    serializable = {
        method: {str(ns): float(v) for ns, v in res.items()}
        for method, res in all_results.items()
    }
    with open(RESULTS_DIR / "exp02e_results.json", "w") as f:
        json.dump(serializable, f, indent=2)

    # Summary table.
    print("\n" + "=" * 80)
    print("SUMMARY: CosSim at each noise level")
    print(f"{'Method':<25}", end="")
    for ns in noise_levels:
        print(f" σ={ns:.2f}", end="")
    print()
    print("-" * 80)
    for method, res in all_results.items():
        print(f"{method:<25}", end="")
        for ns in noise_levels:
            v = res.get(ns, 0)
            print(f" {v:>5.3f}", end="")
        print()
if __name__ == "__main__":
|
||
main()
|