"""Experiment 6: BioHash — Learnable Fly Algorithm.

Replace the random projection of the Fly hash with a learned projection,
trained via a contrastive loss on real sentence embeddings.

The key insight from Dasgupta et al. 2017 (Science): random projection +
winner-take-all (WTA) already preserves neighborhoods. Learning the
projection should make it even better.

Training objective:
- Positive pairs (similar sentences): maximize the overlap of sparse codes
  (a Jaccard-like similarity, since every code has exactly k active bits).
- Negative pairs (different sentences): minimize the overlap.

Since WTA is not differentiable, training uses a soft relaxation — here a
straight-through estimator (hard WTA forward, scaled-softmax gradient
backward; Gumbel-softmax would be an alternative) — and test time uses hard
WTA.
"""
import sys
import time
import json
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

# Fall back to CPU so the helpers (and the script) still run without a GPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
RESULTS_DIR = Path(__file__).parent.parent / "doc"
# Seed applied in main() so experiment runs are reproducible.
SEED = 0


def winner_take_all(x, k):
    """Return a binary mask marking the top-``k`` entries of ``x`` along the last dim.

    The output has the same shape/dtype/device as ``x``; every slice along the
    last dimension contains exactly ``k`` ones (``topk`` indices are distinct).
    Ties are resolved by ``torch.topk``'s ordering.
    """
    _, idx = x.topk(k, dim=-1)
    out = torch.zeros_like(x)
    out.scatter_(-1, idx, 1.0)
    return out


def jaccard(a, b):
    """Mean Jaccard similarity between rows of two binary tensors, as a Python float."""
    intersection = (a * b).sum(dim=-1)
    union = ((a + b) > 0).float().sum(dim=-1)
    # clamp avoids 0/0 when both codes are empty.
    return (intersection / union.clamp(min=1)).mean().item()


def soft_topk(x, k, temperature=1.0):
    """Differentiable stand-in for WTA via a straight-through estimator.

    Forward value equals ``winner_take_all(x, k)`` exactly; the gradient is
    that of ``k * softmax(x / temperature)``.
    """
    hard = winner_take_all(x, k)
    soft = torch.softmax(x / temperature, dim=-1) * k  # scaled softmax
    return hard + (soft - soft.detach())  # STE trick: value of hard, grad of soft


class BioHash(nn.Module):
    """Learnable Fly Hash with WTA sparsification.

    Architecture mirrors the fruit-fly olfactory circuit:
    - Projection neurons (PN): input -> high-dim (learned, replaces random)
    - Kenyon cells (KC): WTA top-k -> sparse binary code
    """

    def __init__(self, input_dim=384, code_dim=16384, k=50):
        super().__init__()
        self.k = k
        self.code_dim = code_dim
        # Learnable projection (replaces the random matrix).
        self.proj = nn.Linear(input_dim, code_dim, bias=False)
        # Initialize with the same scale as the random fly projection.
        nn.init.normal_(self.proj.weight, std=1.0 / input_dim**0.5)

    def forward(self, x, soft=False, temperature=1.0):
        """Encode ``x`` ([batch, input_dim], expected L2-normalized) to [batch, code_dim] codes.

        With ``soft=True`` the codes carry straight-through gradients;
        otherwise plain hard WTA is returned.
        """
        h = self.proj(x)  # [batch, code_dim]
        if soft:
            return soft_topk(h, self.k, temperature)
        return winner_take_all(h, self.k)

    def encode_hard(self, x):
        """Hard WTA encoding without autograd (for inference)."""
        with torch.no_grad():
            return winner_take_all(self.proj(x), self.k)


class RandomFlyHash(nn.Module):
    """Baseline: the original random Fly algorithm (projection is not learned)."""

    def __init__(self, input_dim=384, code_dim=16384, k=50):
        super().__init__()
        self.k = k
        proj = torch.randn(input_dim, code_dim) * (1.0 / input_dim**0.5)
        # Buffer, not Parameter: moves with .to(device) but is never trained.
        self.register_buffer('proj', proj)

    def encode_hard(self, x):
        """Hard WTA encoding of ``x @ proj`` without autograd."""
        with torch.no_grad():
            return winner_take_all(x @ self.proj, self.k)


