fix: sample first token independently for each row in multi-sample generation

Previously, when generating multiple samples (num_samples > 1), the first
token after prefill was sampled once and broadcast to all rows, causing
all samples to start identically. Now the prefill logits are expanded to
num_samples and sampled independently for each row.

Also simplified the generation loop by moving the forward pass to the end
of the loop, eliminating the first_iteration flag and if/else branching.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Andrej Karpathy
2025-12-28 04:52:13 +00:00
parent 2f2d7ab80c
commit 8f979a8bda
2 changed files with 135 additions and 24 deletions
+13 -23
View File
@@ -19,7 +19,7 @@ from contextlib import contextmanager
from collections import deque
from nanochat.common import compute_init, autodetect_device_type
from nanochat.checkpoint_manager import load_model
from contextlib import nullcontext
from contextlib import nullcontext
# -----------------------------------------------------------------------------
# Calculator tool helpers
@@ -107,23 +107,23 @@ class KVCache:
# 1) validate the shapes
assert self.kv_cache is None, "Cannot prefill a non-empty KV cache"
assert other.kv_cache is not None, "Cannot prefill with a None KV cache"
# Extract dimensions explicitly
self_layers, self_kv, self_batch, self_heads, self_seq, self_head_dim = self.kv_shape
other_layers, other_kv, other_batch, other_heads, other_seq, other_head_dim = other.kv_shape
# Validate dimensions
assert self_layers == other_layers, f"Layer count mismatch: {self_layers} != {other_layers}"
assert self_kv == other_kv, f"K/V dimension mismatch: {self_kv} != {other_kv}"
assert self_heads == other_heads, f"Head count mismatch: {self_heads} != {other_heads}"
assert self_head_dim == other_head_dim, f"Head dim mismatch: {self_head_dim} != {other_head_dim}"
# Batch size can be expanded (other can be 1, self can be larger)
assert self_batch == other_batch or other_batch == 1, f"Batch size mismatch: {self_batch} vs {other_batch} (other must be 1 or equal)"
# Sequence length: self must be longer than other
assert self_seq >= other_seq, f"Sequence length mismatch: {self_seq} < {other_seq}"
# 2) initialize the cache
dtype, device = other.kv_cache.dtype, other.kv_cache.device
self.kv_cache = torch.empty(self.kv_shape, dtype=dtype, device=device)
@@ -223,9 +223,7 @@ class Engine:
)
ids = torch.tensor([tokens], dtype=torch.long, device=device)
logits = self.model.forward(ids, kv_cache=kv_cache_prefill)
logits = logits[:, -1, :]
next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1)
sampled_tokens = next_ids[:, 0].tolist()
logits = logits[:, -1, :].expand(num_samples, -1) # (num_samples, vocab_size)
# 2) Replicate the KV cache for each sample/row
kv_length_hint = (len(tokens) + max_tokens) if max_tokens is not None else self.model.config.sequence_len
@@ -242,7 +240,6 @@ class Engine:
# 4) Main generation loop
num_generated = 0
first_iteration = True
while True:
# Stop condition: we've reached max tokens
if max_tokens is not None and num_generated >= max_tokens:
@@ -251,18 +248,9 @@ class Engine:
if all(state.completed for state in row_states):
break
# Get sampled tokens - either from prefill or from forward pass
if first_iteration:
# Use the tokens we already sampled from prefill
sampled_tokens = [sampled_tokens[0]] * num_samples # Broadcast first token to all rows
# TODO: we should sample a token for each row instead of broadcasting
first_iteration = False
else:
# Forward the model and get the next token for each row
logits = self.model.forward(ids, kv_cache=kv_cache_decode) # (B, T, vocab_size)
logits = logits[:, -1, :] # (B, vocab_size) at last time step
next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1)
sampled_tokens = next_ids[:, 0].tolist()
# Sample the next token for each row
next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1)
sampled_tokens = next_ids[:, 0].tolist()
# Process each row: choose the next token, update state, optional tool use
token_column = [] # contains the next token id along each row
@@ -299,8 +287,10 @@ class Engine:
# Yield the token column
yield token_column, token_masks
num_generated += 1
# Prepare ids for next iteration
# Prepare logits for next iteration
ids = torch.tensor(token_column, dtype=torch.long, device=device).unsqueeze(1)
logits = self.model.forward(ids, kv_cache=kv_cache_decode)[:, -1, :] # (B, vocab_size)
def generate_batch(self, tokens, num_samples=1, **kwargs):
"""