initial commit
This commit is contained in:
@@ -0,0 +1,63 @@
|
||||
"""
|
||||
A number of functions that help with evaluating a base model.
|
||||
"""
|
||||
import math
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
@torch.no_grad()
|
||||
def evaluate_bpb(model, batches, steps, token_bytes):
|
||||
"""
|
||||
Instead of the naive 'mean loss', this function returns the bits per byte (bpb),
|
||||
which is a tokenization vocab size-indepedent metric, meaning you are still comparing
|
||||
apples:apples if you change the vocab size. The way this works is that instead of just
|
||||
calculating the average loss as usual, you calculate the sum loss, and indepependently
|
||||
also the sum bytes (of all the target tokens), and divide. This normalizes the loss by
|
||||
the number of bytes that the target tokens represent.
|
||||
|
||||
The added complexity is so that:
|
||||
1) All "normal" tokens are normalized by the length of the token in bytes
|
||||
2) No special tokens (e.g. <|bos|>) are included in the metric - they are masked out.
|
||||
3) No actively masked tokens (using ignore_index of e.g. -1) are included in the metric.
|
||||
|
||||
In addition to evaluate_loss, we need the token_bytes tensor:
|
||||
It is a 1D tensor of shape (vocab_size,), indicating the number of bytes for
|
||||
each token id, or 0 if the token is to not be counted (e.g. special tokens).
|
||||
"""
|
||||
# record the losses
|
||||
total_nats = torch.tensor(0.0, dtype=torch.float32, device=model.get_device())
|
||||
total_bytes = torch.tensor(0, dtype=torch.int64, device=model.get_device())
|
||||
batch_iter = iter(batches)
|
||||
for _ in range(steps):
|
||||
x, y = next(batch_iter)
|
||||
loss2d = model(x, y, loss_reduction='none') # (B, T)
|
||||
loss2d = loss2d.view(-1) # flatten
|
||||
y = y.view(-1) # flatten
|
||||
if (y < 0).any():
|
||||
# slightly more complex code path if some target tokens are ignore_index (e.g. -1)
|
||||
# any target token < 0 is to be ignored: do NOT index token_bytes with negatives
|
||||
valid = y >= 0
|
||||
y_safe = torch.where(valid, y, torch.zeros_like(y))
|
||||
# map valid targets to their byte length; ignored targets contribute 0 bytes
|
||||
num_bytes2d = torch.where(
|
||||
valid,
|
||||
token_bytes[y_safe],
|
||||
torch.zeros_like(y, dtype=token_bytes.dtype)
|
||||
)
|
||||
total_nats += (loss2d * (num_bytes2d > 0)).sum()
|
||||
total_bytes += num_bytes2d.sum()
|
||||
else:
|
||||
# fast path: no ignored targets, safe to index directly
|
||||
num_bytes2d = token_bytes[y]
|
||||
total_nats += (loss2d * (num_bytes2d > 0)).sum()
|
||||
total_bytes += num_bytes2d.sum()
|
||||
# sum reduce across all ranks
|
||||
world_size = dist.get_world_size() if dist.is_initialized() else 1
|
||||
if world_size > 1:
|
||||
dist.all_reduce(total_nats, op=dist.ReduceOp.SUM)
|
||||
dist.all_reduce(total_bytes, op=dist.ReduceOp.SUM)
|
||||
# move both to cpu, calculate bpb and return
|
||||
total_nats = total_nats.item()
|
||||
total_bytes = total_bytes.item()
|
||||
bpb = total_nats / (math.log(2) * total_bytes)
|
||||
return bpb
|
||||
Reference in New Issue
Block a user