"""Experiment 1: Encoder roundtrip test. Goal: Can we encode an embedding into spikes and decode it back with acceptable loss? This is the foundation — if this fails, the whole approach is dead. We train a SpikeAutoencoder on random embeddings (simulating LLM hidden states) and measure reconstruction quality via cosine similarity and MSE. """ import sys import time import json from pathlib import Path import torch import torch.nn as nn import torch.optim as optim import numpy as np sys.path.insert(0, str(Path(__file__).parent.parent / "src")) from nuonuo.encoder import SpikeAutoencoder DEVICE = "cuda" if torch.cuda.is_available() else "cpu" RESULTS_DIR = Path(__file__).parent.parent / "doc" RESULTS_DIR.mkdir(exist_ok=True) def cosine_sim(a, b): """Batch cosine similarity.""" return nn.functional.cosine_similarity(a, b, dim=-1).mean().item() def run_config(embed_dim, num_neurons, num_steps, lr, epochs, batch_size, num_batches): """Train and evaluate one configuration.""" model = SpikeAutoencoder(embed_dim, num_neurons, num_steps).to(DEVICE) optimizer = optim.Adam(model.parameters(), lr=lr) mse_loss = nn.MSELoss() cos_loss = nn.CosineEmbeddingLoss() param_count = sum(p.numel() for p in model.parameters()) print(f" Config: dim={embed_dim}, neurons={num_neurons}, steps={num_steps}") print(f" Parameters: {param_count:,}") history = {"train_mse": [], "train_cos": [], "epoch_time": []} target = torch.ones(batch_size, device=DEVICE) for epoch in range(epochs): t0 = time.time() epoch_mse = 0 epoch_cos = 0 for _ in range(num_batches): # Random embeddings — simulate LLM hidden states (normalized) emb = torch.randn(batch_size, embed_dim, device=DEVICE) emb = nn.functional.normalize(emb, dim=-1) recon, spikes, _ = model(emb) loss_mse = mse_loss(recon, emb) loss_cos = cos_loss(recon, emb, target) # Sparsity regularization: encourage ~10% firing rate firing_rate = spikes.mean() loss_sparse = (firing_rate - 0.1).pow(2) loss = loss_mse + 0.5 * loss_cos + 0.1 * loss_sparse optimizer.zero_grad() loss.backward() optimizer.step() epoch_mse += loss_mse.item() epoch_cos += cosine_sim(recon, emb) epoch_mse /= num_batches epoch_cos /= num_batches dt = time.time() - t0 history["train_mse"].append(epoch_mse) history["train_cos"].append(epoch_cos) history["epoch_time"].append(dt) if (epoch + 1) % 10 == 0 or epoch == 0: fr = spikes.mean().item() print(f" Epoch {epoch+1:3d}: MSE={epoch_mse:.6f}, " f"CosSim={epoch_cos:.4f}, FR={fr:.3f}, Time={dt:.1f}s") # Final eval on fresh data model.eval() with torch.no_grad(): test_emb = torch.randn(256, embed_dim, device=DEVICE) test_emb = nn.functional.normalize(test_emb, dim=-1) recon, spikes, _ = model(test_emb) final_mse = mse_loss(recon, test_emb).item() final_cos = cosine_sim(recon, test_emb) final_fr = spikes.mean().item() print(f" ** Final eval: MSE={final_mse:.6f}, CosSim={final_cos:.4f}, FR={final_fr:.3f}") return { "embed_dim": embed_dim, "num_neurons": num_neurons, "num_steps": num_steps, "param_count": param_count, "final_mse": final_mse, "final_cos": final_cos, "final_firing_rate": final_fr, "history": history, } def main(): print("=" * 60) print("Experiment 1: Encoder Roundtrip Test") print("=" * 60) configs = [ # (embed_dim, num_neurons, num_steps) # Start small, scale up if promising (256, 512, 32), (256, 1024, 32), (256, 1024, 64), (768, 2048, 64), (768, 4096, 64), (768, 4096, 128), ] all_results = [] for embed_dim, num_neurons, num_steps in configs: print(f"\n--- Config: dim={embed_dim}, neurons={num_neurons}, steps={num_steps} ---") result = run_config( 
def main():
    print("=" * 60)
    print("Experiment 1: Encoder Roundtrip Test")
    print("=" * 60)

    configs = [
        # (embed_dim, num_neurons, num_steps)
        # Start small, scale up if promising
        (256, 512, 32),
        (256, 1024, 32),
        (256, 1024, 64),
        (768, 2048, 64),
        (768, 4096, 64),
        (768, 4096, 128),
    ]

    all_results = []
    for embed_dim, num_neurons, num_steps in configs:
        print(f"\n--- Config: dim={embed_dim}, neurons={num_neurons}, steps={num_steps} ---")
        result = run_config(
            embed_dim=embed_dim,
            num_neurons=num_neurons,
            num_steps=num_steps,
            lr=1e-3,
            epochs=50,
            batch_size=64,
            num_batches=20,
        )
        all_results.append(result)
        if DEVICE == "cuda":
            torch.cuda.empty_cache()

    # Save results (coerce history values to plain floats for JSON serialization)
    for r in all_results:
        r["history"]["train_mse"] = [float(x) for x in r["history"]["train_mse"]]
        r["history"]["train_cos"] = [float(x) for x in r["history"]["train_cos"]]
        r["history"]["epoch_time"] = [float(x) for x in r["history"]["epoch_time"]]

    results_file = RESULTS_DIR / "exp01_results.json"
    with open(results_file, "w") as f:
        json.dump(all_results, f, indent=2)

    # Print summary table
    print("\n" + "=" * 80)
    print("SUMMARY")
    print("=" * 80)
    print(f"{'Dim':>5} {'Neurons':>8} {'Steps':>6} {'Params':>10} "
          f"{'MSE':>10} {'CosSim':>8} {'FR':>6}")
    print("-" * 80)
    for r in all_results:
        print(f"{r['embed_dim']:>5} {r['num_neurons']:>8} {r['num_steps']:>6} "
              f"{r['param_count']:>10,} {r['final_mse']:>10.6f} "
              f"{r['final_cos']:>8.4f} {r['final_firing_rate']:>6.3f}")

    # Verdict
    best = max(all_results, key=lambda x: x["final_cos"])
    print(f"\nBest config: dim={best['embed_dim']}, neurons={best['num_neurons']}, "
          f"steps={best['num_steps']}")
    print(f"  CosSim={best['final_cos']:.4f}, MSE={best['final_mse']:.6f}")

    if best["final_cos"] > 0.9:
        print("\n✓ PASS: Roundtrip encoding is viable! CosSim > 0.9")
    elif best["final_cos"] > 0.7:
        print("\n~ MARGINAL: CosSim 0.7-0.9, might work for fuzzy associative recall")
    else:
        print("\n✗ FAIL: Roundtrip encoding loses too much information")


if __name__ == "__main__":
    main()
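# ---------------------------------------------------------------------------
# Post-hoc inspection (sketch): the saved exp01_results.json can be reloaded
# to compare training curves across configs. matplotlib is an assumption
# here, not a dependency of this script; run this separately afterwards.
#
#   import json
#   import matplotlib.pyplot as plt
#   results = json.load(open("doc/exp01_results.json"))
#   for r in results:
#       label = f"{r['embed_dim']}d/{r['num_neurons']}n/{r['num_steps']}t"
#       plt.plot(r["history"]["train_cos"], label=label)
#   plt.xlabel("epoch")
#   plt.ylabel("train cosine similarity")
#   plt.legend()
#   plt.show()
# ---------------------------------------------------------------------------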