"""Experiment 2g: Multi-hop associative recall.

The unique advantage of Hebbian memory over simple cosine retrieval:
if A→B and B→C are learned, can we recall C from A by chaining through B?
This is impossible with standard RAG (which only does single-hop NN lookup).
If this works, it's the strongest argument for the Hebbian approach.
"""

import sys
import torch
import torch.nn as nn
import numpy as np
from pathlib import Path

# Fall back to CPU so the experiment still runs on machines without a CUDA
# device (a hard-coded "cuda" raises at the first tensor allocation there).
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def cosine(a, b):
    """Return the cosine similarity of 1-D tensors *a* and *b* as a float.

    Returns 0.0 when either vector has zero norm.
    """
    if a.norm() == 0 or b.norm() == 0:
        return 0.0
    return nn.functional.cosine_similarity(a.unsqueeze(0), b.unsqueeze(0)).item()


def winner_take_all(x, k):
    """Return a tensor shaped like *x* with 1.0 at the *k* largest entries of
    the last dimension and 0.0 everywhere else."""
    _, idx = x.topk(k, dim=-1)
    out = torch.zeros_like(x)
    out.scatter_(-1, idx, 1.0)
    return out


def _random_unit(dim, device):
    """Draw a standard-normal vector of length *dim* and L2-normalize it."""
    return nn.functional.normalize(torch.randn(dim, device=device), dim=0)


def _mean_over_trials(run, keys, trials):
    """Call *run* ``trials`` times and average each result key across calls."""
    results = [run() for _ in range(trials)]
    return {key: np.mean([r[key] for r in results]) for key in keys}


class HebbianMemory:
    """Simple Hebbian memory for multi-hop tests.

    Inputs are mapped to binary codes with exactly ``k`` active units by a
    fixed random projection followed by winner-take-all (``sep``).
    ``learn`` adds the outer product ``target_code ⊗ cue_code`` to a dense
    ``code_dim × code_dim`` matrix ``W``. Recall multiplies a code by ``W``
    and re-sparsifies the result.

    Note: ``W`` is dense, so memory grows as ``code_dim ** 2`` floats
    (about 1 GiB at the default ``code_dim=16384`` with float32).
    """

    def __init__(self, input_dim=768, code_dim=16384, k=20, device=None):
        self.k = k
        # Defaults to the module-level DEVICE; callers may pin a device.
        self.device = DEVICE if device is None else device
        # Random projection; the 1/sqrt(input_dim) scale does not change
        # which units win the top-k, only the magnitude of the scores.
        self.proj = torch.randn(input_dim, code_dim, device=self.device) * (
            1.0 / input_dim**0.5
        )
        self.W = torch.zeros(code_dim, code_dim, device=self.device)

    def sep(self, x):
        """Pattern-separate *x* into a binary code with exactly ``k`` ones."""
        return winner_take_all(x @ self.proj, self.k)

    def learn(self, cue, target):
        """Store cue → target with one Hebbian outer-product update of ``W``."""
        cc = self.sep(cue)
        tc = self.sep(target)
        self.W += torch.outer(tc, cc)

    def recall_code(self, code, k=None):
        """Apply one associative step to *code* and keep the top-*k* units.

        *k* defaults to ``self.k``. ``W`` and the binary codes are
        non-negative, so ``W @ code`` is too. When fewer than *k* entries are
        positive, ``topk`` also selects some zero-score units.
        """
        if k is None:
            k = self.k
        raw = self.W @ code
        return winner_take_all(raw, k)

    def recall(self, cue):
        """Single-hop recall: the code associated with raw input *cue*."""
        return self.recall_code(self.sep(cue))

    def multi_hop_recall(self, cue, hops=2):
        """Chain through associations: cue → hop1 → hop2 → ...

        ``hops=0`` returns the cue's own code.
        """
        code = self.sep(cue)
        for _ in range(hops):
            code = self.recall_code(code)
        return code


