Delete autocast, an unnecessary thorn in my side; manage dtypes directly instead
This commit is contained in:
+12
-18
@@ -10,8 +10,6 @@ torchrun --nproc_per_node=8 -m scripts.chat_eval -- -a ARC-Easy
|
||||
|
||||
import argparse
|
||||
from functools import partial
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
@@ -185,7 +183,6 @@ if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-i', '--source', type=str, required=True, help="Source of the model: sft|rl")
|
||||
parser.add_argument('-a', '--task-name', type=str, default=None, help="Task name. Default = all tasks. Use | to split multiple tasks.")
|
||||
parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
|
||||
parser.add_argument('-t', '--temperature', type=float, default=0.0)
|
||||
parser.add_argument('-m', '--max-new-tokens', type=int, default=512)
|
||||
parser.add_argument('-n', '--num-samples', type=int, default=1)
|
||||
@@ -199,8 +196,6 @@ if __name__ == "__main__":
|
||||
|
||||
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
|
||||
|
||||
model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
|
||||
engine = Engine(model, tokenizer)
|
||||
@@ -220,19 +215,18 @@ if __name__ == "__main__":
|
||||
# Run all the task evaluations sequentially
|
||||
results = {}
|
||||
for task_name in task_names:
|
||||
with autocast_ctx:
|
||||
acc = run_chat_eval(
|
||||
task_name,
|
||||
model, tokenizer, engine,
|
||||
batch_size=args.batch_size,
|
||||
num_samples=args.num_samples,
|
||||
max_new_tokens=args.max_new_tokens,
|
||||
temperature=args.temperature,
|
||||
top_k=args.top_k,
|
||||
max_problems=args.max_problems,
|
||||
)
|
||||
results[task_name] = acc
|
||||
print0(f"{task_name} accuracy: {100 * acc:.2f}%")
|
||||
acc = run_chat_eval(
|
||||
task_name,
|
||||
model, tokenizer, engine,
|
||||
batch_size=args.batch_size,
|
||||
num_samples=args.num_samples,
|
||||
max_new_tokens=args.max_new_tokens,
|
||||
temperature=args.temperature,
|
||||
top_k=args.top_k,
|
||||
max_problems=args.max_problems,
|
||||
)
|
||||
results[task_name] = acc
|
||||
print0(f"{task_name} accuracy: {100 * acc:.2f}%")
|
||||
|
||||
# Log to report
|
||||
from nanochat.report import get_report
|
||||
|
||||
Reference in New Issue
Block a user