omni: W1 audio align smoke — synthetic dataset + 50-step script
End-to-end smoke proving the audio path:
wav -> WhisperEncoder (frozen) -> Projector -> prepend to text embeddings
-> tiny d6 GPT (random init) -> CE loss on text only
Pass criterion is a plain "loss drops by at least 0.5". On a 4090 the run
finishes in ~1 s and goes 5.55 -> 0.17 over 50 steps, so the threshold has
plenty of headroom against false positives.
Two design calls worth keeping in mind:
1. Synthetic sine clips, not LibriSpeech. W1 is forward-path proof, not
alignment quality, and a deterministic offline dataset means no network
on the smoke path. data/audio_smoke/manifest.jsonl is the only thing
committed; wavs are regenerated by audio_smoke_data.py and gitignored.
W2 swaps in real LibriSpeech.
2. Standalone byte-level tokenizer (UTF-8 bytes + a single BOS, vocab=257).
Avoids depending on a trained nanochat BPE — the d6 GPT is random
anyway, so vocab choice doesn't matter for "does the gradient flow"
smoke. W2 onwards uses the real BPE on a real base.
Caveat documented in doc/todo.md: because the LM is also random and being
trained, the loss-down here mostly reflects the LM memorising 5 short
strings, not Whisper-Projector alignment. That's fine for proving
plumbing; W2 freezes the LM so projector-only gradient is the only path
to lower loss.
This commit is contained in:
@@ -14,3 +14,6 @@ wandb/
|
|||||||
|
|
||||||
# Claude Code runtime
|
# Claude Code runtime
|
||||||
.claude/
|
.claude/
|
||||||
|
|
||||||
|
# W1 audio smoke: regenerated by scripts/audio_smoke_data.py, only manifest is committed
|
||||||
|
data/audio_smoke/wavs/
|
||||||
|
|||||||
@@ -0,0 +1,5 @@
|
|||||||
|
{"wav": "wavs/sine_0220.wav", "text": "low tone", "sr": 16000}
|
||||||
|
{"wav": "wavs/sine_0330.wav", "text": "mid low tone", "sr": 16000}
|
||||||
|
{"wav": "wavs/sine_0440.wav", "text": "middle tone", "sr": 16000}
|
||||||
|
{"wav": "wavs/sine_0660.wav", "text": "mid high tone", "sr": 16000}
|
||||||
|
{"wav": "wavs/sine_0880.wav", "text": "high tone", "sr": 16000}
|
||||||
+9
-4
@@ -19,12 +19,17 @@
|
|||||||
|
|
||||||
参考 research §1.2 模块图。
|
参考 research §1.2 模块图。
|
||||||
|
|
||||||
- [ ] `nanochat/audio.py`:WhisperEncoder wrapper(冻结,权重优先走 ModelScope,例如 `iic/Whisper-large-v3` / `iic/Whisper-small`;HF mirror 留作 fallback)+ Projector(MLP,输出维度对齐 nanochat `model_dim`)
|
- [x] `nanochat/audio.py`:WhisperEncoder wrapper(冻结,ModelScope 优先经 `WHISPER_MS_ID`,HF fallback 默认 `openai/whisper-base`)+ Projector(MLP,输出维度对齐 nanochat `n_embd`)
|
||||||
- [ ] `nanochat/gpt.py` `GPT.forward()` 加可选 `audio_features` 参数,作为 soft tokens prepend 到 text embedding 前面
|
- [x] `nanochat/gpt.py` `GPT.forward()` 加可选 `audio_features` 参数,作为 soft tokens prepend 到 text embedding 前面(kv_cache 路径暂不支持,audio 位置 targets 自动 -1 mask)
|
||||||
- [ ] mini dataset:1–10 段 5s wav + 字幕,落 `data/audio_smoke/`(git 内不存音频,仅清单 + 下载脚本)
|
- [x] mini dataset:5 段 5s 合成正弦 + 字幕,落 `data/audio_smoke/`(wav 由 `scripts/audio_smoke_data.py` 生成,gitignore 排除)
|
||||||
- [ ] `scripts/audio_align_smoke.py`:50 步、d6 nanochat base、loss 下降即过
|
- [x] `scripts/audio_align_smoke.py`:50 步、d6 随机初始化 GPT、字节级 tokenizer、loss 下降即过(4090 实测 ~1s,5.55→0.17)
|
||||||
- [ ] CI 加 audio smoke job(ailab runner 装 ffmpeg;whisper 走 transformers 即可)
|
- [ ] CI 加 audio smoke job(ailab runner 装 ffmpeg;whisper 走 transformers 即可)
|
||||||
|
|
||||||
|
W1 后续可改进(暂搁,留给 W3+/W5+ 质感任务):
|
||||||
|
|
||||||
|
- 当前用 `last_hidden_state`(最偏文本语义的层);为质感感知应切到中间层 / 多层 weighted sum / w2v-bert
|
||||||
|
- d6 GPT 是随机初始化,alignment 信号其实在练 LM 而非 projector;W2 上真 base 后 freeze LM、只练 projector 才是真正的弱对齐
|
||||||
|
|
||||||
## W2 — S1 弱对齐训练
|
## W2 — S1 弱对齐训练
|
||||||
|
|
||||||
- [ ] 拉 LibriSpeech 100h(HF mirror),预提 Whisper-base encoder 特征落盘 webdataset
|
- [ ] 拉 LibriSpeech 100h(HF mirror),预提 Whisper-base encoder 特征落盘 webdataset
|
||||||
|
|||||||
@@ -0,0 +1,179 @@
|
|||||||
|
"""
|
||||||
|
W1 smoke: prove the audio path works end-to-end.
|
||||||
|
|
||||||
|
Pipeline:
|
||||||
|
wav -> WhisperEncoder (frozen) -> Projector -> prepend to text embeddings
|
||||||
|
-> tiny d6 GPT (random init) -> CE loss on text tokens only
|
||||||
|
|
||||||
|
The model is randomly initialized and the dataset is 5 synthetic sine clips,
|
||||||
|
so the only thing this validates is that gradients flow through the projector
|
||||||
|
into a decreasing loss. Pass criterion: end loss < start loss by a clear margin.
|
||||||
|
|
||||||
|
Standalone tokenizer (UTF-8 bytes + a single BOS) so the smoke does not depend
|
||||||
|
on the nanochat BPE tokenizer being trained yet — that prerequisite belongs to
|
||||||
|
W2 onwards.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python -m scripts.audio_align_smoke
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import wave
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from nanochat.audio import Projector, WhisperEncoder
|
||||||
|
from nanochat.common import (
|
||||||
|
COMPUTE_DTYPE,
|
||||||
|
autodetect_device_type,
|
||||||
|
compute_cleanup,
|
||||||
|
compute_init,
|
||||||
|
)
|
||||||
|
from nanochat.gpt import GPT, GPTConfig
|
||||||
|
|
||||||
|
|
||||||
|
# Byte-level tokenizer: vocab[0..255] = raw UTF-8 byte, 256 = <BOS>.
|
||||||
|
BOS_ID = 256
|
||||||
|
VOCAB_SIZE = 257
|
||||||
|
|
||||||
|
|
||||||
|
def encode(text):
|
||||||
|
return [BOS_ID] + list(text.encode("utf-8"))
|
||||||
|
|
||||||
|
|
||||||
|
def load_manifest(data_dir):
|
||||||
|
items = []
|
||||||
|
with open(Path(data_dir) / "manifest.jsonl") as f:
|
||||||
|
for line in f:
|
||||||
|
items.append(json.loads(line))
|
||||||
|
return items
|
||||||
|
|
||||||
|
|
||||||
|
def load_wav_mono16k(path):
|
||||||
|
"""Read a mono PCM16 WAV (matches scripts/audio_smoke_data.py output)."""
|
||||||
|
with wave.open(str(path), "rb") as w:
|
||||||
|
assert w.getnchannels() == 1, f"expected mono, got {w.getnchannels()} channels"
|
||||||
|
assert w.getsampwidth() == 2, f"expected pcm16, got sampwidth {w.getsampwidth()}"
|
||||||
|
sr = w.getframerate()
|
||||||
|
frames = w.readframes(w.getnframes())
|
||||||
|
audio = np.frombuffer(frames, dtype=np.int16).astype(np.float32) / 32768.0
|
||||||
|
return audio, sr
|
||||||
|
|
||||||
|
|
||||||
|
def build_gpt(depth, head_dim, max_seq_len, device):
|
||||||
|
base_dim = depth * 64 # nanochat's default aspect ratio
|
||||||
|
model_dim = ((base_dim + head_dim - 1) // head_dim) * head_dim
|
||||||
|
n_head = model_dim // head_dim
|
||||||
|
config = GPTConfig(
|
||||||
|
sequence_len=max_seq_len,
|
||||||
|
vocab_size=VOCAB_SIZE,
|
||||||
|
n_layer=depth,
|
||||||
|
n_head=n_head,
|
||||||
|
n_kv_head=n_head,
|
||||||
|
n_embd=model_dim,
|
||||||
|
window_pattern="L",
|
||||||
|
)
|
||||||
|
with torch.device("meta"):
|
||||||
|
model = GPT(config)
|
||||||
|
model.to_empty(device=device)
|
||||||
|
model.init_weights()
|
||||||
|
return model, config
|
||||||
|
|
||||||
|
|
||||||
|
def pack_text_batch(text_ids_list, device):
|
||||||
|
"""idx[i, t] is input token; targets[i, t] is the next token (or -1 to ignore).
|
||||||
|
Right-pad to the longest sequence with 0/-1.
|
||||||
|
"""
|
||||||
|
in_len = max(len(ids) for ids in text_ids_list) - 1
|
||||||
|
B = len(text_ids_list)
|
||||||
|
idx = torch.zeros((B, in_len), dtype=torch.long, device=device)
|
||||||
|
targets = torch.full((B, in_len), -1, dtype=torch.long, device=device)
|
||||||
|
for i, ids in enumerate(text_ids_list):
|
||||||
|
L = len(ids) - 1
|
||||||
|
idx[i, :L] = torch.tensor(ids[:-1], dtype=torch.long, device=device)
|
||||||
|
targets[i, :L] = torch.tensor(ids[1:], dtype=torch.long, device=device)
|
||||||
|
return idx, targets
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--data-dir", default="data/audio_smoke")
|
||||||
|
parser.add_argument("--depth", type=int, default=6)
|
||||||
|
parser.add_argument("--head-dim", type=int, default=64)
|
||||||
|
parser.add_argument("--max-seq-len", type=int, default=2048)
|
||||||
|
parser.add_argument("--num-iters", type=int, default=50)
|
||||||
|
parser.add_argument("--lr", type=float, default=3e-3)
|
||||||
|
parser.add_argument("--whisper", default="openai/whisper-base",
|
||||||
|
help="HF Whisper id (override via WHISPER_HF_ID env)")
|
||||||
|
parser.add_argument("--loss-drop-min", type=float, default=0.5,
|
||||||
|
help="end loss must be at least this much lower than start loss")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
device_type = autodetect_device_type()
|
||||||
|
ddp, _, _, _, device = compute_init(device_type)
|
||||||
|
assert not ddp, "smoke is single-process"
|
||||||
|
|
||||||
|
# Synthetic audio + manifest: regenerate if missing so the script is self-contained.
|
||||||
|
if not (Path(args.data_dir) / "manifest.jsonl").exists():
|
||||||
|
from scripts.audio_smoke_data import generate_synthetic
|
||||||
|
generate_synthetic(args.data_dir)
|
||||||
|
|
||||||
|
items = load_manifest(args.data_dir)
|
||||||
|
audios = [load_wav_mono16k(Path(args.data_dir) / it["wav"])[0] for it in items]
|
||||||
|
texts = [it["text"] for it in items]
|
||||||
|
print(f"loaded {len(items)} samples: {texts}")
|
||||||
|
|
||||||
|
# Frozen Whisper encoder + Projector to nanochat n_embd
|
||||||
|
print(f"loading Whisper encoder ({args.whisper})...")
|
||||||
|
whisper = WhisperEncoder(hf_id=args.whisper, device=device, dtype=COMPUTE_DTYPE)
|
||||||
|
|
||||||
|
# Pre-extract Whisper input_features and encoder outputs once; encoder is frozen
|
||||||
|
# so its output never changes across training steps -> hoist out of the loop.
|
||||||
|
input_features = whisper.preprocess(audios).to(device=device, dtype=COMPUTE_DTYPE)
|
||||||
|
print(f"input_features: {tuple(input_features.shape)}")
|
||||||
|
audio_feats = whisper(input_features).detach()
|
||||||
|
print(f"whisper features: {tuple(audio_feats.shape)} (T_a soft tokens)")
|
||||||
|
|
||||||
|
# GPT (random init, d6 by default) and Projector
|
||||||
|
gpt, config = build_gpt(args.depth, args.head_dim, args.max_seq_len, device)
|
||||||
|
print(f"GPT: depth={config.n_layer} n_embd={config.n_embd} n_head={config.n_head}")
|
||||||
|
projector = Projector(in_dim=whisper.d_model, out_dim=config.n_embd).to(device=device)
|
||||||
|
|
||||||
|
# Tokenize transcripts and pack into a batch
|
||||||
|
text_ids_list = [encode(t) for t in texts]
|
||||||
|
idx, targets = pack_text_batch(text_ids_list, device=device)
|
||||||
|
print(f"text idx: {tuple(idx.shape)} (max_text_len-1)")
|
||||||
|
|
||||||
|
# Single AdamW over projector + LM. Whisper stays frozen (requires_grad=False
|
||||||
|
# was set in WhisperEncoder.__init__, so its params won't appear here anyway).
|
||||||
|
train_params = list(projector.parameters()) + [p for p in gpt.parameters() if p.requires_grad]
|
||||||
|
optim = torch.optim.AdamW(train_params, lr=args.lr, betas=(0.9, 0.95), weight_decay=0.0)
|
||||||
|
|
||||||
|
losses = []
|
||||||
|
t0 = time.time()
|
||||||
|
for step in range(args.num_iters):
|
||||||
|
soft_tokens = projector(audio_feats) # (B, T_a, n_embd)
|
||||||
|
loss = gpt(idx, targets=targets, audio_features=soft_tokens)
|
||||||
|
optim.zero_grad(set_to_none=True)
|
||||||
|
loss.backward()
|
||||||
|
torch.nn.utils.clip_grad_norm_(train_params, 1.0)
|
||||||
|
optim.step()
|
||||||
|
losses.append(loss.item())
|
||||||
|
if step % 5 == 0 or step == args.num_iters - 1:
|
||||||
|
print(f"step {step:03d} | loss {loss.item():.4f}")
|
||||||
|
dt = time.time() - t0
|
||||||
|
|
||||||
|
drop = losses[0] - losses[-1]
|
||||||
|
print(f"\nDone {args.num_iters} steps in {dt:.1f}s | start={losses[0]:.4f} end={losses[-1]:.4f} drop={drop:.4f}")
|
||||||
|
assert drop >= args.loss_drop_min, f"loss did not drop enough: {drop:.4f} < {args.loss_drop_min}"
|
||||||
|
print("PASS: audio path forward+backward works, loss is descending.")
|
||||||
|
|
||||||
|
compute_cleanup()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -0,0 +1,78 @@
|
|||||||
|
"""
|
||||||
|
Generate the W1 audio smoke dataset: a handful of 5s sine-wave clips paired
|
||||||
|
with deterministic transcripts.
|
||||||
|
|
||||||
|
Why synthetic instead of real speech: W1 only proves the forward path
|
||||||
|
(WhisperEncoder -> Projector -> GPT prepend) and that the projector's gradient
|
||||||
|
flows into a decreasing loss on a tiny fixed set. Real speech adds a network
|
||||||
|
dependency to a step that should be reproducible offline. W2 swaps in
|
||||||
|
LibriSpeech.
|
||||||
|
|
||||||
|
Audio files land under data/audio_smoke/wavs/ (gitignored). The manifest
|
||||||
|
data/audio_smoke/manifest.jsonl is the only artifact committed.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python -m scripts.audio_smoke_data
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import wave
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
SAMPLES = [
|
||||||
|
(220.0, "low tone"),
|
||||||
|
(330.0, "mid low tone"),
|
||||||
|
(440.0, "middle tone"),
|
||||||
|
(660.0, "mid high tone"),
|
||||||
|
(880.0, "high tone"),
|
||||||
|
]
|
||||||
|
SR = 16000
|
||||||
|
DURATION_S = 5.0
|
||||||
|
|
||||||
|
|
||||||
|
def synth_sine(freq_hz, duration_s=DURATION_S, sr=SR):
|
||||||
|
"""Sine + 2nd harmonic + a sliver of noise so Whisper sees non-degenerate
|
||||||
|
frames (a pure tone collapses to a near-constant log-mel)."""
|
||||||
|
t = np.arange(int(sr * duration_s)) / sr
|
||||||
|
x = 0.5 * np.sin(2 * np.pi * freq_hz * t) + 0.25 * np.sin(2 * np.pi * 2 * freq_hz * t)
|
||||||
|
rng = np.random.default_rng(int(freq_hz))
|
||||||
|
x = x + 0.01 * rng.standard_normal(len(x))
|
||||||
|
return x.astype(np.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def write_wav_pcm16(path, audio, sr=SR):
|
||||||
|
"""Write mono PCM16 WAV using the stdlib (no scipy/soundfile dependency)."""
|
||||||
|
pcm = np.clip(audio, -1.0, 1.0)
|
||||||
|
pcm = (pcm * 32767.0).astype(np.int16)
|
||||||
|
with wave.open(str(path), "wb") as w:
|
||||||
|
w.setnchannels(1)
|
||||||
|
w.setsampwidth(2)
|
||||||
|
w.setframerate(sr)
|
||||||
|
w.writeframes(pcm.tobytes())
|
||||||
|
|
||||||
|
|
||||||
|
def generate_synthetic(data_dir):
|
||||||
|
data_dir = Path(data_dir)
|
||||||
|
wav_dir = data_dir / "wavs"
|
||||||
|
wav_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
manifest_path = data_dir / "manifest.jsonl"
|
||||||
|
with open(manifest_path, "w") as f:
|
||||||
|
for freq, text in SAMPLES:
|
||||||
|
name = f"sine_{int(freq):04d}.wav"
|
||||||
|
wav_path = wav_dir / name
|
||||||
|
if not wav_path.exists():
|
||||||
|
write_wav_pcm16(wav_path, synth_sine(freq))
|
||||||
|
f.write(json.dumps({"wav": f"wavs/{name}", "text": text, "sr": SR}) + "\n")
|
||||||
|
print(f"Wrote {len(SAMPLES)} samples to {data_dir}")
|
||||||
|
return manifest_path
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--data-dir", default="data/audio_smoke")
|
||||||
|
args = parser.parse_args()
|
||||||
|
generate_synthetic(args.data_dir)
|
||||||
Reference in New Issue
Block a user