"""
Audio modality for nanochat-omni (W1).

A frozen Whisper encoder produces soft tokens; Projector maps them into
nanochat's residual stream (n_embd) so they can be prepended to the text
token embeddings, LLaVA-style. Output remains text-only.

Weights:
- ModelScope first when an ms_id is passed or WHISPER_MS_ID is set
  (e.g. iic/Whisper-small, iic/Whisper-large-v3) — the preferred path on
  CN boxes (ailab/zy/etc).
- HuggingFace fallback (honors HF_ENDPOINT for hf-mirror).

The encoder is held frozen; only Projector is trained.
"""

import os

import torch
import torch.nn as nn
import torch.nn.functional as F

from nanochat.gpt import Linear


def _load_whisper_pretrained(name_or_path):
    """Load (feature_extractor, encoder) from a local path or hub id via transformers.

    The full WhisperModel is instantiated, but only its `.encoder` submodule
    is returned; the decoder is dropped with the temporary model object.
    """
    from transformers import WhisperFeatureExtractor, WhisperModel

    extractor = WhisperFeatureExtractor.from_pretrained(name_or_path)
    model = WhisperModel.from_pretrained(name_or_path)
    return extractor, model.encoder


def _load_whisper_via_modelscope(ms_id):
    """Download `ms_id` with ModelScope, then load it from the local snapshot."""
    # Import transformers up front so a missing dependency fails before the
    # (potentially large) download rather than after it.
    import transformers  # noqa: F401
    from modelscope import snapshot_download

    local_path = snapshot_download(ms_id)
    return _load_whisper_pretrained(local_path)


def _load_whisper_via_hf(hf_id):
    """Load `hf_id` from the HuggingFace hub (honors HF_ENDPOINT)."""
    return _load_whisper_pretrained(hf_id)


def load_whisper(hf_id="openai/whisper-base", ms_id=None):
    """Load (feature_extractor, encoder); the encoder is the `.encoder` submodule.

    Resolution order:
      1. ModelScope, if `ms_id` is given or WHISPER_MS_ID is set. An explicit
         `ms_id` argument takes precedence over the environment variable.
      2. HuggingFace. NOTE: when WHISPER_HF_ID is set it overrides `hf_id`,
         including an explicitly passed one.

    A ModelScope failure is recorded and the HF path is tried next.

    Raises:
        RuntimeError: if every attempted source fails. The message lists each
            attempt's error, and the last underlying exception is chained as
            the cause.
    """
    ms_id = ms_id or os.environ.get("WHISPER_MS_ID")
    hf_id = os.environ.get("WHISPER_HF_ID", hf_id)
    errors = []
    if ms_id:
        try:
            return _load_whisper_via_modelscope(ms_id)
        except Exception as e:  # best-effort: fall through to HuggingFace
            errors.append(f"modelscope({ms_id}): {e}")
    try:
        return _load_whisper_via_hf(hf_id)
    except Exception as e:
        errors.append(f"hf({hf_id}): {e}")
        raise RuntimeError("Failed to load Whisper encoder. Tried: " + " | ".join(errors)) from e


class WhisperEncoder(nn.Module):
    """Frozen Whisper encoder.

    Forward takes log-mel input_features (B, n_mels, T_mel) and returns
    (B, T_enc, d_model). Parameters never require grad, forward runs under
    no_grad, and the wrapped encoder is pinned to eval mode even when a
    parent module calls .train().
    """

    def __init__(self, hf_id="openai/whisper-base", ms_id=None, device=None, dtype=None):
        super().__init__()
        extractor, encoder = load_whisper(hf_id=hf_id, ms_id=ms_id)
        self.feature_extractor = extractor
        self.encoder = encoder
        self.encoder.requires_grad_(False)
        self.encoder.eval()
        self._d_model = encoder.config.d_model
        self.sampling_rate = extractor.sampling_rate
        if device is not None or dtype is not None:
            self.encoder.to(device=device, dtype=dtype)

    @property
    def d_model(self):
        """Width of the encoder's output features (encoder.config.d_model)."""
        return self._d_model

    def train(self, mode=True):
        """Set this module's training flag while keeping the encoder in eval.

        nn.Module.train() recurses into child modules, so without this
        override a parent model's .train() would flip the frozen encoder back
        into training mode and re-enable any dropout/layerdrop it is
        configured with.
        """
        super().train(mode)
        self.encoder.eval()
        return self

    def preprocess(self, audio_arrays):
        """Convert raw audio to log-mel features.

        audio_arrays: list of 1D np.float32 (mono, self.sampling_rate Hz).
        Returns the feature extractor's input_features as a torch tensor
        (B, n_mels, T_mel).
        """
        out = self.feature_extractor(
            audio_arrays,
            sampling_rate=self.sampling_rate,
            return_tensors="pt",
        )
        return out.input_features

    @torch.no_grad()
    def forward(self, input_features):
        """Encode log-mel features into (B, T_enc, d_model) hidden states."""
        # Features from preprocess() are built on CPU, while the encoder may
        # have been moved to another device/dtype in __init__. Align them so
        # the encoder's first layer doesn't hit a device or dtype mismatch.
        ref = next(self.encoder.parameters())
        input_features = input_features.to(device=ref.device, dtype=ref.dtype)
        out = self.encoder(input_features=input_features)
        return out.last_hidden_state


class Projector(nn.Module):
    """LLaVA-style 2-layer MLP: audio_d -> hidden -> n_embd.

    Uses nanochat's Linear so master weights stay fp32 while forward runs in
    the activation dtype (typically bf16). Matches the convention in gpt.py.

    Args:
        in_dim: feature width of the incoming audio tokens.
        out_dim: target residual-stream width (n_embd).
        hidden_dim: MLP hidden width; defaults to out_dim when falsy.
    """

    def __init__(self, in_dim, out_dim, hidden_dim=None):
        super().__init__()
        hidden_dim = hidden_dim or out_dim
        self.fc1 = Linear(in_dim, hidden_dim, bias=False)
        self.fc2 = Linear(hidden_dim, out_dim, bias=False)
        # Uniform(-s, s) with s = sqrt(3 / in_dim) has variance 1 / in_dim.
        s = (3.0 / in_dim) ** 0.5
        torch.nn.init.uniform_(self.fc1.weight, -s, s)
        # Zero-init the output layer so the projector starts out emitting zeros.
        torch.nn.init.zeros_(self.fc2.weight)

    def forward(self, x):
        """Map (..., in_dim) audio features to (..., out_dim)."""
        x = self.fc1(x)
        x = F.gelu(x)
        x = self.fc2(x)
        return x