From 7cc94cf58470b2e614cf1350e7ae1a9cbc8141d0 Mon Sep 17 00:00:00 2001
From: Fam Zheng
Date: Tue, 5 May 2026 22:39:05 +0100
Subject: [PATCH] =?UTF-8?q?omni:=20nanochat/audio.py=20=E2=80=94=20frozen?=
 =?UTF-8?q?=20Whisper=20encoder=20+=20Projector?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

The audio modality module that pairs with the gpt.forward audio_features
hook. Two things live here:

WhisperEncoder: thin wrapper around transformers' WhisperModel.encoder.
  - Weight loading prefers ModelScope when WHISPER_MS_ID is set (matches
    the CN-mirror policy in doc/todo.md — modelscope is first-class for
    model weights, hf-mirror is fallback). Otherwise falls back to plain
    HF, with WHISPER_HF_ID as the override and `openai/whisper-base` as
    the default (the smallest variant that still produces useful
    features for smoke tests).
  - Encoder params have requires_grad=False from __init__ so they never
    appear in the optimizer's param list. Caller does not need to
    remember to freeze it.
  - preprocess() runs the feature extractor; forward() takes
    (B, n_mels, T_mel) and returns last_hidden_state (B, T_enc, d_model).
    Whisper pads every clip to 30 s, so T_enc is a constant 1500
    regardless of input duration — handy for batching, wasteful for
    short clips. We accept the waste at W1; W2 can switch to
    streaming-style chunking.
  - Note for W3+/W5+: last_hidden_state is the most text-semantic layer.
    When we start caring about timbre / prosody / emotion ("质感感知"),
    we should expose middle layers or a learnable weighted sum across
    layers.

Projector: 2-layer MLP (in_dim → out_dim → out_dim) with GELU and the
nanochat Linear class so master weights stay fp32 while forward runs in
the activation dtype (bf16). fc2 is zero-initialized so the model starts
ignoring audio entirely, which gives a clean baseline before any
training signal flows through (the audio path is off by default and is
switched on only by training).
---
 nanochat/audio.py | 136 ++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 136 insertions(+)
 create mode 100644 nanochat/audio.py

