Files
nanochat-omni/nanochat/audio.py
T
Fam Zheng 7cc94cf584 omni: nanochat/audio.py — frozen Whisper encoder + Projector
The audio modality module that pairs with the gpt.forward audio_features
hook. Two things live here:

WhisperEncoder: thin wrapper around transformers' WhisperModel.encoder.
- Weight loading prefers ModelScope when an ms_id is passed or WHISPER_MS_ID
  is set (matches the CN-mirror policy in doc/todo.md — modelscope is
  first-class for model weights, hf-mirror is fallback). Otherwise, or if the
  ModelScope load fails, it falls back to plain HF, with WHISPER_HF_ID as the
  override and `openai/whisper-base` as the default (the smallest variant that
  still produces useful features for smoke).
- Encoder params have requires_grad=False from __init__ so they never
  appear in the optimizer's param list. Caller does not need to remember
  to freeze it.
- preprocess() runs the feature extractor; forward() takes (B, n_mels,
  T_mel) and returns last_hidden_state (B, T_enc, d_model). Whisper pads
  every clip to 30 s, so T_enc is a constant 1500 regardless of input
  duration — handy for batching, wasteful for short clips. We accept the
  waste at W1; W2 can switch to streaming-style chunking.
- Note for W3+/W5+: last_hidden_state is the most text-semantic layer.
  When we start caring about timbre / prosody / emotion ("质感感知", i.e.
  perception of vocal texture and quality), we should expose middle layers
  or a learnable weighted sum across layers.

Projector: 2-layer MLP (in_dim → hidden_dim → out_dim, with hidden_dim
defaulting to out_dim) with GELU and the nanochat Linear class so master
weights stay fp32 while forward runs in the activation dtype (bf16). fc2 is
zero-initialized so the model starts out ignoring audio entirely, which gives
a clean baseline before any training signal flows through (the audio path is
effectively off at initialization and only turns on as training updates fc2).
2026-05-05 22:39:05 +01:00

117 lines
3.9 KiB
Python

"""
Audio modality for nanochat-omni (W1).
Frozen Whisper encoder produces soft tokens; Projector maps them into nanochat's
residual stream (n_embd) so they can be prepended to text token embeddings
LLaVA-style. Output remains text-only.
Weights:
- ModelScope first when WHISPER_MS_ID is set (e.g. iic/Whisper-small,
iic/Whisper-large-v3) — preferred path on CN boxes (ailab/zy/etc).
- HuggingFace fallback (honors HF_ENDPOINT for hf-mirror).
The encoder is held frozen; only Projector is trained.
"""
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from nanochat.gpt import Linear
def _load_whisper_via_modelscope(ms_id):
    """Download ``ms_id`` from ModelScope, then load it as a local checkpoint.

    ``snapshot_download`` materializes the repo into a local directory; that
    directory is handed to the HuggingFace loader so both weight sources build
    the (feature_extractor, encoder) pair through one shared code path.

    Returns:
        (WhisperFeatureExtractor, encoder submodule of WhisperModel).
    """
    from modelscope import snapshot_download

    local_path = snapshot_download(ms_id)
    # from_pretrained accepts a local directory just like a Hub id, so reuse
    # the HF helper instead of duplicating the loading logic here.
    return _load_whisper_via_hf(local_path)
def _load_whisper_via_hf(hf_id):
    """Build the Whisper feature extractor and encoder from ``hf_id``.

    ``hf_id`` is anything ``from_pretrained`` accepts — a Hub repo id or a
    local checkpoint directory. The full WhisperModel is loaded, but only its
    ``.encoder`` half is handed back.

    Returns:
        (WhisperFeatureExtractor, encoder submodule of WhisperModel).
    """
    from transformers import WhisperFeatureExtractor, WhisperModel

    feature_extractor = WhisperFeatureExtractor.from_pretrained(hf_id)
    whisper_encoder = WhisperModel.from_pretrained(hf_id).encoder
    return feature_extractor, whisper_encoder
def load_whisper(hf_id="openai/whisper-base", ms_id=None):
    """Load (feature_extractor, encoder) for a Whisper checkpoint.

    Resolution order:
      1. ModelScope, if ``ms_id`` is given (when ``ms_id`` is falsy, the
         ``WHISPER_MS_ID`` environment variable is used instead).
      2. HuggingFace. The ``WHISPER_HF_ID`` environment variable, when set,
         takes precedence over the ``hf_id`` argument.

    A ModelScope failure falls back to HuggingFace with a ``RuntimeWarning``.
    The fallback id may name a different checkpoint (and a different
    ``d_model``), so the switch should never be silent.

    Returns:
        (feature_extractor, encoder): the ``.encoder`` submodule only, no decoder.

    Raises:
        RuntimeError: if every attempted source fails. The message lists each
            source's error, and the last failure is chained as ``__cause__``.
    """
    import warnings

    ms_id = ms_id or os.environ.get("WHISPER_MS_ID")
    hf_id = os.environ.get("WHISPER_HF_ID", hf_id)
    errors = []
    if ms_id:
        try:
            return _load_whisper_via_modelscope(ms_id)
        except Exception as e:  # deliberate: any ModelScope failure -> try HF
            errors.append(f"modelscope({ms_id}): {e}")
            warnings.warn(
                f"Loading Whisper from ModelScope ({ms_id}) failed: {e}; "
                f"falling back to HuggingFace ({hf_id}).",
                RuntimeWarning,
                stacklevel=2,
            )
    try:
        return _load_whisper_via_hf(hf_id)
    except Exception as e:
        errors.append(f"hf({hf_id}): {e}")
        # Chain the last failure so its traceback is not lost.
        raise RuntimeError("Failed to load Whisper encoder. Tried: " + " | ".join(errors)) from e
