diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 501cc4a..9a80c7c 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -265,8 +265,7 @@ class GPT(nn.Module): # Forward the lm_head (compute logits) softcap = 15 # smoothly cap the logits to the range [-softcap, softcap] logits = self.lm_head(x) # (B, T, padded_vocab_size) <- very big tensor, large amount of memory - # slice to remove padding - logits = logits[..., :self.config.vocab_size] + logits = logits[..., :self.config.vocab_size] # slice to remove padding logits = logits.float() # switch to fp32 for logit softcap and loss computation logits = softcap * torch.tanh(logits / softcap) # squash the logits