omni: gpt.forward — optional audio_features for soft-token prepend

W1 needs the GPT to consume Whisper-projector outputs as a prefix of "soft
tokens" sitting in front of the text token embeddings (LLaVA-style). The
change is intentionally minimal:

- forward() takes a new keyword-only arg `audio_features` of shape
  (B, T_a, n_embd). They must already be projected to n_embd by the caller
  (Projector lives in nanochat.audio, kept out of GPT itself).
- The audio rows are normed (matches the post-wte norm convention) and
  concatenated *after* smear so smear stays a strictly text-side op (its
  prev-token semantics aren't defined for soft tokens, and revisiting that
  belongs to a later phase).
- Rotary embeddings are re-sliced for T_a + T_text. Audio gets positions
  0..T_a-1, text 0-shifts to T_a..T_full-1. The 10× over-allocated rotary
  cache in __init__ already covers this.
- value_embeds lookup uses an idx padded with 0 for audio positions. They
  feed the v residual but the gate (`ve_gate`) is input-dependent and will
  learn to suppress the dummy rows; for W1 smoke this is fine.
- targets are auto-padded with -1 (ignore_index) over audio positions so
  the LM is only graded on text predictions.

Not yet supported: audio_features with kv_cache. The KV-cache path is a
prefill+decode protocol that would need its own audio-aware semantics; W1
runs train-style forwards only, so we just assert.
This commit is contained in:
Fam Zheng
2026-05-05 22:38:49 +01:00
parent 9cae824aa5
commit d760915daa
+18 -2
View File
@@ -413,7 +413,7 @@ class GPT(nn.Module):
group["initial_lr"] = group["lr"]
return optimizer
def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean', audio_features=None):
B, T = idx.size()
# Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2))
@@ -448,6 +448,22 @@ class GPT(nn.Module):
gate = self.smear_lambda.to(x.dtype) * torch.sigmoid(self.smear_gate(x[:, :, :24]))
x = x + gate * x_pre_smear
# Audio soft-token prepend (LLaVA-style): audio_features must already be projected to n_embd.
idx_for_ve = idx
if audio_features is not None:
assert kv_cache is None, "audio_features prepend not supported with kv_cache"
audio = norm(audio_features.to(COMPUTE_DTYPE))
T_a = audio.size(1)
x = torch.cat([audio, x], dim=1)
T_full = T_a + T
assert T_full <= self.cos.size(1), f"Sequence length grew beyond rotary cache: {T_full} > {self.cos.size(1)}"
cos_sin = self.cos[:, :T_full], self.sin[:, :T_full]
idx_pad = torch.zeros((B, T_a), dtype=idx.dtype, device=idx.device)
idx_for_ve = torch.cat([idx_pad, idx], dim=1)
if targets is not None:
pad = torch.full((B, T_a), -1, dtype=targets.dtype, device=targets.device)
targets = torch.cat([pad, targets], dim=1)
# Forward the trunk of the Transformer
x0 = x # save initial normalized embedding for x0 residual
n_layer = self.config.n_layer
@@ -455,7 +471,7 @@ class GPT(nn.Module):
x_backout = None
for i, block in enumerate(self.transformer.h):
x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
ve = self.value_embeds[str(i)](idx).to(x.dtype) if str(i) in self.value_embeds else None
ve = self.value_embeds[str(i)](idx_for_ve).to(x.dtype) if str(i) in self.value_embeds else None
x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache)
if i == backout_layer:
x_backout = x