nanochat-omni/nanochat/gpt.py
Fam Zheng d760915daa omni: gpt.forward — optional audio_features for soft-token prepend
W1 needs the GPT to consume Whisper-projector outputs as a prefix of "soft
tokens" sitting in front of the text token embeddings (LLaVA-style). The
change is intentionally minimal:

- forward() takes a new keyword-only arg `audio_features` of shape
  (B, T_a, n_embd). The caller must already have projected them to n_embd
  (the Projector lives in nanochat.audio and is kept out of GPT itself).
- The audio rows are normed (matching the post-wte norm convention) and
  concatenated *after* smear, so smear stays a strictly text-side op: its
  prev-token semantics aren't defined for soft tokens, and revisiting that
  belongs to a later phase.
- Rotary embeddings are re-sliced for T_full = T_a + T_text. Audio gets
  positions 0..T_a-1, and text shifts from 0..T_text-1 to T_a..T_full-1.
  The 10× over-allocated rotary cache in __init__ already has room for the
  longer sequence.
- The value_embeds lookup uses idx padded with 0 at the audio positions.
  The resulting dummy rows still feed the v residual, but the gate
  (`ve_gate`) is input-dependent and should learn to suppress them; for
  the W1 smoke run this is fine.
- targets are auto-padded with -1 (ignore_index) over audio positions so
  the LM is only graded on text predictions.
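To make the ordering in the norm/smear bullet concrete, here is a minimal single-sequence sketch in plain Python. It assumes the text path is norm(wte(idx)) followed by smear. `smear` is passed in as an opaque callable because its internals are not described here, and the `eps` value is an assumption. The names `rms_norm` and `build_input_rows` are hypothetical, not nanochat functions. In the model the same step would be a `torch.cat((audio, text), dim=1)` on (B, T, n_embd) tensors.

```python
import math
from typing import Callable, List

Vector = List[float]


def rms_norm(v: Vector, eps: float = 1e-6) -> Vector:
    # Parameter-free RMSNorm over the embedding dimension (eps is an assumption).
    scale = 1.0 / math.sqrt(sum(x * x for x in v) / len(v) + eps)
    return [x * scale for x in v]


def build_input_rows(
    audio_rows: List[Vector],
    text_rows: List[Vector],
    smear: Callable[[List[Vector]], List[Vector]],
) -> List[Vector]:
    # Text path: norm the token embeddings, then apply smear (text-only op).
    text = smear([rms_norm(r) for r in text_rows])
    # Audio path: norm the already-projected features, never pass them to smear.
    audio = [rms_norm(r) for r in audio_rows]
    # Prepend: audio soft tokens first, text tokens after.
    return audio + text
```

Passing a smear that records its input shows that it only ever sees the T_text text rows, while the output has T_a + T_text rows.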
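The index bookkeeping from the rotary, value_embeds and targets bullets can be sketched in plain Python as follows. It uses nested lists in place of (B, T) tensors, and `prepend_layout`, `AUDIO_IDX_PAD` and `rotary_cache_len` are hypothetical names chosen for illustration. The sketch assumes every row in the batch has the same text length.

```python
from typing import List, Tuple

IGNORE_INDEX = -1  # ignore_index used for targets at audio positions
AUDIO_IDX_PAD = 0  # dummy token id fed to the value_embeds lookup


def prepend_layout(
    idx: List[List[int]],
    targets: List[List[int]],
    t_a: int,
    rotary_cache_len: int,
) -> Tuple[List[int], List[int], List[List[int]], List[List[int]]]:
    t_text = len(idx[0])
    t_full = t_a + t_text
    # Rotary tables are sliced to the first t_full positions, so the
    # cache allocated up front must be at least that long.
    assert t_full <= rotary_cache_len, "rotary cache shorter than T_a + T_text"
    audio_pos = list(range(0, t_a))      # 0 .. T_a-1
    text_pos = list(range(t_a, t_full))  # T_a .. T_full-1
    # value_embeds sees token id 0 at every audio position.
    padded_idx = [[AUDIO_IDX_PAD] * t_a + row for row in idx]
    # Audio positions carry no label, so only text predictions are graded.
    padded_targets = [[IGNORE_INDEX] * t_a + row for row in targets]
    return audio_pos, text_pos, padded_idx, padded_targets
```

With T_a = 2 and a 3-token text row, text positions become 2..4 and the first two target columns are -1. A loss such as `torch.nn.functional.cross_entropy(..., ignore_index=-1)` then skips those columns.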

Not yet supported: audio_features with kv_cache. The KV-cache path is a
prefill+decode protocol that would need its own audio-aware semantics; W1
runs train-style forwards only, so we just assert.
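A rough sketch of the argument checks implied above follows: no audio on the KV-cache path, plus the (B, T_a, n_embd) shape contract from the argument description. It works on plain shape tuples rather than tensors. `check_audio_forward_args` and its parameters are hypothetical, and the batch-size check is an assumption that follows from concatenating audio and text per batch row.

```python
from typing import Optional, Tuple


def check_audio_forward_args(
    audio_shape: Optional[Tuple[int, int, int]],
    batch_size: int,
    n_embd: int,
    kv_cache: Optional[object],
) -> None:
    if audio_shape is None:
        return  # text-only forward: nothing extra to check
    # W1 only supports train-style forwards when soft tokens are present.
    assert kv_cache is None, "audio_features with kv_cache is not supported yet"
    b, t_a, d = audio_shape
    assert b == batch_size, "audio_features batch size must match idx"
    assert d == n_embd, "audio_features must already be projected to n_embd"
```

Keeping this as a hard assert rather than a silent fallback means a decode-time caller fails loudly instead of running with misaligned positions.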
2026-05-05 22:38:49 +01:00
