omni: gpt.forward — optional audio_features for soft-token prepend
W1 needs the GPT to consume Whisper-projector outputs as a prefix of "soft tokens" sitting in front of the text token embeddings (LLaVA-style). The change is intentionally minimal:

- forward() takes a new optional keyword arg `audio_features` (default None) of shape (B, T_a, n_embd). The features must already be projected to n_embd by the caller (the Projector lives in nanochat.audio and is kept out of GPT itself).
- The audio rows are normed (matching the post-wte norm convention) and concatenated *after* smear, so smear stays a strictly text-side op (its prev-token semantics aren't defined for soft tokens, and revisiting that belongs to a later phase).
- Rotary embeddings are re-sliced for T_a + T_text. Audio gets positions 0..T_a-1, and text positions are shifted from 0..T_text-1 to T_a..T_full-1. The 10× over-allocated rotary cache in __init__ already covers this.
- The value_embeds lookup uses an idx padded with 0 for the audio positions. These dummy rows feed the v residual, but the gate (`ve_gate`) is input-dependent and will learn to suppress them; for W1 smoke testing this is fine.
- targets are auto-padded with -1 (ignore_index) over the audio positions, so the LM is only graded on text predictions.

Not yet supported: audio_features with kv_cache. The KV-cache path is a prefill+decode protocol that would need its own audio-aware semantics; W1 runs train-style forwards only, so we just assert.
This commit is contained in:
+18
-2
@@ -413,7 +413,7 @@ class GPT(nn.Module):
|
|||||||
group["initial_lr"] = group["lr"]
|
group["initial_lr"] = group["lr"]
|
||||||
return optimizer
|
return optimizer
|
||||||
|
|
||||||
def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
|
def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean', audio_features=None):
|
||||||
B, T = idx.size()
|
B, T = idx.size()
|
||||||
|
|
||||||
# Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2))
|
# Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2))
|
||||||
@@ -448,6 +448,22 @@ class GPT(nn.Module):
|
|||||||
gate = self.smear_lambda.to(x.dtype) * torch.sigmoid(self.smear_gate(x[:, :, :24]))
|
gate = self.smear_lambda.to(x.dtype) * torch.sigmoid(self.smear_gate(x[:, :, :24]))
|
||||||
x = x + gate * x_pre_smear
|
x = x + gate * x_pre_smear
|
||||||
|
|
||||||
|
# Audio soft-token prepend (LLaVA-style): audio_features must already be projected to n_embd.
|
||||||
|
idx_for_ve = idx
|
||||||
|
if audio_features is not None:
|
||||||
|
assert kv_cache is None, "audio_features prepend not supported with kv_cache"
|
||||||
|
audio = norm(audio_features.to(COMPUTE_DTYPE))
|
||||||
|
T_a = audio.size(1)
|
||||||
|
x = torch.cat([audio, x], dim=1)
|
||||||
|
T_full = T_a + T
|
||||||
|
assert T_full <= self.cos.size(1), f"Sequence length grew beyond rotary cache: {T_full} > {self.cos.size(1)}"
|
||||||
|
cos_sin = self.cos[:, :T_full], self.sin[:, :T_full]
|
||||||
|
idx_pad = torch.zeros((B, T_a), dtype=idx.dtype, device=idx.device)
|
||||||
|
idx_for_ve = torch.cat([idx_pad, idx], dim=1)
|
||||||
|
if targets is not None:
|
||||||
|
pad = torch.full((B, T_a), -1, dtype=targets.dtype, device=targets.device)
|
||||||
|
targets = torch.cat([pad, targets], dim=1)
|
||||||
|
|
||||||
# Forward the trunk of the Transformer
|
# Forward the trunk of the Transformer
|
||||||
x0 = x # save initial normalized embedding for x0 residual
|
x0 = x # save initial normalized embedding for x0 residual
|
||||||
n_layer = self.config.n_layer
|
n_layer = self.config.n_layer
|
||||||
@@ -455,7 +471,7 @@ class GPT(nn.Module):
|
|||||||
x_backout = None
|
x_backout = None
|
||||||
for i, block in enumerate(self.transformer.h):
|
for i, block in enumerate(self.transformer.h):
|
||||||
x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
|
x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
|
||||||
ve = self.value_embeds[str(i)](idx).to(x.dtype) if str(i) in self.value_embeds else None
|
ve = self.value_embeds[str(i)](idx_for_ve).to(x.dtype) if str(i) in self.value_embeds else None
|
||||||
x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache)
|
x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache)
|
||||||
if i == backout_layer:
|
if i == backout_layer:
|
||||||
x_backout = x
|
x_backout = x
|
||||||
|
|||||||
Reference in New Issue
Block a user