nuonuo/experiments/exp12_lifecycle.py
Fam Zheng d923aa1e31 NuoNuo: Hippocampal memory module prototype
Hopfield + Hebbian hybrid memory system for LLMs.
Two nights of experiments (16 iterations), validated on LongMemEval (ICLR 2025).

Architecture:
- Single-hop: Two-Stage Hopfield (NN top-20 → softmax settle)
- Multi-hop: Hebbian W matrix with WTA pattern separation
- 64% on LongMemEval (500 questions), retrieval-only, no LLM dependency
- 4ms latency @ 20K memories, ~1GB VRAM

Key findings:
- Hopfield attention solved noise tolerance (20% → 100% vs flat Hebbian)
- WTA pattern separation enables 20K+ capacity
- Multi-hop associative chains (6 hops, CosSim=1.0) — RAG can't do this
- MiniLM-L6 is optimal (discrimination gap > absolute similarity)
- Paraphrase cue augmentation: 55% → 100% on synthetic, 36% → 64% on benchmark
- SNN encoder viable (CosSim 0.99) but not needed for current architecture
2026-04-07 10:37:24 +01:00

259 lines
9.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Experiment P4: Memory Lifecycle Management.
Questions:
1. What's worth storing? (not everything in a conversation is a "memory")
2. When to forget? (access-based decay, age-based decay, capacity pressure)
3. Can we merge similar memories? (deduplication / compression)
4. Importance scoring: how to prioritize during recall and forgetting?
Strategy: implement and test each mechanism, measure impact on recall quality.
"""
import sys
import time
from pathlib import Path
from collections import Counter
import torch
import torch.nn as nn
import numpy as np
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
from nuonuo.hippocampus import HippocampalMemory
DEVICE = "cuda"
def cosine(a, b):
    """Scalar cosine similarity between two 1-D embedding tensors."""
    # cosine_similarity wants a batch dim; add one to each side, then unwrap.
    sim = nn.functional.cosine_similarity(a[None, :], b[None, :])
    return sim.item()
def load_model():
    """Load the MiniLM sentence encoder onto DEVICE.

    Imported lazily so the module can be imported without pulling in
    sentence_transformers (and its torch model download) up front.
    """
    from sentence_transformers import SentenceTransformer
    model = SentenceTransformer("all-MiniLM-L6-v2", device=DEVICE)
    return model
def emb(model, text):
    """Encode a single string into one L2-normalized embedding tensor."""
    batch = model.encode(
        [text],
        convert_to_tensor=True,
        normalize_embeddings=True,
        device=DEVICE,
    )
    # encode() returns a (1, dim) batch; unwrap the single row.
    return batch[0]
def test_deduplication(model):
    """Test: can we detect and merge duplicate/near-duplicate memories?"""
    print("=== Test 1: Deduplication ===\n")
    mem = HippocampalMemory(embed_dim=384)
    # Seed the store with several intentional near-duplicate pairs.
    memories = [
        ("The database is slow", "Check missing indexes"),
        ("Database is really slow today", "Check missing indexes on users table"),  # near-dup
        ("DB performance is terrible", "Look at index usage"),  # near-dup
        ("Deploy to production", "Use blue-green deployment"),
        ("Push to prod", "Blue-green deployment via GitHub Actions"),  # near-dup
        ("The API returns 500 errors", "Check for OOM in Python worker"),
        ("Getting 500 errors from API", "Python worker might be OOM"),  # near-dup
        ("Set up monitoring", "Prometheus + Grafana"),
        ("We need better observability", "Set up Prometheus and Grafana"),  # near-dup
    ]
    for cue, target in memories:
        mem.store(emb(model, cue), emb(model, target),
                  metadata={"cue": cue, "target": target})
    print(f" Before dedup: {len(mem.memories)} memories")
    # Greedy single-pass clustering: each unclaimed entry seeds a group and
    # absorbs every later unclaimed entry whose cue is similar enough.
    entries = list(mem.memories.values())
    groups = []
    used = set()
    for idx, seed in enumerate(entries):
        if idx in used:
            continue
        used.add(idx)
        cluster = [idx]
        for other in range(idx + 1, len(entries)):
            if other in used:
                continue
            sim = cosine(seed.cue_embedding, entries[other].cue_embedding)
            if sim > 0.7:  # threshold for "near-duplicate"
                cluster.append(other)
                used.add(other)
        groups.append(cluster)
    print(f" Found {len(groups)} groups (from {len(entries)} memories):")
    for cluster in groups:
        if len(cluster) < 2:
            continue
        cues = [entries[k].metadata.get("cue", "?") for k in cluster]
        print(f" Group ({len(cluster)}): {[c[:30] for c in cues]}")
    # Merge policy: within each group keep the entry with the longest target
    # text (assumed most informative) and forget the rest.
    to_remove = []
    for cluster in groups:
        if len(cluster) < 2:
            continue
        keep = max(cluster,
                   key=lambda k: len(entries[k].metadata.get("target", "")))
        to_remove.extend(entries[k].memory_id for k in cluster if k != keep)
    for mid in to_remove:
        mem.forget(mid)
    print(f" After dedup: {len(mem.memories)} memories")
    print(f" Removed {len(to_remove)} duplicates")
def test_importance_scoring(model):
"""Test: importance-based memory management."""
print("\n=== Test 2: Importance Scoring ===\n")
# Simulate conversation with varying importance
conversations = [
# (user, assistant, expected_importance)
("Hi there!", "Hello! How can I help?", "low"),
("What's the weather?", "It's sunny today.", "low"),
("The production database crashed at 3am", "Emergency: restore from latest backup at s3://backups/db-latest.sql", "high"),
("What time is it?", "It's 3:45 PM.", "low"),
("The auth service JWT secret was compromised", "Rotate secret immediately: kubectl set env deployment/auth JWT_SECRET=new_value", "critical"),
("Deploy the hotfix", "Deployed via GitHub Actions, monitor Grafana for 30 min", "high"),
("Thanks for your help", "You're welcome!", "low"),
]
def score_importance(user_msg, assistant_msg):
"""Simple heuristic importance scoring."""
score = 0.3 # base
# Length suggests complexity
if len(assistant_msg.split()) > 15:
score += 0.2
# Technical keywords
critical_words = ["crash", "emergency", "compromised", "secret", "password",
"production", "outage", "down", "data loss"]
high_words = ["deploy", "config", "fix", "bug", "error", "migrate",
"backup", "restore", "rollback"]
for w in critical_words:
if w in (user_msg + assistant_msg).lower():
score += 0.3
for w in high_words:
if w in (user_msg + assistant_msg).lower():
score += 0.1
# Questions suggest retrievable info
if "?" in user_msg:
score += 0.1
return min(score, 1.0)
for user, assistant, expected in conversations:
score = score_importance(user, assistant)
status = "" if (expected == "low" and score < 0.5) or \
(expected == "high" and 0.5 <= score < 0.8) or \
(expected == "critical" and score >= 0.8) else ""
should_store = score >= 0.4
print(f" {status} [{score:.2f}] {'STORE' if should_store else 'SKIP ':>5} "
f"({expected:>8}) '{user[:40]}...'")
def test_forgetting_strategies(model):
    """Test: different forgetting strategies under memory pressure."""
    print("\n=== Test 3: Forgetting Strategies ===\n")
    # Simulate a week of memories, 10 per day, with a hard capacity cap.
    days = 7
    per_day = 10
    max_capacity = 30  # Force forgetting after 30 memories
    cue_template = "Day {day} task {i}: {topic}"
    target_template = "Solution for day {day} task {i}"
    topics = ["database", "deploy", "monitoring", "auth", "API",
              "caching", "logging", "testing", "docker", "CI/CD"]

    def run_strategy(strategy_name, forget_fn):
        mem = HippocampalMemory(embed_dim=384)
        day_memories = {}  # day -> list of memory_ids
        for day in range(1, days + 1):
            stored_ids = []
            for i in range(per_day):
                cue = cue_template.format(day=day, i=i, topic=topics[i])
                target = target_template.format(day=day, i=i)
                stored_ids.append(
                    mem.store(emb(model, cue), emb(model, target),
                              metadata={"day": day, "task": i},
                              timestamp=float(day)))
                # Enforce the cap as soon as it is exceeded (NOTE(review):
                # original indentation was lost; assumed per-store check).
                if len(mem.memories) > max_capacity:
                    forget_fn(mem, max_capacity)
            day_memories[day] = stored_ids
        # Probe recall for every surviving memory, grouped by day.
        day_recall = {}
        for day in range(1, days + 1):
            correct = 0
            total = 0
            stored_ids = day_memories[day]
            for i in range(per_day):
                mid = stored_ids[i] if i < len(stored_ids) else None
                if mid is None or mid not in mem.memories:
                    continue  # already forgotten; not counted
                cue = cue_template.format(day=day, i=i, topic=topics[i])
                results = mem.recall(emb(model, cue), top_k=1)
                if results and results[0].memory_id == mid:
                    correct += 1
                total += 1
            day_recall[day] = (correct, total)
        # Report survivors and per-day recall
        surviving = len(mem.memories)
        print(f" {strategy_name}: {surviving} memories surviving")
        for day in range(1, days + 1):
            c, t = day_recall[day]
            pct = f"{c}/{t}" if t > 0 else "0/0"
            print(f" Day {day}: {pct}")

    def _evict(mem, cap, key_fn):
        """Forget the lowest-ranked entries until only `cap` remain.

        `sorted` is stable, so ties keep insertion order — identical to
        sorting the entries directly on the same key.
        """
        ranked = sorted(mem.memories.values(), key=key_fn)
        excess = len(mem.memories) - cap
        for entry in ranked[:excess]:
            mem.forget(entry.memory_id)

    # Strategy 1: FIFO (oldest timestamp first)
    def forget_fifo(mem, cap):
        _evict(mem, cap, lambda e: e.timestamp)

    # Strategy 2: LRU (least accessed first)
    def forget_lru(mem, cap):
        _evict(mem, cap, lambda e: e.access_count)

    # Strategy 3: lowest combined recency+access score first
    def forget_low_importance(mem, cap):
        _evict(mem, cap, lambda e: e.timestamp + e.access_count * 0.5)

    print("(max_capacity=30, 7 days × 10 memories = 70 total)")
    run_strategy("FIFO (oldest first)", forget_fifo)
    print()
    run_strategy("LRU (least accessed)", forget_lru)
    print()
    run_strategy("Importance (recency+access)", forget_low_importance)
def main():
    """Run all three lifecycle experiments in order."""
    banner = "=" * 60
    print(banner)
    print("Experiment P4: Memory Lifecycle")
    print(banner)
    model = load_model()
    for experiment in (test_deduplication,
                       test_importance_scoring,
                       test_forgetting_strategies):
        experiment(model)


if __name__ == "__main__":
    main()