"""
W1 smoke: prove the audio path works end-to-end.

Pipeline:
    wav -> WhisperEncoder (frozen) -> Projector -> prepend to text embeddings
        -> tiny d6 GPT (random init) -> CE loss on text tokens only

The model is randomly initialized and the dataset is 5 synthetic sine clips,
so the only thing this validates is that gradients flow through the projector
into a decreasing loss. Pass criterion: end loss < start loss by a clear margin.

Standalone tokenizer (UTF-8 bytes + a single BOS) so the smoke does not depend
on the nanochat BPE tokenizer being trained yet — that prerequisite belongs
to W2 onwards.

Usage:
    python -m scripts.audio_align_smoke
"""

import argparse
import json
import time
import wave
from pathlib import Path

import numpy as np
import torch

from nanochat.audio import Projector, WhisperEncoder
from nanochat.common import (
    COMPUTE_DTYPE,
    autodetect_device_type,
    compute_cleanup,
    compute_init,
)
from nanochat.gpt import GPT, GPTConfig

# Byte-level tokenizer: vocab[0..255] = raw UTF-8 byte, 256 = BOS.
BOS_ID = 256  # first id after the 256 raw byte values
VOCAB_SIZE = 257  # 256 byte ids + BOS

# Sample rate every clip is required to have. load_wav_mono16k's name implies
# 16 kHz, and the rate is not passed on to whisper.preprocess, so a clip at any
# other rate would be silently misinterpreted.
EXPECTED_SAMPLE_RATE = 16_000


def encode(text):
    """Return the token ids for *text*: BOS followed by its UTF-8 bytes."""
    return [BOS_ID] + list(text.encode("utf-8"))


