group common code to make things neater in gpt logit computation
This commit is contained in:
+8
-10
@@ -260,20 +260,18 @@ class GPT(nn.Module):
|
|||||||
x = norm(x)
|
x = norm(x)
|
||||||
|
|
||||||
# Forward the lm_head (compute logits)
|
# Forward the lm_head (compute logits)
|
||||||
softcap = 15
|
softcap = 15 # smoothly cap the logits to the range [-softcap, softcap]
|
||||||
|
logits = self.lm_head(x) # (B, T, vocab_size) <- very big tensor, large amount of memory
|
||||||
|
logits = logits.float() # switch to fp32 for logit softcap and loss computation
|
||||||
|
logits = softcap * torch.tanh(logits / softcap) # squash the logits
|
||||||
|
|
||||||
if targets is not None:
|
if targets is not None:
|
||||||
# training mode: compute and return the loss
|
# training: given the targets, compute and return the loss
|
||||||
# TODO: experiment with Liger Kernels / chunked cross-entropy etc.
|
# TODO experiment with chunked cross-entropy?
|
||||||
logits = self.lm_head(x)
|
|
||||||
logits = logits.float() # use tf32/fp32 for logits
|
|
||||||
logits = softcap * torch.tanh(logits / softcap) # logits softcap
|
|
||||||
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction)
|
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction)
|
||||||
return loss
|
return loss
|
||||||
else:
|
else:
|
||||||
# inference mode: compute and return the logits
|
# inference: just return the logits directly
|
||||||
logits = self.lm_head(x)
|
|
||||||
logits = logits.float() # use tf32/fp32 for logits
|
|
||||||
logits = softcap * torch.tanh(logits / softcap) # logits softcap
|
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
|
|||||||
Reference in New Issue
Block a user