add engram-lite, add log, tune scaling laws analysis scripts
+134
@@ -4,6 +4,140 @@ A running summary documenting some experiments and findings. Started ~Jan 7 2026

---

## 2026-01-27: Bigram Hash Embeddings (Engram-lite)

Explored N-gram memory modules inspired by the [DeepSeek Engram paper](https://arxiv.org/abs/2506.08046) and [modded-nanogpt PR #201](https://github.com/KellerJordan/modded-nanogpt/pull/201).

### Background

The Engram paper introduces "conditional memory" as a complement to MoE: O(1) hash lookups retrieve static N-gram patterns instead of reconstructing them through computation. The key insight is that transformers waste early layers "simulating retrieval through computation" for patterns like named entities and formulaic phrases that could be simple table lookups.

### What We Tried

**1. Full Engram module with context-aware gating (paper design)**
```python
# Hash bigrams to retrieve embeddings, then gate with hidden state
e = embed(hash(prev_token, curr_token))
q = RMSNorm(h)                    # hidden state as query
k = RMSNorm(W_k @ e)              # projected embedding as key
v = W_v @ e
α = sigmoid(q · k / √d)           # scalar gate per position
output = α * v
```
- Injected after block 1 (paper found early injection optimal)
- Slight improvement, but quite a bit of complexity added.
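
To make the gate concrete, here is a minimal plain-Python sketch of the pseudocode above for a single position. It is an illustration only, not the paper's or this repo's implementation: the helper names (`rms_norm`, `matvec`, `gated_bigram_output`), the toy vectors, and the identity projections are assumptions, and learnable RMSNorm gains are omitted.

```python
import math

def rms_norm(x, eps=1e-6):
    """RMS-normalize a vector (no learnable gain, for simplicity)."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms for v in x]

def matvec(W, x):
    """Multiply matrix W (a list of rows) by vector x."""
    return [sum(w * v for w, v in zip(row, x)) for row in W]

def gated_bigram_output(h, e, W_k, W_v):
    """Return (alpha, alpha * (W_v @ e)) with alpha = sigmoid(q . k / sqrt(d))."""
    d = len(h)
    q = rms_norm(h)                 # hidden state as query
    k = rms_norm(matvec(W_k, e))    # projected bigram embedding as key
    v = matvec(W_v, e)
    score = sum(a * b for a, b in zip(q, k)) / math.sqrt(d)
    alpha = 1.0 / (1.0 + math.exp(-score))  # one scalar gate for this position
    return alpha, [alpha * x for x in v]

# Toy inputs (assumed values): 4-dim hidden state and bigram embedding
identity = [[1.0 if i == j else 0.0 for j in range(4)] for i in range(4)]
h = [0.5, -1.0, 2.0, 0.1]
e = [1.0, 0.0, -1.0, 0.5]
alpha, out = gated_bigram_output(h, e, identity, identity)

# A zero embedding gives a zero contribution, whatever the gate value is
alpha_zero, out_zero = gated_bigram_output(h, [0.0] * 4, identity, identity)
```

The extra cost relative to plain addition is visible here: two projections and two normalizations per position, all to produce one scalar.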

**2. Early-layer only injection**
- Only inject bigram signal in first 4 layers (where paper claims static pattern offloading helps most)
- **Result:** Actually hurt performance. The model seems to need uniform injection across all layers.

**3. Trigrams**
- Extended to hash both 2-grams and 3-grams, concatenating embeddings
- **Result:** No improvement over bigrams alone. Dilutes capacity from more frequent 2-gram patterns.

**4. Bigram-only with x0-style injection (modded-nanogpt engram-lite approach)**
- Simple hash: `((36313 * curr) XOR (27191 * prev)) mod table_size`
- Zero-init embedding table, learned per-layer lambdas
- Add to residual at every layer: `x = resid_λ[i]*x + x0_λ[i]*x0 + bigram_λ[i]*x0_bigram`
- **Result:** This simple approach works and provides a consistent improvement.
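
The per-layer mixing rule in approach 4 can be sketched in plain Python for one position. The zero-init bigram row and the 0.1 lambda init come from these notes; the `resid_lambdas = 1.0` and `x0_lambdas = 0.0` inits, the helper name `mix_residual`, and the toy vectors are assumptions for illustration.

```python
def mix_residual(x, x0, x0_bigram, resid_lam, x0_lam, bigram_lam):
    """x = resid_lam*x + x0_lam*x0 + bigram_lam*x0_bigram, elementwise."""
    return [resid_lam * a + x0_lam * b + bigram_lam * c
            for a, b, c in zip(x, x0, x0_bigram)]

n_layers = 3
resid_lambdas = [1.0] * n_layers   # assumed init
x0_lambdas = [0.0] * n_layers      # assumed init
bigram_lambdas = [0.1] * n_layers  # init 0.1, as in the notes

x0 = [0.2, -0.4, 0.6]              # token embedding at one position (toy values)
x0_bigram = [0.0, 0.0, 0.0]        # row of a zero-initialized bigram table

x = list(x0)
for i in range(n_layers):
    x = mix_residual(x, x0, x0_bigram,
                     resid_lambdas[i], x0_lambdas[i], bigram_lambdas[i])
    # the transformer block for layer i would run here

# Once the table has learned something, the bigram term shifts the residual
x_learned = mix_residual(x0, x0, [1.0, 1.0, 1.0], 1.0, 0.0, 0.1)
```

With a zero-initialized table the bigram term contributes nothing at step 0, so under these assumed inits the model starts out identical to the baseline even though `bigram_lambdas` is nonzero.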

TL;DR: the winning approach follows modded-nanogpt's "engram-lite": add the following module and feed its output into the residual stream (scaled by a per-layer learnable λ) before every single block:

```python
import torch.nn as nn

class BigramEmbed(nn.Module):
    def __init__(self, vocab_size, embed_dim, table_multiplier=5):
        super().__init__()
        self.table_size = vocab_size * table_multiplier
        self.embed = nn.Embedding(self.table_size, embed_dim)

    def forward(self, idx):
        # idx: (B, T) token ids. Outer parentheses matter: `%` binds tighter than `^`.
        # Returns T-1 positions; position 0 has no predecessor and is not handled here.
        h = ((36313 * idx[:, 1:]) ^ (27191 * idx[:, :-1])) % (self.table_size - 1)
        return self.embed(h)
```
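
A quick way to sanity-check the hash without PyTorch is to run it on plain Python integers. The constants and the `table_size - 1` modulus are taken from the snippet above; the function name `bigram_hash` and the sample token ids are made up for illustration.

```python
def bigram_hash(prev, curr, table_size):
    """Hash a (prev, curr) token pair into the range [0, table_size - 2]."""
    return ((36313 * curr) ^ (27191 * prev)) % (table_size - 1)

vocab_size = 32768
table_size = vocab_size * 5  # table_multiplier = 5

tokens = [17, 4242, 31999, 0, 17, 4242]  # arbitrary sample ids
indices = [bigram_hash(p, c, table_size)
           for p, c in zip(tokens[:-1], tokens[1:])]

# Python's `%` binds tighter than `^`, so dropping the outer parentheses
# reduces only the second product and the result can exceed the table size.
unparenthesized = (36313 * 4242) ^ (27191 * 17) % (table_size - 1)
```

The pair `(17, 4242)` appears twice and maps to the same index both times, and every index stays below `table_size - 1`, so the last table row is never produced by the hash itself.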

As for optimal hyperparameters:

- **Table size:** `vocab_size * 5` (~164K entries for 32K vocab). Swept a number of settings and 5 performed best.
- **Injection:** Every layer via learned `bigram_lambdas` (init 0.1 was better than 0.0).
- **Normalization:** Also tried adding a `norm()` to the embeddings (mirroring the token embeddings); this was slightly worse.
- **Init:** Zero-init embedding, so the module starts as an identity (a small noisy init was worse).
- **Optimizer:** AdamW with the same LR as the token embeddings.

### Key Learnings

1. **Gating didn't help at our scale.** The paper's context-aware gating mechanism (sigmoid dot-product gate) added parameters and complexity without improvement. modded-nanogpt found the same: "simple direct addition to the residual stream outperformed by a decent margin."

2. **Uniform injection beats early-only.** Despite the paper's finding that early layers benefit most, restricting injection to early layers hurt. The x0-style "add everywhere with learned lambda" pattern works better for our architecture/scale.

3. **Bigrams are sufficient.** Trigrams didn't help - the extra context doesn't pay for the diluted capacity.

4. **Scale matters.** The Engram paper's results are at 27B params with MoE. At our ~100M-1B scale, the simpler approach wins. The elaborate gating mechanism may become useful at larger scales where collision handling matters more.

### Parameters Added

For a d12 model with `table_multiplier=5`:
- Bigram embedding: 32768 × 5 × 768 = 125,829,120 (~126M) params
- Per-layer lambdas: 12 scalars (negligible)

If you're keeping track, we now have *a lot* of parameters, a significant amount of them in embeddings (token embeddings, bigram embeddings, value embeddings). For example, for a d12 we now have:

```
Parameter counts:
wte                     : 25,165,824
bigram_embed            : 125,829,120
value_embeds            : 150,994,944
lm_head                 : 25,165,824
transformer_matrices    : 84,935,808
scalars                 : 36
total                   : 412,091,556
```

In other words, only about a quarter of the parameters (`transformer_matrices` plus `lm_head`, roughly 27%) are now weight projections; the large majority sit in embedding tables.
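
The shares can be recomputed directly from the table above. This is just arithmetic on the listed counts; the grouping into "embedding tables" versus "matmul weights" follows the wording of these notes, and the variable names are illustrative.

```python
# Parameter counts for d12, copied from the table above
counts = {
    "wte": 25_165_824,
    "bigram_embed": 125_829_120,
    "value_embeds": 150_994_944,
    "lm_head": 25_165_824,
    "transformer_matrices": 84_935_808,
    "scalars": 36,
}
total = sum(counts.values())

# Lookup tables vs. weights that are actually multiplied against activations
embed_share = (counts["wte"] + counts["bigram_embed"] + counts["value_embeds"]) / total
matmul_share = (counts["transformer_matrices"] + counts["lm_head"]) / total
transformer_share = counts["transformer_matrices"] / total

print(f"total              : {total:,}")
print(f"embedding tables   : {embed_share:.1%}")
print(f"matmuls (+lm_head) : {matmul_share:.1%}")
print(f"transformer only   : {transformer_share:.1%}")
```

Counting the transformer matrices alone, without `lm_head`, brings the share closer to a fifth.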

Still, on all axes (steps, wall clock time, flops), this somewhat parameter-bloated architecture beats the baseline and will now become the default.

After adding the engram-lite, I re-ran the scaling laws to determine the new optimal tokens:params ratio. I swept FLOPs in the range 1e18..1e19, exponentially strided in 4 settings (1e18, 2e18, 5e18, 1e19). I looked at a number of ways of determining the effective parameter count for the purposes of the scaling laws. The results looked like this:

```
Kaplan-style (all projections including lm_head and no embeddings)

Optimal configurations (from quadratic fits):
FLOPs        Eff Params      Tokens          Ratio      Val BPB
-----------------------------------------------------------------
1e+18        110,678,115     1,241,505,403   11.2       0.8972
2e+18        167,797,457     1,785,336,422   10.7       0.8616
5e+18        250,650,865     2,642,234,152   10.8       0.8293
1e+19        381,758,347     3,806,871,243   10.3       0.7999

N \propto C^0.54, D \propto C^0.49

Chinchilla-style (all parameters, period.)

Optimal configurations (from quadratic fits):
FLOPs        Eff Params      Tokens          Ratio      Val BPB
-----------------------------------------------------------------
1e+18        416,320,605     1,232,157,011   3.0        0.8974
2e+18        560,239,841     1,763,669,281   3.2        0.8616
5e+18        741,495,903     2,629,909,368   3.6        0.8291
1e+19        988,644,331     3,884,841,895   4.0        0.7999

N \propto C^0.37, D \propto C^0.50

Transformer-only-style (only the projections inside the transformer)

Optimal configurations (from quadratic fits):
FLOPs        Eff Params      Tokens          Ratio      Val BPB
-----------------------------------------------------------------
1e+18        80,259,665      1,315,639,547   17.2       0.8966
2e+18        131,488,566     1,864,134,141   14.5       0.8622
5e+18        220,985,474     2,595,328,843   12.1       0.8302
1e+19        401,213,504     3,328,704,512   8.5        0.7994

N \propto C^0.70, D \propto C^0.41
```

Clearly, the Kaplan-style ratios are the most consistent and produce stable ~0.5 exponents for both params and tokens. Since the ratio D/N scales as C^(0.49 - 0.54), it is nearly flat across budgets, meaning we can use a single fixed tokens:params ratio for compute-optimal models. This turns out to be about 10.5, which now becomes the new default.
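
As a rough cross-check, the exponents can be refit from the four Kaplan-style optima in the table with an ordinary least-squares line in log-log space. This sketch only uses the tabulated points, so its exponents and ratios may differ slightly from the quoted values, which come from the notebook's own fit; `fit_power_law` is a hypothetical helper name.

```python
import math

def fit_power_law(xs, ys):
    """Least-squares fit of log10(y) = slope * log10(x) + intercept."""
    lx = [math.log10(x) for x in xs]
    ly = [math.log10(y) for y in ys]
    mx = sum(lx) / len(lx)
    my = sum(ly) / len(ly)
    slope = (sum((a - mx) * (b - my) for a, b in zip(lx, ly))
             / sum((a - mx) ** 2 for a in lx))
    return slope, my - slope * mx

# Kaplan-style optima copied from the table above
flops = [1e18, 2e18, 5e18, 1e19]
params = [110_678_115, 167_797_457, 250_650_865, 381_758_347]
tokens = [1_241_505_403, 1_785_336_422, 2_642_234_152, 3_806_871_243]

slope_n, _ = fit_power_law(flops, params)   # N ∝ C^slope_n
slope_d, _ = fit_power_law(flops, tokens)   # D ∝ C^slope_d
ratio_exponent = slope_d - slope_n          # D/N ∝ C^(slope_d - slope_n)
ratios = [t / p for t, p in zip(tokens, params)]

print(f"N exponent: {slope_n:.3f}, D exponent: {slope_d:.3f}")
print(f"ratio exponent: {ratio_exponent:.3f}")
print("tokens per param:", [f"{r:.1f}" for r in ratios])
```

A ratio exponent near zero is what justifies using one fixed tokens:params ratio across budgets; for the Chinchilla-style and transformer-only counts the same refit gives exponents that are visibly further apart.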

---

## 2026-01-19 to 2026-01-22: Optimizer Hyperparameter Sweep

Ran ~320 experiments across 6 rounds, scaling from d12→d16→d20 to find optimal optimizer hyperparameters. Added granular per-component control to `setup_optimizers()`: separate LRs and betas for embedding, unembedding, value_embeds, resid_lambdas, x0_lambdas, and Muon matrix params.

+155
-9
@@ -15,14 +15,16 @@
 "metadata": {},
 "outputs": [],
 "source": [
+"%matplotlib inline\n",
 "import os\n",
 "import pandas as pd\n",
 "import numpy as np\n",
 "import matplotlib.pyplot as plt\n",
 "\n",
 "# Load results\n",
+"tag = \"jan26\"\n",
 "base_dir = os.environ.get('NANOCHAT_BASE_DIR', os.path.expanduser('~/.cache/nanochat'))\n",
-"results_path = os.path.join(base_dir, 'scaling_laws_results', 'results.csv')\n",
+"results_path = os.path.join(base_dir, f'scaling_laws_results_{tag}', 'results.csv')\n",
 "\n",
 "df = pd.read_csv(results_path)\n",
 "flops_budgets = sorted(df['flops_budget'].unique())\n",
@@ -31,6 +33,99 @@
 "df"
 ]
 },
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"# =============================================================================\n",
+"# FILTERING: Remove incomplete or problematic runs\n",
+"# =============================================================================\n",
+"\n",
+"print(f\"Before filtering: {len(df)} runs\")\n",
+"\n",
+"# Filter out runs with missing/invalid val_bpb (incomplete runs)\n",
+"df = df[df['val_bpb'].notna() & (df['val_bpb'] > 0)]\n",
+"\n",
+"# Optional: exclude specific flops budgets that aren't done yet\n",
+"# exclude_flops = [1e19] # <-- adjust as runs complete\n",
+"# df = df[~df['flops_budget'].isin(exclude_flops)]\n",
+"\n",
+"# Optional: exclude specific depths\n",
+"# exclude_depths = [18, 20]\n",
+"# df = df[~df['depth'].isin(exclude_depths)]\n",
+"\n",
+"print(f\"After filtering: {len(df)} runs\")\n",
+"print(f\"FLOPs budgets: {sorted(df['flops_budget'].unique())}\")\n",
+"print(f\"Depths: {sorted(df['depth'].unique())}\")\n",
+"\n",
+"# Update flops_budgets list after filtering\n",
+"flops_budgets = sorted(df['flops_budget'].unique())"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"## Effective Parameter Count\n",
+"\n",
+"Different scaling law papers use different conventions for counting parameters:\n",
+"- **Kaplan et al.** excluded embedding parameters (claimed cleaner laws)\n",
+"- **Chinchilla** included all parameters (and noted Kaplan had a bug)\n",
+"\n",
+"Our CSV now has granular counts:\n",
+"- `params_wte` - token embedding (lookup table)\n",
+"- `params_bigram_embed` - bigram hash embeddings (lookup table)\n",
+"- `params_value_embeds` - value embeddings (lookup table)\n",
+"- `params_lm_head` - unembedding projection (matmul)\n",
+"- `params_transformer` - attention + MLP matrices (matmuls)\n",
+"- `params_scalars` - resid/x0/bigram lambdas (tiny)\n",
+"\n",
+"**Experiment below** with different combinations to see which gives the cleanest scaling laws."
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"# =============================================================================\n",
+"# EXPERIMENT HERE: Define which parameters to count for scaling laws\n",
+"# =============================================================================\n",
+"\n",
+"def compute_effective_params(row):\n",
+"    \"\"\"\n",
+"    Compute the 'effective' parameter count for scaling law analysis.\n",
+"\n",
+"    Modify this function to experiment with different conventions:\n",
+"    - Chinchilla-style: include everything\n",
+"    - Kaplan-style: exclude embeddings\n",
+"    - Matmul-only: just transformer + lm_head (the actual compute)\n",
+"    - etc.\n",
+"    \"\"\"\n",
+"    # Option 1: Chinchilla-style (all params)\n",
+"    # return row['params_total']\n",
+"\n",
+"    # Option 2: Kaplan-style (exclude embeddings)\n",
+"    return row['params_transformer'] + row['params_lm_head']\n",
+"\n",
+"    # Option 3: Transformer-only (exclude all embeddings AND lm_head)\n",
+"    # return row['params_transformer']\n",
+"\n",
+"\n",
+"# Compute derived columns\n",
+"df['effective_params'] = df.apply(compute_effective_params, axis=1)\n",
+"df['param_data_ratio'] = df['tokens_trained'] / df['effective_params']\n",
+"\n",
+"# Show parameter breakdown for first few rows\n",
+"print(\"Parameter breakdown (first row per flops budget):\")\n",
+"param_cols = ['depth', 'params_wte', 'params_bigram_embed', 'params_value_embeds',\n",
+"              'params_lm_head', 'params_transformer', 'params_scalars', 'params_total', 'effective_params']\n",
+"df.groupby('flops_budget').first()[param_cols]"
+]
+},
 {
 "cell_type": "markdown",
 "metadata": {},
@@ -54,11 +149,11 @@
 "optimal_by_bpb = []\n",
 "\n",
 "for flops, color in zip(flops_budgets, colors):\n",
-"    subset = df[df['flops_budget'] == flops].sort_values('num_scaling_params')\n",
-"    ax.plot(subset['num_scaling_params'], subset['val_bpb'], 'o', color=color, label=f'{flops:.0e}', markersize=8)\n",
+"    subset = df[df['flops_budget'] == flops].sort_values('effective_params')\n",
+"    ax.plot(subset['effective_params'], subset['val_bpb'], 'o', color=color, label=f'{flops:.0e}', markersize=8)\n",
 "\n",
 "    # Fit quadratic in log-space: val_bpb = a*(log N)^2 + b*(log N) + c\n",
-"    log_params = np.log10(subset['num_scaling_params'])\n",
+"    log_params = np.log10(subset['effective_params'])\n",
 "    coeffs = np.polyfit(log_params, subset['val_bpb'], 2)\n",
 "    a, b, c = coeffs\n",
 "\n",
@@ -83,13 +178,13 @@
 "        # Fallback to raw minimum if quadratic doesn't have minimum\n",
 "        best_idx = subset['val_bpb'].idxmin()\n",
 "        best = subset.loc[best_idx]\n",
-"        ax.scatter([best['num_scaling_params']], [best['val_bpb']], s=150, color=color,\n",
+"        ax.scatter([best['effective_params']], [best['val_bpb']], s=150, color=color,\n",
 "            zorder=5, edgecolors='black', linewidths=2)\n",
-"        optimal_by_bpb.append({'flops': flops, 'params': best['num_scaling_params'],\n",
+"        optimal_by_bpb.append({'flops': flops, 'params': best['effective_params'],\n",
 "            'tokens': best['tokens_trained'], 'ratio': best['param_data_ratio'], 'bpb': best['val_bpb']})\n",
 "\n",
 "ax.set_xscale('log')\n",
-"ax.set_xlabel('Parameters')\n",
+"ax.set_xlabel('Effective Parameters')\n",
 "ax.set_ylabel('Validation Loss (bpb)')\n",
 "ax.set_title('IsoFLOP Curves')\n",
 "ax.legend(title='FLOPs', loc='upper right')\n",
@@ -138,10 +233,61 @@
 "\n",
 "# Print the optimal points (from quadratic fits)\n",
 "print(\"\\nOptimal configurations (from quadratic fits):\")\n",
-"print(f\"{'FLOPs':<12} {'Params':<15} {'Tokens':<15} {'Ratio':<10} {'Val BPB':<10}\")\n",
+"print(f\"{'FLOPs':<12} {'Eff Params':<15} {'Tokens':<15} {'Ratio':<10} {'Val BPB':<10}\")\n",
 "print(\"-\" * 65)\n",
 "for _, row in opt_df.iterrows():\n",
-"    print(f\"{row['flops']:<12.0e} {int(row['params']):<15,} {int(row['tokens']):<15,} {row['ratio']:<10.1f} {row['bpb']:<10.4f}\")\n"
+"    print(f\"{row['flops']:<12.0e} {int(row['params']):<15,} {int(row['tokens']):<15,} {row['ratio']:<10.1f} {row['bpb']:<10.4f}\")"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"# =============================================================================\n",
+"# Optimal Ratio Summary (from power law fits)\n",
+"# =============================================================================\n",
+"\n",
+"# From the power law fits: N ∝ C^a and D ∝ C^b\n",
+"# The ratio D/N ∝ C^(b-a). If a ≈ b, ratio is roughly constant.\n",
+"\n",
+"if len(opt_df) >= 2:\n",
+"    log_f = np.log10(opt_df['flops'])\n",
+"    log_p = np.log10(opt_df['params'])\n",
+"    log_t = np.log10(opt_df['tokens'])\n",
+"\n",
+"    # Fit power laws\n",
+"    slope_n, intercept_n = np.polyfit(log_f, log_p, 1)\n",
+"    slope_d, intercept_d = np.polyfit(log_f, log_t, 1)\n",
+"\n",
+"    # The ratio D/N at a reference compute (geometric mean of our budgets)\n",
+"    ref_flops = np.sqrt(opt_df['flops'].min() * opt_df['flops'].max())\n",
+"    log_ref = np.log10(ref_flops)\n",
+"\n",
+"    # Predicted optimal N and D at reference compute\n",
+"    pred_log_n = intercept_n + slope_n * log_ref\n",
+"    pred_log_d = intercept_d + slope_d * log_ref\n",
+"    optimal_ratio = 10**(pred_log_d - pred_log_n)\n",
+"\n",
+"    # Also compute from the fitted optimals directly (mean and std)\n",
+"    mean_ratio = opt_df['ratio'].mean()\n",
+"    std_ratio = opt_df['ratio'].std()\n",
+"\n",
+"    print(\"=\" * 60)\n",
+"    print(\"OPTIMAL RATIO SUMMARY\")\n",
+"    print(\"=\" * 60)\n",
+"    print(f\"\\nPower law exponents:\")\n",
+"    print(f\" N ∝ C^{slope_n:.3f}\")\n",
+"    print(f\" D ∝ C^{slope_d:.3f}\")\n",
+"    print(f\" Ratio exponent (b-a): {slope_d - slope_n:.3f} (should be ~0 if ratio is constant)\")\n",
+"    print(f\"\\nOptimal ratio (tokens per effective param):\")\n",
+"    print(f\" From power law at C={ref_flops:.1e}: {optimal_ratio:.1f}\")\n",
+"    print(f\" Mean across budgets: {mean_ratio:.1f} ± {std_ratio:.1f}\")\n",
+"    print(f\" Chinchilla reference: 20\")\n",
+"    print(f\"\\nPer-budget ratios: {[f'{r:.1f}' for r in opt_df['ratio'].values]}\")\n",
+"else:\n",
+"    print(\"Need at least 2 flops budgets to compute power law fits\")"
 ]
 },
 {