diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 07a1eae..31db579 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -413,7 +413,7 @@ class GPT(nn.Module): group["initial_lr"] = group["lr"] return optimizer - def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'): + def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean', audio_features=None): B, T = idx.size() # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2)) @@ -448,6 +448,22 @@ class GPT(nn.Module): gate = self.smear_lambda.to(x.dtype) * torch.sigmoid(self.smear_gate(x[:, :, :24])) x = x + gate * x_pre_smear + # Audio soft-token prepend (LLaVA-style): audio_features must already be projected to n_embd. + idx_for_ve = idx + if audio_features is not None: + assert kv_cache is None, "audio_features prepend not supported with kv_cache" + audio = norm(audio_features.to(COMPUTE_DTYPE)) + T_a = audio.size(1) + x = torch.cat([audio, x], dim=1) + T_full = T_a + T + assert T_full <= self.cos.size(1), f"Sequence length grew beyond rotary cache: {T_full} > {self.cos.size(1)}" + cos_sin = self.cos[:, :T_full], self.sin[:, :T_full] + idx_pad = torch.zeros((B, T_a), dtype=idx.dtype, device=idx.device) + idx_for_ve = torch.cat([idx_pad, idx], dim=1) + if targets is not None: + pad = torch.full((B, T_a), -1, dtype=targets.dtype, device=targets.device) + targets = torch.cat([pad, targets], dim=1) + # Forward the trunk of the Transformer x0 = x # save initial normalized embedding for x0 residual n_layer = self.config.n_layer @@ -455,7 +471,7 @@ class GPT(nn.Module): x_backout = None for i, block in enumerate(self.transformer.h): x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0 - ve = self.value_embeds[str(i)](idx).to(x.dtype) if str(i) in self.value_embeds else None + ve = self.value_embeds[str(i)](idx_for_ve).to(x.dtype) if str(i) in self.value_embeds else None x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache) if i == backout_layer: x_backout = x