"""Experiment 2: STDP Associative Recall.

Core question: Can STDP learn associations between spike patterns, such that
presenting a cue recalls the correct target?

Test protocol:
1. Generate N pairs of (cue, target) spike patterns
2. Train STDP network on all pairs
3. Present each cue and measure similarity between recall and correct target
4. Measure interference: does recall of pair K degrade after learning pair K+1?

This is the make-or-break experiment for the whole approach.

NOTE: recall is measured once, after *all* pairs have been learned. Interference
is therefore observed indirectly, as the change in discrimination when the
number of stored pairs grows (Test 2), not pair-by-pair during training.
"""

import json
import sys
import time
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn

sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
from nuonuo.memory import STDPMemoryNetwork  # noqa: E402

# Prefer the GPU, but fall back to CPU so the script still runs (slowly)
# on machines without CUDA instead of crashing on the first tensor allocation.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
RESULTS_DIR = Path(__file__).parent.parent / "doc"

# Default configuration shared by every sweep in main(); each test overrides
# only the parameter(s) it is varying.
_BASELINE = {
    "num_neurons": 2048,
    "num_steps": 64,
    "num_pairs": 10,
    "firing_rate": 0.05,
    "num_presentations": 5,
    "a_plus": 0.005,
    "a_minus": 0.006,
}


def _safe_cosine(u, v):
    """Return the cosine similarity of two 1-D tensors, or 0.0 if either is all-zero."""
    if u.norm() == 0 or v.norm() == 0:
        # A silent (all-zero) vector has no direction; report "no similarity"
        # rather than relying on the epsilon clamp inside cosine_similarity.
        return 0.0
    return nn.functional.cosine_similarity(u.unsqueeze(0), v.unsqueeze(0)).item()


def _sync_device():
    """Wait for queued CUDA work to finish so wall-clock timings are meaningful."""
    if str(DEVICE).startswith("cuda") and torch.cuda.is_available():
        torch.cuda.synchronize()


def spike_similarity(a, b):
    """Cosine similarity between two spike trains (flattened).

    Both tensors are flattened and cast to float, so they must contain the
    same number of elements. Returns a Python float; 0.0 if either train
    contains no spikes.
    """
    return _safe_cosine(a.flatten().float(), b.flatten().float())


def firing_rate_similarity(a, b):
    """Similarity based on per-neuron firing rates.

    Each input is averaged over dim 0 (the time-step axis for patterns made by
    ``generate_spike_pattern``) and the resulting rate vectors are compared
    with cosine similarity. Returns a Python float; 0.0 if either rate vector
    is all-zero.
    """
    return _safe_cosine(a.float().mean(dim=0), b.float().mean(dim=0))


def generate_spike_pattern(num_steps, num_neurons, firing_rate=0.05, device="cuda"):
    """Generate a random sparse spike pattern.

    Returns a float tensor of shape ``(num_steps, num_neurons)`` on ``device``
    whose entries are 1.0 with probability ``firing_rate`` and 0.0 otherwise.
    """
    return (torch.rand(num_steps, num_neurons, device=device) < firing_rate).float()