def generate_training_data(model, n_pairs=5000, noise_std=0.3):
    """Encode a pool of templated engineering sentences for contrastive training.

    Returns a [n_sentences, embed_dim] tensor of normalized embeddings on
    ``DEVICE``. Pairs are formed later by the trainer: positives by adding
    noise to an anchor (simulating a paraphrase), negatives by picking a
    different sentence. The template grid yields 10*15*10 = 1500 sentences,
    so the ``n_pairs * 2`` cap only matters for small ``n_pairs``.
    ``noise_std`` is currently unused here (the trainer applies the noise).
    """
    # Diverse training sentences
    templates = [
        "The {} is having {} issues",
        "We need to {} the {} system",
        "The {} team is working on {}",
        "There's a bug in the {} {}",
        "Let's deploy {} to {}",
        "The {} performance is {}",
        "How do I configure {}?",
        "The {} logs show {}",
        "We should monitor the {} {}",
        "The {} needs {} upgrade",
    ]
    subjects = ["database", "API", "server", "frontend", "backend",
                "auth", "cache", "queue", "storage", "network",
                "deployment", "monitoring", "logging", "testing", "CI/CD"]
    modifiers = ["critical", "minor", "performance", "security", "timeout",
                 "memory", "disk", "CPU", "latency", "throughput"]

    sentences = [t.format(s, m) for t in templates for s in subjects for m in modifiers]
    np.random.shuffle(sentences)
    sentences = sentences[:n_pairs * 2]  # enough for pairs

    # Encode
    embs = model.encode(sentences, convert_to_tensor=True,
                        normalize_embeddings=True, device=DEVICE,
                        batch_size=256)
    return embs


def _sample_negative_indices(anchor_idx, n):
    """Draw one index per anchor, uniformly from the other ``n - 1`` indices."""
    if n < 2:
        raise ValueError("need at least two embeddings to sample negatives")
    # An offset in [1, n-1] modulo n can never map an index onto itself.
    offsets = torch.randint(1, n, anchor_idx.shape, device=anchor_idx.device)
    return (anchor_idx + offsets) % n


