"""Experiment 2b: STDP Associative Recall (v2 - fixed learning). v1 failed completely because W=0 → no spikes → no STDP updates (chicken-egg). v2 fixes this with teacher-forced STDP: directly use (cue, target) as (pre, post). Also tests DirectAssociativeMemory (simple outer-product Hebbian) as baseline. """ import sys import time import json from pathlib import Path import torch import torch.nn as nn import numpy as np sys.path.insert(0, str(Path(__file__).parent.parent / "src")) from nuonuo.memory import STDPMemoryNetwork, DirectAssociativeMemory DEVICE = "cuda" RESULTS_DIR = Path(__file__).parent.parent / "doc" def spike_cosine(a, b): """Cosine similarity on firing rate vectors.""" if a.dim() == 2: a = a.mean(dim=0) if b.dim() == 2: b = b.mean(dim=0) if a.norm() == 0 or b.norm() == 0: return 0.0 return nn.functional.cosine_similarity(a.unsqueeze(0), b.unsqueeze(0)).item() def vec_cosine(a, b): """Cosine similarity of two 1D vectors.""" if a.norm() == 0 or b.norm() == 0: return 0.0 return nn.functional.cosine_similarity(a.unsqueeze(0), b.unsqueeze(0)).item() def gen_spikes(num_steps, num_neurons, fr=0.05, device="cuda"): return (torch.rand(num_steps, num_neurons, device=device) < fr).float() def test_stdp_v2(num_neurons, num_steps, num_pairs, fr, num_pres, a_plus): """Test the v2 STDP network.""" net = STDPMemoryNetwork( num_neurons=num_neurons, a_plus=a_plus, a_minus=a_plus*1.2, w_init_std=0.01 ).to(DEVICE) cues = [gen_spikes(num_steps, num_neurons, fr) for _ in range(num_pairs)] targets = [gen_spikes(num_steps, num_neurons, fr) for _ in range(num_pairs)] # Learn t0 = time.time() for i in range(num_pairs): net.learn_association(cues[i], targets[i], num_presentations=num_pres) learn_t = time.time() - t0 # Recall correct_sims = [] wrong_sims = [] for i in range(num_pairs): recalled = net.recall(cues[i]) cs = spike_cosine(recalled, targets[i]) correct_sims.append(cs) for j in range(num_pairs): if j != i: wrong_sims.append(spike_cosine(recalled, targets[j])) mc = np.mean(correct_sims) mw = np.mean(wrong_sims) if wrong_sims else 0 ws = net.get_weight_stats() print(f" STDP: pairs={num_pairs}, pres={num_pres}, A+={a_plus:.3f} | " f"Correct={mc:.4f}, Wrong={mw:.4f}, Disc={mc-mw:.4f}, " f"W_abs={ws['abs_mean']:.4f}, sparsity={ws['sparsity']:.2f}, " f"time={learn_t:.1f}s") return {"method": "stdp_v2", "correct": mc, "wrong": mw, "disc": mc-mw, "w_stats": ws, "time": learn_t, "num_pairs": num_pairs, "a_plus": a_plus, "num_pres": num_pres} def test_direct_hebbian(num_neurons, num_steps, num_pairs, fr, lr): """Test the direct outer-product Hebbian memory.""" net = DirectAssociativeMemory(num_neurons=num_neurons, lr=lr).to(DEVICE) cues = [gen_spikes(num_steps, num_neurons, fr) for _ in range(num_pairs)] targets = [gen_spikes(num_steps, num_neurons, fr) for _ in range(num_pairs)] # Learn t0 = time.time() for i in range(num_pairs): net.learn(cues[i], targets[i]) learn_t = time.time() - t0 # Recall correct_sims = [] wrong_sims = [] for i in range(num_pairs): recalled = net.recall(cues[i]) # continuous vector target_rate = targets[i].mean(dim=0) cs = vec_cosine(recalled, target_rate) correct_sims.append(cs) for j in range(num_pairs): if j != i: wrong_sims.append(vec_cosine(recalled, targets[j].mean(dim=0))) mc = np.mean(correct_sims) mw = np.mean(wrong_sims) if wrong_sims else 0 ws = net.get_weight_stats() print(f" Hebbian: pairs={num_pairs}, lr={lr:.3f} | " f"Correct={mc:.4f}, Wrong={mw:.4f}, Disc={mc-mw:.4f}, " f"W_abs={ws['abs_mean']:.6f}, sparsity={ws['sparsity']:.2f}, " f"time={learn_t:.3f}s") return {"method": "direct_hebbian", "correct": mc, "wrong": mw, "disc": mc-mw, "w_stats": ws, "time": learn_t, "num_pairs": num_pairs, "lr": lr} def main(): print("=" * 60) print("Experiment 2b: STDP v2 + Direct Hebbian") print("=" * 60) results = [] N = 2048 S = 64 FR = 0.05 # --- Part A: Direct Hebbian (baseline) --- print("\n=== Part A: Direct Hebbian Memory ===") print("\nA1: Scaling pairs (lr=0.5)") for n in [1, 5, 10, 20, 50, 100]: r = test_direct_hebbian(N, S, n, FR, lr=0.5) results.append({**r, "test": f"hebb_pairs_{n}"}) print("\nA2: Learning rate sweep (10 pairs)") for lr in [0.01, 0.1, 0.5, 1.0, 5.0]: r = test_direct_hebbian(N, S, 10, FR, lr=lr) results.append({**r, "test": f"hebb_lr_{lr}"}) # --- Part B: STDP v2 --- print("\n=== Part B: STDP v2 (teacher-forced) ===") print("\nB1: Sanity check - single pair") r = test_stdp_v2(N, S, 1, FR, num_pres=5, a_plus=0.01) results.append({**r, "test": "stdp_single"}) print("\nB2: A+ sweep (10 pairs, 5 presentations)") for ap in [0.001, 0.005, 0.01, 0.05, 0.1]: r = test_stdp_v2(N, S, 10, FR, num_pres=5, a_plus=ap) results.append({**r, "test": f"stdp_ap_{ap}"}) print("\nB3: Presentation count (10 pairs, A+=0.01)") for pres in [1, 3, 5, 10, 20]: r = test_stdp_v2(N, S, 10, FR, num_pres=pres, a_plus=0.01) results.append({**r, "test": f"stdp_pres_{pres}"}) print("\nB4: Scaling pairs (A+=0.01, 5 presentations)") for n in [1, 5, 10, 20, 50]: r = test_stdp_v2(N, S, n, FR, num_pres=5, a_plus=0.01) results.append({**r, "test": f"stdp_pairs_{n}"}) # Save with open(RESULTS_DIR / "exp02b_results.json", "w") as f: json.dump(results, f, indent=2, default=float) # Best from each method print("\n" + "=" * 60) hebb_best = max([r for r in results if r["method"] == "direct_hebbian"], key=lambda x: x["disc"], default=None) stdp_best = max([r for r in results if r["method"] == "stdp_v2"], key=lambda x: x["disc"], default=None) if hebb_best: print(f"Best Hebbian: {hebb_best['test']} — " f"Correct={hebb_best['correct']:.4f}, Disc={hebb_best['disc']:.4f}") if stdp_best: print(f"Best STDP: {stdp_best['test']} — " f"Correct={stdp_best['correct']:.4f}, Disc={stdp_best['disc']:.4f}") if __name__ == "__main__": main()