clean up original tokenizing_distributed_data_loader (#478)
This commit is contained in:
committed by
GitHub
parent
dc291c627f
commit
43078c347e
+6
-53
@@ -1,24 +1,19 @@
|
|||||||
"""
|
"""
|
||||||
Distributed dataloaders for pretraining.
|
Distributed dataloaders for pretraining.
|
||||||
|
|
||||||
Two implementations are provided:
|
BOS-aligned bestfit:
|
||||||
|
|
||||||
1. Original (tokenizing_distributed_data_loader):
|
|
||||||
- Streams tokens into a flat buffer, reshapes to (B, T)
|
|
||||||
- Rows may start mid-document (no guaranteed BOS at position 0)
|
|
||||||
- 100% token utilization, simple and efficient
|
|
||||||
|
|
||||||
2. BOS-aligned bestfit (tokenizing_distributed_data_loader_bos_bestfit):
|
|
||||||
- Every row starts with BOS token
|
- Every row starts with BOS token
|
||||||
- Documents packed using best-fit algorithm to minimize cropping
|
- Documents packed using best-fit algorithm to minimize cropping
|
||||||
- When no document fits remaining space, crops a document to fill exactly
|
- When no document fits remaining space, crops a document to fill exactly
|
||||||
- 100% utilization (no padding), ~35% tokens cropped at T=2048
|
- 100% utilization (no padding), ~35% tokens cropped at T=2048
|
||||||
|
|
||||||
The tradeoff: BOS-aligned loses ~35% of tokens to cropping, but ensures that
|
Compared to the original tokenizing_distributed_data_loader:
|
||||||
|
BOS-aligned loses ~35% of tokens to cropping, but ensures that
|
||||||
there are fewer "confusing" tokens in the train/val batches as every token can
|
there are fewer "confusing" tokens in the train/val batches as every token can
|
||||||
now attend back to the BOS token and sees the full context of the document.
|
now attend back to the BOS token and sees the full context of the document.
|
||||||
(2) is the new default if you have enough data.
|
|
||||||
Fallback to (1) if you have very limited data AND long documents.
|
Fallback to the original if you have very limited data AND long documents:
|
||||||
|
https://github.com/karpathy/nanochat/blob/3c3a3d7/nanochat/dataloader.py#L78-L117
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -75,48 +70,6 @@ def _document_batches(split, resume_state_dict, tokenizer_batch_size):
|
|||||||
epoch += 1
|
epoch += 1
|
||||||
|
|
||||||
|
|
||||||
def tokenizing_distributed_data_loader_with_state(tokenizer, B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda", resume_state_dict=None):
|
|
||||||
"""
|
|
||||||
Stream pretraining text from parquet files, tokenize, yield training batches.
|
|
||||||
|
|
||||||
This is the original dataloader that streams tokens into a flat buffer and reshapes.
|
|
||||||
Rows may start mid-document (no guaranteed BOS at position 0).
|
|
||||||
|
|
||||||
Supports approximate resume via state_dict.
|
|
||||||
"""
|
|
||||||
assert split in ["train", "val"], "split must be 'train' or 'val'"
|
|
||||||
|
|
||||||
batches = _document_batches(split, resume_state_dict, tokenizer_batch_size)
|
|
||||||
needed_tokens = B * T + 1 # +1 for target at last position
|
|
||||||
bos_token = tokenizer.get_bos_token_id()
|
|
||||||
token_buffer = []
|
|
||||||
pq_idx, rg_idx, epoch = 0, 0, 1
|
|
||||||
|
|
||||||
while True:
|
|
||||||
|
|
||||||
# Accumulate enough tokens
|
|
||||||
while len(token_buffer) < needed_tokens:
|
|
||||||
doc_batch, (pq_idx, rg_idx, epoch) = next(batches)
|
|
||||||
token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads)
|
|
||||||
for tokens in token_lists:
|
|
||||||
token_buffer.extend(tokens)
|
|
||||||
tokens = token_buffer[:needed_tokens] # Read B*T+1 tokens (+1 is only for the target for the last token)
|
|
||||||
token_buffer = token_buffer[B*T:] # Advance by B*T tokens, so we move exactly one window of B*T tokens over
|
|
||||||
|
|
||||||
# Package tokens into inputs and targets, yield
|
|
||||||
use_cuda = device == "cuda"
|
|
||||||
scratch = torch.tensor(tokens, dtype=torch.long, pin_memory=use_cuda)
|
|
||||||
inputs = scratch[:-1].view(B, T).to(device=device, non_blocking=use_cuda)
|
|
||||||
targets = scratch[1:].view(B, T).to(device=device, non_blocking=use_cuda)
|
|
||||||
yield inputs, targets, {"pq_idx": pq_idx, "rg_idx": rg_idx, "epoch": epoch}
|
|
||||||
|
|
||||||
|
|
||||||
def tokenizing_distributed_data_loader(*args, **kwargs):
|
|
||||||
"""Helper that omits state_dict from yields."""
|
|
||||||
for inputs, targets, state_dict in tokenizing_distributed_data_loader_with_state(*args, **kwargs):
|
|
||||||
yield inputs, targets
|
|
||||||
|
|
||||||
|
|
||||||
def tokenizing_distributed_data_loader_with_state_bos_bestfit(
|
def tokenizing_distributed_data_loader_with_state_bos_bestfit(
|
||||||
tokenizer, B, T, split,
|
tokenizer, B, T, split,
|
||||||
tokenizer_threads=4, tokenizer_batch_size=128,
|
tokenizer_threads=4, tokenizer_batch_size=128,
|
||||||
|
|||||||
Reference in New Issue
Block a user