remove leftover mid references (#491)

This commit is contained in:
Sofie Van Landeghem
2026-02-02 17:33:46 +01:00
committed by GitHub
parent b19b4f3e49
commit 72b9064f9d
7 changed files with 6 additions and 12 deletions
+1 -2
View File
@@ -38,7 +38,6 @@ parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('d
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16")
# Model loading
parser.add_argument("--source", type=str, default="sft", help="mid|sft - which checkpoint to load from")
parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
parser.add_argument("--model-step", type=int, default=None, help="model step to load from")
# Training horizon
@@ -77,7 +76,7 @@ use_dummy_wandb = args.run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-rl", name=args.run, config=user_config)
# Init model and tokenizer
model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.model_step)
model, tokenizer, meta = load_model("sft", device, phase="eval", model_tag=args.model_tag, step=args.model_step)
engine = Engine(model, tokenizer) # for sampling rollouts
# -----------------------------------------------------------------------------