Autoresearch round 2: smear, backout, and hyperparameter tuning
New architectural features:
- Smear: mix previous token embedding into current position via learned gate, providing cheap bigram-like info (works in training + KV cache)
- Backout: subtract learned fraction of mid-layer residual before logit projection to remove low-level features

Hyperparameter tuning:
- Muon momentum warmdown 0.97→0.90 during LR warmdown phase
- Non-uniform per-layer init: resid_lambdas 1.15→1.05, x0_lambdas 0.20→0.05
- c_fc init scale 0.4x, QK norm scale 1.2, sliding window seq_len/4
- Speedrun data:params ratio reduced to 8

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
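The momentum warmdown named in the commit message can be illustrated with a small schedule function. This is only a sketch: the commit states the endpoints (0.97 and 0.90) and that the change happens during the LR warmdown phase, while the linear interpolation, the `warmdown_frac` parameter, and the function name `muon_momentum` are assumptions made for illustration.

```python
def muon_momentum(step: int, total_steps: int, warmdown_frac: float = 0.5,
                  start: float = 0.97, end: float = 0.90) -> float:
    """Hypothetical schedule: hold `start`, then move linearly to `end`.

    Assumption: the warmdown phase is the last `warmdown_frac` of training
    and the interpolation is linear; neither is stated in the commit.
    """
    warmdown_start = int(total_steps * (1.0 - warmdown_frac))
    if step < warmdown_start:
        # Before the warmdown phase the momentum stays at its starting value.
        return start
    # Fraction of the warmdown phase completed, clamped to [0, 1].
    span = max(1, total_steps - warmdown_start)
    progress = min((step - warmdown_start) / span, 1.0)
    return start + (end - start) * progress
```

Called once per optimizer step, the returned value would be written into the Muon parameter group's momentum before the update is applied.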
@@ -100,10 +100,13 @@ class KVCache:
         self.v_cache = torch.zeros(num_layers, batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
         # Current sequence length per batch element (FA3 needs int32)
         self.cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device)
+        # Previous token's normalized embedding for smear (set by model forward pass)
+        self.prev_embedding = None
 
     def reset(self):
         """Reset cache to empty state."""
         self.cache_seqlens.zero_()
+        self.prev_embedding = None
 
     def get_pos(self):
         """Get current position (assumes all batch elements at same position)."""
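To make the role of `prev_embedding` concrete, here is a rough sketch of how a smear layer could consume it. It is not the model code from this commit: `SmearSketch`, `TinyCache`, the scalar per-position gate, and the zero vector used before the first token are all assumptions. The only detail taken from the diff is that the cache carries the previous token's normalized embedding between forward passes.

```python
import torch
import torch.nn as nn


class TinyCache:
    """Stand-in for the KVCache field added in the diff; holds only the smear state."""

    def __init__(self):
        self.prev_embedding = None


class SmearSketch(nn.Module):
    """Hypothetical smear layer: x_t <- x_t + g_t * x_{t-1}, with a learned gate g_t."""

    def __init__(self, dim: int):
        super().__init__()
        # Assumption: one scalar gate per position, computed from the current embedding.
        self.gate = nn.Linear(dim, 1, bias=False)

    def forward(self, x: torch.Tensor, cache=None) -> torch.Tensor:
        # x: (batch, seq, dim) normalized token embeddings.
        if cache is not None and cache.prev_embedding is not None:
            # Decoding: the predecessor of x[:, 0] comes from the previous call.
            first_prev = cache.prev_embedding
        else:
            # Start of a sequence: nothing precedes the first token (assumed zeros).
            first_prev = torch.zeros_like(x[:, :1])
        # Shift the sequence right by one so position t sees the embedding of t-1.
        prev = torch.cat([first_prev, x[:, :-1]], dim=1)
        if cache is not None:
            # Remember the last un-smeared embedding for the next forward pass.
            cache.prev_embedding = x[:, -1:].detach().clone()
        g = torch.sigmoid(self.gate(x))  # (batch, seq, 1), values in (0, 1)
        return x + g * prev
```

With this arrangement a full-sequence pass (no cache) and a token-by-token pass (with a cache) compute the same values, which is the property the commit message describes as "works in training + KV cache".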
@@ -129,6 +132,9 @@ class KVCache:
         self.k_cache[:, :, :other_pos, :, :] = other.k_cache[:, :, :other_pos, :, :]
         self.v_cache[:, :, :other_pos, :, :] = other.v_cache[:, :, :other_pos, :, :]
         self.cache_seqlens.fill_(other_pos)
+        # Copy smear state: expand batch=1 prev_embedding to num_samples
+        if other.prev_embedding is not None:
+            self.prev_embedding = other.prev_embedding.expand(self.batch_size, -1, -1).clone()
 
 # -----------------------------------------------------------------------------
 @torch.inference_mode()
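The `expand(...).clone()` pattern in the second hunk can be checked in isolation. The sketch below assumes a `(1, 1, dim)` shape for the prefill state and uses a hypothetical helper name, `copy_smear_state`; only the `expand` and `clone` calls mirror the diff. `expand` returns a broadcast view that shares storage with the source, so `clone` is what gives each sample its own writable copy.

```python
import torch


def copy_smear_state(prev_embedding: torch.Tensor, batch_size: int) -> torch.Tensor:
    """Broadcast a batch=1 smear state to `batch_size` independent rows."""
    # expand() makes a view without copying; clone() materializes separate
    # storage so rows can diverge once each sample decodes its own tokens.
    return prev_embedding.expand(batch_size, -1, -1).clone()


# Assumed shape of the state left behind by a batch=1 prefill: (1, 1, dim).
prefill_prev = torch.ones(1, 1, 4)
decode_prev = copy_smear_state(prefill_prev, batch_size=3)

# Overwrite one sample's state; the other rows and the source stay untouched.
decode_prev[0].zero_()
```

Without the `clone()`, an in-place write to the expanded view would be rejected by PyTorch or would alias every row to the same memory, which is presumably why the diff copies the tensor.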