"""Experiment 2c: Pattern separation + improved associative recall.

Key insight from 2b: random spike patterns have too much overlap, causing
catastrophic interference in associative memory.

Fix: Implement pattern separation (like dentate gyrus in hippocampus):
1. Winner-take-all: only top-k neurons fire → guaranteed sparse, minimal overlap
2. Random projection: patterns projected through a fixed random matrix
   (implemented below as a dense Gaussian matrix; sparsity comes from step 1)
3. Scale up neurons to improve signal-to-noise ratio (capacity ∝ N/P)

Also test: direct Hebbian in rate-space (skip spike conversion entirely)
"""

import sys
import time
import json
from pathlib import Path

import torch
import torch.nn as nn
import numpy as np

sys.path.insert(0, str(Path(__file__).parent.parent / "src"))

# Prefer the GPU, but keep the script runnable on CPU-only machines.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
RESULTS_DIR = Path(__file__).parent.parent / "doc"

# Number of stored targets used as "wrong" references when scoring recall.
_MAX_DISTRACTORS = 20


def cosine(a, b):
    """Return the cosine similarity of two 1-D tensors as a Python float.

    Returns 0.0 if either vector has zero norm, so an empty (all-zero) code
    never produces NaN.
    """
    if a.norm() == 0 or b.norm() == 0:
        return 0.0
    return nn.functional.cosine_similarity(a.unsqueeze(0), b.unsqueeze(0)).item()


def winner_take_all(x, k):
    """Return a binary mask with 1.0 at the k largest entries of the last dim.

    The magnitudes of ``x`` are discarded: the result only encodes *which*
    units won. The output is built from constants, so no gradient flows back
    to ``x``.

    Args:
        x: Tensor of activations; the competition runs along ``dim=-1``.
        k: Number of winners per row. ``torch.topk`` raises if k exceeds the
            size of the last dimension.

    Returns:
        Tensor with the same shape and dtype as ``x`` containing exactly k
        ones per row and zeros elsewhere.
    """
    winners = x.topk(k, dim=-1).indices
    out = torch.zeros_like(x)
    out.scatter_(-1, winners, 1.0)  # Binary: active or not
    return out


class PatternSeparator(nn.Module):
    """Dentate gyrus analog: maps input patterns to sparse binary codes.

    The input is multiplied by a fixed (not learned) dense Gaussian matrix and
    the ``k_active`` largest responses are set to 1, all others to 0.
    """

    def __init__(self, input_dim, code_dim, k_active):
        """Create a separator from ``input_dim`` inputs to ``code_dim`` units.

        Args:
            input_dim: Length of the input vectors.
            code_dim: Length of the produced code.
            k_active: Number of active (1.0) units in every code.
        """
        super().__init__()
        self.k_active = k_active
        # Fixed random projection, scaled by 1/sqrt(input_dim). Registered as
        # a buffer so it moves with .to(device) but is never trained.
        proj = torch.randn(input_dim, code_dim) * (1.0 / input_dim**0.5)
        self.register_buffer('proj', proj)

    def forward(self, x):
        """x: [input_dim] → [code_dim] sparse binary"""
        h = x @ self.proj
        return winner_take_all(h, self.k_active)


