"""Experiment 5b: Lightweight performance benchmarks. Skip the 65536 config that OOMs. """ import sys import time from pathlib import Path import torch import torch.nn as nn DEVICE = "cuda" RESULTS_DIR = Path(__file__).parent.parent / "doc" def winner_take_all(x, k): _, idx = x.topk(k, dim=-1) out = torch.zeros_like(x) out.scatter_(-1, idx, 1.0) return out class BenchMemory: def __init__(self, input_dim, code_dim, k): self.k = k self.code_dim = code_dim self.proj = (torch.randn(input_dim, code_dim, device=DEVICE) * (1.0 / input_dim**0.5)) self.W = torch.zeros(code_dim, code_dim, device=DEVICE) def sep(self, x): return winner_take_all(x @ self.proj, self.k) def learn(self, cue, target): self.W += torch.outer(self.sep(target), self.sep(cue)) def recall(self, query, hops=1): code = self.sep(query) for _ in range(hops): code = winner_take_all(self.W @ code, self.k) return code def main(): input_dim = 384 # Learning throughput print("=== Learning Throughput ===") for code_dim, k in [(8192, 50), (16384, 50), (32768, 50)]: mem = BenchMemory(input_dim, code_dim, k) n = 5000 cues = torch.randn(n, input_dim, device=DEVICE) targets = torch.randn(n, input_dim, device=DEVICE) torch.cuda.synchronize() t0 = time.time() for i in range(n): mem.learn(cues[i], targets[i]) torch.cuda.synchronize() dt = time.time() - t0 print(f" code={code_dim}, k={k}: {n/dt:.0f} memories/s ({dt:.2f}s for {n})") del mem torch.cuda.empty_cache() # Recall latency print("\n=== Recall Latency ===") for code_dim, k in [(8192, 50), (16384, 50), (32768, 50)]: mem = BenchMemory(input_dim, code_dim, k) for _ in range(1000): mem.learn(torch.randn(input_dim, device=DEVICE), torch.randn(input_dim, device=DEVICE)) queries = torch.randn(1000, input_dim, device=DEVICE) torch.cuda.synchronize() t0 = time.time() for i in range(1000): mem.recall(queries[i]) torch.cuda.synchronize() ms = (time.time() - t0) / 1000 * 1000 print(f" code={code_dim}, k={k}: {ms:.3f} ms/query") del mem torch.cuda.empty_cache() # Multi-hop latency print("\n=== Multi-hop Latency (code=16384, k=50) ===") mem = BenchMemory(input_dim, 16384, 50) for _ in range(1000): mem.learn(torch.randn(input_dim, device=DEVICE), torch.randn(input_dim, device=DEVICE)) queries = torch.randn(500, input_dim, device=DEVICE) for hops in [1, 2, 3, 5, 10]: torch.cuda.synchronize() t0 = time.time() for i in range(500): mem.recall(queries[i], hops=hops) torch.cuda.synchronize() ms = (time.time() - t0) / 500 * 1000 print(f" hops={hops:>2}: {ms:.3f} ms/query") del mem torch.cuda.empty_cache() # Memory usage print("\n=== GPU Memory Usage ===") for cd in [4096, 8192, 16384, 32768]: torch.cuda.empty_cache() before = torch.cuda.memory_allocated() mem = BenchMemory(input_dim, cd, 50) for _ in range(1000): mem.learn(torch.randn(input_dim, device=DEVICE), torch.randn(input_dim, device=DEVICE)) after = torch.cuda.memory_allocated() mb = (after - before) / 1024**2 w_mb = cd * cd * 4 / 1024**2 print(f" code_dim={cd:>5}: total={mb:.0f} MB (W matrix={w_mb:.0f} MB)") del mem torch.cuda.empty_cache() # E2E with sentence-transformers print("\n=== End-to-End Pipeline ===") from sentence_transformers import SentenceTransformer model = SentenceTransformer("all-MiniLM-L6-v2", device=DEVICE) mem = BenchMemory(384, 16384, 50) embs = model.encode([f"Sentence {i}" for i in range(1000)], convert_to_tensor=True, normalize_embeddings=True, device=DEVICE) for i in range(999): mem.learn(embs[i], embs[i+1]) query = "What is the test?" n_runs = 50 # Embedding time torch.cuda.synchronize() t0 = time.time() for _ in range(n_runs): q = model.encode([query], convert_to_tensor=True, normalize_embeddings=True, device=DEVICE)[0] torch.cuda.synchronize() embed_ms = (time.time() - t0) / n_runs * 1000 # Recall time torch.cuda.synchronize() t0 = time.time() for _ in range(n_runs): mem.recall(q) torch.cuda.synchronize() recall_ms = (time.time() - t0) / n_runs * 1000 print(f" Embedding: {embed_ms:.1f} ms") print(f" Recall: {recall_ms:.3f} ms") print(f" Total: {embed_ms + recall_ms:.1f} ms") print(f" Bottleneck: embedding is {embed_ms/recall_ms:.0f}x slower than recall") if __name__ == "__main__": main()