diff --git a/.gitignore b/.gitignore index 81f3f0c..b0a6848 100644 --- a/.gitignore +++ b/.gitignore @@ -14,3 +14,6 @@ wandb/ # Claude Code runtime .claude/ + +# W1 audio smoke: regenerated by scripts/audio_smoke_data.py, only manifest is committed +data/audio_smoke/wavs/ diff --git a/data/audio_smoke/manifest.jsonl b/data/audio_smoke/manifest.jsonl new file mode 100644 index 0000000..8f09949 --- /dev/null +++ b/data/audio_smoke/manifest.jsonl @@ -0,0 +1,5 @@ +{"wav": "wavs/sine_0220.wav", "text": "low tone", "sr": 16000} +{"wav": "wavs/sine_0330.wav", "text": "mid low tone", "sr": 16000} +{"wav": "wavs/sine_0440.wav", "text": "middle tone", "sr": 16000} +{"wav": "wavs/sine_0660.wav", "text": "mid high tone", "sr": 16000} +{"wav": "wavs/sine_0880.wav", "text": "high tone", "sr": 16000} diff --git a/doc/todo.md b/doc/todo.md index d4fd94c..a93899d 100644 --- a/doc/todo.md +++ b/doc/todo.md @@ -19,12 +19,17 @@ 参考 research §1.2 模块图。 -- [ ] `nanochat/audio.py`:WhisperEncoder wrapper(冻结,权重优先走 ModelScope,例如 `iic/Whisper-large-v3` / `iic/Whisper-small`;HF mirror 留作 fallback)+ Projector(MLP,输出维度对齐 nanochat `model_dim`) -- [ ] `nanochat/gpt.py` `GPT.forward()` 加可选 `audio_features` 参数,作为 soft tokens prepend 到 text embedding 前面 -- [ ] mini dataset:1–10 段 5s wav + 字幕,落 `data/audio_smoke/`(git 内不存音频,仅清单 + 下载脚本) -- [ ] `scripts/audio_align_smoke.py`:50 步、d6 nanochat base、loss 下降即过 +- [x] `nanochat/audio.py`:WhisperEncoder wrapper(冻结,ModelScope 优先经 `WHISPER_MS_ID`,HF fallback 默认 `openai/whisper-base`)+ Projector(MLP,输出维度对齐 nanochat `n_embd`) +- [x] `nanochat/gpt.py` `GPT.forward()` 加可选 `audio_features` 参数,作为 soft tokens prepend 到 text embedding 前面(kv_cache 路径暂不支持,audio 位置 targets 自动 -1 mask) +- [x] mini dataset:5 段 5s 合成正弦 + 字幕,落 `data/audio_smoke/`(wav 由 `scripts/audio_smoke_data.py` 生成,gitignore 排除) +- [x] `scripts/audio_align_smoke.py`:50 步、d6 随机初始化 GPT、字节级 tokenizer、loss 下降即过(4090 实测 ~1s,5.55→0.17) - [ ] CI 加 audio smoke job(ailab runner 装 ffmpeg;whisper 走 transformers 即可) +W1 后续可改进(暂搁,留给 W3+/W5+ 质感任务): + +- 当前用 `last_hidden_state`(最偏文本语义的层);为质感感知应切到中间层 / 多层 weighted sum / w2v-bert +- d6 GPT 是随机初始化,alignment 信号其实在练 LM 而非 projector;W2 上真 base 后 freeze LM、只练 projector 才是真正的弱对齐 + ## W2 — S1 弱对齐训练 - [ ] 拉 LibriSpeech 100h(HF mirror),预提 Whisper-base encoder 特征落盘 webdataset diff --git a/scripts/audio_align_smoke.py b/scripts/audio_align_smoke.py new file mode 100644 index 0000000..49bb89d --- /dev/null +++ b/scripts/audio_align_smoke.py @@ -0,0 +1,179 @@ +""" +W1 smoke: prove the audio path works end-to-end. + +Pipeline: + wav -> WhisperEncoder (frozen) -> Projector -> prepend to text embeddings + -> tiny d6 GPT (random init) -> CE loss on text tokens only + +The model is randomly initialized and the dataset is 5 synthetic sine clips, +so the only thing this validates is that gradients flow through the projector +into a decreasing loss. Pass criterion: end loss < start loss by a clear margin. + +Standalone tokenizer (UTF-8 bytes + a single BOS) so the smoke does not depend +on the nanochat BPE tokenizer being trained yet — that prerequisite belongs to +W2 onwards. + +Usage: + python -m scripts.audio_align_smoke +""" + +import argparse +import json +import time +import wave +from pathlib import Path + +import numpy as np +import torch + +from nanochat.audio import Projector, WhisperEncoder +from nanochat.common import ( + COMPUTE_DTYPE, + autodetect_device_type, + compute_cleanup, + compute_init, +) +from nanochat.gpt import GPT, GPTConfig + + +# Byte-level tokenizer: vocab[0..255] = raw UTF-8 byte, 256 = . +BOS_ID = 256 +VOCAB_SIZE = 257 + + +def encode(text): + return [BOS_ID] + list(text.encode("utf-8")) + + +def load_manifest(data_dir): + items = [] + with open(Path(data_dir) / "manifest.jsonl") as f: + for line in f: + items.append(json.loads(line)) + return items + + +def load_wav_mono16k(path): + """Read a mono PCM16 WAV (matches scripts/audio_smoke_data.py output).""" + with wave.open(str(path), "rb") as w: + assert w.getnchannels() == 1, f"expected mono, got {w.getnchannels()} channels" + assert w.getsampwidth() == 2, f"expected pcm16, got sampwidth {w.getsampwidth()}" + sr = w.getframerate() + frames = w.readframes(w.getnframes()) + audio = np.frombuffer(frames, dtype=np.int16).astype(np.float32) / 32768.0 + return audio, sr + + +def build_gpt(depth, head_dim, max_seq_len, device): + base_dim = depth * 64 # nanochat's default aspect ratio + model_dim = ((base_dim + head_dim - 1) // head_dim) * head_dim + n_head = model_dim // head_dim + config = GPTConfig( + sequence_len=max_seq_len, + vocab_size=VOCAB_SIZE, + n_layer=depth, + n_head=n_head, + n_kv_head=n_head, + n_embd=model_dim, + window_pattern="L", + ) + with torch.device("meta"): + model = GPT(config) + model.to_empty(device=device) + model.init_weights() + return model, config + + +def pack_text_batch(text_ids_list, device): + """idx[i, t] is input token; targets[i, t] is the next token (or -1 to ignore). + Right-pad to the longest sequence with 0/-1. + """ + in_len = max(len(ids) for ids in text_ids_list) - 1 + B = len(text_ids_list) + idx = torch.zeros((B, in_len), dtype=torch.long, device=device) + targets = torch.full((B, in_len), -1, dtype=torch.long, device=device) + for i, ids in enumerate(text_ids_list): + L = len(ids) - 1 + idx[i, :L] = torch.tensor(ids[:-1], dtype=torch.long, device=device) + targets[i, :L] = torch.tensor(ids[1:], dtype=torch.long, device=device) + return idx, targets + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--data-dir", default="data/audio_smoke") + parser.add_argument("--depth", type=int, default=6) + parser.add_argument("--head-dim", type=int, default=64) + parser.add_argument("--max-seq-len", type=int, default=2048) + parser.add_argument("--num-iters", type=int, default=50) + parser.add_argument("--lr", type=float, default=3e-3) + parser.add_argument("--whisper", default="openai/whisper-base", + help="HF Whisper id (override via WHISPER_HF_ID env)") + parser.add_argument("--loss-drop-min", type=float, default=0.5, + help="end loss must be at least this much lower than start loss") + args = parser.parse_args() + + device_type = autodetect_device_type() + ddp, _, _, _, device = compute_init(device_type) + assert not ddp, "smoke is single-process" + + # Synthetic audio + manifest: regenerate if missing so the script is self-contained. + if not (Path(args.data_dir) / "manifest.jsonl").exists(): + from scripts.audio_smoke_data import generate_synthetic + generate_synthetic(args.data_dir) + + items = load_manifest(args.data_dir) + audios = [load_wav_mono16k(Path(args.data_dir) / it["wav"])[0] for it in items] + texts = [it["text"] for it in items] + print(f"loaded {len(items)} samples: {texts}") + + # Frozen Whisper encoder + Projector to nanochat n_embd + print(f"loading Whisper encoder ({args.whisper})...") + whisper = WhisperEncoder(hf_id=args.whisper, device=device, dtype=COMPUTE_DTYPE) + + # Pre-extract Whisper input_features and encoder outputs once; encoder is frozen + # so its output never changes across training steps -> hoist out of the loop. + input_features = whisper.preprocess(audios).to(device=device, dtype=COMPUTE_DTYPE) + print(f"input_features: {tuple(input_features.shape)}") + audio_feats = whisper(input_features).detach() + print(f"whisper features: {tuple(audio_feats.shape)} (T_a soft tokens)") + + # GPT (random init, d6 by default) and Projector + gpt, config = build_gpt(args.depth, args.head_dim, args.max_seq_len, device) + print(f"GPT: depth={config.n_layer} n_embd={config.n_embd} n_head={config.n_head}") + projector = Projector(in_dim=whisper.d_model, out_dim=config.n_embd).to(device=device) + + # Tokenize transcripts and pack into a batch + text_ids_list = [encode(t) for t in texts] + idx, targets = pack_text_batch(text_ids_list, device=device) + print(f"text idx: {tuple(idx.shape)} (max_text_len-1)") + + # Single AdamW over projector + LM. Whisper stays frozen (requires_grad=False + # was set in WhisperEncoder.__init__, so its params won't appear here anyway). + train_params = list(projector.parameters()) + [p for p in gpt.parameters() if p.requires_grad] + optim = torch.optim.AdamW(train_params, lr=args.lr, betas=(0.9, 0.95), weight_decay=0.0) + + losses = [] + t0 = time.time() + for step in range(args.num_iters): + soft_tokens = projector(audio_feats) # (B, T_a, n_embd) + loss = gpt(idx, targets=targets, audio_features=soft_tokens) + optim.zero_grad(set_to_none=True) + loss.backward() + torch.nn.utils.clip_grad_norm_(train_params, 1.0) + optim.step() + losses.append(loss.item()) + if step % 5 == 0 or step == args.num_iters - 1: + print(f"step {step:03d} | loss {loss.item():.4f}") + dt = time.time() - t0 + + drop = losses[0] - losses[-1] + print(f"\nDone {args.num_iters} steps in {dt:.1f}s | start={losses[0]:.4f} end={losses[-1]:.4f} drop={drop:.4f}") + assert drop >= args.loss_drop_min, f"loss did not drop enough: {drop:.4f} < {args.loss_drop_min}" + print("PASS: audio path forward+backward works, loss is descending.") + + compute_cleanup() + + +if __name__ == "__main__": + main() diff --git a/scripts/audio_smoke_data.py b/scripts/audio_smoke_data.py new file mode 100644 index 0000000..fb7de37 --- /dev/null +++ b/scripts/audio_smoke_data.py @@ -0,0 +1,78 @@ +""" +Generate the W1 audio smoke dataset: a handful of 5s sine-wave clips paired +with deterministic transcripts. + +Why synthetic instead of real speech: W1 only proves the forward path +(WhisperEncoder -> Projector -> GPT prepend) and that the projector's gradient +flows into a decreasing loss on a tiny fixed set. Real speech adds a network +dependency to a step that should be reproducible offline. W2 swaps in +LibriSpeech. + +Audio files land under data/audio_smoke/wavs/ (gitignored). The manifest +data/audio_smoke/manifest.jsonl is the only artifact committed. + +Usage: + python -m scripts.audio_smoke_data +""" + +import argparse +import json +import wave +from pathlib import Path + +import numpy as np + + +SAMPLES = [ + (220.0, "low tone"), + (330.0, "mid low tone"), + (440.0, "middle tone"), + (660.0, "mid high tone"), + (880.0, "high tone"), +] +SR = 16000 +DURATION_S = 5.0 + + +def synth_sine(freq_hz, duration_s=DURATION_S, sr=SR): + """Sine + 2nd harmonic + a sliver of noise so Whisper sees non-degenerate + frames (a pure tone collapses to a near-constant log-mel).""" + t = np.arange(int(sr * duration_s)) / sr + x = 0.5 * np.sin(2 * np.pi * freq_hz * t) + 0.25 * np.sin(2 * np.pi * 2 * freq_hz * t) + rng = np.random.default_rng(int(freq_hz)) + x = x + 0.01 * rng.standard_normal(len(x)) + return x.astype(np.float32) + + +def write_wav_pcm16(path, audio, sr=SR): + """Write mono PCM16 WAV using the stdlib (no scipy/soundfile dependency).""" + pcm = np.clip(audio, -1.0, 1.0) + pcm = (pcm * 32767.0).astype(np.int16) + with wave.open(str(path), "wb") as w: + w.setnchannels(1) + w.setsampwidth(2) + w.setframerate(sr) + w.writeframes(pcm.tobytes()) + + +def generate_synthetic(data_dir): + data_dir = Path(data_dir) + wav_dir = data_dir / "wavs" + wav_dir.mkdir(parents=True, exist_ok=True) + manifest_path = data_dir / "manifest.jsonl" + with open(manifest_path, "w") as f: + for freq, text in SAMPLES: + name = f"sine_{int(freq):04d}.wav" + wav_path = wav_dir / name + if not wav_path.exists(): + write_wav_pcm16(wav_path, synth_sine(freq)) + f.write(json.dumps({"wav": f"wavs/{name}", "text": text, "sr": SR}) + "\n") + print(f"Wrote {len(SAMPLES)} samples to {data_dir}") + return manifest_path + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--data-dir", default="data/audio_smoke") + args = parser.parse_args() + generate_synthetic(args.data_dir)