class HebbianMemory(nn.Module):
    """Heteroassociative memory with pattern separation.

    Cues and targets are each encoded by their own ``PatternSeparator``; the
    association between the two codes is stored in a ``code_dim x code_dim``
    weight matrix by outer-product (Hebbian) updates.
    """

    def __init__(self, input_dim, code_dim=8192, k_active=50, lr=1.0):
        """Build the two separators and an all-zero association matrix.

        Args:
            input_dim: Length of the continuous cue/target vectors.
            code_dim: Length of the sparse codes (W is code_dim x code_dim).
            k_active: Number of active units per code.
            lr: Scale applied to every Hebbian update.
        """
        super().__init__()
        self.separator = PatternSeparator(input_dim, code_dim, k_active)
        self.code_dim = code_dim
        self.lr = lr

        # Separate separator for targets (different random projection)
        self.target_separator = PatternSeparator(input_dim, code_dim, k_active)

        # Association matrix: separated_cue → separated_target
        self.W = nn.Parameter(torch.zeros(code_dim, code_dim), requires_grad=False)

    def learn(self, cue, target):
        """Store one cue → target association.

        cue, target: [dim] continuous vectors
        """
        cue_code = self.separator(cue)
        target_code = self.target_separator(target)
        # Outer product Hebbian update: W += lr * target_code ⊗ cue_code
        with torch.no_grad():
            self.W.add_(torch.outer(target_code, cue_code), alpha=self.lr)

    def recall(self, cue, k_recall=50):
        """Return the recalled target code: binary, with k_recall active units."""
        cue_code = self.separator(cue)
        raw = self.W @ cue_code
        # WTA on output to clean up
        return winner_take_all(raw, k_recall)

    def recall_continuous(self, cue):
        """Returns continuous activation (for cosine sim)."""
        cue_code = self.separator(cue)
        return self.W @ cue_code


def test_hebbian_with_separation(input_dim, code_dim, k_active, num_pairs, lr):
    """Test associative recall with pattern separation.

    Stores ``num_pairs`` random unit-norm cue/target pairs, recalls every
    target from its cue and compares the recalled code (in code space) with
    the correct target code and with the codes of up to 20 other targets.

    Args:
        input_dim: Length of the random cue/target vectors.
        code_dim: Length of the sparse codes.
        k_active: Active units per code; also used as ``k_recall``.
        num_pairs: Number of associations to store.
        lr: Hebbian update scale.

    Returns:
        Dict with mean cosine to the correct target ("correct"), to the wrong
        targets ("wrong", 0 if there are none), their difference ("disc") and
        the configuration that was run.
    """
    mem = HebbianMemory(input_dim, code_dim, k_active, lr).to(DEVICE)

    # Generate random normalized vectors as memories
    cues = [nn.functional.normalize(torch.randn(input_dim, device=DEVICE), dim=0)
            for _ in range(num_pairs)]
    targets = [nn.functional.normalize(torch.randn(input_dim, device=DEVICE), dim=0)
               for _ in range(num_pairs)]

    # Learn
    for cue, target in zip(cues, targets):
        mem.learn(cue, target)

    # The separator is a fixed projection, so the codes of the distractor
    # targets can be computed once instead of once per recalled pair.
    distractor_codes = [mem.target_separator(t) for t in targets[:_MAX_DISTRACTORS]]

    # Test recall in code space (after separation)
    correct_sims = []
    wrong_sims = []
    for i, cue in enumerate(cues):
        recalled = mem.recall(cue, k_recall=k_active)
        if i < len(distractor_codes):
            target_code = distractor_codes[i]
        else:
            target_code = mem.target_separator(targets[i])
        correct_sims.append(cosine(recalled, target_code))

        # limit comparisons for speed: only the first 20 targets are distractors
        wrong_sims.extend(
            cosine(recalled, wrong_code)
            for j, wrong_code in enumerate(distractor_codes)
            if j != i
        )

    mc = np.mean(correct_sims)
    mw = np.mean(wrong_sims) if wrong_sims else 0

    print(f" code={code_dim}, k={k_active}, pairs={num_pairs}, lr={lr:.2f} | "
          f"Correct={mc:.4f}, Wrong={mw:.4f}, Disc={mc-mw:.4f}")

    return {"correct": mc, "wrong": mw, "disc": mc - mw,
            "code_dim": code_dim, "k_active": k_active,
            "num_pairs": num_pairs, "lr": lr}


# This is an experiment driver, not a unit test: keep pytest from collecting
# it (its required arguments would be reported as missing fixtures).
test_hebbian_with_separation.__test__ = False


