From d760915daada6e527172278a0f724c93899ca2b6 Mon Sep 17 00:00:00 2001 From: Fam Zheng Date: Tue, 5 May 2026 22:38:49 +0100 Subject: [PATCH] =?UTF-8?q?omni:=20gpt.forward=20=E2=80=94=20optional=20au?= =?UTF-8?q?dio=5Ffeatures=20for=20soft-token=20prepend?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit W1 needs the GPT to consume Whisper-projector outputs as a prefix of "soft tokens" sitting in front of the text token embeddings (LLaVA-style). The change is intentionally minimal: - forward() takes a new keyword-only arg `audio_features` of shape (B, T_a, n_embd). They must already be projected to n_embd by the caller (Projector lives in nanochat.audio, kept out of GPT itself). - The audio rows are normed (matches the post-wte norm convention) and concatenated *after* smear so smear stays a strictly text-side op (its prev-token semantics aren't defined for soft tokens, and revisiting that belongs to a later phase). - Rotary embeddings are re-sliced for T_a + T_text. Audio gets positions 0..T_a-1, text 0-shifts to T_a..T_full-1. The 10× over-allocated rotary cache in __init__ already covers this. - value_embeds lookup uses an idx padded with 0 for audio positions. They feed the v residual but the gate (`ve_gate`) is input-dependent and will learn to suppress the dummy rows; for W1 smoke this is fine. - targets are auto-padded with -1 (ignore_index) over audio positions so the LM is only graded on text predictions. Not yet supported: audio_features with kv_cache. The KV-cache path is a prefill+decode protocol that would need its own audio-aware semantics; W1 runs train-style forwards only, so we just assert. --- nanochat/gpt.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 07a1eae..31db579 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -413,7 +413,7 @@ class GPT(nn.Module): group["initial_lr"] = group["lr"] return optimizer - def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'): + def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean', audio_features=None): B, T = idx.size() # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2)) @@ -448,6 +448,22 @@ class GPT(nn.Module): gate = self.smear_lambda.to(x.dtype) * torch.sigmoid(self.smear_gate(x[:, :, :24])) x = x + gate * x_pre_smear + # Audio soft-token prepend (LLaVA-style): audio_features must already be projected to n_embd. + idx_for_ve = idx + if audio_features is not None: + assert kv_cache is None, "audio_features prepend not supported with kv_cache" + audio = norm(audio_features.to(COMPUTE_DTYPE)) + T_a = audio.size(1) + x = torch.cat([audio, x], dim=1) + T_full = T_a + T + assert T_full <= self.cos.size(1), f"Sequence length grew beyond rotary cache: {T_full} > {self.cos.size(1)}" + cos_sin = self.cos[:, :T_full], self.sin[:, :T_full] + idx_pad = torch.zeros((B, T_a), dtype=idx.dtype, device=idx.device) + idx_for_ve = torch.cat([idx_pad, idx], dim=1) + if targets is not None: + pad = torch.full((B, T_a), -1, dtype=targets.dtype, device=targets.device) + targets = torch.cat([pad, targets], dim=1) + # Forward the trunk of the Transformer x0 = x # save initial normalized embedding for x0 residual n_layer = self.config.n_layer @@ -455,7 +471,7 @@ class GPT(nn.Module): x_backout = None for i, block in enumerate(self.transformer.h): x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0 - ve = self.value_embeds[str(i)](idx).to(x.dtype) if str(i) in self.value_embeds else None + ve = self.value_embeds[str(i)](idx_for_ve).to(x.dtype) if str(i) in self.value_embeds else None x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache) if i == backout_layer: x_backout = x