nuonuo/experiments/exp13_snn_hopfield.py
Fam Zheng d923aa1e31 NuoNuo: Hippocampal memory module prototype
Hopfield + Hebbian hybrid memory system for LLMs.
Two nights of experiments (16 iterations), validated on LongMemEval (ICLR 2025).

Architecture:
- Single-hop: Two-Stage Hopfield (NN top-20 → softmax settle)
- Multi-hop: Hebbian W matrix with WTA pattern separation
- 64% on LongMemEval (500 questions), retrieval-only, no LLM dependency
- 4ms latency @ 20K memories, ~1GB VRAM

Key findings:
- Hopfield attention solved noise tolerance (20% → 100% vs flat Hebbian)
- WTA pattern separation enables 20K+ capacity
- Multi-hop associative chains (6 hops, CosSim=1.0) — RAG can't do this
- MiniLM-L6 is optimal (discrimination gap > absolute similarity)
- Paraphrase cue augmentation: 55% → 100% on synthetic, 36% → 64% on benchmark
- SNN encoder viable (CosSim 0.99) but not needed for current architecture
2026-04-07 10:37:24 +01:00

307 lines
10 KiB
Python

"""Experiment P5: SNN-native Hopfield (spike-based attention).
Goal: Implement Hopfield-like attractor dynamics using LIF neurons.
The connection: Hopfield softmax attention with inverse temperature β
is equivalent to a Boltzmann distribution at temperature 1/β.
In SNN terms: β maps to membrane time constant / threshold ratio.
Approach: Replace softmax(β * q @ K^T) @ V with:
1. Encode query as spike train
2. Feed through recurrent LIF network with stored patterns as synaptic weights
3. Network settles to attractor (nearest stored pattern)
4. Read out associated target
This is closer to biological CA3 recurrent dynamics.
"""
import sys
import time
from pathlib import Path
import torch
import torch.nn as nn
import snntorch as snn
import numpy as np
DEVICE = "cuda"


def cosine(a, b):
    """Scalar cosine similarity between two 1-D tensors."""
    sim = nn.functional.cosine_similarity(a[None, :], b[None, :])
    return sim.item()