def test_overlap_analysis(code_dim, k_active, num_patterns, input_dim=768):
    """Measure how orthogonal the separated patterns actually are.

    Encodes ``num_patterns`` random unit-norm vectors with one separator and
    reports the mean and maximum pairwise cosine similarity of the codes.

    Args:
        code_dim: Length of the sparse codes.
        k_active: Active units per code.
        num_patterns: Number of random patterns to encode.
        input_dim: Length of the random input vectors.

    Returns:
        Dict with "mean_overlap" and "max_overlap"; both are 0.0 when fewer
        than two patterns are requested (there is no pair to compare).
    """
    sep = PatternSeparator(input_dim, code_dim, k_active).to(DEVICE)

    patterns = [
        sep(nn.functional.normalize(torch.randn(input_dim, device=DEVICE), dim=0))
        for _ in range(num_patterns)
    ]

    # Pairwise cosine similarity
    sims = [
        cosine(patterns[i], patterns[j])
        for i in range(num_patterns)
        for j in range(i + 1, num_patterns)
    ]

    if sims:
        mean_sim = np.mean(sims)
        max_sim = np.max(sims)
    else:
        mean_sim = max_sim = 0.0

    print(f" code={code_dim}, k={k_active}: mean_overlap={mean_sim:.4f}, max_overlap={max_sim:.4f}")
    return {"mean_overlap": mean_sim, "max_overlap": max_sim}


# Experiment driver, not a unit test (see test_hebbian_with_separation).
test_overlap_analysis.__test__ = False


def main():
    """Run all sweeps, save the results as JSON and print the capacity curve."""
    # Create the output directory first so a missing folder fails (or is
    # fixed) before the long-running sweeps, not after them.
    RESULTS_DIR.mkdir(parents=True, exist_ok=True)

    print("=" * 60)
    print("Experiment 2c: Pattern Separation + Hebbian Memory")
    print("=" * 60)

    results = []

    # Part 1: Overlap analysis — how orthogonal are separated patterns?
    print("\n=== Part 1: Pattern overlap after separation ===")
    for code_dim in [2048, 4096, 8192, 16384]:
        for k in [20, 50, 100]:
            ov = test_overlap_analysis(code_dim, k, 100)
            results.append({"test": "overlap", "code_dim": code_dim, "k": k, **ov})

    # Part 2: Associative recall with separation
    print("\n=== Part 2: Recall with pattern separation ===")

    print("\n-- Scaling pairs --")
    for n in [1, 5, 10, 20, 50, 100, 200, 500]:
        r = test_hebbian_with_separation(768, 8192, 50, n, lr=1.0)
        results.append({"test": f"sep_pairs_{n}", **r})

    print("\n-- Code dimension sweep (100 pairs) --")
    for cd in [2048, 4096, 8192, 16384]:
        r = test_hebbian_with_separation(768, cd, 50, 100, lr=1.0)
        results.append({"test": f"sep_codedim_{cd}", **r})

    print("\n-- Sparsity sweep (100 pairs, code=8192) --")
    for k in [10, 20, 50, 100, 200]:
        r = test_hebbian_with_separation(768, 8192, k, 100, lr=1.0)
        results.append({"test": f"sep_k_{k}", **r})

    print("\n-- Capacity test: find the breaking point (code=16384, k=20) --")
    for n in [10, 50, 100, 200, 500, 1000, 2000]:
        r = test_hebbian_with_separation(768, 16384, 20, n, lr=1.0)
        results.append({"test": f"cap_{n}", **r})

    # Save (default=float converts NumPy scalars that json cannot serialize)
    with open(RESULTS_DIR / "exp02c_results.json", "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, default=float)

    # Print the capacity curve from the "cap_*" runs
    recall_results = [
        r for r in results
        if r.get("disc") is not None and r.get("test", "").startswith("cap_")
    ]
    if recall_results:
        print("\n=== Capacity curve (code=16384, k=20) ===")
        print(f"{'Pairs':>6} {'Correct':>8} {'Wrong':>8} {'Disc':>8}")
        for r in recall_results:
            print(f"{r['num_pairs']:>6} {r['correct']:>8.4f} {r['wrong']:>8.4f} {r['disc']:>8.4f}")


if __name__ == "__main__":
    main()