add support for CPU and for MPS. I had to change a few cosmetic things. I also discovered I think a bit of a bug, where I was casting wte to bfloat16 in the wrong place (the model init) instead of in init_weights
This commit is contained in:
@@ -33,7 +33,7 @@ def evaluate_bpb(model, batches, steps, token_bytes):
|
||||
loss2d = model(x, y, loss_reduction='none') # (B, T)
|
||||
loss2d = loss2d.view(-1) # flatten
|
||||
y = y.view(-1) # flatten
|
||||
if (y < 0).any():
|
||||
if (y.int() < 0).any(): # mps does not currently have kernel for < 0 for int64, only int32
|
||||
# slightly more complex code path if some target tokens are ignore_index (e.g. -1)
|
||||
# any target token < 0 is to be ignored: do NOT index token_bytes with negatives
|
||||
valid = y >= 0
|
||||
|
||||
Reference in New Issue
Block a user