class WhisperEncoder(nn.Module):
    """Frozen Whisper encoder plus its feature extractor.

    All encoder parameters get ``requires_grad=False`` at construction, so they
    never reach an optimizer. The encoder is also pinned to eval mode, even
    when a parent module calls ``.train()``.

    forward() takes log-mel ``input_features`` of shape (B, n_mels, T_mel) and
    returns the encoder's ``last_hidden_state``, shape (B, T_enc, d_model).
    """

    def __init__(self, hf_id="openai/whisper-base", ms_id=None, device=None, dtype=None):
        super().__init__()
        extractor, encoder = load_whisper(hf_id=hf_id, ms_id=ms_id)
        self.feature_extractor = extractor
        self.encoder = encoder
        # Freeze up front so callers never have to remember to.
        for p in self.encoder.parameters():
            p.requires_grad = False
        self.encoder.eval()
        self._d_model = encoder.config.d_model
        self.sampling_rate = extractor.sampling_rate
        if device is not None or dtype is not None:
            self.encoder.to(device=device, dtype=dtype)

    @property
    def d_model(self):
        """Width of the encoder output (``encoder.config.d_model``)."""
        return self._d_model

    def train(self, mode=True):
        """Set this module's training flag while keeping the encoder in eval.

        ``nn.Module.train`` recurses into every child. Without this override, a
        parent's ``model.train()`` would flip the frozen encoder back into
        training mode and enable any train-only behavior it has (e.g. dropout,
        if the loaded config sets it).
        """
        super().train(mode)
        self.encoder.eval()
        return self

    def preprocess(self, audio_arrays):
        """audio_arrays: list of 1D np.float32 (mono, sampling_rate Hz).

        Returns the input_features tensor (B, n_mels, T_mel).
        """
        out = self.feature_extractor(
            audio_arrays,
            sampling_rate=self.sampling_rate,
            return_tensors="pt",
        )
        return out.input_features

    @torch.no_grad()
    def forward(self, input_features):
        """Encode log-mel features (B, n_mels, T_mel) -> (B, T_enc, d_model).

        ``preprocess()`` builds CPU tensors (``return_tensors="pt"``), while
        ``__init__`` may have moved or cast the encoder (e.g. to CUDA / bf16).
        The input is therefore aligned with the encoder's first parameter
        before the call.
        """
        param = next(self.encoder.parameters(), None)
        if param is not None:
            input_features = input_features.to(device=param.device, dtype=param.dtype)
        out = self.encoder(input_features=input_features)
        return out.last_hidden_state
class Projector(nn.Module):
    """LLaVA-style 2-layer MLP: in_dim -> hidden_dim -> out_dim.

    Typically maps the audio encoder width (d_model) to nanochat's n_embd.
    ``hidden_dim`` defaults to ``out_dim``. Uses nanochat's Linear so master
    weights stay fp32 while forward runs in the activation dtype (typically
    bf16), matching the convention in gpt.py. fc2 starts at zero, so at
    initialization the projector outputs zeros and contributes nothing until
    training moves fc2.
    """

    def __init__(self, in_dim, out_dim, hidden_dim=None):
        super().__init__()
        hidden_dim = hidden_dim or out_dim
        self.in_dim = in_dim  # kept so init_weights() can recompute the fc1 scale
        self.fc1 = Linear(in_dim, hidden_dim, bias=False)
        self.fc2 = Linear(hidden_dim, out_dim, bias=False)
        self.init_weights()

    def init_weights(self):
        """(Re)initialize the weights in place.

        fc1 ~ U(-s, s) with s = sqrt(3 / in_dim), i.e. std = 1/sqrt(in_dim).
        fc2 = 0.

        This is separate from __init__ so it can be re-applied, e.g. after the
        module was built on the meta device and materialized with
        ``to_empty()``, which leaves the storage uninitialized.
        """
        s = (3.0 / self.in_dim) ** 0.5
        torch.nn.init.uniform_(self.fc1.weight, -s, s)
        torch.nn.init.zeros_(self.fc2.weight)

    def forward(self, x):
        """Apply fc1 -> GELU -> fc2.

        Assumes nanochat's Linear acts on the last dimension like nn.Linear,
        so (..., in_dim) maps to (..., out_dim).
        """
        return self.fc2(F.gelu(self.fc1(x)))