class SNNHopfield(nn.Module):
    """Spike-based Hopfield network.

    Architecture:
    - Input layer: converts query embedding to current injection
    - Recurrent layer: LIF neurons with Hopfield-like connection weights
    - Readout: spike rates -> attention weights -> target embedding

    The recurrent weights are set (not trained) based on stored patterns,
    making this a "configured" SNN, not a "trained" one.
    """

    def __init__(self, dim, beta=0.9, threshold=1.0, num_steps=50):
        """Configure the network.

        Args:
            dim: embedding dimensionality (one LIF neuron per dimension).
            beta: LIF membrane decay factor (snntorch `Leaky` beta).
            threshold: spike threshold of the LIF neurons.
            num_steps: number of simulation steps per recall.
        """
        super().__init__()
        self.dim = dim
        self.num_steps = num_steps
        self.beta_lif = beta  # LIF membrane decay
        self.threshold = threshold
        self.lif = snn.Leaky(beta=beta, threshold=threshold)
        # Parallel lists: cue_patterns[i] is associated with target_patterns[i].
        self.cue_patterns = []
        self.target_patterns = []

    def store(self, cue_emb, target_emb):
        """Store one cue -> target association (detached copies)."""
        self.cue_patterns.append(cue_emb.detach())
        self.target_patterns.append(target_emb.detach())

    def _build_weights(self):
        """Build Hopfield-like recurrent weights from stored patterns.

        W_ij = sum_mu (pattern_mu_i * pattern_mu_j) / N

        This creates attractor states at each stored pattern.
        """
        if not self.cue_patterns:
            return torch.zeros(self.dim, self.dim, device=DEVICE)
        patterns = torch.stack(self.cue_patterns)  # [N_patterns, dim]
        W = patterns.T @ patterns / len(self.cue_patterns)  # [dim, dim]
        # Remove diagonal (no self-connections, like biological networks)
        W.fill_diagonal_(0)
        return W

    def _settle(self, query_emb):
        """Run the LIF attractor dynamics for a query; return spike rates [dim].

        1. Inject query as constant current (first half of the simulation).
        2. Let the network run freely on recurrent input (second half) so it
           settles toward an attractor.
        Shared by `recall` and `recall_pure_spike` (previously duplicated).
        """
        W = self._build_weights()
        mem = torch.zeros(self.dim, device=DEVICE)
        spike_counts = torch.zeros(self.dim, device=DEVICE)
        # Constant input current from query (scaled to help reach threshold).
        input_current = query_emb * 2.0
        for step in range(self.num_steps):
            # Total current: external input + recurrent contribution.
            if step < self.num_steps // 2:
                # First half: external input drives the network.
                total_current = input_current + W @ (mem / self.threshold)
            else:
                # Second half: only recurrent (free running, settle to attractor).
                total_current = W @ (mem / self.threshold)
            spk, mem = self.lif(total_current, mem)
            spike_counts += spk
        # Mean spike rate per neuron is the settled representation.
        return spike_counts / self.num_steps

    def recall(self, query_emb):
        """Hybrid recall: spike settle + soft (softmax) readout.

        Returns:
            (recalled, best_idx): normalized attention-weighted target embedding
            and the index of the most similar stored cue, or (None, None) if
            nothing has been stored.
        """
        # Guard first: avoids running the simulation for nothing (original
        # simulated before checking and returned (None, None) afterwards).
        if not self.cue_patterns:
            return None, None
        spike_rates = self._settle(query_emb)
        # Find nearest stored pattern by spike rate similarity.
        cue_mat = torch.stack(self.cue_patterns)
        sims = nn.functional.cosine_similarity(
            spike_rates.unsqueeze(0), cue_mat, dim=-1)
        # Softmax attention based on similarity (hybrid: spike settle + soft readout).
        attn = torch.softmax(sims * 16.0, dim=0)
        target_mat = torch.stack(self.target_patterns)
        recalled = attn @ target_mat
        recalled = nn.functional.normalize(recalled, dim=0)
        best_idx = sims.argmax().item()
        return recalled, best_idx

    def recall_pure_spike(self, query_emb):
        """Fully spike-based recall (no softmax at readout).

        Returns the stored target of the best-matching cue, or (None, None)
        if nothing has been stored.
        """
        # Fix: original crashed on torch.stack([]) when storage was empty;
        # mirror `recall`'s (None, None) contract instead.
        if not self.cue_patterns:
            return None, None
        spike_rates = self._settle(query_emb)
        # Pure spike readout: direct cosine similarity (no softmax).
        cue_mat = torch.stack(self.cue_patterns)
        sims = nn.functional.cosine_similarity(
            spike_rates.unsqueeze(0), cue_mat, dim=-1)
        best_idx = sims.argmax().item()
        return self.target_patterns[best_idx], best_idx
def load_model():
    """Load the MiniLM sentence encoder on the experiment device."""
    # Imported lazily so the module can be imported without the dependency.
    from sentence_transformers import SentenceTransformer
    encoder = SentenceTransformer("all-MiniLM-L6-v2", device=DEVICE)
    return encoder
def emb(model, text):
    """Encode a single string into a normalized embedding tensor [dim]."""
    batch = model.encode(
        [text],
        convert_to_tensor=True,
        normalize_embeddings=True,
        device=DEVICE,
    )
    return batch[0]
def test_basic(model):
    """Sweep simulation length and LIF decay on a small cue->target set."""
    print("=== Test 1: Basic SNN Hopfield ===\n")
    pairs = [
        ("The database is slow", "Check missing indexes"),
        ("Deploy to production", "Use blue-green deployment"),
        ("The API returns 500", "Check for OOM in worker"),
        ("Set up monitoring", "Prometheus and Grafana"),
        ("Tests failing in CI", "Need postgres container"),
    ]
    # Paraphrased cues, index-aligned with `pairs`.
    paraphrases = ["DB is crawling", "Ship the release",
                   "Getting 500 errors", "Need observability", "CI broken"]
    n = len(pairs)
    for num_steps in (20, 50, 100, 200):
        for beta in (0.8, 0.9, 0.95):
            net = SNNHopfield(384, beta=beta, num_steps=num_steps).to(DEVICE)
            for cue, target in pairs:
                net.store(emb(model, cue), emb(model, target))
            # Exact-cue recall accuracy.
            exact_hits = sum(
                1 for i, (cue, _) in enumerate(pairs)
                if net.recall(emb(model, cue))[1] == i)
            # Paraphrased-cue recall accuracy.
            para_hits = sum(
                1 for i, para in enumerate(paraphrases)
                if net.recall(emb(model, para))[1] == i)
            print(f" steps={num_steps:>3}, β={beta}: "
                  f"Exact={exact_hits}/{n}, Para={para_hits}/{n}")
