diff --git a/nanochat/audio.py b/nanochat/audio.py new file mode 100644 index 0000000..b656276 --- /dev/null +++ b/nanochat/audio.py @@ -0,0 +1,116 @@ +""" +Audio modality for nanochat-omni (W1). + +Frozen Whisper encoder produces soft tokens; Projector maps them into nanochat's +residual stream (n_embd) so they can be prepended to text token embeddings +LLaVA-style. Output remains text-only. + +Weights: +- ModelScope first when WHISPER_MS_ID is set (e.g. iic/Whisper-small, + iic/Whisper-large-v3) — preferred path on CN boxes (ailab/zy/etc). +- HuggingFace fallback (honors HF_ENDPOINT for hf-mirror). + +The encoder is held frozen; only Projector is trained. +""" + +import os + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from nanochat.gpt import Linear + + +def _load_whisper_via_modelscope(ms_id): + from modelscope import snapshot_download + local_path = snapshot_download(ms_id) + from transformers import WhisperModel, WhisperFeatureExtractor + extractor = WhisperFeatureExtractor.from_pretrained(local_path) + model = WhisperModel.from_pretrained(local_path) + return extractor, model.encoder + + +def _load_whisper_via_hf(hf_id): + from transformers import WhisperModel, WhisperFeatureExtractor + extractor = WhisperFeatureExtractor.from_pretrained(hf_id) + model = WhisperModel.from_pretrained(hf_id) + return extractor, model.encoder + + +def load_whisper(hf_id="openai/whisper-base", ms_id=None): + """Load (feature_extractor, encoder). Tries ModelScope if ms_id is given, + falls back to HuggingFace. Returns the .encoder submodule (no decoder).""" + ms_id = ms_id or os.environ.get("WHISPER_MS_ID") + hf_id = os.environ.get("WHISPER_HF_ID", hf_id) + errors = [] + if ms_id: + try: + return _load_whisper_via_modelscope(ms_id) + except Exception as e: + errors.append(f"modelscope({ms_id}): {e}") + try: + return _load_whisper_via_hf(hf_id) + except Exception as e: + errors.append(f"hf({hf_id}): {e}") + raise RuntimeError("Failed to load Whisper encoder. Tried: " + " | ".join(errors)) + + +class WhisperEncoder(nn.Module): + """Frozen Whisper encoder. Forward takes log-mel input_features + (B, n_mels, T_mel) and returns (B, T_enc, d_model).""" + + def __init__(self, hf_id="openai/whisper-base", ms_id=None, device=None, dtype=None): + super().__init__() + extractor, encoder = load_whisper(hf_id=hf_id, ms_id=ms_id) + self.feature_extractor = extractor + self.encoder = encoder + for p in self.encoder.parameters(): + p.requires_grad = False + self.encoder.eval() + self._d_model = encoder.config.d_model + self.sampling_rate = extractor.sampling_rate + if device is not None or dtype is not None: + self.encoder.to(device=device, dtype=dtype) + + @property + def d_model(self): + return self._d_model + + def preprocess(self, audio_arrays): + """audio_arrays: list of 1D np.float32 (mono, sampling_rate Hz). + Returns input_features tensor (B, n_mels, T_mel).""" + out = self.feature_extractor( + audio_arrays, + sampling_rate=self.sampling_rate, + return_tensors="pt", + ) + return out.input_features + + @torch.no_grad() + def forward(self, input_features): + out = self.encoder(input_features=input_features) + return out.last_hidden_state + + +class Projector(nn.Module): + """LLaVA-style 2-layer MLP: audio_d -> hidden -> n_embd. + + Uses nanochat's Linear so master weights stay fp32 while forward runs in + the activation dtype (typically bf16). Matches the convention in gpt.py. + """ + + def __init__(self, in_dim, out_dim, hidden_dim=None): + super().__init__() + hidden_dim = hidden_dim or out_dim + self.fc1 = Linear(in_dim, hidden_dim, bias=False) + self.fc2 = Linear(hidden_dim, out_dim, bias=False) + s = (3.0 / in_dim) ** 0.5 + torch.nn.init.uniform_(self.fc1.weight, -s, s) + torch.nn.init.zeros_(self.fc2.weight) + + def forward(self, x): + x = self.fc1(x) + x = F.gelu(x) + x = self.fc2(x) + return x