diff --git a/nanochat/audio.py b/nanochat/audio.py
new file mode 100644
index 0000000..b656276
--- /dev/null
+++ b/nanochat/audio.py
@@ -0,0 +1,136 @@
+"""
+Audio modality for nanochat-omni (W1).
+
+Frozen Whisper encoder produces soft tokens; Projector maps them into nanochat's
+residual stream (n_embd) so they can be prepended to text token embeddings
+LLaVA-style. Output remains text-only.
+
+Weights:
+- ModelScope first when WHISPER_MS_ID is set (e.g. iic/Whisper-small,
+  iic/Whisper-large-v3) — preferred path on CN boxes (ailab/zy/etc).
+- HuggingFace fallback (honors HF_ENDPOINT for hf-mirror).
+
+The encoder is held frozen; only Projector is trained.
+"""
+
+import os
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from nanochat.gpt import Linear
+
+
+def _load_whisper_via_modelscope(ms_id):
+    """Download `ms_id` from ModelScope; return (feature_extractor, encoder)."""
+    from modelscope import snapshot_download
+    local_path = snapshot_download(ms_id)
+    from transformers import WhisperModel, WhisperFeatureExtractor
+    extractor = WhisperFeatureExtractor.from_pretrained(local_path)
+    model = WhisperModel.from_pretrained(local_path)
+    return extractor, model.encoder
+
+
+def _load_whisper_via_hf(hf_id):
+    """Load `hf_id` through transformers; return (feature_extractor, encoder)."""
+    from transformers import WhisperModel, WhisperFeatureExtractor
+    extractor = WhisperFeatureExtractor.from_pretrained(hf_id)
+    model = WhisperModel.from_pretrained(hf_id)
+    return extractor, model.encoder
+
+
+def load_whisper(hf_id="openai/whisper-base", ms_id=None):
+    """Load (feature_extractor, encoder). Tries ModelScope if ms_id is given,
+    falls back to HuggingFace. Returns the .encoder submodule (no decoder).
+
+    ms_id defaults to $WHISPER_MS_ID; $WHISPER_HF_ID, when set, replaces hf_id.
+    Raises RuntimeError naming every failed attempt, chained to the last
+    underlying exception."""
+    ms_id = ms_id or os.environ.get("WHISPER_MS_ID")
+    hf_id = os.environ.get("WHISPER_HF_ID", hf_id)
+    errors = []
+    if ms_id:
+        try:
+            return _load_whisper_via_modelscope(ms_id)
+        except Exception as e:
+            errors.append(f"modelscope({ms_id}): {e}")
+    try:
+        return _load_whisper_via_hf(hf_id)
+    except Exception as e:
+        errors.append(f"hf({hf_id}): {e}")
+        msg = "Failed to load Whisper encoder. Tried: " + " | ".join(errors)
+        # Chain the cause so the underlying traceback is not lost.
+        raise RuntimeError(msg) from e
+
+
+class WhisperEncoder(nn.Module):
+    """Frozen Whisper encoder. Forward takes log-mel input_features
+    (B, n_mels, T_mel) and returns (B, T_enc, d_model)."""
+
+    def __init__(self, hf_id="openai/whisper-base", ms_id=None, device=None, dtype=None):
+        super().__init__()
+        extractor, encoder = load_whisper(hf_id=hf_id, ms_id=ms_id)
+        self.feature_extractor = extractor
+        self.encoder = encoder
+        # Freeze: exclude every encoder param from autograd, start in eval mode.
+        for p in self.encoder.parameters():
+            p.requires_grad = False
+        self.encoder.eval()
+        self._d_model = encoder.config.d_model
+        self.sampling_rate = extractor.sampling_rate
+        if device is not None or dtype is not None:
+            self.encoder.to(device=device, dtype=dtype)
+
+    @property
+    def d_model(self):
+        return self._d_model
+
+    def train(self, mode=True):
+        """Set the wrapper's mode but keep the frozen encoder in eval mode,
+        so a parent's .train() cannot put it back into training mode."""
+        super().train(mode)
+        self.encoder.eval()
+        return self
+
+    def preprocess(self, audio_arrays):
+        """audio_arrays: list of 1D np.float32 (mono, sampling_rate Hz).
+        Returns input_features tensor (B, n_mels, T_mel)."""
+        out = self.feature_extractor(
+            audio_arrays,
+            sampling_rate=self.sampling_rate,
+            return_tensors="pt",
+        )
+        return out.input_features
+
+    @torch.no_grad()
+    def forward(self, input_features):
+        # preprocess() does not move or cast its output; align it with the encoder.
+        ref = next(self.encoder.parameters())
+        input_features = input_features.to(device=ref.device, dtype=ref.dtype)
+        out = self.encoder(input_features=input_features)
+        return out.last_hidden_state
+
+
+class Projector(nn.Module):
+    """LLaVA-style 2-layer MLP: audio_d -> hidden -> n_embd.
+
+    Uses nanochat's Linear so master weights stay fp32 while forward runs in
+    the activation dtype (typically bf16). Matches the convention in gpt.py.
+    """
+
+    def __init__(self, in_dim, out_dim, hidden_dim=None):
+        super().__init__()
+        hidden_dim = hidden_dim or out_dim
+        self.fc1 = Linear(in_dim, hidden_dim, bias=False)
+        self.fc2 = Linear(hidden_dim, out_dim, bias=False)
+        s = (3.0 / in_dim) ** 0.5
+        torch.nn.init.uniform_(self.fc1.weight, -s, s)
+        # Zero fc2 so the projector outputs 0 until training moves it.
+        torch.nn.init.zeros_(self.fc2.weight)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = F.gelu(x)
+        x = self.fc2(x)
+        return x