"""Experiment 2d: Robustness and capacity limits.

Pattern separation + Hebbian recall is perfect with clean cues. Now test:

1. Noisy cues: add gaussian noise to cue before recall
2. Partial cues: zero out part of the cue
3. Capacity stress test: push to 10K+ memories
4. Full pipeline: encoder -> separator -> memory -> decoder
"""
import sys
import time
import json
from pathlib import Path

import torch
import torch.nn as nn
import numpy as np

# Fall back to CPU so the experiment still runs on machines without a GPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
RESULTS_DIR = Path(__file__).parent.parent / "doc"


def cosine(a, b):
    """Return cosine similarity between two 1-D tensors (0.0 if either is all-zero)."""
    if a.norm() == 0 or b.norm() == 0:
        return 0.0
    return nn.functional.cosine_similarity(a.unsqueeze(0), b.unsqueeze(0)).item()


def winner_take_all(x, k):
    """Binarize x along the last dim: 1.0 at the k largest entries, 0.0 elsewhere."""
    topk_vals, topk_idx = x.topk(k, dim=-1)
    out = torch.zeros_like(x)
    out.scatter_(-1, topk_idx, 1.0)
    return out


class PatternSeparator(nn.Module):
    """Fixed random projection followed by k-winner-take-all.

    Maps a dense input vector to a sparse binary code with exactly
    ``k_active`` ones. The projection is untrained; it is registered as a
    buffer so ``.to(device)`` moves it along with the module.
    """

    def __init__(self, input_dim, code_dim, k_active):
        super().__init__()
        self.k_active = k_active
        # 1/sqrt(input_dim) scaling keeps projected activations O(1).
        proj = torch.randn(input_dim, code_dim) * (1.0 / input_dim**0.5)
        self.register_buffer('proj', proj)

    def forward(self, x):
        h = x @ self.proj
        return winner_take_all(h, self.k_active)


class HebbianMemory(nn.Module):
    """One-shot hetero-associative memory over sparse binary codes.

    Cues and targets are separated by two independent random separators.
    ``learn`` adds the outer product of the two codes to W; ``recall`` is a
    single matrix-vector product followed by k-WTA cleanup.
    """

    def __init__(self, input_dim, code_dim=16384, k_active=20, lr=1.0):
        super().__init__()
        self.separator = PatternSeparator(input_dim, code_dim, k_active)
        self.target_separator = PatternSeparator(input_dim, code_dim, k_active)
        self.code_dim = code_dim
        self.k_active = k_active
        self.lr = lr
        # Non-learnable association matrix; gradients never flow through it.
        self.W = nn.Parameter(torch.zeros(code_dim, code_dim), requires_grad=False)

    def learn(self, cue, target):
        """Store one cue->target association via a Hebbian outer-product update."""
        cue_code = self.separator(cue)
        target_code = self.target_separator(target)
        self.W.data += self.lr * torch.outer(target_code, cue_code)

    def recall_code(self, cue_code):
        """Recall a target code from an already-separated cue code."""
        raw = self.W @ cue_code
        return winner_take_all(raw, self.k_active)

    def recall(self, cue):
        """Separate the raw cue, then recall its associated target code."""
        cue_code = self.separator(cue)
        return self.recall_code(cue_code)


def run_noise_test(num_pairs, noise_levels, code_dim=16384, k=20, input_dim=768):
    """Test recall under noisy cues.

    For each noise level, perturbs every stored cue with gaussian noise,
    renormalizes, and measures recall quality against the stored target code.
    Returns {noise_std: {"mean_cos": float, "exact_rate": float}}.
    """
    mem = HebbianMemory(input_dim, code_dim, k).to(DEVICE)
    cues = [nn.functional.normalize(torch.randn(input_dim, device=DEVICE), dim=0)
            for _ in range(num_pairs)]
    targets = [nn.functional.normalize(torch.randn(input_dim, device=DEVICE), dim=0)
               for _ in range(num_pairs)]
    for i in range(num_pairs):
        mem.learn(cues[i], targets[i])

    # Pre-compute target codes
    target_codes = [mem.target_separator(t) for t in targets]

    results = {}
    for noise_std in noise_levels:
        correct_sims = []
        for i in range(num_pairs):
            # Add noise to cue, then project back onto the unit sphere.
            noisy_cue = cues[i] + torch.randn_like(cues[i]) * noise_std
            noisy_cue = nn.functional.normalize(noisy_cue, dim=0)
            recalled = mem.recall(noisy_cue)
            cs = cosine(recalled, target_codes[i])
            correct_sims.append(cs)
        mc = np.mean(correct_sims)
        # Exact match rate (CosSim > 0.99)
        exact_rate = np.mean([s > 0.99 for s in correct_sims])
        results[noise_std] = {"mean_cos": mc, "exact_rate": exact_rate}
        print(f" noise={noise_std:.2f}: CosSim={mc:.4f}, Exact={exact_rate:.2%}")
    return results


