nuke midtraining from orbit, it's not as needed now that we have a BOS-aligned dataloader. Also change the README a lot. midtrianing is not yet fully properly erased across the board, but good enough for step 1
This commit is contained in:
+13
-37
@@ -1,14 +1,14 @@
|
||||
#!/bin/bash
|
||||
|
||||
# This script is the "Best ChatGPT clone that $100 can buy",
|
||||
# It is designed to run in ~4 hours on 8XH100 node at $3/GPU/hour.
|
||||
# This script is configured to train your own GPT-2 grade LLM (pretraining + finetuning)
|
||||
# It is designed to run on a blank 8XH100 GPU node and takes approximately 3 hours to complete.
|
||||
|
||||
# 1) Example launch (simplest):
|
||||
# bash speedrun.sh
|
||||
# 2) Example launch in a screen session (because the run takes ~4 hours):
|
||||
# screen -L -Logfile speedrun.log -S speedrun bash speedrun.sh
|
||||
# bash runs/speedrun.sh
|
||||
# 2) Example launch in a screen session (because the run takes ~3 hours):
|
||||
# screen -L -Logfile runs/speedrun.log -S speedrun bash runs/speedrun.sh
|
||||
# 3) Example launch with wandb logging, but see below for setting up wandb first:
|
||||
# WANDB_RUN=speedrun screen -L -Logfile speedrun.log -S speedrun bash speedrun.sh
|
||||
# WANDB_RUN=speedrun screen -L -Logfile runs/speedrun.log -S speedrun bash runs/speedrun.sh
|
||||
|
||||
# Default intermediate artifacts directory is in ~/.cache/nanochat
|
||||
export OMP_NUM_THREADS=1
|
||||
@@ -49,13 +49,14 @@ python -m nanochat.report reset
|
||||
# Tokenizer
|
||||
|
||||
# Download the first ~2B characters of pretraining dataset
|
||||
# look at dev/repackage_data_reference.py for details on how this data was prepared
|
||||
# each data shard is ~250M chars
|
||||
# so we download 2e9 / 250e6 = 8 data shards at this point
|
||||
# each shard is ~100MB of text (compressed), so this is about ~800MB of data on disk
|
||||
# look at dev/repackage_data_reference.py for details on how this data was prepared
|
||||
python -m nanochat.dataset -n 8
|
||||
# Immediately also kick off downloading more shards in the background while tokenizer trains
|
||||
# See comment below for why 370 is the right number here
|
||||
# Approximately 350 shards are needed for 10B tokens of data for pretraining.
|
||||
# The maximum total number of shards available in the entire dataset is 1822.
|
||||
python -m nanochat.dataset -n 370 &
|
||||
DATASET_DOWNLOAD_PID=$!
|
||||
# train the tokenizer with vocab size 2**15 = 32768 on ~2B characters of data
|
||||
@@ -65,43 +66,27 @@ python -m scripts.tok_eval
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Base model (pretraining)
|
||||
|
||||
# The d20 model is 561M parameters.
|
||||
# Chinchilla says #tokens = 20X #params, so we need 561e6 * 20 = 11.2B tokens.
|
||||
# Assume our tokenizer is 4.8 chars/token, this is 11.2B * 4.8 ~= 54B chars.
|
||||
# At 250M chars/shard, this is 54B / 250M ~= 216 shards needed for pretraining.
|
||||
# Round up to 240 for safety. Also, the new DataLoader wastes about 35% of tokens to cropping
|
||||
# so 240 / (1 - 0.35) = 370 shards are needed.
|
||||
# At ~100MB/shard, this downloads ~37GB of data to disk.
|
||||
# (The total number of shards available in the entire dataset is 1822.)
|
||||
echo "Waiting for dataset download to complete..."
|
||||
wait $DATASET_DOWNLOAD_PID
|
||||
|
||||
# Number of processes/GPUs to use
|
||||
NPROC_PER_NODE=8
|
||||
|
||||
# pretrain the d20 model
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=20 --target-param-data-ratio=20 --run=$WANDB_RUN
|
||||
# d24 model (slightly overtrained is enough to beat GPT-2 => increase data:params ratio from compute optimal 10.5 (default) to 12)
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=24 --target-param-data-ratio=12 --run=$WANDB_RUN
|
||||
# evaluate the model on a larger chunk of train/val data and draw some samples
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_loss
|
||||
# evaluate the model on CORE tasks
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_eval
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Midtraining (teach the model conversation special tokens, tool use, multiple choice)
|
||||
# SFT (teach the model conversation special tokens, tool use, multiple choice)
|
||||
|
||||
# download 2.3MB of synthetic identity conversations to impart a personality to nanochat
|
||||
# see dev/gen_synthetic_data.py for details on how this data was prepared and to get a sense of how you can easily tune it
|
||||
curl -L -o $NANOCHAT_BASE_DIR/identity_conversations.jsonl https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl
|
||||
|
||||
# run midtraining and eval the model
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.mid_train -- --run=$WANDB_RUN
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i mid
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Supervised Finetuning (domain adaptation to each sequence all by itself per row)
|
||||
|
||||
# train sft and re-eval right away (should see a small bump)
|
||||
# run SFT and eval the model
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_sft -- --run=$WANDB_RUN
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i sft
|
||||
|
||||
@@ -111,15 +96,6 @@ torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -
|
||||
# even better, chat with your model over a pretty WebUI ChatGPT style
|
||||
# python -m scripts.chat_web
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Reinforcement Learning. Optional, and currently only on GSM8K
|
||||
# (optional)
|
||||
|
||||
# run reinforcement learning
|
||||
# torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_rl -- --run=$WANDB_RUN
|
||||
# eval the RL model only on GSM8K
|
||||
# torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i rl -a GSM8K
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Generate the full report by putting together all the sections
|
||||
# report.md is the output and will be copied to current directory for convenience
|
||||
|
||||
Reference in New Issue
Block a user