def load_manifest(data_dir):
    """Parse ``<data_dir>/manifest.jsonl`` into a list of JSON objects.

    Blank lines (e.g. a trailing empty line) are skipped instead of being
    handed to ``json.loads``, and the file is always read as UTF-8.
    """
    manifest_path = Path(data_dir) / "manifest.jsonl"
    with open(manifest_path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def load_wav_mono16k(path):
    """Read a mono PCM16 WAV (matches scripts/audio_smoke_data.py output).

    Returns ``(audio, sr)`` where ``audio`` is a float32 array scaled by
    1/32768 and ``sr`` is the frame rate stored in the file. The sample rate
    is reported, not checked, here.

    Raises AssertionError if the file is not mono or not 16-bit.
    """
    with wave.open(str(path), "rb") as w:
        assert w.getnchannels() == 1, f"expected mono, got {w.getnchannels()} channels"
        assert w.getsampwidth() == 2, f"expected pcm16, got sampwidth {w.getsampwidth()}"
        sr = w.getframerate()
        frames = w.readframes(w.getnframes())
    audio = np.frombuffer(frames, dtype=np.int16).astype(np.float32) / 32768.0
    return audio, sr


def build_gpt(depth, head_dim, max_seq_len, device):
    """Build a randomly initialized GPT on *device*; return ``(model, config)``.

    The width is ``depth * 64`` rounded up to a multiple of ``head_dim`` so it
    splits evenly into heads.
    """
    base_dim = depth * 64  # nanochat's default aspect ratio
    model_dim = ((base_dim + head_dim - 1) // head_dim) * head_dim
    n_head = model_dim // head_dim
    config = GPTConfig(
        sequence_len=max_seq_len,
        vocab_size=VOCAB_SIZE,
        n_layer=depth,
        n_head=n_head,
        n_kv_head=n_head,
        n_embd=model_dim,
        window_pattern="L",
    )
    # Construct on the meta device, then allocate storage on the target device
    # and initialize it there.
    with torch.device("meta"):
        model = GPT(config)
    model.to_empty(device=device)
    model.init_weights()
    return model, config


def pack_text_batch(text_ids_list, device):
    """idx[i, t] is input token; targets[i, t] is the next token (or -1 to ignore).

    Right-pad to the longest sequence with 0 (idx) / -1 (targets). Both
    returned tensors have shape ``(len(text_ids_list), max_len - 1)``.
    """
    in_len = max(len(ids) for ids in text_ids_list) - 1
    B = len(text_ids_list)
    idx = torch.zeros((B, in_len), dtype=torch.long, device=device)
    targets = torch.full((B, in_len), -1, dtype=torch.long, device=device)
    for i, ids in enumerate(text_ids_list):
        L = len(ids) - 1
        idx[i, :L] = torch.tensor(ids[:-1], dtype=torch.long, device=device)
        targets[i, :L] = torch.tensor(ids[1:], dtype=torch.long, device=device)
    return idx, targets


def _parse_args():
    """Parse and validate the command-line options of the smoke script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-dir", default="data/audio_smoke")
    parser.add_argument("--depth", type=int, default=6)
    parser.add_argument("--head-dim", type=int, default=64)
    parser.add_argument("--max-seq-len", type=int, default=2048)
    parser.add_argument("--num-iters", type=int, default=50)
    parser.add_argument("--lr", type=float, default=3e-3)
    # NOTE(review): the help text mentions a WHISPER_HF_ID env override, but
    # nothing in this script reads that variable — confirm WhisperEncoder does.
    parser.add_argument("--whisper", default="openai/whisper-base",
                        help="HF Whisper id (override via WHISPER_HF_ID env)")
    parser.add_argument("--loss-drop-min", type=float, default=0.5,
                        help="end loss must be at least this much lower than start loss")
    args = parser.parse_args()
    # The pass criterion compares the first and last recorded loss, so at
    # least one training step is required.
    if args.num_iters < 1:
        parser.error("--num-iters must be >= 1")
    return args


def _load_dataset(data_dir):
    """Load (and if necessary generate) the smoke clips; return ``(audios, texts)``."""
    root = Path(data_dir)
    # Synthetic audio + manifest: regenerate if missing so the script is self-contained.
    if not (root / "manifest.jsonl").exists():
        from scripts.audio_smoke_data import generate_synthetic
        generate_synthetic(data_dir)

    items = load_manifest(data_dir)
    if not items:
        raise ValueError(f"manifest in {data_dir} lists no samples")

    audios = []
    for it in items:
        audio, sr = load_wav_mono16k(root / it["wav"])
        if sr != EXPECTED_SAMPLE_RATE:
            raise ValueError(
                f"{it['wav']}: expected {EXPECTED_SAMPLE_RATE} Hz audio, got {sr} Hz"
            )
        audios.append(audio)
    texts = [it["text"] for it in items]
    return audios, texts


def _run_smoke(args, device):
    """Train projector + GPT on the smoke clips and enforce the loss-drop criterion."""
    audios, texts = _load_dataset(args.data_dir)
    print(f"loaded {len(texts)} samples: {texts}")

    # Frozen Whisper encoder + Projector to nanochat n_embd
    print(f"loading Whisper encoder ({args.whisper})...")
    whisper = WhisperEncoder(hf_id=args.whisper, device=device, dtype=COMPUTE_DTYPE)

    # Pre-extract Whisper input_features and encoder outputs once; encoder is frozen
    # so its output never changes across training steps -> hoist out of the loop.
    input_features = whisper.preprocess(audios).to(device=device, dtype=COMPUTE_DTYPE)
    print(f"input_features: {tuple(input_features.shape)}")
    audio_feats = whisper(input_features).detach()
    print(f"whisper features: {tuple(audio_feats.shape)} (T_a soft tokens)")

    # GPT (random init, d6 by default) and Projector
    gpt, config = build_gpt(args.depth, args.head_dim, args.max_seq_len, device)
    print(f"GPT: depth={config.n_layer} n_embd={config.n_embd} n_head={config.n_head}")
    projector = Projector(in_dim=whisper.d_model, out_dim=config.n_embd).to(device=device)

    # Tokenize transcripts and pack into a batch
    text_ids_list = [encode(t) for t in texts]
    idx, targets = pack_text_batch(text_ids_list, device=device)
    print(f"text idx: {tuple(idx.shape)} (max_text_len-1)")

    # Single AdamW over projector + LM. Whisper stays frozen (requires_grad=False
    # was set in WhisperEncoder.__init__, so its params won't appear here anyway).
    train_params = list(projector.parameters()) + [p for p in gpt.parameters() if p.requires_grad]
    optim = torch.optim.AdamW(train_params, lr=args.lr, betas=(0.9, 0.95), weight_decay=0.0)

    losses = []
    t0 = time.time()
    for step in range(args.num_iters):
        soft_tokens = projector(audio_feats)  # (B, T_a, n_embd)
        loss = gpt(idx, targets=targets, audio_features=soft_tokens)
        optim.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(train_params, 1.0)
        optim.step()
        loss_value = loss.item()  # read once; reused for the log line below
        losses.append(loss_value)
        if step % 5 == 0 or step == args.num_iters - 1:
            print(f"step {step:03d} | loss {loss_value:.4f}")
    dt = time.time() - t0

    drop = losses[0] - losses[-1]
    print(f"\nDone {args.num_iters} steps in {dt:.1f}s | start={losses[0]:.4f} end={losses[-1]:.4f} drop={drop:.4f}")
    # Explicit raise rather than `assert`, so the pass criterion still holds
    # under `python -O`. Written as `not >=` so a NaN drop also fails.
    if not drop >= args.loss_drop_min:
        raise AssertionError(f"loss did not drop enough: {drop:.4f} < {args.loss_drop_min}")
    print("PASS: audio path forward+backward works, loss is descending.")


def main():
    """Entry point: run the audio-alignment smoke test in a single process.

    Raises AssertionError if run under DDP or if the loss does not drop by at
    least ``--loss-drop-min``; ValueError if the dataset is empty or a clip is
    not 16 kHz.
    """
    args = _parse_args()
    device_type = autodetect_device_type()
    ddp, _, _, _, device = compute_init(device_type)
    try:
        if ddp:
            raise AssertionError("smoke is single-process")
        _run_smoke(args, device)
    finally:
        # Runs on failure too, so whatever compute_init set up is released.
        compute_cleanup()


if __name__ == "__main__":
    main()