Hopfield + Hebbian hybrid memory system for LLMs. Two nights of experiments (16 iterations), validated on LongMemEval (ICLR 2025).

Architecture:
- Single-hop: two-stage Hopfield (NN top-20 → softmax settle)
- Multi-hop: Hebbian W matrix with WTA pattern separation
- 64% on LongMemEval (500 questions), retrieval-only, no LLM dependency
- 4 ms latency @ 20K memories, ~1 GB VRAM

Key findings:
- Hopfield attention solved noise tolerance (20% → 100% vs. flat Hebbian)
- WTA pattern separation enables 20K+ capacity
- Multi-hop associative chains (6 hops, CosSim = 1.0) — RAG can't do this
- MiniLM-L6 is optimal (the discrimination gap matters more than absolute similarity)
- Paraphrase cue augmentation: 55% → 100% on synthetic data, 36% → 64% on the benchmark
- SNN encoder is viable (CosSim 0.99) but not needed for the current architecture
195 lines
5.8 KiB
Python
195 lines
5.8 KiB
Python
"""Experiment 2g: Multi-hop associative recall.
|
|
|
|
The unique advantage of Hebbian memory over simple cosine retrieval:
|
|
If A→B and B→C are learned, can we recall C from A by chaining through B?
|
|
|
|
This is impossible with standard RAG (which only does single-hop NN lookup).
|
|
If this works, it's the strongest argument for the Hebbian approach.
|
|
"""
|
|
|
|
import sys
|
|
import torch
|
|
import torch.nn as nn
|
|
import numpy as np
|
|
from pathlib import Path
|
|
|
|
# Torch device for every tensor created in this script. Prefer the GPU, but
# fall back to the CPU so the experiments still run on machines without CUDA.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
|
|
def cosine(a, b):
    """Return the cosine similarity of tensors ``a`` and ``b`` as a Python float.

    If either input has zero norm the similarity is reported as 0.0.
    """
    degenerate = bool(a.norm() == 0) or bool(b.norm() == 0)
    if degenerate:
        return 0.0
    # cosine_similarity works along dim=1, so give each input a leading batch axis.
    batch_a, batch_b = a.unsqueeze(0), b.unsqueeze(0)
    return nn.functional.cosine_similarity(batch_a, batch_b).item()
|
|
|
|
|
|
def winner_take_all(x, k):
    """Binarize ``x`` along its last dimension, keeping only the ``k`` largest entries.

    Returns a tensor shaped like ``x`` holding 1.0 at the top-k positions of
    each last-dim slice and 0.0 everywhere else.
    """
    winners = torch.topk(x, k, dim=-1).indices
    return torch.zeros_like(x).scatter(-1, winners, 1.0)
|
|
|
|
|
|
class HebbianMemory:
    """Simple Hebbian memory for multi-hop tests.

    Dense inputs are pattern-separated into sparse binary codes (a fixed
    random projection followed by k-winner-take-all). Associations are
    stored by accumulating outer products into a square matrix ``W``, whose
    rows index target-code units and whose columns index cue-code units.
    NOTE(review): at the default code_dim=16384, W is ~1 GiB if the default
    dtype is float32.
    """

    def __init__(self, input_dim=768, code_dim=16384, k=20):
        self.k = k
        # Gaussian projection scaled by 1/sqrt(input_dim).
        scale = 1.0 / input_dim ** 0.5
        self.proj = torch.randn(input_dim, code_dim, device=DEVICE) * scale
        self.W = torch.zeros(code_dim, code_dim, device=DEVICE)

    def sep(self, x):
        """Map a dense vector to its sparse k-hot code (project, then top-k)."""
        projected = x @ self.proj
        return winner_take_all(projected, self.k)

    def learn(self, cue, target):
        """Store cue -> target by adding outer(target_code, cue_code) to W in place."""
        cue_code = self.sep(cue)
        target_code = self.sep(target)
        self.W.add_(torch.outer(target_code, cue_code))

    def recall_code(self, code, k=None):
        """Take one associative step from a code; re-sparsify to ``k`` winners (default self.k)."""
        winners = self.k if k is None else k
        return winner_take_all(self.W @ code, winners)

    def recall(self, cue):
        """Single-hop recall starting from a dense cue vector."""
        cue_code = self.sep(cue)
        return self.recall_code(cue_code)

    def multi_hop_recall(self, cue, hops=2):
        """Chain through associations: cue → hop1 → hop2 → ... (``hops`` steps)."""
        state = self.sep(cue)
        for _step in range(hops):
            state = self.recall_code(state)
        return state
