nuonuo/experiments/exp02c_pattern_separation.py
Fam Zheng d923aa1e31 NuoNuo: Hippocampal memory module prototype
Hopfield + Hebbian hybrid memory system for LLMs.
Two nights of experiments (16 iterations), validated on LongMemEval (ICLR 2025).

Architecture:
- Single-hop: Two-Stage Hopfield (NN top-20 → softmax settle)
- Multi-hop: Hebbian W matrix with WTA pattern separation
- 64% on LongMemEval (500 questions), retrieval-only, no LLM dependency
- 4ms latency @ 20K memories, ~1GB VRAM

Key findings:
- Hopfield attention solved noise tolerance (20% → 100% vs flat Hebbian)
- WTA pattern separation enables 20K+ capacity
- Multi-hop associative chains (6 hops, CosSim=1.0) — RAG can't do this
- MiniLM-L6 is optimal (discrimination gap > absolute similarity)
- Paraphrase cue augmentation: 55% → 100% on synthetic, 36% → 64% on benchmark
- SNN encoder viable (CosSim 0.99) but not needed for current architecture
2026-04-07 10:37:24 +01:00

210 lines
7.4 KiB
Python

"""Experiment 2c: Pattern separation + improved associative recall.
Key insight from 2b: random spike patterns have too much overlap,
causing catastrophic interference in associative memory.
Fix: Implement pattern separation (like dentate gyrus in hippocampus):
1. Winner-take-all: only top-k neurons fire → guaranteed sparse, minimal overlap
2. Random sparse projection: patterns projected through sparse random matrix
3. Scale up neurons to improve signal-to-noise ratio (capacity ∝ N/P)
Also test: direct Hebbian in rate-space (skip spike conversion entirely)
"""
import sys
import time
import json
from pathlib import Path
import torch
import torch.nn as nn
import numpy as np
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
# Fall back to CPU so the script still runs on machines without a GPU
# (the original hard-coded "cuda" and crashed on CPU-only hosts).
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
RESULTS_DIR = Path(__file__).parent.parent / "doc"
def cosine(a, b):
    """Cosine similarity between two 1-D tensors, as a Python float.

    Returns 0.0 when either vector is all-zero (cosine undefined there).
    """
    if min(a.norm().item(), b.norm().item()) == 0:
        return 0.0
    sim = nn.functional.cosine_similarity(a.unsqueeze(0), b.unsqueeze(0))
    return sim.item()
def winner_take_all(x, k):
    """Binarize x along the last dim: the top-k entries become 1.0, the rest 0.

    Note the original values are discarded — the output is a binary activity
    mask (active / not active), not a thresholded copy of x.
    """
    # Only the indices matter; the values themselves are thrown away.
    topk_idx = x.topk(k, dim=-1).indices
    out = torch.zeros_like(x)
    out.scatter_(-1, topk_idx, 1.0)
    return out
class PatternSeparator(nn.Module):
    """Dentate-gyrus analog: maps dense inputs to sparse, near-orthogonal codes.

    A fixed (untrained) random projection spreads the input over a larger code
    space; winner-take-all then keeps only the k_active strongest units, giving
    a binary code with guaranteed sparsity and minimal pairwise overlap.
    """

    def __init__(self, input_dim, code_dim, k_active):
        super().__init__()
        self.k_active = k_active
        # Fixed random projection, scaled by 1/sqrt(input_dim) so activations
        # stay O(1). Registered as a buffer: it follows .to(device) but is
        # never updated by training.
        scale = 1.0 / input_dim ** 0.5
        self.register_buffer('proj', torch.randn(input_dim, code_dim) * scale)

    def forward(self, x):
        """Project x ([input_dim]) and binarize to a k_active-sparse [code_dim] code."""
        return winner_take_all(x @ self.proj, self.k_active)
class HebbianMemory(nn.Module):
    """Heteroassociative memory with pattern separation.

    Cues and targets each pass through their own PatternSeparator (independent
    fixed random projections), and a Hebbian outer-product matrix W stores the
    cue-code -> target-code association.
    """

    def __init__(self, input_dim, code_dim=8192, k_active=50, lr=1.0):
        super().__init__()
        # NOTE: separator construction order is kept as-is so the RNG draws
        # (cue projection first, then target projection) stay reproducible.
        self.separator = PatternSeparator(input_dim, code_dim, k_active)
        self.code_dim = code_dim
        self.lr = lr
        # Independent projection for targets, so cue and target codes differ.
        self.target_separator = PatternSeparator(input_dim, code_dim, k_active)
        # Association matrix (separated_cue -> separated_target); updated only
        # by the Hebbian rule below, never by backprop.
        self.W = nn.Parameter(torch.zeros(code_dim, code_dim), requires_grad=False)

    def learn(self, cue, target):
        """Store one association. cue, target: [input_dim] continuous vectors."""
        pre = self.separator(cue)
        post = self.target_separator(target)
        # One-shot Hebbian update: W += lr * (post outer pre)
        self.W.data.add_(torch.outer(post, pre), alpha=self.lr)

    def recall(self, cue, k_recall=50):
        """Recall the target code for cue, cleaned up by an output WTA pass."""
        activation = self.W.mv(self.separator(cue))
        return winner_take_all(activation, k_recall)

    def recall_continuous(self, cue):
        """Raw (pre-WTA) recall activation, for cosine-similarity scoring."""
        return self.W.mv(self.separator(cue))
