nuonuo/experiments/exp03b_consolidation_stress.py
Fam Zheng d923aa1e31 NuoNuo: Hippocampal memory module prototype
Hopfield + Hebbian hybrid memory system for LLMs.
Two nights of experiments (16 iterations), validated on LongMemEval (ICLR 2025).

Architecture:
- Single-hop: Two-Stage Hopfield (NN top-20 → softmax settle)
- Multi-hop: Hebbian W matrix with WTA pattern separation
- 64% on LongMemEval (500 questions), retrieval-only, no LLM dependency
- 4ms latency @ 20K memories, ~1GB VRAM

Key findings:
- Hopfield attention solved noise tolerance (20% → 100% vs flat Hebbian)
- WTA pattern separation enables 20K+ capacity
- Multi-hop associative chains (6 hops, CosSim=1.0) — RAG can't do this
- MiniLM-L6 is optimal (discrimination gap > absolute similarity)
- Paraphrase cue augmentation: 55% → 100% on synthetic, 36% → 64% on benchmark
- SNN encoder viable (CosSim 0.99) but not needed for current architecture
2026-04-07 10:37:24 +01:00

188 lines
6.3 KiB
Python

"""Experiment 3b: Consolidation near capacity limits.
With code_dim=16384 and k=20, capacity is so high that consolidation seems
unnecessary. Test with smaller code_dim (2048) where capacity limits are lower
and consolidation effects should be visible.
Also test: stronger homeostasis to control W_norm growth.
"""
import sys
import time
import json
from pathlib import Path
import torch
import torch.nn as nn
import numpy as np
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
from nuonuo.consolidation import MemoryConsolidator, winner_take_all
DEVICE = "cuda"
RESULTS_DIR = Path(__file__).parent.parent / "doc"
def cosine(a, b):
    """Return cosine similarity of 1-D tensors ``a`` and ``b`` as a float.

    Degenerate inputs (either vector has zero norm) yield 0.0 instead of NaN.
    """
    norm_a, norm_b = a.norm(), b.norm()
    if norm_a == 0 or norm_b == 0:
        return 0.0
    sim = nn.functional.cosine_similarity(a.unsqueeze(0), b.unsqueeze(0))
    return sim.item()
class SmallMemory:
    """Smaller memory for capacity-limited tests.

    A Hebbian associative memory with a reduced code space (default 2048)
    so capacity limits — and thus consolidation effects — become visible.
    """

    def __init__(self, input_dim=768, code_dim=2048, k=50):
        self.k = k
        self.code_dim = code_dim
        scale = 1.0 / input_dim ** 0.5
        # Two fixed random projections: one for cues, one for targets.
        self.proj = torch.randn(input_dim, code_dim, device=DEVICE) * scale
        self.target_proj = torch.randn(input_dim, code_dim, device=DEVICE) * scale
        # Hebbian weight matrix; updated manually, so no autograd tracking.
        self.W = nn.Parameter(
            torch.zeros(code_dim, code_dim, device=DEVICE), requires_grad=False)
        self.consolidator = MemoryConsolidator(code_dim, k)

    def sep(self, x):
        """Sparse code for a cue: project, then winner-take-all with k winners."""
        return winner_take_all(x @ self.proj, self.k)

    def sep_target(self, x):
        """Sparse code for a target, using the separate target projection."""
        return winner_take_all(x @ self.target_proj, self.k)

    def learn(self, cue, target, record=True):
        """Store one cue→target association via an outer-product Hebbian update.

        When ``record`` is True the pair is also logged with the consolidator
        for later replay.
        """
        cue_code = self.sep(cue)
        target_code = self.sep_target(target)
        self.W.data += torch.outer(target_code, cue_code)
        if record:
            self.consolidator.record(cue_code, target_code)

    def recall(self, cue):
        """One-step recall: encode the cue, pass through W, re-sparsify."""
        cue_code = self.sep(cue)
        return winner_take_all(self.W @ cue_code, self.k)

    def test_recall(self, cues, targets):
        """Return (mean cosine vs. target codes, fraction with cosine > 0.5)."""
        scores = []
        for idx, cue_vec in enumerate(cues):
            recalled_code = self.recall(cue_vec)
            scores.append(cosine(recalled_code, self.sep_target(targets[idx])))
        return np.mean(scores), np.mean([s > 0.5 for s in scores])

    def consolidate(self, **kwargs):
        """Delegate one consolidation pass to the attached MemoryConsolidator."""
        return self.consolidator.consolidate(
            self.W, self.proj, self.target_proj, **kwargs)