def train_biohash(model, code_dim=16384, k=50, epochs=100, batch_size=256,
                  lr=1e-3, noise_std=0.3, margin=0.2):
    """Train a BioHash projection with a contrastive code-overlap loss.

    Each "epoch" is a single optimizer step on one randomly sampled batch.

    Args:
        model: sentence encoder exposing ``encode`` and
            ``get_sentence_embedding_dimension``.
        code_dim: number of Kenyon cells (code width).
        k: active bits per code.
        epochs: number of training steps; also the cosine-annealing period.
        batch_size: anchors per step.
        lr: Adam learning rate.
        noise_std: per-dimension std of the Gaussian noise that builds positives.
        margin: negative overlap (fraction of ``k``) tolerated before penalty.

    Returns:
        The trained ``BioHash`` module on ``DEVICE``.
    """
    embed_dim = model.get_sentence_embedding_dimension()
    hasher = BioHash(embed_dim, code_dim, k).to(DEVICE)
    optimizer = optim.Adam(hasher.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)

    print(f"Training BioHash: code={code_dim}, k={k}, noise={noise_std}")

    # Generate training embeddings
    embs = generate_training_data(model, n_pairs=5000)
    n = embs.shape[0]

    for epoch in range(epochs):
        # Sample batch
        idx = torch.randperm(n)[:batch_size]
        anchor = embs[idx]

        # Positive: add noise (simulate paraphrase).
        # NOTE(review): the noise is per-dimension while anchors are unit-norm,
        # so its norm is ~noise_std * sqrt(embed_dim) (~5.9 at 0.3 for 384-d)
        # and dominates the anchor; scale by 1/sqrt(embed_dim) if
        # paraphrase-like positives are intended.
        pos = nn.functional.normalize(
            anchor + torch.randn_like(anchor) * noise_std, dim=-1)

        # Negative: a *different* sentence for every anchor. Independent
        # randperm draws could select the anchor itself and push identical
        # inputs apart.
        neg = embs[_sample_negative_indices(idx, n)]

        # Forward with STE
        code_anchor = hasher(anchor, soft=True, temperature=0.5)
        code_pos = hasher(pos, soft=True, temperature=0.5)
        code_neg = hasher(neg, soft=True, temperature=0.5)

        # Overlap as a fraction of k (differentiable via STE).
        pos_overlap = (code_anchor * code_pos).sum(dim=-1) / k  # maximize
        neg_overlap = (code_anchor * code_neg).sum(dim=-1) / k  # keep below margin

        loss = -pos_overlap.mean() + torch.relu(neg_overlap - margin).mean()

        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(hasher.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

        if (epoch + 1) % 20 == 0:
            # Eval with hard WTA
            with torch.no_grad():
                h_anchor = hasher.encode_hard(anchor)
                h_pos = hasher.encode_hard(pos)
                h_neg = hasher.encode_hard(neg)
                j_pos = jaccard(h_anchor, h_pos)
                j_neg = jaccard(h_anchor, h_neg)
            print(f" Epoch {epoch+1}: loss={loss.item():.4f}, "
                  f"Jaccard_pos={j_pos:.4f}, Jaccard_neg={j_neg:.4f}, "
                  f"gap={j_pos-j_neg:.4f}")

    return hasher


def _hebbian_recall(cue_codes, target_codes, query_codes, k):
    """Recall k-sparse target codes for each query from a Hebbian cue->target memory.

    Equivalent to W = sum_i outer(target_i, cue_i) followed by
    WTA_k(W @ q), but computed as (Q @ cue.T) @ target so the
    [code_dim, code_dim] matrix (1 GiB in float32 at code_dim=16384) is
    never materialized. For binary codes all intermediate values are small
    integers, so the activations match the dense form exactly.
    """
    overlaps = query_codes @ cue_codes.T      # [n_query, n_pairs]
    activations = overlaps @ target_codes     # [n_query, code_dim]
    return winner_take_all(activations, k)


def _best_match(recalled, candidates):
    """Index of the most cosine-similar candidate code for each recalled code."""
    sims = (nn.functional.normalize(recalled, dim=-1)
            @ nn.functional.normalize(candidates, dim=-1).T)
    return sims.argmax(dim=-1)


def evaluate_recall(hasher, model, label=""):
    """Test associative recall (exact cue and paraphrased cue) with this hasher.

    Returns (exact recall fraction, paraphrase recall fraction,
    mean cue/paraphrase code overlap as a fraction of k).
    """
    # Memory pairs
    pairs = [
        ("What's the weather like today?", "User prefers to check weather every morning"),
        ("Let's deploy the new version", "The deployment pipeline uses GitHub Actions with k3s"),
        ("The database is slow again", "Missing index on users table caused slowdown"),
        ("I need to fix the auth bug", "Auth uses JWT tokens with 24h expiry in Redis"),
        ("The API returns 500 errors", "Last 500 was OOM in the Python worker"),
        ("Let's set up monitoring", "Prometheus + Grafana on OCI cluster"),
        ("The tests are failing", "CI needs postgres service container"),
        ("Memory usage is too high", "Known leak in websocket handler"),
        ("Help with Docker setup", "docker-compose for dev, k3s for prod"),
        ("Log files are too large", "Logs rotate daily, 30 days retention, shipped to Loki"),
    ]
    paraphrases = [
        "How's the weather outside?",
        "We should push the new release",
        "DB performance is terrible",
        "There's a login bug to fix",
        "Getting internal server errors",
        "We need better observability",
        "CI tests keep breaking",
        "The service is using too much RAM",
        "Help me with Docker configuration",
        "Logs are eating up disk space",
    ]

    cue_embs = model.encode([p[0] for p in pairs], convert_to_tensor=True,
                            normalize_embeddings=True, device=DEVICE)
    target_embs = model.encode([p[1] for p in pairs], convert_to_tensor=True,
                               normalize_embeddings=True, device=DEVICE)
    para_embs = model.encode(paraphrases, convert_to_tensor=True,
                             normalize_embeddings=True, device=DEVICE)

    cue_codes = hasher.encode_hard(cue_embs)
    target_codes = hasher.encode_hard(target_embs)
    para_codes = hasher.encode_hard(para_embs)
    k = int(cue_codes[0].sum().item())  # active bits per code

    n = len(pairs)
    expected = torch.arange(n, device=cue_codes.device)

    # Exact recall: the stored cue itself.
    exact_recalled = _hebbian_recall(cue_codes, target_codes, cue_codes, k)
    exact_correct = int((_best_match(exact_recalled, target_codes) == expected).sum().item())

    # Paraphrase recall: a reworded cue.
    para_recalled = _hebbian_recall(cue_codes, target_codes, para_codes, k)
    para_correct = int((_best_match(para_recalled, target_codes) == expected).sum().item())

    # Code overlap analysis (fraction of k).
    # Positive: cue i vs its own paraphrase i.
    pos_overlaps = (cue_codes * para_codes).sum(dim=-1) / k
    # Negative: cue i vs the next paraphrase (i + 1) mod n.
    neg_overlaps = (cue_codes * para_codes.roll(-1, dims=0)).sum(dim=-1) / k
    pos_mean = pos_overlaps.mean().item()
    neg_mean = neg_overlaps.mean().item()

    print(f" {label}: Exact={exact_correct}/{n}, Para={para_correct}/{n}, "
          f"CodeOverlap: pos={pos_mean:.3f}, "
          f"neg={neg_mean:.3f}, "
          f"gap={pos_mean-neg_mean:.3f}")
    return exact_correct / n, para_correct / n, pos_mean


def evaluate_at_scale(hasher, model, n_background, label=""):
    """Test paraphrase recall when ``n_background`` distractor pairs share the memory.

    Recalled codes are matched only against the real targets; background
    targets contribute interference through the memory. Returns the
    paraphrase recall fraction.
    """
    pairs = [
        ("The database is slow", "Check missing indexes on users table"),
        ("Deploy to production", "Use blue-green via GitHub Actions"),
        ("Server crashed", "Check logs, likely OOM in Python worker"),
        ("Fix the auth bug", "JWT tokens with 24h expiry in Redis"),
        ("API returns 500", "OOM in Python worker process"),
    ]
    paraphrases = [
        "DB performance terrible",
        "Push the new release",
        "Server is down",
        "Login bug needs fixing",
        "Getting 500 errors from API",
    ]

    # Background noise
    bg_sentences = [f"Background task {i} about topic {i%20}" for i in range(n_background)]
    bg_targets = [f"Background detail {i} with info {i%10}" for i in range(n_background)]

    all_cues = [p[0] for p in pairs] + bg_sentences
    all_targets = [p[1] for p in pairs] + bg_targets

    cue_embs = model.encode(all_cues, convert_to_tensor=True,
                            normalize_embeddings=True, device=DEVICE, batch_size=256)
    target_embs = model.encode(all_targets, convert_to_tensor=True,
                               normalize_embeddings=True, device=DEVICE, batch_size=256)
    para_embs = model.encode(paraphrases, convert_to_tensor=True,
                             normalize_embeddings=True, device=DEVICE)

    # Memory contents: every (cue, target) pair, real and background.
    cue_codes = hasher.encode_hard(cue_embs)
    target_codes = hasher.encode_hard(target_embs)
    para_codes = hasher.encode_hard(para_embs)
    k = int(cue_codes[0].sum().item())  # active bits per code

    # Test paraphrase recall
    n = len(paraphrases)
    recalled = _hebbian_recall(cue_codes, target_codes, para_codes, k)
    matches = _best_match(recalled, target_codes[:len(pairs)])
    correct = int((matches == torch.arange(n, device=matches.device)).sum().item())

    print(f" {label} (bg={n_background}): Para={correct}/{n} ({correct/n:.0%})")
    return correct / n


def main():
    """Run the random-projection baseline, then BioHash noise and k sweeps."""
    print("=" * 60)
    print("Experiment 6: BioHash — Learnable Fly Algorithm")
    print("=" * 60)

    torch.manual_seed(SEED)
    np.random.seed(SEED)

    from sentence_transformers import SentenceTransformer
    model = SentenceTransformer("all-MiniLM-L6-v2", device=DEVICE)
    embed_dim = model.get_sentence_embedding_dimension()

    # Baseline: random projection (current approach)
    print("\n=== Baseline: Random Fly Hash ===")
    random_hasher = RandomFlyHash(embed_dim, 16384, 50).to(DEVICE)
    evaluate_recall(random_hasher, model, "Random")
    for n_bg in [0, 100, 500]:
        evaluate_at_scale(random_hasher, model, n_bg, "Random")

    # Train BioHash with different configs
    print("\n=== Training BioHash ===")
    for noise_std in [0.2, 0.5]:
        print(f"\n--- noise_std={noise_std} ---")
        hasher = train_biohash(model, code_dim=16384, k=50, epochs=200,
                               noise_std=noise_std, lr=1e-3)
        evaluate_recall(hasher, model, f"BioHash(noise={noise_std})")
        for n_bg in [0, 100, 500]:
            evaluate_at_scale(hasher, model, n_bg, f"BioHash(noise={noise_std})")

    # Try different k values with BioHash
    print("\n=== BioHash: k sweep ===")
    for k in [20, 50, 100, 200]:
        hasher = train_biohash(model, code_dim=16384, k=k, epochs=200,
                               noise_std=0.3, lr=1e-3)
        evaluate_recall(hasher, model, f"BioHash(k={k})")
        evaluate_at_scale(hasher, model, 500, f"BioHash(k={k})")


if __name__ == "__main__":
    main()