def test_hebbian_with_separation(input_dim, code_dim, k_active, num_pairs, lr):
    """Test associative recall with pattern separation.

    Stores num_pairs random (cue, target) associations, then scores:
      - correct: cosine(recalled code, true target code)
      - wrong:   cosine(recalled code, up to 20 distractor target codes)
    Prints a summary line and returns a dict of stats for JSON logging.
    """
    mem = HebbianMemory(input_dim, code_dim, k_active, lr).to(DEVICE)
    # Generate random normalized vectors as memories
    cues = [nn.functional.normalize(torch.randn(input_dim, device=DEVICE), dim=0)
            for _ in range(num_pairs)]
    targets = [nn.functional.normalize(torch.randn(input_dim, device=DEVICE), dim=0)
               for _ in range(num_pairs)]
    # Learn
    for cue, target in zip(cues, targets):
        mem.learn(cue, target)
    # Hoisted out of the scoring loops: target codes are deterministic, so
    # compute each once instead of re-running the separator inside the
    # distractor loop (was ~num_pairs * 20 redundant forward passes).
    target_codes = [mem.target_separator(t) for t in targets]
    num_distractors = min(num_pairs, 20)  # limit comparisons for speed
    correct_sims = []
    wrong_sims = []
    for i in range(num_pairs):
        recalled = mem.recall(cues[i], k_recall=k_active)
        correct_sims.append(cosine(recalled, target_codes[i]))
        for j in range(num_distractors):
            if j != i:
                wrong_sims.append(cosine(recalled, target_codes[j]))
    mc = np.mean(correct_sims)
    mw = np.mean(wrong_sims) if wrong_sims else 0
    print(f" code={code_dim}, k={k_active}, pairs={num_pairs}, lr={lr:.2f} | "
          f"Correct={mc:.4f}, Wrong={mw:.4f}, Disc={mc-mw:.4f}")
    return {"correct": mc, "wrong": mw, "disc": mc - mw,
            "code_dim": code_dim, "k_active": k_active,
            "num_pairs": num_pairs, "lr": lr}
def test_overlap_analysis(code_dim, k_active, num_patterns, input_dim=768):
    """Measure how orthogonal the separated patterns actually are.

    Runs num_patterns random unit vectors through a fresh PatternSeparator and
    reports the mean/max pairwise cosine overlap of the resulting codes.
    input_dim defaults to 768 (previously hard-coded), so existing 3-arg
    callers are unaffected while other embedding sizes can now be probed.
    """
    sep = PatternSeparator(input_dim, code_dim, k_active).to(DEVICE)
    patterns = []
    for _ in range(num_patterns):
        x = nn.functional.normalize(torch.randn(input_dim, device=DEVICE), dim=0)
        patterns.append(sep(x))
    # Pairwise cosine similarity over all unordered pairs
    sims = [cosine(patterns[i], patterns[j])
            for i in range(num_patterns)
            for j in range(i + 1, num_patterns)]
    mean_sim = np.mean(sims)
    max_sim = np.max(sims)
    print(f" code={code_dim}, k={k_active}: mean_overlap={mean_sim:.4f}, max_overlap={max_sim:.4f}")
    return {"mean_overlap": mean_sim, "max_overlap": max_sim}
def main():
    """Run all experiment 2c sweeps and save results to doc/exp02c_results.json."""
    print("=" * 60)
    print("Experiment 2c: Pattern Separation + Hebbian Memory")
    print("=" * 60)
    results = []
    # Part 1: Overlap analysis — how orthogonal are separated patterns?
    print("\n=== Part 1: Pattern overlap after separation ===")
    for code_dim in [2048, 4096, 8192, 16384]:
        for k in [20, 50, 100]:
            ov = test_overlap_analysis(code_dim, k, 100)
            results.append({"test": "overlap", "code_dim": code_dim, "k": k, **ov})
    # Part 2: Associative recall with separation
    print("\n=== Part 2: Recall with pattern separation ===")
    print("\n-- Scaling pairs --")
    for n in [1, 5, 10, 20, 50, 100, 200, 500]:
        r = test_hebbian_with_separation(768, 8192, 50, n, lr=1.0)
        results.append({"test": f"sep_pairs_{n}", **r})
    print("\n-- Code dimension sweep (100 pairs) --")
    for cd in [2048, 4096, 8192, 16384]:
        r = test_hebbian_with_separation(768, cd, 50, 100, lr=1.0)
        results.append({"test": f"sep_codedim_{cd}", **r})
    print("\n-- Sparsity sweep (100 pairs, code=8192) --")
    for k in [10, 20, 50, 100, 200]:
        r = test_hebbian_with_separation(768, 8192, k, 100, lr=1.0)
        results.append({"test": f"sep_k_{k}", **r})
    print("\n-- Capacity test: find the breaking point (code=16384, k=20) --")
    for n in [10, 50, 100, 200, 500, 1000, 2000]:
        r = test_hebbian_with_separation(768, 16384, 20, n, lr=1.0)
        results.append({"test": f"cap_{n}", **r})
    # Save — create the output dir first (previously crashed with
    # FileNotFoundError on a fresh checkout without doc/).
    RESULTS_DIR.mkdir(parents=True, exist_ok=True)
    with open(RESULTS_DIR / "exp02c_results.json", "w") as f:
        # default=float coerces the numpy scalars in the stat dicts
        json.dump(results, f, indent=2, default=float)
    # Print the capacity curve from the last sweep
    recall_results = [r for r in results
                      if r.get("disc") is not None and "cap_" in r.get("test", "")]
    if recall_results:
        print("\n=== Capacity curve (code=16384, k=20) ===")
        print(f"{'Pairs':>6} {'Correct':>8} {'Wrong':>8} {'Disc':>8}")
        for r in recall_results:
            print(f"{r['num_pairs']:>6} {r['correct']:>8.4f} {r['wrong']:>8.4f} {r['disc']:>8.4f}")


if __name__ == "__main__":
    main()