def run_partial_cue_test(num_pairs, mask_fractions, code_dim=16384, k=20,
                         input_dim=768):
    """Test recall with partial cues (some dimensions zeroed out).

    For each fraction, zeroes out that share of randomly chosen cue
    dimensions before recall. Returns
    {fraction: {"mean_cos": float, "exact_rate": float}}.
    """
    mem = HebbianMemory(input_dim, code_dim, k).to(DEVICE)
    cues = [nn.functional.normalize(torch.randn(input_dim, device=DEVICE), dim=0)
            for _ in range(num_pairs)]
    targets = [nn.functional.normalize(torch.randn(input_dim, device=DEVICE), dim=0)
               for _ in range(num_pairs)]
    for i in range(num_pairs):
        mem.learn(cues[i], targets[i])

    target_codes = [mem.target_separator(t) for t in targets]

    results = {}
    for frac in mask_fractions:
        correct_sims = []
        for i in range(num_pairs):
            # Zero out frac% of dimensions
            mask = torch.ones(input_dim, device=DEVICE)
            n_zero = int(input_dim * frac)
            # Keep the index tensor on the same device as the mask.
            indices = torch.randperm(input_dim, device=DEVICE)[:n_zero]
            mask[indices] = 0
            partial_cue = cues[i] * mask
            partial_cue = nn.functional.normalize(partial_cue, dim=0)
            recalled = mem.recall(partial_cue)
            cs = cosine(recalled, target_codes[i])
            correct_sims.append(cs)
        mc = np.mean(correct_sims)
        exact_rate = np.mean([s > 0.99 for s in correct_sims])
        results[frac] = {"mean_cos": mc, "exact_rate": exact_rate}
        print(f" mask={frac:.0%}: CosSim={mc:.4f}, Exact={exact_rate:.2%}")
    return results


def run_capacity_stress_test(code_dim=16384, k=20, input_dim=768,
                             checkpoints=None):
    """Push memory count until recall degrades.

    Stores random pairs one at a time; at each checkpoint, recall quality is
    measured on a random sample (up to 100 stored pairs). ``checkpoints`` is
    an optional list of memory counts to evaluate at (defaults to
    100..20000). Returns {n: {"mean_cos", "exact_rate", "w_abs"}}.
    """
    if checkpoints is None:
        checkpoints = [100, 500, 1000, 2000, 5000, 10000, 20000]
    mem = HebbianMemory(input_dim, code_dim, k).to(DEVICE)
    all_cues = []
    # Only the target *codes* are needed for scoring; raw targets are dropped
    # to keep memory bounded over a 20K-iteration run.
    all_target_codes = []
    results = {}
    for n in range(max(checkpoints)):
        cue = nn.functional.normalize(torch.randn(input_dim, device=DEVICE), dim=0)
        target = nn.functional.normalize(torch.randn(input_dim, device=DEVICE), dim=0)
        mem.learn(cue, target)
        all_cues.append(cue)
        all_target_codes.append(mem.target_separator(target))

        if (n + 1) in checkpoints:
            # Test recall on random sample
            sample_size = min(100, n + 1)
            indices = torch.randperm(n + 1)[:sample_size].tolist()
            correct_sims = []
            for idx in indices:
                recalled = mem.recall(all_cues[idx])
                cs = cosine(recalled, all_target_codes[idx])
                correct_sims.append(cs)
            mc = np.mean(correct_sims)
            exact_rate = np.mean([s > 0.99 for s in correct_sims])
            # W stats
            w_abs = mem.W.data.abs().mean().item()
            print(f" N={n+1:>5}: CosSim={mc:.4f}, Exact={exact_rate:.2%}, "
                  f"W_abs={w_abs:.4f}")
            results[n + 1] = {"mean_cos": mc, "exact_rate": exact_rate,
                              "w_abs": w_abs}
    return results


def main():
    """Run all three robustness tests and dump results as JSON."""
    print("=" * 60)
    print("Experiment 2d: Robustness & Capacity")
    print("=" * 60)
    all_results = {}

    # Test 1: Noise robustness
    print("\n=== Noise Robustness (100 pairs) ===")
    noise_results = run_noise_test(
        100, [0.0, 0.1, 0.2, 0.5, 1.0, 2.0, 5.0])
    all_results["noise"] = {str(k): v for k, v in noise_results.items()}

    # Test 2: Partial cue
    print("\n=== Partial Cue Robustness (100 pairs) ===")
    partial_results = run_partial_cue_test(
        100, [0.0, 0.1, 0.2, 0.3, 0.5, 0.7, 0.9])
    all_results["partial"] = {str(k): v for k, v in partial_results.items()}

    # Test 3: Capacity
    print("\n=== Capacity Stress Test (code=16384, k=20) ===")
    cap_results = run_capacity_stress_test()
    all_results["capacity"] = {str(k): v for k, v in cap_results.items()}

    # Ensure the output directory exists so a long run cannot fail at the end.
    RESULTS_DIR.mkdir(parents=True, exist_ok=True)
    with open(RESULTS_DIR / "exp02d_results.json", "w") as f:
        json.dump(all_results, f, indent=2, default=float)


if __name__ == "__main__":
    main()