add alternating window size patterns for the GPT layers, following GPT-3. Experimented a bit and found the pattern SSSL to work well - 3 short, 1 long alternating. This is now the new default and the plots look quite a bit better on flops vs. bpb
This commit is contained in:
@@ -20,6 +20,12 @@ def log0(message):
|
||||
if int(os.environ.get('RANK', 0)) == 0:
|
||||
logger.info(message)
|
||||
|
||||
def _patch_missing_config_keys(model_config_kwargs):
|
||||
"""Add default values for new config keys missing in old checkpoints."""
|
||||
# Old models were trained with full context (no sliding window)
|
||||
if "window_pattern" not in model_config_kwargs:
|
||||
model_config_kwargs["window_pattern"] = "L"
|
||||
|
||||
def _patch_missing_keys(model_data, model_config):
|
||||
"""Add default values for new parameters that may be missing in old checkpoints."""
|
||||
n_layer = model_config.n_layer
|
||||
@@ -84,6 +90,7 @@ def build_model(checkpoint_dir, step, device, phase):
|
||||
# Hack: fix torch compile issue, which prepends all keys with _orig_mod.
|
||||
model_data = {k.removeprefix("_orig_mod."): v for k, v in model_data.items()}
|
||||
model_config_kwargs = meta_data["model_config"]
|
||||
_patch_missing_config_keys(model_config_kwargs)
|
||||
log0(f"Building model with config: {model_config_kwargs}")
|
||||
model_config = GPTConfig(**model_config_kwargs)
|
||||
_patch_missing_keys(model_data, model_config)
|
||||
|
||||
Reference in New Issue
Block a user