Hopfield + Hebbian hybrid memory system for LLMs. Two nights of experiments
(16 iterations), validated on LongMemEval (ICLR 2025).

Architecture:
- Single-hop: Two-Stage Hopfield (NN top-20 → softmax settle)
- Multi-hop: Hebbian W matrix with WTA pattern separation
- 64% on LongMemEval (500 questions), retrieval-only, no LLM dependency
- 4 ms latency @ 20K memories, ~1 GB VRAM

Key findings:
- Hopfield attention solved noise tolerance (20% → 100% vs flat Hebbian)
- WTA pattern separation enables 20K+ capacity
- Multi-hop associative chains (6 hops, CosSim=1.0) — RAG can't do this
- MiniLM-L6 is optimal (discrimination gap > absolute similarity)
- Paraphrase cue augmentation: 55% → 100% on synthetic, 36% → 64% on benchmark
- SNN encoder viable (CosSim 0.99) but not needed for current architecture
159 lines · 4.9 KiB · Python
"""Experiment 5b: Lightweight performance benchmarks.

Skip the 65536 config that OOMs.
"""

import sys
import time
from pathlib import Path

import torch
import torch.nn as nn

# All benchmarks are GPU-only; timings below rely on torch.cuda.synchronize().
DEVICE = "cuda"
# Output directory for results, relative to the repo layout (<repo>/doc).
RESULTS_DIR = Path(__file__).parent.parent / "doc"
def winner_take_all(x, k):
    """Sparsify ``x`` into a k-hot code along the last dimension.

    Returns a tensor of the same shape as ``x`` with 1.0 at each of the
    top-k positions (per last-dim slice) and 0.0 everywhere else.
    """
    top_idx = x.topk(k, dim=-1).indices
    mask = torch.zeros_like(x)
    return mask.scatter_(-1, top_idx, 1.0)
class BenchMemory:
    """Hebbian associative memory over WTA sparse codes, for benchmarking.

    Inputs are projected through a fixed random matrix into a high-dimensional
    code space, sparsified to k-hot codes, and bound pairwise into a Hebbian
    weight matrix W via outer products.
    """

    def __init__(self, input_dim, code_dim, k):
        self.k = k
        self.code_dim = code_dim
        # Fixed random projection, scaled by 1/sqrt(input_dim).
        scale = input_dim ** -0.5
        self.proj = torch.randn(input_dim, code_dim, device=DEVICE) * scale
        # Hebbian association matrix: target codes (rows) x cue codes (cols).
        self.W = torch.zeros(code_dim, code_dim, device=DEVICE)

    def sep(self, x):
        """Pattern-separate ``x``: project to code space, keep a k-hot code."""
        return winner_take_all(x @ self.proj, self.k)

    def learn(self, cue, target):
        """Bind cue -> target with a Hebbian outer-product update on W."""
        cue_code = self.sep(cue)
        target_code = self.sep(target)
        self.W += torch.outer(target_code, cue_code)

    def recall(self, query, hops=1):
        """Follow the association chain from ``query`` for ``hops`` steps."""
        code = self.sep(query)
        for _ in range(hops):
            code = winner_take_all(self.W @ code, self.k)
        return code
def _drain(mem):
    """Free a BenchMemory's GPU tensors between benchmark configs."""
    del mem
    torch.cuda.empty_cache()


def _fill_random(mem, n, input_dim):
    """Store n random cue->target pairs (untimed warm-up population)."""
    for _ in range(n):
        mem.learn(torch.randn(input_dim, device=DEVICE),
                  torch.randn(input_dim, device=DEVICE))


def _bench_learning(input_dim):
    """Measure memories/s for the Hebbian learn() update at several code dims."""
    print("=== Learning Throughput ===")
    for code_dim, k in [(8192, 50), (16384, 50), (32768, 50)]:
        mem = BenchMemory(input_dim, code_dim, k)
        n = 5000
        # Pre-generate inputs so only learn() is inside the timed region.
        cues = torch.randn(n, input_dim, device=DEVICE)
        targets = torch.randn(n, input_dim, device=DEVICE)

        torch.cuda.synchronize()
        t0 = time.perf_counter()
        for i in range(n):
            mem.learn(cues[i], targets[i])
        torch.cuda.synchronize()
        dt = time.perf_counter() - t0
        print(f" code={code_dim}, k={k}: {n/dt:.0f} memories/s ({dt:.2f}s for {n})")
        _drain(mem)