def run_recall_test(num_neurons, num_steps, num_pairs, firing_rate,
                    num_presentations, a_plus, a_minus):
    """Test associative recall with given parameters.

    Builds a fresh ``STDPMemoryNetwork``, trains it on ``num_pairs`` random
    (cue, target) pairs, then presents every cue and compares the recalled
    activity with the correct target and with all other targets.

    Returns:
        dict with the input parameters plus:
        ``mean_correct_sim`` / ``mean_wrong_sim`` (firing-rate cosine
        similarity to the matching / non-matching targets, averaged over
        cues), ``discrimination`` (their difference), ``correct_sims``
        (per-cue list), ``recall_firing_rate`` (mean recalled activity,
        averaged over all cues), ``weight_stats`` (as returned by
        ``net.get_weight_stats()``) and ``learn_time`` in seconds.
    """
    print(f"  neurons={num_neurons}, steps={num_steps}, pairs={num_pairs}, "
          f"FR={firing_rate}, pres={num_presentations}, "
          f"A+={a_plus}, A-={a_minus}")

    net = STDPMemoryNetwork(
        num_neurons=num_neurons,
        a_plus=a_plus,
        a_minus=a_minus,
    ).to(DEVICE)

    # Generate pattern pairs
    cues = []
    targets = []
    for _ in range(num_pairs):
        cues.append(generate_spike_pattern(num_steps, num_neurons, firing_rate, DEVICE))
        targets.append(generate_spike_pattern(num_steps, num_neurons, firing_rate, DEVICE))

    # Learn all pairs. CUDA kernels are launched asynchronously, so synchronize
    # on both sides of the timed region to measure the real training time.
    _sync_device()
    t0 = time.perf_counter()
    for cue, target in zip(cues, targets):
        net.learn_association(cue, target, num_presentations=num_presentations)
    _sync_device()
    learn_time = time.perf_counter() - t0

    # Test recall
    correct_sims = []
    wrong_sims = []
    recall_rates = []
    for i, cue in enumerate(cues):
        recalled = net.recall(cue, num_recall_steps=num_steps)
        # Track activity for every cue (not just the last one) so the reported
        # recall firing rate is an average, consistent with the other metrics.
        recall_rates.append(recalled.float().mean().item())

        # Similarity to correct target
        correct_sims.append(firing_rate_similarity(recalled, targets[i]))

        # Similarity to wrong targets (average); undefined with a single pair.
        others = [
            firing_rate_similarity(recalled, other)
            for j, other in enumerate(targets)
            if j != i
        ]
        if others:
            wrong_sims.append(np.mean(others))

    # Guard the empty cases so num_pairs=0 (or 1 for wrong_sims) yields 0.0
    # instead of NaN plus a NumPy RuntimeWarning.
    mean_correct = float(np.mean(correct_sims)) if correct_sims else 0.0
    mean_wrong = float(np.mean(wrong_sims)) if wrong_sims else 0.0
    discrimination = mean_correct - mean_wrong

    w_stats = net.get_weight_stats()
    recall_fr = float(np.mean(recall_rates)) if recall_rates else 0.0

    print(f"  Correct sim: {mean_correct:.4f}, Wrong sim: {mean_wrong:.4f}, "
          f"Discrimination: {discrimination:.4f}")
    print(f"  Recall FR: {recall_fr:.4f}, W stats: mean={w_stats['abs_mean']:.4f}, "
          f"sparsity={w_stats['sparsity']:.2f}")
    print(f"  Learn time: {learn_time:.1f}s")

    return {
        "num_neurons": num_neurons,
        "num_steps": num_steps,
        "num_pairs": num_pairs,
        "firing_rate": firing_rate,
        "num_presentations": num_presentations,
        "a_plus": a_plus,
        "a_minus": a_minus,
        "mean_correct_sim": mean_correct,
        "mean_wrong_sim": mean_wrong,
        "discrimination": discrimination,
        "correct_sims": correct_sims,
        "recall_firing_rate": recall_fr,
        "weight_stats": w_stats,
        "learn_time": learn_time,
    }


def _run_labelled(results, label, **overrides):
    """Run the baseline configuration with ``overrides`` and record it as ``label``."""
    outcome = run_recall_test(**{**_BASELINE, **overrides})
    results.append({**outcome, "test": label})


def main():
    """Run all recall sweeps, save the results as JSON and print a summary table."""
    print("=" * 60)
    print("Experiment 2: STDP Associative Recall")
    print("=" * 60)
    print(f"Device: {DEVICE}")

    results = []

    # Test 1: Baseline — can it learn even 1 pair?
    print("\n--- Test 1: Single pair (sanity check) ---")
    _run_labelled(results, "single_pair", num_pairs=1)

    # Test 2: Vary number of pairs
    print("\n--- Test 2: Scaling pairs ---")
    for n_pairs in [5, 10, 20, 50]:
        _run_labelled(results, f"pairs_{n_pairs}", num_pairs=n_pairs)

    # Test 3: Vary STDP learning rates (depression kept 20% stronger than potentiation)
    print("\n--- Test 3: STDP learning rate sweep ---")
    for a_plus in [0.001, 0.005, 0.01, 0.05]:
        _run_labelled(results, f"lr_{a_plus}", a_plus=a_plus, a_minus=a_plus * 1.2)

    # Test 4: Vary firing rate
    print("\n--- Test 4: Firing rate sweep ---")
    for fr in [0.02, 0.05, 0.10, 0.20]:
        _run_labelled(results, f"fr_{fr}", firing_rate=fr)

    # Test 5: More presentations
    print("\n--- Test 5: Presentation count ---")
    for n_pres in [1, 3, 5, 10, 20]:
        _run_labelled(results, f"pres_{n_pres}", num_presentations=n_pres)

    # Test 6: Wider network
    print("\n--- Test 6: Network width ---")
    for neurons in [1024, 2048, 4096, 8192]:
        _run_labelled(results, f"width_{neurons}", num_neurons=neurons)

    # Save results
    for r in results:
        r["correct_sims"] = [float(x) for x in r["correct_sims"]]

    # Create the output directory first: losing a long run's results to a
    # missing folder at the very end would be an expensive failure.
    RESULTS_DIR.mkdir(parents=True, exist_ok=True)
    with open(RESULTS_DIR / "exp02_results.json", "w", encoding="utf-8") as f:
        # default=float converts any remaining non-JSON scalars (e.g. NumPy
        # or tensor scalars inside weight_stats) to plain floats.
        json.dump(results, f, indent=2, default=float)

    # Summary
    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    print(f"{'Test':<15} {'Correct':>8} {'Wrong':>8} {'Discrim':>8} {'RecallFR':>8}")
    print("-" * 50)
    for r in results:
        print(f"{r['test']:<15} {r['mean_correct_sim']:>8.4f} "
              f"{r['mean_wrong_sim']:>8.4f} {r['discrimination']:>8.4f} "
              f"{r['recall_firing_rate']:>8.4f}")


if __name__ == "__main__":
    main()