"""Experiment 3b: Consolidation near capacity limits.

With code_dim=16384 and k=20, capacity is so high that consolidation seems
unnecessary. This experiment uses a smaller code_dim (2048, with k=50), where
capacity limits are lower and consolidation effects should be visible.

It also tests stronger homeostasis as a way to control W_norm growth.
"""
import sys
import time
import json
from pathlib import Path

import torch
import torch.nn as nn
import numpy as np

sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
from nuonuo.consolidation import MemoryConsolidator, winner_take_all

# Prefer the GPU, but fall back to CPU so the script can still run without
# CUDA. NOTE(review): assumes nuonuo.consolidation works on whatever device
# the tensors passed to it live on, rather than hard-coding "cuda" -- confirm.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(review): RESULTS_DIR (like the `time` and `json` imports) is not used
# anywhere in this script yet.
RESULTS_DIR = Path(__file__).parent.parent / "doc"


def cosine(a, b):
    """Return the cosine similarity of vectors *a* and *b* as a Python float.

    Returns 0.0 when either vector has zero norm.
    """
    if a.norm() == 0 or b.norm() == 0:
        return 0.0
    # cosine_similarity works on batches; add a batch dim of size 1.
    return nn.functional.cosine_similarity(a.unsqueeze(0), b.unsqueeze(0)).item()


class SmallMemory:
    """Smaller memory for capacity-limited tests.

    Cues and targets are mapped into a ``code_dim``-dimensional code by two
    independent random projections followed by ``winner_take_all(..., k)``.
    Associations are stored in the ``code_dim x code_dim`` matrix ``W`` via
    the outer-product update ``W += outer(target_code, cue_code)`` and are
    recalled by applying ``winner_take_all`` to ``W @ cue_code``.
    """

    def __init__(self, input_dim=768, code_dim=2048, k=50):
        self.k = k
        self.code_dim = code_dim
        # Gaussian random projections, scaled by 1/sqrt(input_dim).
        scale = 1.0 / input_dim**0.5
        self.proj = torch.randn(input_dim, code_dim, device=DEVICE) * scale
        self.target_proj = torch.randn(input_dim, code_dim, device=DEVICE) * scale
        # Association matrix; only ever updated by hand through .data.
        self.W = nn.Parameter(
            torch.zeros(code_dim, code_dim, device=DEVICE), requires_grad=False
        )
        self.consolidator = MemoryConsolidator(code_dim, k)

    def sep(self, x):
        """Encode a cue: project with ``proj``, then winner_take_all(k)."""
        return winner_take_all(x @ self.proj, self.k)

    def sep_target(self, x):
        """Encode a target: project with ``target_proj``, then winner_take_all(k)."""
        return winner_take_all(x @ self.target_proj, self.k)

    def learn(self, cue, target, record=True):
        """Store one cue->target association in ``W``.

        If *record* is true, the (cue code, target code) pair is also passed
        to ``consolidator.record`` so a later ``consolidate`` call can use it.
        """
        cc = self.sep(cue)
        tc = self.sep_target(target)
        self.W.data += torch.outer(tc, cc)
        if record:
            self.consolidator.record(cc, tc)

    def recall(self, cue):
        """Return the recalled target code for *cue*: winner_take_all(W @ code, k)."""
        cc = self.sep(cue)
        raw = self.W @ cc
        return winner_take_all(raw, self.k)

    def test_recall(self, cues, targets):
        """Recall every cue and compare it with its paired target's code.

        *cues* and *targets* are paired by position (assumed equal length).

        Returns:
            A tuple ``(mean cosine similarity, fraction of pairs whose
            similarity exceeds 0.5)``.
        """
        sims = [
            cosine(self.recall(cue), self.sep_target(target))
            for cue, target in zip(cues, targets)
        ]
        return np.mean(sims), np.mean([s > 0.5 for s in sims])

    def consolidate(self, **kwargs):
        """Run ``consolidator.consolidate`` on ``W`` and both projections.

        All keyword arguments are forwarded unchanged; the consolidator's
        return value is returned.
        """
        return self.consolidator.consolidate(
            self.W, self.proj, self.target_proj, **kwargs
        )


def gen_memories(n, dim=768):
    """Return ``(cues, targets)``: two lists of *n* random unit vectors of size *dim*."""
    cues = [nn.functional.normalize(torch.randn(dim, device=DEVICE), dim=0)
            for _ in range(n)]
    targets = [nn.functional.normalize(torch.randn(dim, device=DEVICE), dim=0)
               for _ in range(n)]
    return cues, targets


