"""Experiment 2f: Check discrimination for soft WTA + test learned separator. Soft WTA temp=0.5 showed perfect noise tolerance but might have zero discrimination. Need to check: can it tell correct target from wrong targets? Then test: learned pattern separator (trained to be noise-robust via contrastive loss). """ import sys import time import json from pathlib import Path import torch import torch.nn as nn import torch.optim as optim import numpy as np DEVICE = "cuda" RESULTS_DIR = Path(__file__).parent.parent / "doc" def cosine(a, b): if a.norm() == 0 or b.norm() == 0: return 0.0 return nn.functional.cosine_similarity(a.unsqueeze(0), b.unsqueeze(0)).item() def winner_take_all(x, k): _, idx = x.topk(k, dim=-1) out = torch.zeros_like(x) out.scatter_(-1, idx, 1.0) return out class SoftWTAMemory(nn.Module): def __init__(self, input_dim=768, code_dim=16384, temperature=0.5): super().__init__() self.temperature = temperature proj = torch.randn(input_dim, code_dim) * (1.0 / input_dim**0.5) self.register_buffer('proj', proj) target_proj = torch.randn(input_dim, code_dim) * (1.0 / input_dim**0.5) self.register_buffer('target_proj', target_proj) self.W = nn.Parameter(torch.zeros(code_dim, code_dim), requires_grad=False) def encode(self, x, proj): return torch.softmax((x @ proj) / self.temperature, dim=-1) def learn(self, cue, target): cc = self.encode(cue, self.proj) tc = self.encode(target, self.target_proj) self.W.data += torch.outer(tc, cc) def recall(self, cue): cc = self.encode(cue, self.proj) return self.W @ cc def check_discrimination(temperature, num_pairs=100): """Check correct vs wrong similarity for soft WTA.""" mem = SoftWTAMemory(temperature=temperature).to(DEVICE) cues = [nn.functional.normalize(torch.randn(768, device=DEVICE), dim=0) for _ in range(num_pairs)] targets = [nn.functional.normalize(torch.randn(768, device=DEVICE), dim=0) for _ in range(num_pairs)] for i in range(num_pairs): mem.learn(cues[i], targets[i]) # Test with noise=0.1 for noise_std in [0.0, 0.1, 0.5]: correct_sims = [] wrong_sims = [] for i in range(num_pairs): noisy = nn.functional.normalize( cues[i] + torch.randn_like(cues[i]) * noise_std, dim=0) recalled = mem.recall(noisy) tc = mem.encode(targets[i], mem.target_proj) correct_sims.append(cosine(recalled, tc)) # Compare to random wrong targets for j in range(min(20, num_pairs)): if j != i: wc = mem.encode(targets[j], mem.target_proj) wrong_sims.append(cosine(recalled, wc)) mc = np.mean(correct_sims) mw = np.mean(wrong_sims) print(f" temp={temperature}, noise={noise_std:.1f}: " f"Correct={mc:.4f}, Wrong={mw:.4f}, Disc={mc-mw:.4f}") class LearnedSeparator(nn.Module): """Trained pattern separator: maps similar inputs to same code. 
class LearnedSeparator(nn.Module):
    """Trained pattern separator: maps similar inputs to the same code.

    Architecture: MLP -> sparse output (WTA).
    Training: contrastive loss on (original, noisy) pairs.
    """

    def __init__(self, input_dim=768, code_dim=4096, k_active=50):
        super().__init__()
        self.k_active = k_active
        self.code_dim = code_dim
        self.net = nn.Sequential(
            nn.Linear(input_dim, code_dim),
            nn.ReLU(),
            nn.Linear(code_dim, code_dim),
        )

    def forward(self, x):
        h = self.net(x)
        return winner_take_all(h, self.k_active)

    def forward_soft(self, x, temperature=0.1):
        """Soft (differentiable) version for training."""
        h = self.net(x)
        return torch.softmax(h / temperature, dim=-1)


def train_learned_separator(input_dim=768, code_dim=4096, k_active=50,
                            epochs=100, batch_size=128, noise_std=0.3):
    """Train the separator to produce the same code for original and noisy inputs."""
    sep = LearnedSeparator(input_dim, code_dim, k_active).to(DEVICE)
    optimizer = optim.Adam(sep.parameters(), lr=1e-3)
    print(f"\nTraining learned separator (code_dim={code_dim}, k={k_active}, "
          f"noise={noise_std})")

    for epoch in range(epochs):
        # Batch of normalized vectors, plus a noisy positive and a negative.
        x = nn.functional.normalize(
            torch.randn(batch_size, input_dim, device=DEVICE), dim=1)
        x_noisy = nn.functional.normalize(
            x + torch.randn_like(x) * noise_std, dim=1)
        x_neg = nn.functional.normalize(
            torch.randn(batch_size, input_dim, device=DEVICE), dim=1)

        # Soft codes (differentiable surrogate for the hard WTA codes).
        code = sep.forward_soft(x)
        code_noisy = sep.forward_soft(x_noisy)
        code_neg = sep.forward_soft(x_neg)

        # Contrastive loss: same input -> same code, different input -> different code.
        pos_sim = nn.functional.cosine_similarity(code, code_noisy, dim=1).mean()
        neg_sim = nn.functional.cosine_similarity(code, code_neg, dim=1).mean()
        loss = -pos_sim + 0.5 * torch.relu(neg_sim - 0.1)

        # Sparsity regularization: penalize high-entropy (diffuse) codes.
        entropy = -(code * (code + 1e-10).log()).sum(dim=1).mean()
        loss += 0.01 * entropy

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (epoch + 1) % 20 == 0:
            with torch.no_grad():
                hard_code = sep(x)
                hard_noisy = sep(x_noisy)
                hard_neg = sep(x_neg)
                # Fraction of active units shared between the WTA patterns.
                match_rate = (hard_code * hard_noisy).sum(dim=1).mean() / k_active
                neg_match = (hard_code * hard_neg).sum(dim=1).mean() / k_active
            print(f"  Epoch {epoch+1}: loss={loss.item():.4f}, "
                  f"pos_match={match_rate:.4f}, neg_match={neg_match:.4f}")

    return sep


def test_learned_memory(sep, num_pairs=100, noise_levels=None):
    """Test a Hebbian memory built on codes from the learned separator."""
    if noise_levels is None:
        noise_levels = [0.0, 0.1, 0.2, 0.5, 1.0]
    code_dim = sep.code_dim
    k = sep.k_active
    W = torch.zeros(code_dim, code_dim, device=DEVICE)
    cues = [nn.functional.normalize(torch.randn(768, device=DEVICE), dim=0)
            for _ in range(num_pairs)]
    targets = [nn.functional.normalize(torch.randn(768, device=DEVICE), dim=0)
               for _ in range(num_pairs)]

    # Learn: Hebbian outer product of target and cue codes.
    with torch.no_grad():
        cue_codes = [sep(c.unsqueeze(0)).squeeze() for c in cues]
        target_codes = [sep(t.unsqueeze(0)).squeeze() for t in targets]
    for i in range(num_pairs):
        W += torch.outer(target_codes[i], cue_codes[i])

    # Test recall under increasing cue noise.
    for ns in noise_levels:
        correct_sims = []
        wrong_sims = []
        for i in range(num_pairs):
            noisy = nn.functional.normalize(
                cues[i] + torch.randn_like(cues[i]) * ns, dim=0)
            with torch.no_grad():
                nc = sep(noisy.unsqueeze(0)).squeeze()
            recalled_raw = W @ nc
            recalled = winner_take_all(recalled_raw, k)
            correct_sims.append(cosine(recalled, target_codes[i]))
            for j in range(min(20, num_pairs)):
                if j != i:
                    wrong_sims.append(cosine(recalled, target_codes[j]))
        mc = np.mean(correct_sims)
        mw = np.mean(wrong_sims)
        exact = np.mean([s > 0.99 for s in correct_sims])
        print(f"  noise={ns:.2f}: Correct={mc:.4f}, Wrong={mw:.4f}, "
              f"Disc={mc-mw:.4f}, Exact={exact:.2%}")
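
# --- Optional result logging (hedged sketch) ---
# `json`, `time`, and RESULTS_DIR are defined above but unused; if this
# experiment series persists metrics, a helper like the one below would do it.
# The filename pattern and record schema are assumptions, not the series'
# established format.
def save_results(tag, records):
    """Write a list of metric dicts to RESULTS_DIR as timestamped JSON."""
    RESULTS_DIR.mkdir(parents=True, exist_ok=True)
    out = RESULTS_DIR / f"exp2f_{tag}_{int(time.time())}.json"
    out.write_text(json.dumps(records, indent=2))
    print(f"  Saved {len(records)} records to {out}")
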
print(f" noise={ns:.2f}: Correct={mc:.4f}, Wrong={mw:.4f}, " f"Disc={mc-mw:.4f}, Exact={exact:.2%}") def main(): print("=" * 60) print("Experiment 2f: Discrimination Check + Learned Separator") print("=" * 60) # Part 1: Check discrimination for soft WTA print("\n=== Part 1: Soft WTA Discrimination ===") for temp in [0.01, 0.05, 0.1, 0.5, 1.0]: check_discrimination(temp) print() # Part 2: Learned separator print("\n=== Part 2: Learned Separator ===") # Train with different noise levels for train_noise in [0.1, 0.3, 0.5]: sep = train_learned_separator( code_dim=4096, k_active=50, epochs=200, noise_std=train_noise) print(f"\n Testing (trained with noise={train_noise}):") test_learned_memory(sep, num_pairs=100) print() # Part 3: Larger learned separator print("\n=== Part 3: Larger Learned Separator (code=8192, k=20) ===") sep = train_learned_separator( code_dim=8192, k_active=20, epochs=300, noise_std=0.3) print("\n Testing:") test_learned_memory(sep, num_pairs=200) if __name__ == "__main__": main()