"""Experiment P5: SNN-native Hopfield (spike-based attention). Goal: Implement Hopfield-like attractor dynamics using LIF neurons. The connection: Hopfield softmax attention with inverse temperature β is equivalent to a Boltzmann distribution at temperature 1/β. In SNN terms: β maps to membrane time constant / threshold ratio. Approach: Replace softmax(β * q @ K^T) @ V with: 1. Encode query as spike train 2. Feed through recurrent LIF network with stored patterns as synaptic weights 3. Network settles to attractor (nearest stored pattern) 4. Read out associated target This is closer to biological CA3 recurrent dynamics. """ import sys import time from pathlib import Path import torch import torch.nn as nn import snntorch as snn import numpy as np DEVICE = "cuda" def cosine(a, b): return nn.functional.cosine_similarity(a.unsqueeze(0), b.unsqueeze(0)).item() class SNNHopfield(nn.Module): """Spike-based Hopfield network. Architecture: - Input layer: converts query embedding to current injection - Recurrent layer: LIF neurons with Hopfield-like connection weights - Readout: spike rates → attention weights → target embedding The recurrent weights are set (not trained) based on stored patterns, making this a "configured" SNN, not a "trained" one. """ def __init__(self, dim, beta=0.9, threshold=1.0, num_steps=50): super().__init__() self.dim = dim self.num_steps = num_steps self.beta_lif = beta # LIF membrane decay self.threshold = threshold self.lif = snn.Leaky(beta=beta, threshold=threshold) # Stored patterns self.cue_patterns = [] self.target_patterns = [] def store(self, cue_emb, target_emb): self.cue_patterns.append(cue_emb.detach()) self.target_patterns.append(target_emb.detach()) def _build_weights(self): """Build Hopfield-like recurrent weights from stored patterns. W_ij = Σ_μ (pattern_μ_i * pattern_μ_j) / N This creates attractor states at each stored pattern. """ if not self.cue_patterns: return torch.zeros(self.dim, self.dim, device=DEVICE) patterns = torch.stack(self.cue_patterns) # [N_patterns, dim] W = patterns.T @ patterns / len(self.cue_patterns) # [dim, dim] # Remove diagonal (no self-connections, like biological networks) W.fill_diagonal_(0) return W def recall(self, query_emb): """Spike-based attractor dynamics. 1. Inject query as constant current 2. Let network settle via recurrent dynamics 3. Read spike rates → find nearest stored pattern → get target """ W = self._build_weights() # LIF dynamics mem = torch.zeros(self.dim, device=DEVICE) spike_counts = torch.zeros(self.dim, device=DEVICE) # Constant input current from query (scaled) input_current = query_emb * 2.0 # Scale to help reach threshold for step in range(self.num_steps): # Total current: external input + recurrent if step < self.num_steps // 2: # First half: external input drives the network total_current = input_current + W @ (mem / self.threshold) else: # Second half: only recurrent (free running, settle to attractor) total_current = W @ (mem / self.threshold) spk, mem = self.lif(total_current, mem) spike_counts += spk # Spike rates as representation spike_rates = spike_counts / self.num_steps # [dim] # Find nearest stored pattern by spike rate similarity if not self.cue_patterns: return None, None cue_mat = torch.stack(self.cue_patterns) sims = nn.functional.cosine_similarity( spike_rates.unsqueeze(0), cue_mat, dim=-1) # Softmax attention based on similarity (hybrid: spike settle + soft readout) attn = torch.softmax(sims * 16.0, dim=0) target_mat = torch.stack(self.target_patterns) recalled = attn @ target_mat recalled = nn.functional.normalize(recalled, dim=0) best_idx = sims.argmax().item() return recalled, best_idx def recall_pure_spike(self, query_emb): """Fully spike-based recall (no softmax at readout).""" W = self._build_weights() mem = torch.zeros(self.dim, device=DEVICE) spike_counts = torch.zeros(self.dim, device=DEVICE) input_current = query_emb * 2.0 for step in range(self.num_steps): if step < self.num_steps // 2: total_current = input_current + W @ (mem / self.threshold) else: total_current = W @ (mem / self.threshold) spk, mem = self.lif(total_current, mem) spike_counts += spk spike_rates = spike_counts / self.num_steps # Pure spike readout: direct cosine similarity (no softmax) cue_mat = torch.stack(self.cue_patterns) sims = nn.functional.cosine_similarity( spike_rates.unsqueeze(0), cue_mat, dim=-1) best_idx = sims.argmax().item() return self.target_patterns[best_idx], best_idx def load_model(): from sentence_transformers import SentenceTransformer return SentenceTransformer("all-MiniLM-L6-v2", device=DEVICE) def emb(model, text): return model.encode([text], convert_to_tensor=True, normalize_embeddings=True, device=DEVICE)[0] def test_basic(model): """Basic SNN Hopfield recall.""" print("=== Test 1: Basic SNN Hopfield ===\n") pairs = [ ("The database is slow", "Check missing indexes"), ("Deploy to production", "Use blue-green deployment"), ("The API returns 500", "Check for OOM in worker"), ("Set up monitoring", "Prometheus and Grafana"), ("Tests failing in CI", "Need postgres container"), ] for num_steps in [20, 50, 100, 200]: for beta in [0.8, 0.9, 0.95]: net = SNNHopfield(384, beta=beta, num_steps=num_steps).to(DEVICE) for cue, target in pairs: net.store(emb(model, cue), emb(model, target)) # Test exact recall correct = 0 for i, (cue, target) in enumerate(pairs): recalled, idx = net.recall(emb(model, cue)) if idx == i: correct += 1 # Test paraphrase paraphrases = ["DB is crawling", "Ship the release", "Getting 500 errors", "Need observability", "CI broken"] para_correct = 0 for i, para in enumerate(paraphrases): recalled, idx = net.recall(emb(model, para)) if idx == i: para_correct += 1 n = len(pairs) print(f" steps={num_steps:>3}, β={beta}: " f"Exact={correct}/{n}, Para={para_correct}/{n}") def test_comparison(model): """Compare SNN Hopfield vs standard Hopfield.""" print("\n=== Test 2: SNN vs Standard Hopfield ===\n") pairs = [ ("The database is slow", "Check missing indexes"), ("Deploy to production", "Use blue-green deployment"), ("The API returns 500", "Check for OOM in worker"), ("Set up monitoring", "Prometheus and Grafana"), ("Tests failing in CI", "Need postgres container"), ] paraphrases = ["DB is crawling", "Ship the release", "Getting 500 errors", "Need observability", "CI broken"] # SNN Hopfield snn_net = SNNHopfield(384, beta=0.9, num_steps=100).to(DEVICE) for cue, target in pairs: snn_net.store(emb(model, cue), emb(model, target)) snn_correct = 0 t0 = time.time() for i, para in enumerate(paraphrases): _, idx = snn_net.recall(emb(model, para)) if idx == i: snn_correct += 1 snn_time = (time.time() - t0) / len(paraphrases) * 1000 # Standard Hopfield (softmax attention) cue_embs = [emb(model, p[0]) for p in pairs] target_embs = [emb(model, p[1]) for p in pairs] cue_mat = torch.stack(cue_embs) target_mat = torch.stack(target_embs) std_correct = 0 t0 = time.time() for i, para in enumerate(paraphrases): q = emb(model, para) xi = q for _ in range(3): scores = 16.0 * (xi @ cue_mat.T) attn = torch.softmax(scores, dim=0) xi = attn @ cue_mat xi = nn.functional.normalize(xi, dim=0) scores = 16.0 * (xi @ cue_mat.T) attn = torch.softmax(scores, dim=0) best = attn.argmax().item() if best == i: std_correct += 1 std_time = (time.time() - t0) / len(paraphrases) * 1000 n = len(paraphrases) print(f" SNN Hopfield: {snn_correct}/{n} ({snn_correct/n:.0%}), {snn_time:.1f}ms/query") print(f" Standard Hopfield: {std_correct}/{n} ({std_correct/n:.0%}), {std_time:.1f}ms/query") def test_with_background(model): """SNN Hopfield with background noise.""" print("\n=== Test 3: SNN Hopfield with Background ===\n") pairs = [ ("The database is slow", "Check missing indexes"), ("Deploy to production", "Use blue-green deployment"), ("The API returns 500", "Check for OOM in worker"), ] paraphrases = ["DB is crawling", "Ship the release", "Getting 500 errors"] for n_bg in [0, 10, 50]: net = SNNHopfield(384, beta=0.9, num_steps=100).to(DEVICE) for cue, target in pairs: net.store(emb(model, cue), emb(model, target)) for i in range(n_bg): net.store( emb(model, f"Background task {i} about topic {i%5}"), emb(model, f"Background detail {i}"), ) correct = 0 for i, para in enumerate(paraphrases): _, idx = net.recall(emb(model, para)) if idx == i: correct += 1 n = len(paraphrases) t0 = time.time() net.recall(emb(model, paraphrases[0])) dt = (time.time() - t0) * 1000 print(f" bg={n_bg:>3}: Para={correct}/{n} ({correct/n:.0%}), " f"latency={dt:.1f}ms, " f"W_size={net.dim**2*4/1024/1024:.0f}MB") def main(): print("=" * 60) print("Experiment P5: SNN-native Hopfield") print("=" * 60) model = load_model() test_basic(model) test_comparison(model) test_with_background(model) if __name__ == "__main__": main()