formatting
This commit is contained in:
+1
-2
@@ -265,8 +265,7 @@ class GPT(nn.Module):
|
|||||||
# Forward the lm_head (compute logits)
|
# Forward the lm_head (compute logits)
|
||||||
softcap = 15 # smoothly cap the logits to the range [-softcap, softcap]
|
softcap = 15 # smoothly cap the logits to the range [-softcap, softcap]
|
||||||
logits = self.lm_head(x) # (B, T, padded_vocab_size) <- very big tensor, large amount of memory
|
logits = self.lm_head(x) # (B, T, padded_vocab_size) <- very big tensor, large amount of memory
|
||||||
# slice to remove padding
|
logits = logits[..., :self.config.vocab_size] # slice to remove padding
|
||||||
logits = logits[..., :self.config.vocab_size]
|
|
||||||
logits = logits.float() # switch to fp32 for logit softcap and loss computation
|
logits = logits.float() # switch to fp32 for logit softcap and loss computation
|
||||||
logits = softcap * torch.tanh(logits / softcap) # squash the logits
|
logits = softcap * torch.tanh(logits / softcap) # squash the logits
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user