def test_chain(chain_length, num_chains, dim=768, code_dim=16384, k=20,
               threshold=0.5):
    """Test multi-hop recall along chains of length L.

    Create chains: A₁→A₂→...→Aₗ
    Learn pairs: (A₁,A₂), (A₂,A₃), ..., (Aₗ₋₁,Aₗ)
    Test: given A₁, can we reach A₂, A₃, ..., Aₗ in 1, 2, ... hops?

    All chains are stored in one shared memory, so more chains means more
    interference between them.

    Returns:
        ``{hops: {"mean_cos": ..., "recall_rate": ...}}`` for hops in
        ``1..chain_length-1``. ``recall_rate`` is the fraction of chains
        whose recalled code has cosine > *threshold* with the target's code.
    """
    mem = HebbianMemory(dim, code_dim, k)

    chains = []
    for _ in range(num_chains):
        chain = [_random_unit(dim, mem.device) for _ in range(chain_length)]
        chains.append(chain)
        # Learn consecutive pairs of *this* chain (every chain is stored).
        for prev, nxt in zip(chain, chain[1:]):
            mem.learn(prev, nxt)

    # Test recall at different hop distances
    results = {}
    for hops in range(1, chain_length):
        correct_sims = [
            cosine(mem.multi_hop_recall(chain[0], hops=hops),
                   mem.sep(chain[hops]))
            for chain in chains
        ]
        mc = np.mean(correct_sims)
        exact = np.mean([s > threshold for s in correct_sims])
        results[hops] = {"mean_cos": mc, "recall_rate": exact}
        print(f"  chain_len={chain_length}, chains={num_chains}, "
              f"hops={hops}: CosSim={mc:.4f}, "
              f"recall(cos>{threshold:g})={exact:.2%}")
    return results


def test_convergent_chains(dim=768, code_dim=16384, k=20):
    """Test convergent chains: A→C and B→C. Can we recall C from both A and B?"""
    mem = HebbianMemory(dim, code_dim, k)

    # Create convergent pattern
    a, b, c = (_random_unit(dim, mem.device) for _ in range(3))
    mem.learn(a, c)
    mem.learn(b, c)

    c_code = mem.sep(c)
    sim_a = cosine(mem.recall(a), c_code)
    sim_b = cosine(mem.recall(b), c_code)

    print(f"  Convergent: A→C sim={sim_a:.4f}, B→C sim={sim_b:.4f}")
    return {"a_to_c": sim_a, "b_to_c": sim_b}


def test_divergent_chains(dim=768, code_dim=16384, k=20):
    """Test divergent chains: A→B and A→C. Do B and C interfere?"""
    mem = HebbianMemory(dim, code_dim, k)

    a, b, c = (_random_unit(dim, mem.device) for _ in range(3))
    mem.learn(a, b)
    mem.learn(a, c)

    recalled = mem.recall(a)
    sim_b = cosine(recalled, mem.sep(b))
    sim_c = cosine(recalled, mem.sep(c))

    print(f"  Divergent: A→B sim={sim_b:.4f}, A→C sim={sim_c:.4f}")
    return {"a_to_b": sim_b, "a_to_c": sim_c}


# These experiment drivers use a test_* naming convention but are not pytest
# tests: test_chain takes non-fixture arguments, and the defaults allocate
# ~1 GiB. Opt them out of collection in case they are imported into a test
# module.
for _driver in (test_chain, test_convergent_chains, test_divergent_chains):
    _driver.__test__ = False
del _driver


def main(seed=0, trials=20):
    """Run all multi-hop recall experiments and print the results.

    Args:
        seed: seed for torch's RNG so runs are reproducible; pass None to
            leave the RNG unseeded.
        trials: number of independent memories averaged for the
            convergent and divergent tests.
    """
    if seed is not None:
        torch.manual_seed(seed)

    print("=" * 60)
    print("Experiment 2g: Multi-hop Associative Recall")
    print("=" * 60)

    # Test 1: Simple chains
    print("\n=== Chain recall (single chain) ===")
    for L in [3, 5, 7]:
        test_chain(L, num_chains=1)

    # Test 2: Multiple chains (interference between chains)
    print("\n=== Chain recall (multiple chains, interference) ===")
    for n_chains in [1, 5, 10, 50, 100]:
        print(f"\n-- {n_chains} chains of length 4 --")
        test_chain(4, num_chains=n_chains)

    # Test 3: Convergent
    print("\n=== Convergent chains (A→C, B→C) ===")
    conv = _mean_over_trials(test_convergent_chains, ("a_to_c", "b_to_c"), trials)
    print(f"  Average: A→C={conv['a_to_c']:.4f}, B→C={conv['b_to_c']:.4f}")

    # Test 4: Divergent
    print("\n=== Divergent chains (A→B, A→C) ===")
    div = _mean_over_trials(test_divergent_chains, ("a_to_b", "a_to_c"), trials)
    print(f"  Average: A→B={div['a_to_b']:.4f}, A→C={div['a_to_c']:.4f}")


if __name__ == "__main__":
    main()