add alternating window size patterns for the GPT layers, following GPT-3. Experimented a bit and found the pattern SSSL to work well - 3 short, 1 long alternating. This is now the new default and the plots look quite a bit better on flops vs. bpb
This commit is contained in:
+16
@@ -4,6 +4,22 @@ A running summary documenting some experiments and findings. Started ~Jan 7 2026
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
## 2026-01-11: Sliding Window Attention
|
||||||
|
|
||||||
|
Added configurable sliding window attention, inspired by GPT-3's alternating short/long pattern.
|
||||||
|
|
||||||
|
**Pattern string configuration:**
|
||||||
|
- New `--window_pattern` CLI arg and `GPTConfig.window_pattern` field
|
||||||
|
- Pattern is tiled across layers (e.g., `SSSL` for 20 layers → `SSSLSSSLSSSLSSSLSSSL`)
|
||||||
|
- Final layer always forced to L (full context) regardless of pattern
|
||||||
|
- Short window = `sequence_len // 2`
|
||||||
|
- Long window = `sequence_len` (full context)
|
||||||
|
- All models trained before this change used full context in every layer (pattern `L`), so checkpoint loading now fills in `window_pattern = "L"` for old checkpoints that lack the field; see `_patch_missing_config_keys`
|
||||||
|
|
||||||
|
Quick experiments showed `SSSL` (every 4th layer is long) works well - provides a good balance between compute savings and model quality. This is now the default.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
## 2026-01-11: Flash Attention 3 Integration
|
## 2026-01-11: Flash Attention 3 Integration
|
||||||
|
|
||||||
Replaced PyTorch's `scaled_dot_product_attention` (FA2) with Flash Attention 3 for training and inference.
|
Replaced PyTorch's `scaled_dot_product_attention` (FA2) with Flash Attention 3 for training and inference.
|
||||||
|
|||||||
@@ -20,6 +20,12 @@ def log0(message):
|
|||||||
if int(os.environ.get('RANK', 0)) == 0:
|
if int(os.environ.get('RANK', 0)) == 0:
|
||||||
logger.info(message)
|
logger.info(message)
|
||||||
|
|
||||||
|
def _patch_missing_config_keys(model_config_kwargs):
|
||||||
|
"""Add default values for new config keys missing in old checkpoints."""
|
||||||
|
# Old models were trained with full context (no sliding window)
|
||||||
|
if "window_pattern" not in model_config_kwargs:
|
||||||
|
model_config_kwargs["window_pattern"] = "L"
|
||||||
|
|
||||||
def _patch_missing_keys(model_data, model_config):
|
def _patch_missing_keys(model_data, model_config):
|
||||||
"""Add default values for new parameters that may be missing in old checkpoints."""
|
"""Add default values for new parameters that may be missing in old checkpoints."""
|
||||||
n_layer = model_config.n_layer
|
n_layer = model_config.n_layer
|
||||||
@@ -84,6 +90,7 @@ def build_model(checkpoint_dir, step, device, phase):
|
|||||||
# Hack: fix torch compile issue, which prepends all keys with _orig_mod.
|
# Hack: fix torch compile issue, which prepends all keys with _orig_mod.
|
||||||
model_data = {k.removeprefix("_orig_mod."): v for k, v in model_data.items()}
|
model_data = {k.removeprefix("_orig_mod."): v for k, v in model_data.items()}
|
||||||
model_config_kwargs = meta_data["model_config"]
|
model_config_kwargs = meta_data["model_config"]
|
||||||
|
_patch_missing_config_keys(model_config_kwargs)
|
||||||
log0(f"Building model with config: {model_config_kwargs}")
|
log0(f"Building model with config: {model_config_kwargs}")
|
||||||
model_config = GPTConfig(**model_config_kwargs)
|
model_config = GPTConfig(**model_config_kwargs)
|
||||||
_patch_missing_keys(model_data, model_config)
|
_patch_missing_keys(model_data, model_config)
|
||||||
|
|||||||
+58
-12
@@ -39,6 +39,10 @@ class GPTConfig:
|
|||||||
n_head: int = 6 # number of query heads
|
n_head: int = 6 # number of query heads
|
||||||
n_kv_head: int = 6 # number of key/value heads (GQA)
|
n_kv_head: int = 6 # number of key/value heads (GQA)
|
||||||
n_embd: int = 768
|
n_embd: int = 768
|
||||||
|
# Sliding window attention pattern string, tiled across layers. Final layer always L.
|
||||||
|
# Characters: L=long (full context), S=short (half context)
|
||||||
|
# Examples: "L"=all full context, "SL"=alternating, "SSL"=two short then one long
|
||||||
|
window_pattern: str = "L"
|
||||||
|
|
||||||
|
|
||||||
def norm(x):
|
def norm(x):
|
||||||
@@ -69,7 +73,7 @@ class CausalSelfAttention(nn.Module):
|
|||||||
self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
|
self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
|
||||||
self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
|
self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
|
||||||
|
|
||||||
def forward(self, x, cos_sin, kv_cache):
|
def forward(self, x, cos_sin, window_size, kv_cache):
|
||||||
B, T, C = x.size()
|
B, T, C = x.size()
|
||||||
|
|
||||||
# Project the input to get queries, keys, and values
|
# Project the input to get queries, keys, and values
|
||||||
@@ -85,9 +89,10 @@ class CausalSelfAttention(nn.Module):
|
|||||||
|
|
||||||
# Attention with Flash Attention 3
|
# Attention with Flash Attention 3
|
||||||
# FA3 handles GQA automatically when n_kv_heads < n_heads
|
# FA3 handles GQA automatically when n_kv_heads < n_heads
|
||||||
|
# window_size is a (left, right) tuple: (N, 0) for a causal sliding window of N tokens, (-1, 0) for full causal context
|
||||||
if kv_cache is None:
|
if kv_cache is None:
|
||||||
# Training: simple causal attention
|
# Training: causal attention with optional sliding window
|
||||||
y = flash_attn.flash_attn_func(q, k, v, causal=True)
|
y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=window_size)
|
||||||
else:
|
else:
|
||||||
# Inference: use flash_attn_with_kvcache which handles cache management
|
# Inference: use flash_attn_with_kvcache which handles cache management
|
||||||
k_cache, v_cache = kv_cache.get_layer_cache(self.layer_idx)
|
k_cache, v_cache = kv_cache.get_layer_cache(self.layer_idx)
|
||||||
@@ -96,6 +101,7 @@ class CausalSelfAttention(nn.Module):
|
|||||||
k=k, v=v,
|
k=k, v=v,
|
||||||
cache_seqlens=kv_cache.cache_seqlens,
|
cache_seqlens=kv_cache.cache_seqlens,
|
||||||
causal=True,
|
causal=True,
|
||||||
|
window_size=window_size,
|
||||||
)
|
)
|
||||||
# Advance position after last layer processes
|
# Advance position after last layer processes
|
||||||
if self.layer_idx == kv_cache.n_layers - 1:
|
if self.layer_idx == kv_cache.n_layers - 1:
|
||||||
@@ -126,8 +132,8 @@ class Block(nn.Module):
|
|||||||
self.attn = CausalSelfAttention(config, layer_idx)
|
self.attn = CausalSelfAttention(config, layer_idx)
|
||||||
self.mlp = MLP(config)
|
self.mlp = MLP(config)
|
||||||
|
|
||||||
def forward(self, x, cos_sin, kv_cache):
|
def forward(self, x, cos_sin, window_size, kv_cache):
|
||||||
x = x + self.attn(norm(x), cos_sin, kv_cache)
|
x = x + self.attn(norm(x), cos_sin, window_size, kv_cache)
|
||||||
x = x + self.mlp(norm(x))
|
x = x + self.mlp(norm(x))
|
||||||
return x
|
return x
|
||||||
|
|
||||||
@@ -141,11 +147,14 @@ class GPT(nn.Module):
|
|||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
# For DDP, we want vocab_size divisible by world_size. Also, there are potential performance benefits, see:
|
# Compute per-layer window sizes for sliding window attention
|
||||||
|
# window_size is (left, right) tuple: (-1, 0) for full context, (N, 0) for sliding window
|
||||||
|
self.window_sizes = self._compute_window_sizes(config)
|
||||||
|
# Pad vocab for efficiency (DDP, tensor cores). This is just an optimization - outputs are cropped in forward().
|
||||||
# https://huggingface.co/docs/transformers/main_classes/model#transformers.PreTrainedModel.resize_token_embeddings
|
# https://huggingface.co/docs/transformers/main_classes/model#transformers.PreTrainedModel.resize_token_embeddings
|
||||||
padded_vocab_size = ((config.vocab_size + pad_vocab_size_to - 1) // pad_vocab_size_to) * pad_vocab_size_to
|
padded_vocab_size = ((config.vocab_size + pad_vocab_size_to - 1) // pad_vocab_size_to) * pad_vocab_size_to
|
||||||
if padded_vocab_size != config.vocab_size:
|
if padded_vocab_size != config.vocab_size:
|
||||||
print0(f"Padding vocab_size from {config.vocab_size} to {padded_vocab_size} to be divisible by {pad_vocab_size_to}")
|
print0(f"Padding vocab_size from {config.vocab_size} to {padded_vocab_size} for efficiency")
|
||||||
self.transformer = nn.ModuleDict({
|
self.transformer = nn.ModuleDict({
|
||||||
"wte": nn.Embedding(padded_vocab_size, config.n_embd),
|
"wte": nn.Embedding(padded_vocab_size, config.n_embd),
|
||||||
"h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
|
"h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
|
||||||
@@ -228,6 +237,35 @@ class GPT(nn.Module):
|
|||||||
cos, sin = cos[None, :, None, :], sin[None, :, None, :] # add batch and head dims for later broadcasting
|
cos, sin = cos[None, :, None, :], sin[None, :, None, :] # add batch and head dims for later broadcasting
|
||||||
return cos, sin
|
return cos, sin
|
||||||
|
|
||||||
|
def _compute_window_sizes(self, config):
|
||||||
|
"""
|
||||||
|
Compute per-layer window sizes for sliding window attention.
|
||||||
|
|
||||||
|
Returns list of (left, right) tuples for FA3's window_size parameter:
|
||||||
|
- left: how many tokens before current position to attend to (-1 = unlimited)
|
||||||
|
- right: how many tokens after current position to attend to (0 for causal)
|
||||||
|
|
||||||
|
Pattern string is tiled across layers. Final layer always gets L (full context).
|
||||||
|
Characters: L=long (full context), S=short (half context)
|
||||||
|
"""
|
||||||
|
pattern = config.window_pattern.upper()
|
||||||
|
assert all(c in "SL" for c in pattern), f"Invalid window_pattern: {pattern}. Use only S and L."
|
||||||
|
# Map characters to window sizes
|
||||||
|
long_window = config.sequence_len
|
||||||
|
short_window = long_window // 2
|
||||||
|
char_to_window = {
|
||||||
|
"L": (long_window, 0),
|
||||||
|
"S": (short_window, 0),
|
||||||
|
}
|
||||||
|
# Tile pattern across layers
|
||||||
|
window_sizes = []
|
||||||
|
for layer_idx in range(config.n_layer):
|
||||||
|
char = pattern[layer_idx % len(pattern)]
|
||||||
|
window_sizes.append(char_to_window[char])
|
||||||
|
# Final layer always gets full context
|
||||||
|
window_sizes[-1] = (long_window, 0)
|
||||||
|
return window_sizes
|
||||||
|
|
||||||
def get_device(self):
|
def get_device(self):
|
||||||
return self.transformer.wte.weight.device
|
return self.transformer.wte.weight.device
|
||||||
|
|
||||||
@@ -236,16 +274,24 @@ class GPT(nn.Module):
|
|||||||
Return the estimated FLOPs per token for the model (forward + backward).
|
Return the estimated FLOPs per token for the model (forward + backward).
|
||||||
Each matmul weight parameter contributes 2 FLOPs (multiply *, accumulate +) in forward, and 2X that in backward => 2+4=6.
|
Each matmul weight parameter contributes 2 FLOPs (multiply *, accumulate +) in forward, and 2X that in backward => 2+4=6.
|
||||||
Cleanest explanation of this: https://medium.com/@dzmitrybahdanau/the-flops-calculus-of-language-model-training-3b19c1f025e4
|
Cleanest explanation of this: https://medium.com/@dzmitrybahdanau/the-flops-calculus-of-language-model-training-3b19c1f025e4
|
||||||
On top of that, the term 12 * l * h * q * t accounts for key @ query matmul flops inside attention.
|
On top of that, 12 * h * q * effective_seq_len accounts for key @ query matmul flops inside attention.
|
||||||
|
With sliding windows, effective_seq_len varies per layer (capped by window size).
|
||||||
Ref: https://arxiv.org/abs/2204.02311 (PaLM paper).
|
Ref: https://arxiv.org/abs/2204.02311 (PaLM paper).
|
||||||
This is ~1% off from the exact formulas of Chinchilla paper, the difference is:
|
This is ~1% off from the exact formulas of Chinchilla paper, the difference is:
|
||||||
- Chinchilla counts the embedding layer as flops (? weird, it's just a lookup => we ignore)
|
- Chinchilla counts the embedding layer as flops (? weird, it's just a lookup => we ignore)
|
||||||
- Chinchilla counts exp/sum/divide in attention softmax as flops (a little sus and very tiny => we ignore)
|
- Chinchilla counts exp/sum/divide in attention softmax as flops (a little sus and very tiny => we ignore)
|
||||||
"""
|
"""
|
||||||
nparams = sum(p.numel() for p in self.parameters())
|
nparams = sum(p.numel() for p in self.parameters())
|
||||||
nparams_embedding = self.transformer.wte.weight.numel()
|
# Exclude non-matmul params: embeddings and per-layer scalars
|
||||||
l, h, q, t = self.config.n_layer, self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
|
nparams_exclude = self.transformer.wte.weight.numel() + self.resid_lambdas.numel() + self.x0_lambdas.numel()
|
||||||
num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t
|
h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
|
||||||
|
# Sum attention FLOPs per layer, accounting for sliding window
|
||||||
|
attn_flops = 0
|
||||||
|
for window_size in self.window_sizes:
|
||||||
|
window = window_size[0] # (left, right) tuple, we use left
|
||||||
|
effective_seq = t if window < 0 else min(window, t)
|
||||||
|
attn_flops += 12 * h * q * effective_seq
|
||||||
|
num_flops_per_token = 6 * (nparams - nparams_exclude) + attn_flops
|
||||||
return num_flops_per_token
|
return num_flops_per_token
|
||||||
|
|
||||||
def num_scaling_params(self):
|
def num_scaling_params(self):
|
||||||
@@ -311,7 +357,7 @@ class GPT(nn.Module):
|
|||||||
x0 = x # save initial normalized embedding for x0 residual
|
x0 = x # save initial normalized embedding for x0 residual
|
||||||
for i, block in enumerate(self.transformer.h):
|
for i, block in enumerate(self.transformer.h):
|
||||||
x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
|
x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
|
||||||
x = block(x, cos_sin, kv_cache)
|
x = block(x, cos_sin, self.window_sizes[i], kv_cache)
|
||||||
x = norm(x)
|
x = norm(x)
|
||||||
|
|
||||||
# Forward the lm_head (compute logits)
|
# Forward the lm_head (compute logits)
|
||||||
|
|||||||
@@ -42,6 +42,7 @@ parser.add_argument("--depth", type=int, default=20, help="depth of the Transfor
|
|||||||
parser.add_argument("--aspect_ratio", type=int, default=64, help="model_dim = depth * aspect_ratio")
|
parser.add_argument("--aspect_ratio", type=int, default=64, help="model_dim = depth * aspect_ratio")
|
||||||
parser.add_argument("--head_dim", type=int, default=128, help="target head dimension for attention")
|
parser.add_argument("--head_dim", type=int, default=128, help="target head dimension for attention")
|
||||||
parser.add_argument("--max_seq_len", type=int, default=2048, help="max context length")
|
parser.add_argument("--max_seq_len", type=int, default=2048, help="max context length")
|
||||||
|
# Default is SSSL (3 short-window layers, then 1 full-context layer), the pattern documented as the new default.
parser.add_argument("--window_pattern", type=str, default="SSSL", help="sliding window pattern tiled across layers: L=full, S=half context (e.g. 'SSL')")
|
||||||
# Training horizon (only one used, in order of precedence)
|
# Training horizon (only one used, in order of precedence)
|
||||||
parser.add_argument("--num_iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)")
|
parser.add_argument("--num_iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)")
|
||||||
parser.add_argument("--target_flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)")
|
parser.add_argument("--target_flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)")
|
||||||
@@ -139,7 +140,7 @@ if args.depth != 12:
|
|||||||
# Initialize the Model
|
# Initialize the Model
|
||||||
|
|
||||||
# Create a new model with random weights
|
# Create a new model with random weights
|
||||||
model_config_kwargs = dict(sequence_len=args.max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim)
|
model_config_kwargs = dict(sequence_len=args.max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim, window_pattern=args.window_pattern)
|
||||||
with torch.device("meta"):
|
with torch.device("meta"):
|
||||||
# All tensors are created as meta tensors (they have shape/dtype but no data)
|
# All tensors are created as meta tensors (they have shape/dtype but no data)
|
||||||
model_config = GPTConfig(**model_config_kwargs)
|
model_config = GPTConfig(**model_config_kwargs)
|
||||||
|
|||||||
Reference in New Issue
Block a user