Hopfield + Hebbian hybrid memory system for LLMs. Two nights of experiments (16 iterations), validated on LongMemEval (ICLR 2025). Architecture: - Single-hop: Two-Stage Hopfield (NN top-20 → softmax settle) - Multi-hop: Hebbian W matrix with WTA pattern separation - 64% on LongMemEval (500 questions), retrieval-only, no LLM dependency - 4ms latency @ 20K memories, ~1GB VRAM Key findings: - Hopfield attention solved noise tolerance (20% → 100% vs flat Hebbian) - WTA pattern separation enables 20K+ capacity - Multi-hop associative chains (6 hops, CosSim=1.0) — RAG can't do this - MiniLM-L6 is optimal (discrimination gap > absolute similarity) - Paraphrase cue augmentation: 55% → 100% on synthetic, 36% → 64% on benchmark - SNN encoder viable (CosSim 0.99) but not needed for current architecture
180 lines
5.9 KiB
Python
"""Experiment 1: Encoder roundtrip test.
|
|
|
|
Goal: Can we encode an embedding into spikes and decode it back with acceptable loss?
|
|
This is the foundation — if this fails, the whole approach is dead.
|
|
|
|
We train a SpikeAutoencoder on random embeddings (simulating LLM hidden states)
|
|
and measure reconstruction quality via cosine similarity and MSE.
|
|
"""
|
|
|
|
import sys
|
|
import time
|
|
import json
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
import numpy as np
|
|
|
|
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
|
|
from nuonuo.encoder import SpikeAutoencoder
|
|
|
|
# Run on GPU when one is visible to torch; every tensor and the model are
# placed on this device.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Experiment outputs land in the repo-level doc/ directory (sibling of src/).
# NOTE(review): mkdir has no parents=True — assumes the repo root already exists.
RESULTS_DIR = Path(__file__).parent.parent / "doc"
RESULTS_DIR.mkdir(exist_ok=True)
|
|
|
|
|
|
def cosine_sim(a, b):
    """Mean cosine similarity between corresponding rows of two batches."""
    per_row = nn.functional.cosine_similarity(a, b, dim=-1)
    return float(per_row.mean())
|
|
|
|
|
|
def run_config(embed_dim: int, num_neurons: int, num_steps: int, lr: float,
               epochs: int, batch_size: int, num_batches: int) -> dict:
    """Train and evaluate one configuration.

    Trains a SpikeAutoencoder on freshly sampled, unit-normalized random
    embeddings (stand-ins for LLM hidden states) and reports reconstruction
    quality (MSE, cosine similarity) plus the spike firing rate.

    Args:
        embed_dim: Dimensionality of the input embeddings.
        num_neurons: Width of the spiking layer passed to SpikeAutoencoder.
            # NOTE(review): exact semantics defined in nuonuo.encoder — confirm there.
        num_steps: Number of spiking time steps per forward pass.
            # NOTE(review): assumed; semantics defined by SpikeAutoencoder.
        lr: Adam learning rate.
        epochs: Number of training epochs.
        batch_size: Samples per training batch.
        num_batches: Batches per epoch (new random data every batch — no fixed dataset).

    Returns:
        dict with the config values, parameter count, final eval metrics
        ("final_mse", "final_cos", "final_firing_rate") and per-epoch "history".
    """
    model = SpikeAutoencoder(embed_dim, num_neurons, num_steps).to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    mse_loss = nn.MSELoss()
    cos_loss = nn.CosineEmbeddingLoss()

    param_count = sum(p.numel() for p in model.parameters())
    print(f" Config: dim={embed_dim}, neurons={num_neurons}, steps={num_steps}")
    print(f" Parameters: {param_count:,}")

    history = {"train_mse": [], "train_cos": [], "epoch_time": []}
    # Target of +1 tells CosineEmbeddingLoss to pull recon toward emb
    # (maximize similarity) for every pair in the batch.
    target = torch.ones(batch_size, device=DEVICE)

    for epoch in range(epochs):
        t0 = time.time()
        epoch_mse = 0
        epoch_cos = 0

        for _ in range(num_batches):
            # Random embeddings — simulate LLM hidden states (normalized)
            emb = torch.randn(batch_size, embed_dim, device=DEVICE)
            emb = nn.functional.normalize(emb, dim=-1)

            # NOTE(review): assumes the model returns (reconstruction, spikes, aux)
            # — confirm against SpikeAutoencoder.forward.
            recon, spikes, _ = model(emb)

            loss_mse = mse_loss(recon, emb)
            loss_cos = cos_loss(recon, emb, target)
            # Sparsity regularization: encourage ~10% firing rate
            firing_rate = spikes.mean()
            loss_sparse = (firing_rate - 0.1).pow(2)

            # Weighted sum: MSE dominates, cosine term shapes direction,
            # sparsity is a soft prior.
            loss = loss_mse + 0.5 * loss_cos + 0.1 * loss_sparse

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_mse += loss_mse.item()
            epoch_cos += cosine_sim(recon, emb)

        epoch_mse /= num_batches
        epoch_cos /= num_batches
        dt = time.time() - t0

        history["train_mse"].append(epoch_mse)
        history["train_cos"].append(epoch_cos)
        history["epoch_time"].append(dt)

        # Log on the first epoch and every 10th thereafter.
        if (epoch + 1) % 10 == 0 or epoch == 0:
            # Firing rate is read from the LAST batch of the epoch only.
            fr = spikes.mean().item()
            print(f" Epoch {epoch+1:3d}: MSE={epoch_mse:.6f}, "
                  f"CosSim={epoch_cos:.4f}, FR={fr:.3f}, Time={dt:.1f}s")

    # Final eval on fresh data
    model.eval()
    with torch.no_grad():
        test_emb = torch.randn(256, embed_dim, device=DEVICE)
        test_emb = nn.functional.normalize(test_emb, dim=-1)
        recon, spikes, _ = model(test_emb)
        final_mse = mse_loss(recon, test_emb).item()
        final_cos = cosine_sim(recon, test_emb)
        final_fr = spikes.mean().item()

    print(f" ** Final eval: MSE={final_mse:.6f}, CosSim={final_cos:.4f}, FR={final_fr:.3f}")

    return {
        "embed_dim": embed_dim,
        "num_neurons": num_neurons,
        "num_steps": num_steps,
        "param_count": param_count,
        "final_mse": final_mse,
        "final_cos": final_cos,
        "final_firing_rate": final_fr,
        "history": history,
    }
