7cc94cf584
The audio modality module that pairs with the gpt.forward audio_features
hook. Two things live here:
WhisperEncoder: thin wrapper around transformers' WhisperModel.encoder.
- Weight loading prefers ModelScope when WHISPER_MS_ID is set (matches the
CN-mirror policy in doc/todo.md — ModelScope is first-class for model
weights, hf-mirror is the fallback). If WHISPER_MS_ID is unset, or the
ModelScope load fails, it falls back to plain HF, with WHISPER_HF_ID as the
override and `openai/whisper-base` as the default (the smallest variant that
still produces useful features for smoke tests).
- Encoder params have requires_grad=False from __init__ so they never
appear in the optimizer's param list. Caller does not need to remember
to freeze it.
- preprocess() runs the feature extractor; forward() takes (B, n_mels,
T_mel) and returns last_hidden_state (B, T_enc, d_model). Whisper pads
every clip to 30 s, so T_enc is a constant 1500 regardless of input
duration — handy for batching, wasteful for short clips. We accept the
waste at W1; W2 can switch to streaming-style chunking.
- Note for W3+/W5+: last_hidden_state is the most text-semantic layer.
When we start caring about timbre / prosody / emotion ("质感感知", i.e.
perception of vocal texture), we should expose middle layers or a learnable
weighted sum across layers.
Projector: 2-layer MLP (in_dim → hidden_dim → out_dim, with hidden_dim
defaulting to out_dim) with GELU and the nanochat Linear class, so master
weights stay fp32 while forward runs in the activation dtype (bf16). fc2 is
zero-initialized so the model starts by ignoring audio entirely, which gives
a clean baseline before any training signal flows through (the audio path
contributes nothing by default and is switched on only by training).
117 lines
3.9 KiB
Python
117 lines
3.9 KiB
Python
"""
|
|
Audio modality for nanochat-omni (W1).

A frozen Whisper encoder produces soft tokens; the Projector maps them into
nanochat's residual stream (n_embd) so they can be prepended to the text
token embeddings, LLaVA-style. Output remains text-only.

Weights:
- ModelScope first when WHISPER_MS_ID is set (e.g. iic/Whisper-small,
  iic/Whisper-large-v3) — the preferred path on CN boxes (ailab/zy/etc).
- HuggingFace fallback (honors HF_ENDPOINT for hf-mirror).

The encoder is held frozen; only the Projector is trained.
|
|
"""
|
|
|
|
import os
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
from nanochat.gpt import Linear
|
|
|
|
|
|
def _load_whisper_via_modelscope(ms_id):
    """Fetch `ms_id` from ModelScope, then load Whisper from the local snapshot.

    Returns:
        (feature_extractor, encoder) where encoder is the `.encoder`
        submodule of the loaded WhisperModel.

    Any error from modelscope or transformers (missing package, download
    failure, bad checkpoint) propagates to the caller.
    """
    # Imported lazily so modelscope is only required when this path is used.
    from modelscope import snapshot_download

    snapshot_dir = snapshot_download(ms_id)

    from transformers import WhisperModel, WhisperFeatureExtractor

    return (
        WhisperFeatureExtractor.from_pretrained(snapshot_dir),
        WhisperModel.from_pretrained(snapshot_dir).encoder,
    )
|
|
|
|
|
|
def _load_whisper_via_hf(hf_id):
    """Load Whisper from a HuggingFace repo id (or local path) via transformers.

    Returns:
        (feature_extractor, encoder) where encoder is the `.encoder`
        submodule of the loaded WhisperModel.

    Any error from transformers propagates to the caller.
    """
    from transformers import WhisperModel, WhisperFeatureExtractor

    feature_extractor = WhisperFeatureExtractor.from_pretrained(hf_id)
    whisper_encoder = WhisperModel.from_pretrained(hf_id).encoder
    return feature_extractor, whisper_encoder
|
|
|
|
|
|
def load_whisper(hf_id="openai/whisper-base", ms_id=None):
    """Load (feature_extractor, encoder), trying ModelScope before HuggingFace.

    Resolution order:
      1. ModelScope, if `ms_id` is truthy or, failing that, the WHISPER_MS_ID
         env var is set (the argument wins over the env var).
      2. HuggingFace, using WHISPER_HF_ID when set, else `hf_id`. Note the
         precedence is reversed here: the env var wins over the argument.

    If the ModelScope attempt fails, a RuntimeWarning is emitted before
    falling back. The HF id may name a different checkpoint than the
    requested ModelScope one (e.g. the `openai/whisper-base` default), so a
    silent fallback could quietly change the encoder in use.

    Args:
        hf_id: HuggingFace repo id used when WHISPER_HF_ID is unset.
        ms_id: ModelScope model id; WHISPER_MS_ID is used when this is falsy.

    Returns:
        (feature_extractor, encoder): the encoder is the `.encoder` submodule
        of the loaded WhisperModel (no decoder).

    Raises:
        RuntimeError: if every attempted source fails. The message lists each
            attempt and the HF failure is chained as `__cause__`.
    """
    import warnings

    ms_id = ms_id or os.environ.get("WHISPER_MS_ID")
    hf_id = os.environ.get("WHISPER_HF_ID", hf_id)
    errors = []
    if ms_id:
        try:
            return _load_whisper_via_modelscope(ms_id)
        except Exception as exc:  # deliberately broad: any failure means "try HF"
            errors.append(f"modelscope({ms_id}): {exc}")
            warnings.warn(
                f"Loading Whisper from ModelScope ({ms_id}) failed: {exc}; "
                f"falling back to HuggingFace ({hf_id}).",
                RuntimeWarning,
                stacklevel=2,
            )
    try:
        return _load_whisper_via_hf(hf_id)
    except Exception as exc:
        errors.append(f"hf({hf_id}): {exc}")
        # `exc` is unbound once the except block ends, so keep a reference
        # for exception chaining below.
        last_error = exc
    raise RuntimeError(
        "Failed to load Whisper encoder. Tried: " + " | ".join(errors)
    ) from last_error
