many small tweaks. base, eval, core work now i think
This commit is contained in:
+10
-6
@@ -15,6 +15,7 @@ import time
|
||||
import json
|
||||
import random
|
||||
import yaml
|
||||
from contextlib import nullcontext
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
@@ -118,18 +119,21 @@ def load_hf_model(hf_path: str, device):
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def main():
|
||||
assert len(sys.argv) in [1, 2], "Usage: python base_eval.py [hf_path]"
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--hf-path', type=str, default=None, help='HuggingFace model path to evaluate')
|
||||
parser.add_argument('--max-per-task', type=int, default=-1, help='Max examples per task to evaluate (-1 = disable)')
|
||||
args = parser.parse_args()
|
||||
|
||||
# distributed / precision setup
|
||||
device_type = autodetect_device_type()
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
dtype = torch.bfloat16 if device_type == "cuda" else torch.float32 # use fp32 on CPU|MPS
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=dtype)
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
|
||||
|
||||
# Load model and tokenizer from command line or from file system
|
||||
if len(sys.argv) >= 2:
|
||||
if args.hf_path is not None:
|
||||
# atm assume that if a path is given, it's a huggingface model path
|
||||
hf_path = sys.argv[1]
|
||||
hf_path = args.hf_path
|
||||
print0(f"Loading huggingface model from: {hf_path}")
|
||||
model, tokenizer = load_hf_model(hf_path, device)
|
||||
model_name = hf_path # just for logging
|
||||
@@ -142,7 +146,7 @@ def main():
|
||||
|
||||
# Evaluate the model
|
||||
with autocast_ctx:
|
||||
out = evaluate_model(model, tokenizer, device)
|
||||
out = evaluate_model(model, tokenizer, device, max_per_task=args.max_per_task)
|
||||
|
||||
# Write out the results to a csv file
|
||||
core_metric = None
|
||||
|
||||
Reference in New Issue
Block a user