|
|
|
|
|
|
def main():
    """Run the encoder roundtrip experiment across all configurations.

    For each (embed_dim, num_neurons, num_steps) configuration: train and
    evaluate via run_config, then save all results to doc/exp01_results.json,
    print a summary table, and emit a PASS / MARGINAL / FAIL verdict based on
    the best final cosine similarity.
    """
    print("=" * 60)
    print("Experiment 1: Encoder Roundtrip Test")
    print("=" * 60)

    configs = [
        # (embed_dim, num_neurons, num_steps)
        # Start small, scale up if promising
        (256, 512, 32),
        (256, 1024, 32),
        (256, 1024, 64),
        (768, 2048, 64),
        (768, 4096, 64),
        (768, 4096, 128),
    ]

    all_results = []
    for embed_dim, num_neurons, num_steps in configs:
        print(f"\n--- Config: dim={embed_dim}, neurons={num_neurons}, steps={num_steps} ---")
        result = run_config(
            embed_dim=embed_dim,
            num_neurons=num_neurons,
            num_steps=num_steps,
            lr=1e-3,
            epochs=50,
            batch_size=64,
            num_batches=20,
        )
        all_results.append(result)
        # Release cached GPU memory between configs; the call is meaningless
        # (and previously ran unconditionally) on CPU-only hosts, so guard it.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # Save results — coerce history entries to plain floats so json.dump
    # never chokes on tensor/np scalar types.
    for r in all_results:
        for key in ("train_mse", "train_cos", "epoch_time"):
            r["history"][key] = [float(x) for x in r["history"][key]]

    results_file = RESULTS_DIR / "exp01_results.json"
    # Explicit encoding keeps output stable across platform default codecs.
    with open(results_file, "w", encoding="utf-8") as f:
        json.dump(all_results, f, indent=2)

    # Print summary table
    print("\n" + "=" * 80)
    print("SUMMARY")
    print("=" * 80)
    print(f"{'Dim':>5} {'Neurons':>8} {'Steps':>6} {'Params':>10} {'MSE':>10} {'CosSim':>8} {'FR':>6}")
    print("-" * 80)
    for r in all_results:
        print(f"{r['embed_dim']:>5} {r['num_neurons']:>8} {r['num_steps']:>6} "
              f"{r['param_count']:>10,} {r['final_mse']:>10.6f} "
              f"{r['final_cos']:>8.4f} {r['final_firing_rate']:>6.3f}")

    # Verdict: judge viability on the single best configuration.
    best = max(all_results, key=lambda x: x["final_cos"])
    print(f"\nBest config: dim={best['embed_dim']}, neurons={best['num_neurons']}, "
          f"steps={best['num_steps']}")
    print(f" CosSim={best['final_cos']:.4f}, MSE={best['final_mse']:.6f}")

    if best["final_cos"] > 0.9:
        print("\n✓ PASS: Roundtrip encoding is viable! CosSim > 0.9")
    elif best["final_cos"] > 0.7:
        print("\n~ MARGINAL: CosSim 0.7-0.9, might work for fuzzy associative recall")
    else:
        print("\n✗ FAIL: Roundtrip encoding loses too much information")
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|