fix broken import sigh
This commit is contained in:
@@ -28,7 +28,7 @@ from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint
|
|||||||
from nanochat.loss_eval import evaluate_bpb
|
from nanochat.loss_eval import evaluate_bpb
|
||||||
from nanochat.engine import Engine
|
from nanochat.engine import Engine
|
||||||
from nanochat.flash_attention import HAS_FA3
|
from nanochat.flash_attention import HAS_FA3
|
||||||
from scripts.base_eval import evaluate_model
|
from scripts.base_eval import evaluate_core
|
||||||
print_banner()
|
print_banner()
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@@ -305,7 +305,7 @@ while True:
|
|||||||
if args.core_metric_every > 0 and (last_step or (step > 0 and step % args.core_metric_every == 0)):
|
if args.core_metric_every > 0 and (last_step or (step > 0 and step % args.core_metric_every == 0)):
|
||||||
model.eval()
|
model.eval()
|
||||||
with autocast_ctx:
|
with autocast_ctx:
|
||||||
results = evaluate_model(orig_model, tokenizer, device, max_per_task=args.core_metric_max_per_task)
|
results = evaluate_core(orig_model, tokenizer, device, max_per_task=args.core_metric_max_per_task)
|
||||||
print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f}")
|
print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f}")
|
||||||
wandb_run.log({
|
wandb_run.log({
|
||||||
"step": step,
|
"step": step,
|
||||||
|
|||||||
Reference in New Issue
Block a user