def gen_memories(n, dim=768):
    """Generate ``n`` random unit-norm cue vectors and ``n`` target vectors.

    Returns (cues, targets): two lists of L2-normalized tensors on DEVICE.
    All cues are drawn before all targets, preserving the RNG draw order.
    """
    def _unit_vector():
        return nn.functional.normalize(torch.randn(dim, device=DEVICE), dim=0)

    cues = [_unit_vector() for _ in range(n)]
    targets = [_unit_vector() for _ in range(n)]
    return cues, targets
def test_capacity_with_consolidation():
    """Find where small memory breaks and see if consolidation helps."""
    print("=== Capacity with code_dim=2048, k=50 ===")
    for n_pairs in [50, 100, 200, 500, 1000, 2000]:
        baseline = SmallMemory()
        consolidated = SmallMemory()
        # Share projections so both memories see identical sparse codes.
        consolidated.proj = baseline.proj
        consolidated.target_proj = baseline.target_proj
        cues, targets = gen_memories(n_pairs)
        # Learn every pair in both memories; only one records for replay.
        for cue, target in zip(cues, targets):
            baseline.learn(cue, target, record=False)
            consolidated.learn(cue, target, record=True)
        cos_no, rate_no = baseline.test_recall(cues, targets)
        # Consolidate with strong homeostasis to rein in W growth.
        consolidated.consolidate(num_epochs=3, homeostasis_factor=0.80,
                                 prune_threshold=0.01)
        cos_yes, rate_yes = consolidated.test_recall(cues, targets)
        w_no = baseline.W.data.norm().item()
        w_yes = consolidated.W.data.norm().item()
        print(f" N={n_pairs:>5}: "
              f"No_consol: CosSim={cos_no:.4f} Rate={rate_no:.0%} W={w_no:.0f} | "
              f"With_consol: CosSim={cos_yes:.4f} Rate={rate_yes:.0%} W={w_yes:.0f}")
def test_multi_night_at_limit():
    """7-day scenario near capacity limits."""
    print("\n=== 7-Day Scenario (code_dim=2048, k=50, 200/day) ===")
    mem = SmallMemory()
    all_cues, all_targets = [], []
    for day in range(1, 8):
        # Daytime: learn 200 fresh pairs.
        cues_today, targets_today = gen_memories(200)
        all_cues.extend(cues_today)
        all_targets.extend(targets_today)
        for cue, target in zip(cues_today, targets_today):
            mem.learn(cue, target)
        # Evaluate retention over all memories, today's batch, and day 1.
        cos_all, rate_all = mem.test_recall(all_cues, all_targets)
        cos_today, _rate_today = mem.test_recall(cues_today, targets_today)
        cos_day1, _ = mem.test_recall(all_cues[:200], all_targets[:200])
        w_norm = mem.W.data.norm().item()
        print(f" Day {day} (total={len(all_cues)}): "
              f"All={cos_all:.4f}({rate_all:.0%}), "
              f"Today={cos_today:.4f}, Day1={cos_day1:.4f}, "
              f"W={w_norm:.0f}")
        # Night phase: consolidate, then drop most of the replay buffer.
        mem.consolidate(num_epochs=3, homeostasis_factor=0.85,
                        prune_threshold=0.01)
        mem.consolidator.selective_clear(keep_fraction=0.3)
        cos_after, rate_after = mem.test_recall(all_cues, all_targets)
        cos_day1_after, _ = mem.test_recall(all_cues[:200], all_targets[:200])
        w_after = mem.W.data.norm().item()
        print(f" → Night {day}: "
              f"All={cos_after:.4f}({rate_after:.0%}), Day1={cos_day1_after:.4f}, "
              f"W={w_after:.0f}")
def test_homeostasis_sweep():
    """Find the right homeostasis factor."""
    print("\n=== Homeostasis Factor Sweep (500 pairs, 10 nights) ===")
    for hf in [1.0, 0.99, 0.95, 0.90, 0.85, 0.80, 0.70]:
        mem = SmallMemory()
        cues, targets = gen_memories(500)
        for cue, target in zip(cues, targets):
            mem.learn(cue, target)
        # Ten consolidation "nights" at the factor under test.
        for _ in range(10):
            mem.consolidate(num_epochs=1, homeostasis_factor=hf)
        cos, rate = mem.test_recall(cues, targets)
        w = mem.W.data.norm().item()
        # Fraction of near-zero weights, i.e. how sparse W has become.
        sp = (mem.W.data.abs() < 0.01).float().mean().item()
        print(f" hf={hf:.2f}: CosSim={cos:.4f}, Rate={rate:.0%}, "
              f"W_norm={w:.1f}, Sparsity={sp:.2%}")
def main():
    """Run all three consolidation stress experiments in sequence."""
    banner = "=" * 60
    print(banner)
    print("Experiment 3b: Consolidation Under Stress")
    print(banner)
    test_capacity_with_consolidation()
    test_multi_night_at_limit()
    test_homeostasis_sweep()
if __name__ == "__main__":
main()