remove leftover mid references (#491)
This commit is contained in:
committed by
GitHub
parent
b19b4f3e49
commit
72b9064f9d
@@ -164,7 +164,6 @@ def load_model_from_dir(checkpoints_dir, device, phase, model_tag=None, step=Non
|
|||||||
def load_model(source, *args, **kwargs):
|
def load_model(source, *args, **kwargs):
|
||||||
model_dir = {
|
model_dir = {
|
||||||
"base": "base_checkpoints",
|
"base": "base_checkpoints",
|
||||||
"mid": "mid_checkpoints",
|
|
||||||
"sft": "chatsft_checkpoints",
|
"sft": "chatsft_checkpoints",
|
||||||
"rl": "chatrl_checkpoints",
|
"rl": "chatrl_checkpoints",
|
||||||
}[source]
|
}[source]
|
||||||
|
|||||||
+1
-5
@@ -211,8 +211,6 @@ EXPECTED_FILES = [
|
|||||||
"base-model-training.md",
|
"base-model-training.md",
|
||||||
"base-model-loss.md",
|
"base-model-loss.md",
|
||||||
"base-model-evaluation.md",
|
"base-model-evaluation.md",
|
||||||
"midtraining.md",
|
|
||||||
"chat-evaluation-mid.md",
|
|
||||||
"chat-sft.md",
|
"chat-sft.md",
|
||||||
"chat-evaluation-sft.md",
|
"chat-evaluation-sft.md",
|
||||||
"chat-rl.md",
|
"chat-rl.md",
|
||||||
@@ -316,8 +314,6 @@ class Report:
|
|||||||
# extract the most important metrics from the sections
|
# extract the most important metrics from the sections
|
||||||
if file_name == "base-model-evaluation.md":
|
if file_name == "base-model-evaluation.md":
|
||||||
final_metrics["base"] = extract(section, "CORE")
|
final_metrics["base"] = extract(section, "CORE")
|
||||||
if file_name == "chat-evaluation-mid.md":
|
|
||||||
final_metrics["mid"] = extract(section, chat_metrics)
|
|
||||||
if file_name == "chat-evaluation-sft.md":
|
if file_name == "chat-evaluation-sft.md":
|
||||||
final_metrics["sft"] = extract(section, chat_metrics)
|
final_metrics["sft"] = extract(section, chat_metrics)
|
||||||
if file_name == "chat-evaluation-rl.md":
|
if file_name == "chat-evaluation-rl.md":
|
||||||
@@ -337,7 +333,7 @@ class Report:
|
|||||||
# Custom ordering: CORE first, ChatCORE last, rest in middle
|
# Custom ordering: CORE first, ChatCORE last, rest in middle
|
||||||
all_metrics = sorted(all_metrics, key=lambda x: (x != "CORE", x == "ChatCORE", x))
|
all_metrics = sorted(all_metrics, key=lambda x: (x != "CORE", x == "ChatCORE", x))
|
||||||
# Fixed column widths
|
# Fixed column widths
|
||||||
stages = ["base", "mid", "sft", "rl"]
|
stages = ["base", "sft", "rl"]
|
||||||
metric_width = 15
|
metric_width = 15
|
||||||
value_width = 8
|
value_width = 8
|
||||||
# Write table header
|
# Write table header
|
||||||
|
|||||||
+1
-1
@@ -12,7 +12,7 @@ from nanochat.engine import Engine
|
|||||||
from nanochat.checkpoint_manager import load_model
|
from nanochat.checkpoint_manager import load_model
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description='Chat with the model')
|
parser = argparse.ArgumentParser(description='Chat with the model')
|
||||||
parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|mid|rl")
|
parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|rl")
|
||||||
parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
|
parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
|
||||||
parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
|
parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
|
||||||
parser.add_argument('-p', '--prompt', type=str, default='', help='Prompt the model, get a single response back')
|
parser.add_argument('-p', '--prompt', type=str, default='', help='Prompt the model, get a single response back')
|
||||||
|
|||||||
@@ -183,7 +183,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# Parse command-line arguments
|
# Parse command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-i', '--source', type=str, required=True, help="Source of the model: sft|mid|rl")
|
parser.add_argument('-i', '--source', type=str, required=True, help="Source of the model: sft|rl")
|
||||||
parser.add_argument('-a', '--task-name', type=str, default=None, help="Task name. Default = all tasks. Use | to split multiple tasks.")
|
parser.add_argument('-a', '--task-name', type=str, default=None, help="Task name. Default = all tasks. Use | to split multiple tasks.")
|
||||||
parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
|
parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
|
||||||
parser.add_argument('-t', '--temperature', type=float, default=0.0)
|
parser.add_argument('-t', '--temperature', type=float, default=0.0)
|
||||||
|
|||||||
+1
-2
@@ -38,7 +38,6 @@ parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('d
|
|||||||
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
|
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
|
||||||
parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16")
|
parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16")
|
||||||
# Model loading
|
# Model loading
|
||||||
parser.add_argument("--source", type=str, default="sft", help="mid|sft - which checkpoint to load from")
|
|
||||||
parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
|
parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
|
||||||
parser.add_argument("--model-step", type=int, default=None, help="model step to load from")
|
parser.add_argument("--model-step", type=int, default=None, help="model step to load from")
|
||||||
# Training horizon
|
# Training horizon
|
||||||
@@ -77,7 +76,7 @@ use_dummy_wandb = args.run == "dummy" or not master_process
|
|||||||
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-rl", name=args.run, config=user_config)
|
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-rl", name=args.run, config=user_config)
|
||||||
|
|
||||||
# Init model and tokenizer
|
# Init model and tokenizer
|
||||||
model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.model_step)
|
model, tokenizer, meta = load_model("sft", device, phase="eval", model_tag=args.model_tag, step=args.model_step)
|
||||||
engine = Engine(model, tokenizer) # for sampling rollouts
|
engine = Engine(model, tokenizer) # for sampling rollouts
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
|
|||||||
+1
-1
@@ -62,7 +62,7 @@ MAX_MAX_TOKENS = 4096
|
|||||||
|
|
||||||
parser = argparse.ArgumentParser(description='NanoChat Web Server')
|
parser = argparse.ArgumentParser(description='NanoChat Web Server')
|
||||||
parser.add_argument('-n', '--num-gpus', type=int, default=1, help='Number of GPUs to use (default: 1)')
|
parser.add_argument('-n', '--num-gpus', type=int, default=1, help='Number of GPUs to use (default: 1)')
|
||||||
parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|mid|rl")
|
parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|rl")
|
||||||
parser.add_argument('-t', '--temperature', type=float, default=0.8, help='Default temperature for generation')
|
parser.add_argument('-t', '--temperature', type=float, default=0.8, help='Default temperature for generation')
|
||||||
parser.add_argument('-k', '--top-k', type=int, default=50, help='Default top-k sampling parameter')
|
parser.add_argument('-k', '--top-k', type=int, default=50, help='Default top-k sampling parameter')
|
||||||
parser.add_argument('-m', '--max-tokens', type=int, default=512, help='Default max tokens for generation')
|
parser.add_argument('-m', '--max-tokens', type=int, default=512, help='Default max tokens for generation')
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ LLM because it has to learn how every token (a little semantic chunk/atom)
|
|||||||
maps to the sequence of individual characters that make it up. Larger models
|
maps to the sequence of individual characters that make it up. Larger models
|
||||||
learn this eventually on their own, but if we want this capability to exist
|
learn this eventually on their own, but if we want this capability to exist
|
||||||
in smaller models, we have to actively encourage it by over-representing it
|
in smaller models, we have to actively encourage it by over-representing it
|
||||||
in the training data. Midtraining is a good place to do this.
|
in the training data. SFT is a good place to do this.
|
||||||
|
|
||||||
To preview a few example conversations, run:
|
To preview a few example conversations, run:
|
||||||
python -m tasks.spellingbee
|
python -m tasks.spellingbee
|
||||||
|
|||||||
Reference in New Issue
Block a user