diff --git a/nanochat/dataloader.py b/nanochat/dataloader.py index c1636b1..6636f54 100644 --- a/nanochat/dataloader.py +++ b/nanochat/dataloader.py @@ -16,7 +16,6 @@ def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokeniz bos_token = tokenizer.get_bos_token_id() # scratch buffer holds the tokens for one iteration token_buffer = deque() # we stream tokens on the right and pop from the left - scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=True) # infinite iterator over document batches def document_batches(): @@ -38,8 +37,8 @@ def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokeniz token_buffer.extend(tokens) batch_index += 1 # Move tokens from the deque into the scratch buffer - for i in range(needed_tokens): - scratch[i] = token_buffer.popleft() + tokens = [token_buffer.popleft() for _ in range(needed_tokens)] + scratch = torch.tensor(tokens, dtype=torch.int64, pin_memory=True) # Create the inputs/targets as 1D tensors inputs_cpu = scratch[:-1].to(dtype=torch.int32) targets_cpu = scratch[1:]