def _bench_recall(input_dim):
    """Measure single-hop recall latency per query at several code dims."""
    print("\n=== Recall Latency ===")
    for code_dim, k in [(8192, 50), (16384, 50), (32768, 50)]:
        mem = BenchMemory(input_dim, code_dim, k)
        _fill_random(mem, 1000, input_dim)

        queries = torch.randn(1000, input_dim, device=DEVICE)
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        for i in range(1000):
            mem.recall(queries[i])
        torch.cuda.synchronize()
        ms = (time.perf_counter() - t0) / 1000 * 1000
        print(f" code={code_dim}, k={k}: {ms:.3f} ms/query")
        _drain(mem)


def _bench_multihop(input_dim):
    """Measure recall latency as a function of hop count (fixed config)."""
    print("\n=== Multi-hop Latency (code=16384, k=50) ===")
    mem = BenchMemory(input_dim, 16384, 50)
    _fill_random(mem, 1000, input_dim)

    queries = torch.randn(500, input_dim, device=DEVICE)
    for hops in [1, 2, 3, 5, 10]:
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        for i in range(500):
            mem.recall(queries[i], hops=hops)
        torch.cuda.synchronize()
        ms = (time.perf_counter() - t0) / 500 * 1000
        print(f" hops={hops:>2}: {ms:.3f} ms/query")
    _drain(mem)


def _bench_memory_usage(input_dim):
    """Report allocated GPU memory per code_dim (W matrix dominates)."""
    print("\n=== GPU Memory Usage ===")
    for cd in [4096, 8192, 16384, 32768]:
        torch.cuda.empty_cache()
        before = torch.cuda.memory_allocated()
        mem = BenchMemory(input_dim, cd, 50)
        _fill_random(mem, 1000, input_dim)
        after = torch.cuda.memory_allocated()
        mb = (after - before) / 1024**2
        # Theoretical size of the float32 code_dim x code_dim W matrix.
        w_mb = cd * cd * 4 / 1024**2
        print(f" code_dim={cd:>5}: total={mb:.0f} MB (W matrix={w_mb:.0f} MB)")
        _drain(mem)


def _bench_e2e():
    """End-to-end latency: sentence embedding + recall, averaged over runs."""
    print("\n=== End-to-End Pipeline ===")
    # Local import: sentence-transformers is only needed for this section.
    from sentence_transformers import SentenceTransformer
    model = SentenceTransformer("all-MiniLM-L6-v2", device=DEVICE)

    mem = BenchMemory(384, 16384, 50)
    embs = model.encode([f"Sentence {i}" for i in range(1000)],
                        convert_to_tensor=True, normalize_embeddings=True,
                        device=DEVICE)
    # Chain consecutive sentences: sentence i cues sentence i+1.
    for i in range(999):
        mem.learn(embs[i], embs[i+1])

    query = "What is the test?"
    n_runs = 50

    # Embedding time
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(n_runs):
        q = model.encode([query], convert_to_tensor=True,
                         normalize_embeddings=True, device=DEVICE)[0]
    torch.cuda.synchronize()
    embed_ms = (time.perf_counter() - t0) / n_runs * 1000

    # Recall time (reuses the last embedded query q)
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(n_runs):
        mem.recall(q)
    torch.cuda.synchronize()
    recall_ms = (time.perf_counter() - t0) / n_runs * 1000

    print(f" Embedding: {embed_ms:.1f} ms")
    print(f" Recall: {recall_ms:.3f} ms")
    print(f" Total: {embed_ms + recall_ms:.1f} ms")
    print(f" Bottleneck: embedding is {embed_ms/recall_ms:.0f}x slower than recall")


def main():
    """Run all lightweight benchmarks: learning, recall, multi-hop, memory, E2E."""
    input_dim = 384
    _bench_learning(input_dim)
    _bench_recall(input_dim)
    _bench_multihop(input_dim)
    _bench_memory_usage(input_dim)
    _bench_e2e()
# Script entry point: run the full benchmark suite.
if __name__ == "__main__":
    main()