fix(model): apply float32 cast before logits softcapping
This change ensures that the logits softcapping operation (tanh) is performed in float32 precision rather than bfloat16. Previously, the code cast to float32 after the tanh operation, which meant the non-linearity was computed with bfloat16 precision
This commit is contained in:
+2
-1
@@ -265,13 +265,14 @@ class GPT(nn.Module):
|
||||
# training mode: compute and return the loss
|
||||
# TODO: experiment with Liger Kernels / chunked cross-entropy etc.
|
||||
logits = self.lm_head(x)
|
||||
logits = softcap * torch.tanh(logits / softcap) # logits softcap
|
||||
logits = logits.float() # use tf32/fp32 for logits
|
||||
logits = softcap * torch.tanh(logits / softcap) # logits softcap
|
||||
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction)
|
||||
return loss
|
||||
else:
|
||||
# inference mode: compute and return the logits
|
||||
logits = self.lm_head(x)
|
||||
logits = logits.float() # use tf32/fp32 for logits
|
||||
logits = softcap * torch.tanh(logits / softcap) # logits softcap
|
||||
return logits
|
||||
|
||||
|
||||
Reference in New Issue
Block a user