def test_capacity_with_consolidation():
    """Find where small memory breaks and see if consolidation helps.

    For each memory size, the same pairs are learned by two memories that
    share projections. One of them is consolidated, and recall and W norm
    are printed for both.
    """
    print("=== Capacity with code_dim=2048, k=50 ===")
    for n_pairs in [50, 100, 200, 500, 1000, 2000]:
        mem_no_consol = SmallMemory()
        mem_with_consol = SmallMemory()
        # Share the projections so any difference comes from consolidation.
        mem_with_consol.proj = mem_no_consol.proj
        mem_with_consol.target_proj = mem_no_consol.target_proj
        cues, targets = gen_memories(n_pairs)

        # Learn in both (only the consolidating memory records history).
        for cue, target in zip(cues, targets):
            mem_no_consol.learn(cue, target, record=False)
            mem_with_consol.learn(cue, target, record=True)

        cos_no, rate_no = mem_no_consol.test_recall(cues, targets)

        # Consolidate with strong homeostasis
        mem_with_consol.consolidate(num_epochs=3, homeostasis_factor=0.80,
                                    prune_threshold=0.01)
        cos_yes, rate_yes = mem_with_consol.test_recall(cues, targets)

        w_no = mem_no_consol.W.data.norm().item()
        w_yes = mem_with_consol.W.data.norm().item()

        print(f" N={n_pairs:>5}: "
              f"No_consol: CosSim={cos_no:.4f} Rate={rate_no:.0%} W={w_no:.0f} | "
              f"With_consol: CosSim={cos_yes:.4f} Rate={rate_yes:.0%} W={w_yes:.0f}")


def test_multi_night_at_limit():
    """7-day scenario near capacity limits.

    Each day, 200 new pairs are learned. Recall is then reported on all pairs
    so far, on today's pairs and on day 1's pairs. Each night the memory is
    consolidated and ``selective_clear(keep_fraction=0.3)`` is called on the
    consolidator (presumably keeping ~30% of the recorded history -- see
    MemoryConsolidator). Recall is then reported again.
    """
    print("\n=== 7-Day Scenario (code_dim=2048, k=50, 200/day) ===")
    per_day = 200
    mem = SmallMemory()
    all_cues = []
    all_targets = []

    for day in range(1, 8):
        cues_today, targets_today = gen_memories(per_day)
        all_cues.extend(cues_today)
        all_targets.extend(targets_today)

        for cue, target in zip(cues_today, targets_today):
            mem.learn(cue, target)

        # Test on all memories so far
        cos_all, rate_all = mem.test_recall(all_cues, all_targets)
        cos_today, rate_today = mem.test_recall(cues_today, targets_today)
        cos_day1, _ = mem.test_recall(all_cues[:per_day], all_targets[:per_day])
        w_norm = mem.W.data.norm().item()

        print(f" Day {day} (total={len(all_cues)}): "
              f"All={cos_all:.4f}({rate_all:.0%}), "
              f"Today={cos_today:.4f}, Day1={cos_day1:.4f}, "
              f"W={w_norm:.0f}")

        # Night: consolidate
        mem.consolidate(num_epochs=3, homeostasis_factor=0.85,
                        prune_threshold=0.01)
        mem.consolidator.selective_clear(keep_fraction=0.3)

        cos_after, rate_after = mem.test_recall(all_cues, all_targets)
        cos_day1_after, _ = mem.test_recall(all_cues[:per_day],
                                            all_targets[:per_day])
        w_after = mem.W.data.norm().item()
        print(f" → Night {day}: "
              f"All={cos_after:.4f}({rate_after:.0%}), Day1={cos_day1_after:.4f}, "
              f"W={w_after:.0f}")


def test_homeostasis_sweep():
    """Find the right homeostasis factor.

    For each factor, 500 pairs are learned and then consolidated for 10
    nights (one epoch each). Recall, W norm and W sparsity (the fraction of
    entries with |w| < 0.01) are printed.
    """
    print("\n=== Homeostasis Factor Sweep (500 pairs, 10 nights) ===")
    for hf in [1.0, 0.99, 0.95, 0.90, 0.85, 0.80, 0.70]:
        mem = SmallMemory()
        cues, targets = gen_memories(500)
        for cue, target in zip(cues, targets):
            mem.learn(cue, target)

        for _ in range(10):
            mem.consolidate(num_epochs=1, homeostasis_factor=hf)

        cos, rate = mem.test_recall(cues, targets)
        w = mem.W.data.norm().item()
        sp = (mem.W.data.abs() < 0.01).float().mean().item()
        print(f" hf={hf:.2f}: CosSim={cos:.4f}, Rate={rate:.0%}, "
              f"W_norm={w:.1f}, Sparsity={sp:.2%}")


def main():
    """Run all three stress experiments in order."""
    print("=" * 60)
    print("Experiment 3b: Consolidation Under Stress")
    print("=" * 60)

    test_capacity_with_consolidation()
    test_multi_night_at_limit()
    test_homeostasis_sweep()


if __name__ == "__main__":
    main()