|
|
|
|
|
|
def test_chain(chain_length, num_chains, dim=768, code_dim=16384, k=20,
               threshold=0.5):
    """Test multi-hop recall along chains of length L.

    Create chains: A₁→A₂→...→Aₗ
    Learn pairs: (A₁,A₂), (A₂,A₃), ..., (Aₗ₋₁,Aₗ) for EVERY chain, all stored
    in one shared memory (so several chains also probe interference).
    Test: given A₁, can we reach A₂, A₃, ..., Aₗ in 1, 2, ... hops?

    Args:
        chain_length: number of random unit vectors per chain (L).
        num_chains: number of independent chains stored in the memory.
        dim, code_dim, k: forwarded to ``HebbianMemory``.
        threshold: cosine similarity a recalled code must exceed to count
            as a successful recall (default 0.5).

    Returns:
        dict mapping hop count (1 .. chain_length-1) to
        ``{"mean_cos": ..., "recall_rate": ...}``.

    Raises:
        ValueError: if ``num_chains`` < 1 (the averages would be undefined).
    """
    if num_chains < 1:
        raise ValueError(f"num_chains must be >= 1, got {num_chains}")

    mem = HebbianMemory(dim, code_dim, k)

    chains = []
    for _ in range(num_chains):
        chain = [nn.functional.normalize(torch.randn(dim, device=DEVICE), dim=0)
                 for _ in range(chain_length)]
        chains.append(chain)
        # Learn the consecutive pairs of *this* chain, so that every chain
        # (not just the last one built) is stored in the shared memory.
        for cue, nxt in zip(chain, chain[1:]):
            mem.learn(cue, nxt)

    # Test recall at different hop distances
    results = {}
    for hops in range(1, chain_length):
        correct_sims = []
        for chain in chains:
            target_code = mem.sep(chain[hops])
            recalled = mem.multi_hop_recall(chain[0], hops=hops)
            correct_sims.append(cosine(recalled, target_code))

        mc = np.mean(correct_sims)
        # Fraction of chains whose recalled code is close enough to the target.
        exact = np.mean([s > threshold for s in correct_sims])
        results[hops] = {"mean_cos": mc, "recall_rate": exact}

        print(f" chain_len={chain_length}, chains={num_chains}, "
              f"hops={hops}: CosSim={mc:.4f}, recall>{threshold:.0%}={exact:.2%}")

    return results
|
|
|
|
|
|
def test_convergent_chains(dim=768, code_dim=16384, k=20):
    """Fan-in check: store A→C and B→C, then recall from each source.

    Can we recall C from both A and B? Returns the cosine similarity between
    each recalled code and C's sparse code, under keys "a_to_c" and "b_to_c".
    """
    mem = HebbianMemory(dim, code_dim, k)

    # Three random unit vectors, drawn in the order a, b, c.
    a, b, c = (nn.functional.normalize(torch.randn(dim, device=DEVICE), dim=0)
               for _ in range(3))

    # Both sources point at the same target.
    for source in (a, b):
        mem.learn(source, c)

    c_code = mem.sep(c)
    sim_a, sim_b = (cosine(mem.recall(source), c_code) for source in (a, b))

    print(f" Convergent: A→C sim={sim_a:.4f}, B→C sim={sim_b:.4f}")
    return {"a_to_c": sim_a, "b_to_c": sim_b}
|
|
|
|
|
|
def test_divergent_chains(dim=768, code_dim=16384, k=20):
    """Fan-out check: store A→B and A→C, then recall once from A.

    Do B and C interfere? Returns the cosine similarity between the single
    recalled code and the sparse codes of B and C, under keys "a_to_b" and
    "a_to_c".
    """
    mem = HebbianMemory(dim, code_dim, k)

    # Three random unit vectors, drawn in the order a, b, c.
    a, b, c = (nn.functional.normalize(torch.randn(dim, device=DEVICE), dim=0)
               for _ in range(3))

    # One source, two targets.
    for target in (b, c):
        mem.learn(a, target)

    recalled = mem.recall(a)
    sim_b, sim_c = (cosine(recalled, mem.sep(target)) for target in (b, c))

    print(f" Divergent: A→B sim={sim_b:.4f}, A→C sim={sim_c:.4f}")
    return {"a_to_b": sim_b, "a_to_c": sim_c}
|
|
|
|
|
|
def main():
    """Run every multi-hop experiment and print per-test and averaged results."""
    banner = "=" * 60
    print(banner)
    print("Experiment 2g: Multi-hop Associative Recall")
    print(banner)

    # Test 1: a single chain, at increasing lengths.
    print("\n=== Chain recall (single chain) ===")
    for length in (3, 5, 7):
        test_chain(length, num_chains=1)

    # Test 2: many length-4 chains in one memory (interference between chains).
    print("\n=== Chain recall (multiple chains, interference) ===")
    for n_chains in (1, 5, 10, 50, 100):
        print(f"\n-- {n_chains} chains of length 4 --")
        test_chain(4, num_chains=n_chains)

    # Test 3: fan-in, averaged over 20 independent memories.
    print("\n=== Convergent chains (A→C, B→C) ===")
    convergent = [test_convergent_chains() for _ in range(20)]
    mean_a = np.mean([trial["a_to_c"] for trial in convergent])
    mean_b = np.mean([trial["b_to_c"] for trial in convergent])
    print(f" Average: A→C={mean_a:.4f}, B→C={mean_b:.4f}")

    # Test 4: fan-out, averaged over 20 independent memories.
    print("\n=== Divergent chains (A→B, A→C) ===")
    divergent = [test_divergent_chains() for _ in range(20)]
    mean_b = np.mean([trial["a_to_b"] for trial in divergent])
    mean_c = np.mean([trial["a_to_c"] for trial in divergent])
    print(f" Average: A→B={mean_b:.4f}, A→C={mean_c:.4f}")
|
|
|
|
|
|
# Script entry point: run all experiments only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()
|