W1 needs the GPT to consume Whisper-projector outputs as a prefix of "soft
tokens" sitting in front of the text token embeddings (LLaVA-style). The
change is intentionally minimal:
- forward() takes a new keyword-only arg `audio_features` of shape
(B, T_a, n_embd). The features must already be projected to n_embd by the
caller (the Projector lives in nanochat.audio, kept out of GPT itself).
- The audio rows are normed (matches the post-wte norm convention) and
concatenated *after* smear so smear stays a strictly text-side op (its
prev-token semantics aren't defined for soft tokens, and revisiting that
belongs to a later phase).
- Rotary embeddings are re-sliced for T_full = T_a + T_text. Audio gets
positions 0..T_a-1; text shifts from 0..T_text-1 to T_a..T_full-1. The 10×
over-allocated rotary cache in __init__ already covers this.
- value_embeds lookup uses an idx padded with 0 for audio positions. The
resulting dummy rows feed the v residual, but the gate (`ve_gate`) is
input-dependent and can learn to suppress them; for the W1 smoke test this
is fine.
- targets are auto-padded with -1 (ignore_index) over audio positions so
the LM is only graded on text predictions.
Not yet supported: audio_features with kv_cache. The KV-cache path is a
prefill+decode protocol that would need its own audio-aware semantics; W1
runs train-style forwards only, so we just assert.
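
The bookkeeping described above (dummy idx for value_embeds, ignored
targets, contiguous positions, and the kv_cache assert) can be sketched in
plain Python. This is an illustration only: `build_prefix_bookkeeping` and
`IGNORE_INDEX` are hypothetical names, and the real forward() would operate
on tensors rather than lists.

```python
IGNORE_INDEX = -1  # targets with this value are skipped by the loss


def build_prefix_bookkeeping(idx, targets, t_audio, kv_cache=None):
    """Return (padded_idx, padded_targets, positions) for an audio-prefixed batch.

    idx, targets: B rows of T_text token ids (lists of ints).
    t_audio: number of soft-token rows placed in front of the text.
    """
    # The audio path is train-style only; it is not combined with a KV cache.
    assert kv_cache is None, "audio_features with kv_cache is not supported"
    assert len(idx) == len(targets)

    # value_embeds lookup: dummy token id 0 over the audio positions.
    padded_idx = [[0] * t_audio + list(row) for row in idx]
    # Loss: audio positions are ignored, so only text predictions are graded.
    padded_targets = [[IGNORE_INDEX] * t_audio + list(row) for row in targets]

    # Rotary positions: audio takes 0..T_a-1, text takes T_a..T_full-1.
    t_full = t_audio + len(idx[0])
    positions = list(range(t_full))
    return padded_idx, padded_targets, positions


padded_idx, padded_targets, positions = build_prefix_bookkeeping(
    idx=[[5, 6, 7]], targets=[[6, 7, 8]], t_audio=2
)
```

All three outputs have length T_full per row, so they line up with the
concatenated (audio + text) embedding sequence.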
New architectural features:
- Smear: mix previous token embedding into current position via learned
gate, providing cheap bigram-like info (works in training + KV cache)
- Backout: subtract learned fraction of mid-layer residual before logit
projection to remove low-level features
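
A minimal sketch of the two operations, assuming the simplest additive
forms. The function names, the per-position scalar gate, and the scalar
`frac` are assumptions for illustration; the commit itself does not spell
out the exact parameterization.

```python
def smear(x, gate):
    """Mix the previous token's embedding into each position.

    x: list of T embedding vectors (lists of floats).
    gate: list of T scalars; gate[t] scales how much of x[t-1] is added.
    """
    out = [list(x[0])]  # position 0 has no previous token
    for t in range(1, len(x)):
        # Uses the original x[t-1], not the already-smeared row.
        out.append([cur + gate[t] * prev for cur, prev in zip(x[t], x[t - 1])])
    return out


def backout(x_final, x_mid, frac):
    """Subtract a fraction of the mid-layer residual before the logit projection."""
    return [
        [f - frac * m for f, m in zip(row_final, row_mid)]
        for row_final, row_mid in zip(x_final, x_mid)
    ]


smeared = smear([[1.0, 2.0], [3.0, 4.0]], gate=[0.0, 0.5])
backed_out = backout([[1.0, 1.0]], [[2.0, 4.0]], frac=0.25)
```

Because smear only looks one position back, a decode step under a KV cache
needs just the previous token's embedding to reproduce it.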
Hyperparameter tuning:
- Muon momentum warmdown 0.97→0.90 during LR warmdown phase
- Non-uniform per-layer init: resid_lambdas 1.15→1.05, x0_lambdas 0.20→0.05
- c_fc init scale 0.4x, QK norm scale 1.2, sliding window seq_len/4
- Speedrun data:params ratio reduced to 8
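
For illustration, the momentum warmdown and the per-layer init ranges can
be sketched as linear interpolations. The linear shape, the function names,
and the step arguments are assumptions; only the endpoint values come from
the list above.

```python
def lerp(start, end, frac):
    """Linear interpolation between start (frac=0) and end (frac=1)."""
    return start + (end - start) * frac


def muon_momentum(step, total_steps, warmdown_steps, high=0.97, low=0.90):
    """Hold momentum at `high`, then decay to `low` across the LR warmdown."""
    warmdown_start = total_steps - warmdown_steps
    if step <= warmdown_start:
        return high
    frac = min((step - warmdown_start) / warmdown_steps, 1.0)
    return lerp(high, low, frac)


def per_layer_values(n_layer, first, last):
    """One init value per layer, moving from `first` to `last` with depth."""
    if n_layer == 1:
        return [first]
    return [lerp(first, last, i / (n_layer - 1)) for i in range(n_layer)]


resid_lambdas = per_layer_values(3, 1.15, 1.05)
x0_lambdas = per_layer_values(3, 0.20, 0.05)
```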
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This change ensures that the logits softcapping operation (tanh) is
performed in float32 precision rather than bfloat16. Previously, the code
cast to float32 after the tanh operation, which meant the non-linearity was
computed with bfloat16 precision.
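
The effect of the cast order can be illustrated without any tensor library
by emulating bfloat16 rounding with the standard library. This is a sketch
under assumptions: `to_bf16`, `to_f32`, both softcap helpers, and the cap
value 15.0 are made up for the demonstration and are not taken from the
model code.

```python
import math
import struct


def to_f32(x):
    """Round a Python float to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bf16(x):
    """Round a finite float to the nearest bfloat16 value (ties to even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bias = 0x7FFF + ((bits >> 16) & 1)  # round-to-nearest-even on the low 16 bits
    bits = ((bits + bias) >> 16) << 16
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]


def softcap_tanh_in_bf16(logit, cap):
    """Old order: every step rounds to bfloat16, then the result is upcast."""
    z = to_bf16(to_bf16(logit) / cap)
    t = to_bf16(math.tanh(z))
    return to_f32(to_bf16(cap * t))  # the upcast comes too late to help


def softcap_tanh_in_f32(logit, cap):
    """New order: upcast the bfloat16 logit first, then do the math in float32."""
    z = to_f32(to_f32(to_bf16(logit)) / cap)
    return to_f32(cap * to_f32(math.tanh(z)))


cap = 15.0
logit = 7.3
reference = cap * math.tanh(to_bf16(logit) / cap)  # float64 reference
err_bf16 = abs(softcap_tanh_in_bf16(logit, cap) - reference)
err_f32 = abs(softcap_tanh_in_f32(logit, cap) - reference)
```

Both paths start from the same bfloat16 logit; only the precision in which
the division, tanh, and multiplication run differs.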