"""Experiment 3: Sleep Consolidation Effects. Test questions: 1. Does consolidation (replay + homeostasis) help or hurt recall? 2. Does replay with noise improve noise tolerance? 3. How does pruning affect capacity? 4. Multi-night scenario: learn day 1, consolidate, learn day 2, consolidate. Do day 1 memories survive? 5. Selective consolidation: replay important memories more → priority memory """ import sys import time import json from pathlib import Path import torch import torch.nn as nn import numpy as np sys.path.insert(0, str(Path(__file__).parent.parent / "src")) from nuonuo.consolidation import MemoryConsolidator, winner_take_all DEVICE = "cuda" RESULTS_DIR = Path(__file__).parent.parent / "doc" def cosine(a, b): if a.norm() == 0 or b.norm() == 0: return 0.0 return nn.functional.cosine_similarity(a.unsqueeze(0), b.unsqueeze(0)).item() class TestableMemory: """Memory with consolidation support for testing.""" def __init__(self, input_dim=768, code_dim=16384, k=20): self.k = k self.code_dim = code_dim self.proj = (torch.randn(input_dim, code_dim, device=DEVICE) * (1.0 / input_dim**0.5)) self.target_proj = (torch.randn(input_dim, code_dim, device=DEVICE) * (1.0 / input_dim**0.5)) self.W = nn.Parameter(torch.zeros(code_dim, code_dim, device=DEVICE), requires_grad=False) self.consolidator = MemoryConsolidator(code_dim, k) def sep(self, x): return winner_take_all(x @ self.proj, self.k) def sep_target(self, x): return winner_take_all(x @ self.target_proj, self.k) def learn(self, cue, target, record=True): cc = self.sep(cue) tc = self.sep_target(target) self.W.data += torch.outer(tc, cc) if record: self.consolidator.record(cc, tc) def recall(self, cue): cc = self.sep(cue) raw = self.W @ cc return winner_take_all(raw, self.k) def test_recall(self, cues, targets, noise_std=0.0): """Test recall accuracy.""" correct = [] for i in range(len(cues)): if noise_std > 0: c = nn.functional.normalize( cues[i] + torch.randn_like(cues[i]) * noise_std, dim=0) else: c = cues[i] recalled = self.recall(c) tc = self.sep_target(targets[i]) correct.append(cosine(recalled, tc)) return np.mean(correct), np.mean([s > 0.5 for s in correct]) def consolidate(self, **kwargs): return self.consolidator.consolidate( self.W, self.proj, self.target_proj, **kwargs) def gen_memories(n, dim=768): cues = [nn.functional.normalize(torch.randn(dim, device=DEVICE), dim=0) for _ in range(n)] targets = [nn.functional.normalize(torch.randn(dim, device=DEVICE), dim=0) for _ in range(n)] return cues, targets def test_basic_consolidation(): """Does replay + homeostasis help?""" print("=== Test 1: Basic Consolidation Effect ===") for n_pairs in [100, 500]: mem = TestableMemory() cues, targets = gen_memories(n_pairs) # Learn for i in range(n_pairs): mem.learn(cues[i], targets[i]) # Before consolidation cos_before, rate_before = mem.test_recall(cues, targets) w_norm_before = mem.W.data.norm().item() print(f"\n {n_pairs} pairs:") print(f" Before: CosSim={cos_before:.4f}, Rate={rate_before:.2%}, " f"W_norm={w_norm_before:.2f}") # Consolidation with different settings for epochs in [1, 3, 5, 10]: # Clone memory for each test mem_test = TestableMemory() mem_test.W.data.copy_(mem.W.data) mem_test.proj = mem.proj mem_test.target_proj = mem.target_proj mem_test.consolidator.replay_buffer = list(mem.consolidator.replay_buffer) stats = mem_test.consolidate( num_epochs=epochs, homeostasis_factor=0.95, prune_threshold=0.001) cos_after, rate_after = mem_test.test_recall(cues, targets) print(f" After (epochs={epochs}): CosSim={cos_after:.4f}, " f"Rate={rate_after:.2%}, " f"W_norm={stats['final_w_norm']:.2f}, " f"Sparsity={stats['final_sparsity']:.2%}") def test_noisy_replay(): """Does replay with noise improve noise tolerance?""" print("\n=== Test 2: Noisy Replay for Robustness ===") n_pairs = 100 mem_base = TestableMemory() cues, targets = gen_memories(n_pairs) for i in range(n_pairs): mem_base.learn(cues[i], targets[i]) # Test at different noise levels test_noises = [0.0, 0.05, 0.1, 0.2] # No consolidation (baseline) print("\n No consolidation:") for ns in test_noises: cos, rate = mem_base.test_recall(cues, targets, noise_std=ns) print(f" test_noise={ns:.2f}: CosSim={cos:.4f}, Rate={rate:.2%}") # Consolidation with different replay noise for replay_noise in [0.0, 0.1, 0.5, 1.0]: mem_test = TestableMemory() mem_test.W.data.copy_(mem_base.W.data) mem_test.proj = mem_base.proj mem_test.target_proj = mem_base.target_proj mem_test.consolidator.replay_buffer = list(mem_base.consolidator.replay_buffer) mem_test.consolidate(num_epochs=5, replay_noise=replay_noise, homeostasis_factor=0.95) print(f"\n Consolidated (replay_noise={replay_noise}):") for ns in test_noises: cos, rate = mem_test.test_recall(cues, targets, noise_std=ns) print(f" test_noise={ns:.2f}: CosSim={cos:.4f}, Rate={rate:.2%}") def test_multi_night(): """Multi-night scenario: learn, consolidate, learn more. Do old memories survive?""" print("\n=== Test 3: Multi-Night Memory Survival ===") mem = TestableMemory() # Day 1: Learn 100 memories cues_d1, targets_d1 = gen_memories(100) for i in range(100): mem.learn(cues_d1[i], targets_d1[i]) cos_d1, _ = mem.test_recall(cues_d1, targets_d1) print(f" After Day 1 (100 memories): CosSim={cos_d1:.4f}") # Night 1: Consolidate stats = mem.consolidate(num_epochs=5, homeostasis_factor=0.95) cos_d1_after, _ = mem.test_recall(cues_d1, targets_d1) print(f" After Night 1 consolidation: CosSim={cos_d1_after:.4f}, " f"W_norm={stats['final_w_norm']:.2f}") mem.consolidator.selective_clear(keep_fraction=0.3) # Day 2: Learn 100 more memories cues_d2, targets_d2 = gen_memories(100) for i in range(100): mem.learn(cues_d2[i], targets_d2[i]) cos_d1_mid, _ = mem.test_recall(cues_d1, targets_d1) cos_d2_mid, _ = mem.test_recall(cues_d2, targets_d2) print(f" After Day 2 (100 more): Day1={cos_d1_mid:.4f}, Day2={cos_d2_mid:.4f}") # Night 2: Consolidate (with day 1 carryover + day 2) stats = mem.consolidate(num_epochs=5, homeostasis_factor=0.95) cos_d1_final, _ = mem.test_recall(cues_d1, targets_d1) cos_d2_final, _ = mem.test_recall(cues_d2, targets_d2) print(f" After Night 2: Day1={cos_d1_final:.4f}, Day2={cos_d2_final:.4f}, " f"W_norm={stats['final_w_norm']:.2f}") # Continue for 5 more days for day in range(3, 8): mem.consolidator.selective_clear(keep_fraction=0.3) cues_new, targets_new = gen_memories(100) for i in range(100): mem.learn(cues_new[i], targets_new[i]) mem.consolidate(num_epochs=5, homeostasis_factor=0.95) cos_d1_now, _ = mem.test_recall(cues_d1, targets_d1) cos_d2_now, _ = mem.test_recall(cues_d2, targets_d2) cos_new, _ = mem.test_recall(cues_new, targets_new) w_norm = mem.W.data.norm().item() sparsity = (mem.W.data.abs() < 0.001).float().mean().item() print(f" After Day {day}: Day1={cos_d1_now:.4f}, Day2={cos_d2_now:.4f}, " f"Latest={cos_new:.4f}, W_norm={w_norm:.1f}, Sparsity={sparsity:.2%}") def test_priority_replay(): """Test selective consolidation: replay important memories more.""" print("\n=== Test 4: Priority Replay ===") mem = TestableMemory() # 50 "important" memories (replay 5x) cues_imp, targets_imp = gen_memories(50) for i in range(50): mem.learn(cues_imp[i], targets_imp[i]) # Record extra copies for priority replay cc = mem.sep(cues_imp[i]) tc = mem.sep_target(targets_imp[i]) for _ in range(4): # 4 extra = 5x total mem.consolidator.record(cc, tc) # 50 "unimportant" memories (replay 1x, normal) cues_unimp, targets_unimp = gen_memories(50) for i in range(50): mem.learn(cues_unimp[i], targets_unimp[i]) cos_imp_before, _ = mem.test_recall(cues_imp, targets_imp) cos_unimp_before, _ = mem.test_recall(cues_unimp, targets_unimp) print(f" Before consolidation: Important={cos_imp_before:.4f}, " f"Unimportant={cos_unimp_before:.4f}") # Consolidate with strong homeostasis (will decay unimportant more) mem.consolidate(num_epochs=10, homeostasis_factor=0.90) cos_imp_after, _ = mem.test_recall(cues_imp, targets_imp) cos_unimp_after, _ = mem.test_recall(cues_unimp, targets_unimp) print(f" After consolidation: Important={cos_imp_after:.4f}, " f"Unimportant={cos_unimp_after:.4f}") print(f" Priority effect: Δimportant={cos_imp_after-cos_imp_before:+.4f}, " f"Δunimportant={cos_unimp_after-cos_unimp_before:+.4f}") def test_forgetting_curve(): """Measure memory decay over multiple consolidation cycles without replay.""" print("\n=== Test 5: Forgetting Curve ===") mem = TestableMemory() cues, targets = gen_memories(100) for i in range(100): mem.learn(cues[i], targets[i]) cos0, _ = mem.test_recall(cues, targets) print(f" Day 0: CosSim={cos0:.4f}") # Simulate nights with homeostasis but NO replay for night in range(1, 11): # Only homeostasis + pruning, no replay mem.W.data *= 0.95 mask = mem.W.data.abs() >= 0.001 mem.W.data *= mask.float() cos, rate = mem.test_recall(cues, targets) w_norm = mem.W.data.norm().item() print(f" Night {night:2d} (no replay): CosSim={cos:.4f}, " f"Rate={rate:.2%}, W_norm={w_norm:.2f}") # Same but WITH replay print("\n --- With replay ---") mem2 = TestableMemory() mem2.proj = mem.proj mem2.target_proj = mem.target_proj for i in range(100): mem2.learn(cues[i], targets[i]) for night in range(1, 11): mem2.consolidate(num_epochs=1, homeostasis_factor=0.95) cos, rate = mem2.test_recall(cues, targets) w_norm = mem2.W.data.norm().item() print(f" Night {night:2d} (with replay): CosSim={cos:.4f}, " f"Rate={rate:.2%}, W_norm={w_norm:.2f}") def main(): print("=" * 60) print("Experiment 3: Sleep Consolidation") print("=" * 60) test_basic_consolidation() test_noisy_replay() test_multi_night() test_priority_replay() test_forgetting_curve() if __name__ == "__main__": main()