big, breaking change but large upside: swap previous FineWeb-EDU dataset to NVIDIA ClimbMix dataset. Requires people to download the data shards. The upside is that training GPT-2 capablity model now only takes ~2 hours, down from 2.76 hours, so this is a huge win data-wise
This commit is contained in:
+41
-9
@@ -20,19 +20,43 @@ from nanochat.common import get_base_dir
|
||||
# The specifics of the current pretraining dataset
|
||||
|
||||
# The URL on the internet where the data is hosted and downloaded from on demand
|
||||
BASE_URL = "https://huggingface.co/datasets/karpathy/fineweb-edu-100b-shuffle/resolve/main"
|
||||
MAX_SHARD = 1822 # the last datashard is shard_01822.parquet
|
||||
BASE_URL = "https://huggingface.co/datasets/karpathy/climbmix-400b-shuffle/resolve/main"
|
||||
MAX_SHARD = 6542 # the last datashard is shard_06542.parquet
|
||||
index_to_filename = lambda index: f"shard_{index:05d}.parquet" # format of the filenames
|
||||
base_dir = get_base_dir()
|
||||
DATA_DIR = os.path.join(base_dir, "base_data")
|
||||
os.makedirs(DATA_DIR, exist_ok=True)
|
||||
DATA_DIR = os.path.join(base_dir, "base_data_climbmix")
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# These functions are useful utilities to other modules, can/should be imported
|
||||
|
||||
def list_parquet_files(data_dir=None):
|
||||
def list_parquet_files(data_dir=None, warn_on_legacy=False):
|
||||
""" Looks into a data dir and returns full paths to all parquet files. """
|
||||
data_dir = DATA_DIR if data_dir is None else data_dir
|
||||
|
||||
# Legacy-supporting code due to the upgrade from FinewebEdu-100B to ClimbMix-400B
|
||||
# This code will eventually be deleted.
|
||||
if not os.path.exists(data_dir):
|
||||
if warn_on_legacy:
|
||||
print()
|
||||
print("=" * 80)
|
||||
print(" WARNING: DATASET UPGRADE REQUIRED")
|
||||
print("=" * 80)
|
||||
print()
|
||||
print(f" Could not find: {data_dir}")
|
||||
print()
|
||||
print(" nanochat recently switched from FinewebEdu-100B to ClimbMix-400B.")
|
||||
print(" Everyone who does `git pull` as of March 4, 2026 is expected to see this message.")
|
||||
print(" To upgrade to the new ClimbMix-400B dataset, run these two commands:")
|
||||
print()
|
||||
print(" python -m nanochat.dataset -n 170 # download ~170 shards, enough for GPT-2, adjust as desired")
|
||||
print(" python -m scripts.tok_train # re-train tokenizer on new ClimbMix data")
|
||||
print()
|
||||
print(" For now, falling back to your old FinewebEdu-100B dataset...")
|
||||
print("=" * 80)
|
||||
print()
|
||||
# attempt a fallback to the legacy data directory
|
||||
data_dir = os.path.join(base_dir, "base_data")
|
||||
|
||||
parquet_files = sorted([
|
||||
f for f in os.listdir(data_dir)
|
||||
if f.endswith('.parquet') and not f.endswith('.tmp')
|
||||
@@ -110,13 +134,21 @@ def download_single_file(index):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Download FineWeb-Edu 100BT dataset shards")
|
||||
parser.add_argument("-n", "--num-files", type=int, default=-1, help="Number of shards to download (default: -1), -1 = disable")
|
||||
parser = argparse.ArgumentParser(description="Download pretraining dataset shards")
|
||||
parser.add_argument("-n", "--num-files", type=int, default=-1, help="Number of train shards to download (default: -1), -1 = disable")
|
||||
parser.add_argument("-w", "--num-workers", type=int, default=4, help="Number of parallel download workers (default: 4)")
|
||||
args = parser.parse_args()
|
||||
|
||||
num = MAX_SHARD + 1 if args.num_files == -1 else min(args.num_files, MAX_SHARD + 1)
|
||||
ids_to_download = list(range(num))
|
||||
# Prepare the output directory
|
||||
os.makedirs(DATA_DIR, exist_ok=True)
|
||||
|
||||
# The way this works is that the user specifies the number of train shards to download via the -n flag.
|
||||
# In addition to that, the validation shard is *always* downloaded and is pinned to be the last shard.
|
||||
num_train_shards = MAX_SHARD if args.num_files == -1 else min(args.num_files, MAX_SHARD)
|
||||
ids_to_download = list(range(num_train_shards))
|
||||
ids_to_download.append(MAX_SHARD) # always download the validation shard
|
||||
|
||||
# Download the shards
|
||||
print(f"Downloading {len(ids_to_download)} shards using {args.num_workers} workers...")
|
||||
print(f"Target directory: {DATA_DIR}")
|
||||
print()
|
||||
|
||||
Reference in New Issue
Block a user