def test_comparison(model):
    """Head-to-head: spike-settled recall vs plain softmax-attention Hopfield."""
    print("\n=== Test 2: SNN vs Standard Hopfield ===\n")
    pairs = [
        ("The database is slow", "Check missing indexes"),
        ("Deploy to production", "Use blue-green deployment"),
        ("The API returns 500", "Check for OOM in worker"),
        ("Set up monitoring", "Prometheus and Grafana"),
        ("Tests failing in CI", "Need postgres container"),
    ]
    paraphrases = ["DB is crawling", "Ship the release",
                   "Getting 500 errors", "Need observability", "CI broken"]
    n = len(paraphrases)

    # --- SNN Hopfield ---
    spiking = SNNHopfield(384, beta=0.9, num_steps=100).to(DEVICE)
    for cue, target in pairs:
        spiking.store(emb(model, cue), emb(model, target))
    start = time.time()
    snn_hits = 0
    for i, para in enumerate(paraphrases):
        _, idx = spiking.recall(emb(model, para))
        snn_hits += int(idx == i)
    snn_ms = (time.time() - start) / n * 1000

    # --- Standard Hopfield (softmax attention, 3 settle iterations) ---
    cue_mat = torch.stack([emb(model, cue) for cue, _ in pairs])
    target_mat = torch.stack([emb(model, target) for _, target in pairs])
    std_hits = 0
    start = time.time()
    for i, para in enumerate(paraphrases):
        state = emb(model, para)
        for _ in range(3):
            attn = torch.softmax(16.0 * (state @ cue_mat.T), dim=0)
            state = nn.functional.normalize(attn @ cue_mat, dim=0)
        final_attn = torch.softmax(16.0 * (state @ cue_mat.T), dim=0)
        std_hits += int(final_attn.argmax().item() == i)
    std_ms = (time.time() - start) / n * 1000

    print(f" SNN Hopfield: {snn_hits}/{n} ({snn_hits/n:.0%}), {snn_ms:.1f}ms/query")
    print(f" Standard Hopfield: {std_hits}/{n} ({std_hits/n:.0%}), {std_ms:.1f}ms/query")
def test_with_background(model):
    """Recall robustness and latency as distractor memories are added."""
    print("\n=== Test 3: SNN Hopfield with Background ===\n")
    pairs = [
        ("The database is slow", "Check missing indexes"),
        ("Deploy to production", "Use blue-green deployment"),
        ("The API returns 500", "Check for OOM in worker"),
    ]
    paraphrases = ["DB is crawling", "Ship the release", "Getting 500 errors"]
    n = len(paraphrases)
    for n_bg in (0, 10, 50):
        net = SNNHopfield(384, beta=0.9, num_steps=100).to(DEVICE)
        for cue, target in pairs:
            net.store(emb(model, cue), emb(model, target))
        # Distractor memories crowding the attractor landscape.
        for k in range(n_bg):
            net.store(
                emb(model, f"Background task {k} about topic {k%5}"),
                emb(model, f"Background detail {k}"),
            )
        hits = 0
        for i, para in enumerate(paraphrases):
            _, idx = net.recall(emb(model, para))
            hits += int(idx == i)
        # Time a single extra recall for the latency figure.
        start = time.time()
        net.recall(emb(model, paraphrases[0]))
        elapsed_ms = (time.time() - start) * 1000
        print(f" bg={n_bg:>3}: Para={hits}/{n} ({hits/n:.0%}), "
              f"latency={elapsed_ms:.1f}ms, "
              f"W_size={net.dim**2*4/1024/1024:.0f}MB")
def main():
    """Run every SNN-Hopfield experiment in order."""
    banner = "=" * 60
    print(banner)
    print("Experiment P5: SNN-native Hopfield")
    print(banner)
    model = load_model()
    for suite in (test_basic, test_comparison, test_with_background):
        suite(model)
# Entry point when executed as a standalone script.
if __name__ == "__main__":
    main()