From e8c30c3b199b7a9f04016110080537d3c589712d Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Wed, 7 Jan 2026 22:28:53 +0000 Subject: [PATCH] add notebook used for scaling laws analysis --- dev/scaling_analysis.ipynb | 227 +++++++++++++++++++++++++++++++++++++ 1 file changed, 227 insertions(+) create mode 100644 dev/scaling_analysis.ipynb diff --git a/dev/scaling_analysis.ipynb b/dev/scaling_analysis.ipynb new file mode 100644 index 0000000..a196bd1 --- /dev/null +++ b/dev/scaling_analysis.ipynb @@ -0,0 +1,227 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Scaling Laws Analysis\n", + "\n", + "Analyze results from `scaling_laws.sh` to find the optimal param:data ratio for nanochat." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import pandas as pd\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "\n", + "# Load results\n", + "base_dir = os.environ.get('NANOCHAT_BASE_DIR', os.path.expanduser('~/.cache/nanochat'))\n", + "results_path = os.path.join(base_dir, 'scaling_laws_results', 'results.csv')\n", + "\n", + "df = pd.read_csv(results_path)\n", + "flops_budgets = sorted(df['flops_budget'].unique())\n", + "print(f\"Loaded {len(df)} runs across {len(flops_budgets)} FLOPs budgets\")\n", + "print(f\"Columns: {list(df.columns)}\")\n", + "df" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## IsoFLOP Curves (à la Chinchilla)\n", + "\n", + "For each compute budget, plot loss vs model size. Looking for the U-shape valley that reveals the optimal model size for each FLOPs budget." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, axes = plt.subplots(1, 3, figsize=(16, 5))\n", + "\n", + "# Plot 1: IsoFLOP curves - Val BPB vs Parameters (the Chinchilla plot!)\n", + "ax = axes[0]\n", + "colors = plt.cm.viridis(np.linspace(0, 0.9, len(flops_budgets)))\n", + "optimal_by_bpb = []\n", + "\n", + "for flops, color in zip(flops_budgets, colors):\n", + " subset = df[df['flops_budget'] == flops].sort_values('num_scaling_params')\n", + " ax.plot(subset['num_scaling_params'], subset['val_bpb'], 'o', color=color, label=f'{flops:.0e}', markersize=8)\n", + "\n", + " # Fit quadratic in log-space: val_bpb = a*(log N)^2 + b*(log N) + c\n", + " log_params = np.log10(subset['num_scaling_params'])\n", + " coeffs = np.polyfit(log_params, subset['val_bpb'], 2)\n", + " a, b, c = coeffs\n", + "\n", + " # Plot fitted curve (dashed)\n", + " log_fit_x = np.linspace(log_params.min() - 0.1, log_params.max() + 0.1, 100)\n", + " fit_y = a * log_fit_x**2 + b * log_fit_x + c\n", + " ax.plot(10**log_fit_x, fit_y, '--', color=color, linewidth=2)\n", + "\n", + " # Find minimum of quadratic: d/dx(ax^2 + bx + c) = 0 => x = -b/(2a)\n", + " if a > 0: # parabola opens upward (has a minimum)\n", + " log_opt = -b / (2 * a)\n", + " opt_params = 10**log_opt\n", + " opt_bpb = a * log_opt**2 + b * log_opt + c\n", + " # Mark the fitted optimal\n", + " ax.scatter([opt_params], [opt_bpb], s=150, color=color,\n", + " zorder=5, edgecolors='black', linewidths=2, marker='*')\n", + " # Interpolate tokens and ratio from actual data (don't use C≈6ND approximation)\n", + " opt_tokens = np.interp(np.log10(opt_params), log_params, subset['tokens_trained'])\n", + " opt_ratio = np.interp(np.log10(opt_params), log_params, subset['param_data_ratio'])\n", + " optimal_by_bpb.append({'flops': flops, 'params': opt_params, 'tokens': opt_tokens, 'ratio': opt_ratio, 'bpb': opt_bpb})\n", + " else:\n", + " # Fallback to raw minimum if quadratic 
doesn't have minimum\n", + " best_idx = subset['val_bpb'].idxmin()\n", + " best = subset.loc[best_idx]\n", + " ax.scatter([best['num_scaling_params']], [best['val_bpb']], s=150, color=color,\n", + " zorder=5, edgecolors='black', linewidths=2)\n", + " optimal_by_bpb.append({'flops': flops, 'params': best['num_scaling_params'],\n", + " 'tokens': best['tokens_trained'], 'ratio': best['param_data_ratio'], 'bpb': best['val_bpb']})\n", + "\n", + "ax.set_xscale('log')\n", + "ax.set_xlabel('Parameters')\n", + "ax.set_ylabel('Validation Loss (bpb)')\n", + "ax.set_title('IsoFLOP Curves')\n", + "ax.legend(title='FLOPs', loc='upper right')\n", + "ax.grid(True, alpha=0.3)\n", + "\n", + "opt_df = pd.DataFrame(optimal_by_bpb)\n", + "\n", + "# Plot 2: Optimal model size vs compute (power law)\n", + "ax = axes[1]\n", + "ax.loglog(opt_df['flops'], opt_df['params'], 'o', markersize=10, color='#2ecc71')\n", + "ax.set_xlabel('FLOPs')\n", + "ax.set_ylabel('Optimal Parameters')\n", + "ax.set_title('Optimal Model Size')\n", + "ax.grid(True, alpha=0.3)\n", + "\n", + "# Fit and show power law\n", + "if len(opt_df) >= 2:\n", + " log_f = np.log10(opt_df['flops'])\n", + " log_p = np.log10(opt_df['params'])\n", + " slope, intercept = np.polyfit(log_f, log_p, 1)\n", + " fit_f = np.logspace(log_f.min() - 0.5, log_f.max() + 0.5, 100)\n", + " fit_p = 10**(intercept + slope * np.log10(fit_f))\n", + " ax.plot(fit_f, fit_p, 'r--', alpha=0.7, label=f'N ∝ C^{slope:.2f}')\n", + " ax.legend()\n", + "\n", + "# Plot 3: Optimal tokens vs compute (power law)\n", + "ax = axes[2]\n", + "ax.loglog(opt_df['flops'], opt_df['tokens'], 'o', markersize=10, color='#e74c3c')\n", + "ax.set_xlabel('FLOPs')\n", + "ax.set_ylabel('Optimal Tokens')\n", + "ax.set_title('Optimal Training Tokens')\n", + "ax.grid(True, alpha=0.3)\n", + "\n", + "# Fit and show power law\n", + "if len(opt_df) >= 2:\n", + " log_f = np.log10(opt_df['flops'])\n", + " log_t = np.log10(opt_df['tokens'])\n", + " slope, intercept = np.polyfit(log_f, 
log_t, 1)\n", + " fit_f = np.logspace(log_f.min() - 0.5, log_f.max() + 0.5, 100)\n", + " fit_t = 10**(intercept + slope * np.log10(fit_f))\n", + " ax.plot(fit_f, fit_t, 'r--', alpha=0.7, label=f'D ∝ C^{slope:.2f}')\n", + " ax.legend()\n", + "\n", + "plt.tight_layout()\n", + "plt.show()\n", + "\n", + "# Print the optimal points (from quadratic fits)\n", + "print(\"\\nOptimal configurations (from quadratic fits):\")\n", + "print(f\"{'FLOPs':<12} {'Params':<15} {'Tokens':<15} {'Ratio':<10} {'Val BPB':<10}\")\n", + "print(\"-\" * 65)\n", + "for _, row in opt_df.iterrows():\n", + " print(f\"{row['flops']:<12.0e} {int(row['params']):<15,} {int(row['tokens']):<15,} {row['ratio']:<10.1f} {row['bpb']:<10.4f}\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Val BPB vs Depth and Ratio" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n", + "\n", + "# Plot 1: Val BPB vs Depth\n", + "ax = axes[0]\n", + "for flops in flops_budgets:\n", + " subset = df[df['flops_budget'] == flops].sort_values('depth')\n", + " ax.plot(subset['depth'], subset['val_bpb'], 'o-', label=f'{flops:.0e}')\n", + " # Mark the best (lowest)\n", + " best_idx = subset['val_bpb'].idxmin()\n", + " best = subset.loc[best_idx]\n", + " ax.scatter([best['depth']], [best['val_bpb']], s=100, zorder=5, edgecolors='black', linewidths=2)\n", + "\n", + "ax.set_xlabel('Depth')\n", + "ax.set_ylabel('Val BPB (lower is better)')\n", + "ax.set_title('Validation BPB vs Model Depth')\n", + "ax.legend(title='FLOPs')\n", + "ax.grid(True, alpha=0.3)\n", + "\n", + "# Plot 2: Val BPB vs Param:Data Ratio\n", + "ax = axes[1]\n", + "for flops in flops_budgets:\n", + " subset = df[df['flops_budget'] == flops].sort_values('param_data_ratio')\n", + " ax.plot(subset['param_data_ratio'], subset['val_bpb'], 'o-', label=f'{flops:.0e}')\n", + " best_idx = subset['val_bpb'].idxmin()\n", + " best = 
subset.loc[best_idx]\n", + " ax.scatter([best['param_data_ratio']], [best['val_bpb']], s=100, zorder=5, edgecolors='black', linewidths=2)\n", + "\n", + "ax.axvline(x=20, color='red', linestyle='--', alpha=0.5, label='Chinchilla (20)')\n", + "ax.set_xlabel('Param:Data Ratio (tokens/param)')\n", + "ax.set_ylabel('Val BPB (lower is better)')\n", + "ax.set_title('Val BPB vs Param:Data Ratio')\n", + "ax.legend(title='FLOPs')\n", + "ax.grid(True, alpha=0.3)\n", + "\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}