|
|
|
|
|
|
class WhisperEncoder(nn.Module):
    """Frozen Whisper encoder wrapper.

    forward() takes log-mel `input_features` (B, n_mels, T_mel) and returns
    the encoder's `last_hidden_state` (B, T_enc, d_model).

    The wrapped encoder is frozen (requires_grad=False on every parameter)
    and pinned to eval mode. Calling `.train()` on this module, or on any
    parent module, leaves the inner encoder in eval mode.

    Args:
        hf_id: HuggingFace repo id forwarded to `load_whisper`.
        ms_id: optional ModelScope id forwarded to `load_whisper`.
        device: if given (or if dtype is given), the encoder is moved once at
            construction.
        dtype: if given, the encoder is cast once at construction.
    """

    def __init__(self, hf_id="openai/whisper-base", ms_id=None, device=None, dtype=None):
        super().__init__()
        extractor, encoder = load_whisper(hf_id=hf_id, ms_id=ms_id)
        self.feature_extractor = extractor
        self.encoder = encoder
        # Freeze up front so callers never have to remember to; frozen params
        # can then be filtered out when building the optimizer's param list.
        for p in self.encoder.parameters():
            p.requires_grad = False
        self.encoder.eval()
        self._d_model = encoder.config.d_model
        self.sampling_rate = extractor.sampling_rate
        if device is not None or dtype is not None:
            self.encoder.to(device=device, dtype=dtype)

    @property
    def d_model(self):
        """Hidden size of the encoder output (`encoder.config.d_model`)."""
        return self._d_model

    def train(self, mode=True):
        """Set this module's mode while keeping the frozen encoder in eval mode.

        nn.Module.train() recurses into every child. Without this override, a
        parent `model.train()` would flip the encoder back into training mode.
        """
        super().train(mode)
        self.encoder.eval()
        return self

    def preprocess(self, audio_arrays):
        """Run the Whisper feature extractor on raw audio.

        audio_arrays: list of 1D np.float32 (mono, sampling_rate Hz).
        Returns the input_features tensor (B, n_mels, T_mel).
        """
        out = self.feature_extractor(
            audio_arrays,
            sampling_rate=self.sampling_rate,
            return_tensors="pt",
        )
        return out.input_features

    @torch.no_grad()
    def forward(self, input_features):
        """Encode log-mel features with the frozen encoder.

        `input_features` is first moved to the encoder's device and cast to
        its parameter dtype. preprocess() output is typically float32 on CPU,
        which would otherwise clash with an encoder cast to bf16 and/or moved
        to an accelerator.
        """
        param = next(self.encoder.parameters(), None)
        if param is not None:
            input_features = input_features.to(device=param.device, dtype=param.dtype)
        out = self.encoder(input_features=input_features)
        return out.last_hidden_state
|
|
|
|
|
|
class Projector(nn.Module):
    """LLaVA-style 2-layer MLP: audio_d (in_dim) -> hidden -> n_embd (out_dim).

    Uses nanochat's Linear so master weights stay fp32 while forward runs in
    the activation dtype (typically bf16). Matches the convention in gpt.py.

    Initialization (see init_weights): fc1 is uniform in +/- sqrt(3 / in_dim)
    and fc2 is all zeros. A freshly initialized projector therefore outputs
    exactly zero, so the LM starts out ignoring the audio path.

    Args:
        in_dim: feature size of the incoming audio features.
        out_dim: output feature size (the LM's n_embd).
        hidden_dim: inner width; falls back to out_dim when falsy.
    """

    def __init__(self, in_dim, out_dim, hidden_dim=None):
        super().__init__()
        hidden_dim = hidden_dim or out_dim
        self.in_dim = in_dim
        self.fc1 = Linear(in_dim, hidden_dim, bias=False)
        self.fc2 = Linear(hidden_dim, out_dim, bias=False)
        self.init_weights()

    @torch.no_grad()
    def init_weights(self):
        """(Re)initialize fc1/fc2 in place.

        This is a separate method so the init can be re-applied if the
        parameters' storage is materialized after construction (e.g. built on
        the meta device, then `to_empty()`). Otherwise fc2 would not be zero.
        NOTE(review): confirm whether the model builder needs to call this.
        """
        bound = (3.0 / self.in_dim) ** 0.5
        nn.init.uniform_(self.fc1.weight, -bound, bound)
        nn.init.zeros_(self.fc2.weight)

    def forward(self, x):
        """Project features along the last dimension: fc1 -> GELU -> fc2."""
        x = self.fc1(x)
        x = F.gelu(x)
        x = self.fc2(x)
        return x
|