End-to-end smoke proving the audio path:
wav -> WhisperEncoder (frozen) -> Projector -> prepend to text embeddings
-> tiny d6 GPT (random init) -> CE loss on text only
Pass criterion is a plain "loss drops by at least 0.5". On a 4090 the run
finishes in ~1 s and goes 5.55 -> 0.17 over 50 steps (a drop of 5.38), so
the threshold has plenty of headroom against spurious failures.
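
For reference, the pass criterion can be sketched as a single comparison. This is a minimal illustration, not the smoke script itself; the helper name `smoke_passed` and its `min_drop` argument are hypothetical, and only the two endpoint losses quoted above are used.

```python
def smoke_passed(losses, min_drop=0.5):
    """Return True if the loss fell by at least `min_drop` from first to last step.

    `losses` is the per-step training loss; only the endpoints matter here.
    (Hypothetical helper, not the name used in the smoke script.)
    """
    if len(losses) < 2:
        raise ValueError("need at least two loss values")
    return (losses[0] - losses[-1]) >= min_drop


# Endpoints quoted in the message for the 4090 run.
reported = [5.55, 0.17]
assert smoke_passed(reported)

# A run that barely moves fails the check.
assert not smoke_passed([5.55, 5.20])
```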
Two design calls worth keeping in mind:
1. Synthetic sine clips, not LibriSpeech. W1 is forward-path proof, not
alignment quality, and a deterministic offline dataset means no network
on the smoke path. data/audio_smoke/manifest.jsonl is the only thing
committed; wavs are regenerated by audio_smoke_data.py and gitignored.
W2 swaps in real LibriSpeech.
2. Standalone byte-level tokenizer (UTF-8 bytes + a single BOS, vocab=257).
Avoids depending on a trained nanochat BPE — the d6 GPT is random
anyway, so vocab choice doesn't matter for "does the gradient flow"
smoke. W2 onwards uses the real BPE on a real base.
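
A byte-level tokenizer of the kind described in point 2 fits in a few lines. The sketch below assumes BOS takes id 256 (the one id above the byte range) and that BOS is prepended on encode; the function names are illustrative and not taken from the repo.

```python
BOS_ID = 256       # assumption: BOS sits just above the 256 byte values
VOCAB_SIZE = 257   # 256 byte ids + 1 BOS, as stated above


def encode(text, prepend_bos=True):
    """Map a string to UTF-8 byte ids, optionally prefixed with BOS."""
    ids = list(text.encode("utf-8"))
    return [BOS_ID] + ids if prepend_bos else ids


def decode(ids):
    """Drop BOS and turn the remaining byte ids back into a string."""
    return bytes(i for i in ids if i != BOS_ID).decode("utf-8", errors="replace")


ids = encode("hi")
assert ids == [256, 104, 105]
assert decode(ids) == "hi"
```

Multi-byte characters simply expand to several ids, so nothing outside the 257-entry vocabulary can ever be produced.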
Caveat documented in doc/todo.md: because the LM is also randomly
initialised and being trained, the loss drop here mostly reflects the LM
memorising 5 short strings, not Whisper-Projector alignment. That's fine
for proving plumbing; W2 freezes the LM so that gradient through the
projector is the only path to lower loss.
The audio modality module that pairs with the gpt.forward audio_features
hook. Two things live here:
WhisperEncoder: thin wrapper around transformers' WhisperModel.encoder.
- Weight loading prefers ModelScope when WHISPER_MS_ID is set (matches the
CN-mirror policy in doc/todo.md — modelscope is first-class for model
weights, hf-mirror is fallback). Otherwise falls back to plain HF, with
WHISPER_HF_ID as the override and `openai/whisper-base` as the default
(the smallest variant that still produces useful features for smoke).
- Encoder params have requires_grad=False from __init__ so they never
appear in the optimizer's param list. Caller does not need to remember
to freeze it.
- preprocess() runs the feature extractor; forward() takes (B, n_mels,
T_mel) and returns last_hidden_state (B, T_enc, d_model). Whisper pads
every clip to 30 s, so T_enc is a constant 1500 regardless of input
duration — handy for batching, wasteful for short clips. We accept the
waste at W1; W2 can switch to streaming-style chunking.
- Note for W3+/W5+: last_hidden_state is the most text-semantic layer.
When we start caring about timbre / prosody / emotion ("texture
perception"), we should expose middle layers or a learnable weighted sum
across layers.
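
The constant T_enc = 1500 follows from Whisper's standard front-end settings: 16 kHz audio, a hop of 160 samples, a fixed 30 s window, and a stride-2 convolution at the encoder input. The arithmetic below is only a worked check of those numbers; the helper names are illustrative.

```python
SAMPLE_RATE = 16_000   # Hz, Whisper's expected input rate
HOP_LENGTH = 160       # samples between mel frames (10 ms)
CHUNK_SECONDS = 30     # every clip is padded or trimmed to this length
CONV_STRIDE = 2        # the encoder's second conv halves the time axis


def mel_frames(seconds=CHUNK_SECONDS):
    """Number of log-mel frames for a window of `seconds`."""
    return seconds * SAMPLE_RATE // HOP_LENGTH


def encoder_frames(seconds=CHUNK_SECONDS):
    """Number of encoder output positions (T_enc) for that window."""
    return mel_frames(seconds) // CONV_STRIDE


def useful_fraction(clip_seconds):
    """Share of the 1500 positions that cover real audio rather than padding."""
    return min(clip_seconds, CHUNK_SECONDS) / CHUNK_SECONDS


assert mel_frames() == 3000
assert encoder_frames() == 1500
```

For a 3 s clip, `useful_fraction(3)` is one tenth, which is the waste the bullet above accepts at W1.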
Projector: 2-layer MLP (in_dim → out_dim → out_dim) with GELU and the
nanochat Linear class so master weights stay fp32 while forward runs in
the activation dtype (bf16). fc2 is zero-initialized so the model starts
ignoring audio entirely, which gives a clean baseline before any training
signal flows through (the audio path is off by default and only switched
on by training).
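
The effect of the zero-initialized fc2 can be seen without any framework. The sketch below is a dependency-free stand-in for the Projector, assuming bias-free layers and the exact (erf) form of GELU; the real module uses the nanochat Linear class on tensors, and all names here are illustrative.

```python
import math
import random


def gelu(x):
    """Exact GELU via the error function."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def linear(x, w):
    """Bias-free linear map: x has length in_dim, w is out_dim rows of in_dim."""
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def make_projector(in_dim, out_dim, seed=0):
    """fc1 gets a random init, fc2 is all zeros (in_dim -> out_dim -> out_dim)."""
    rng = random.Random(seed)
    fc1 = [[rng.gauss(0.0, in_dim ** -0.5) for _ in range(in_dim)]
           for _ in range(out_dim)]
    fc2 = [[0.0] * out_dim for _ in range(out_dim)]
    return fc1, fc2


def project(x, fc1, fc2):
    hidden = [gelu(h) for h in linear(x, fc1)]
    return linear(hidden, fc2)


fc1, fc2 = make_projector(in_dim=4, out_dim=3)
out = project([0.5, -1.0, 2.0, 0.1], fc1, fc2)
assert out == [0.0, 0.0, 0.0]  # audio contributes nothing at step 0
```

With this init the first gradient step can only move fc2, since the gradient reaching fc1 is multiplied by the still-zero fc2 weights.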
W1 needs the GPT to consume Whisper-projector outputs as a prefix of "soft
tokens" sitting in front of the text token embeddings (LLaVA-style). The
change is intentionally minimal:
- forward() takes a new keyword-only arg `audio_features` of shape
(B, T_a, n_embd). They must already be projected to n_embd by the caller
(Projector lives in nanochat.audio, kept out of GPT itself).
- The audio rows are normed (matches the post-wte norm convention) and
concatenated *after* smear so smear stays a strictly text-side op (its
prev-token semantics aren't defined for soft tokens, and revisiting that
belongs to a later phase).
- Rotary embeddings are re-sliced for T_full = T_a + T_text. Audio gets
positions 0..T_a-1; text shifts from 0..T_text-1 to T_a..T_full-1. The 10×
over-allocated rotary cache in __init__ already covers this.
- value_embeds lookup uses an idx padded with 0 for audio positions. These
dummy rows feed the v residual, but the gate (`ve_gate`) is input-dependent
and is expected to learn to suppress them; for W1 smoke this is fine.
- targets are auto-padded with -1 (ignore_index) over audio positions so
the LM is only graded on text predictions.
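
The index bookkeeping in the bullets above (positions, the padded idx for value_embeds, and the padded targets) can be sketched on plain lists. This is an illustration of the layout only, assuming audio always sits in front of the text; the helper names are hypothetical and the real code operates on tensors inside gpt.forward.

```python
IGNORE_INDEX = -1  # targets at audio positions are excluded from the CE loss


def positions(t_audio, t_text):
    """Rotary positions: audio takes 0..T_a-1, text takes T_a..T_full-1."""
    audio_pos = list(range(t_audio))
    text_pos = list(range(t_audio, t_audio + t_text))
    return audio_pos, text_pos


def pad_value_embed_idx(idx, t_audio):
    """Prefix each row of token ids with 0 so value_embeds has a row to look up."""
    return [[0] * t_audio + list(row) for row in idx]


def pad_targets(targets, t_audio):
    """Prefix each row of targets with IGNORE_INDEX over the audio positions."""
    return [[IGNORE_INDEX] * t_audio + list(row) for row in targets]


audio_pos, text_pos = positions(t_audio=3, t_text=2)
assert audio_pos == [0, 1, 2] and text_pos == [3, 4]
assert pad_value_embed_idx([[7, 8]], 3) == [[0, 0, 0, 7, 8]]
assert pad_targets([[8, 9]], 3) == [[-1, -1, -1, 8, 9]]
```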
Not yet supported: audio_features with kv_cache. The KV-cache path is a
prefill+decode protocol that would need its own audio-aware semantics; W1
runs train-style forwards only, so we just assert.
- pyproject.toml + uv.lock: pytorch-cu128/cpu indexes → mirror.sjtu.edu.cn
(aliyun lacks 2.9.1, sjtu has it)
- nanochat/dataset.py: climbmix BASE_URL → hf-mirror.com
This is for ailab (CN, RTX 5090), where pytorch.org and huggingface.co
are not directly reachable. Override at uv-sync time with the
UV_DEFAULT_INDEX env var.
When swapping Float8Linear for Linear in the disable_fp8 context manager,
passing device=fp8_module.weight.device allocates fresh weight tensors
on the GPU, causing an unnecessary VRAM spike (~1GB for large models).
This fix creates the Linear with device='meta', which allocates no
physical memory, then swaps in a reference to the existing weight tensor.
That removes the spike during the evaluation phase.
Fixes issue #592
Co-authored-by: RoomWithOutRoof <roomwithoutroof@sparklab.ai>
The bf16 cast is intentional for speed on Hopper+ GPUs, but should be
skipped on other platforms rather than blindly applied. fp16 is unstable
here due to its limited exponent range, and fp32 platforms don't benefit
from the cast. Now: bf16 when COMPUTE_DTYPE is bf16, no cast otherwise.
Inspired by PR #667.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
New architectural features:
- Smear: mix previous token embedding into current position via learned
gate, providing cheap bigram-like info (works in training + KV cache)
- Backout: subtract learned fraction of mid-layer residual before logit
projection to remove low-level features
Hyperparameter tuning:
- Muon momentum warmdown 0.97→0.90 during LR warmdown phase
- Non-uniform per-layer init: resid_lambdas 1.15→1.05, x0_lambdas 0.20→0.05
- c_fc init scale 0.4x, QK norm scale 1.2, sliding window seq_len/4
- Speedrun data:params ratio reduced to 8
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* printing the step count
* adding reply-only loss for chat
* using the mask from the tokenizer's render_conversation function
* undoing some changes
* putting back a comment that was removed accidentally; no functionality change
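
A reply-only loss of this kind typically comes down to masking the shifted targets. The sketch below assumes the tokenizer yields a list of token ids plus a parallel 0/1 mask where 1 marks tokens to be trained on; the helper name and the -1 ignore index are illustrative choices, not taken from the diff.

```python
IGNORE_INDEX = -1  # assumed ignore_index for the cross-entropy loss


def reply_only_targets(ids, mask):
    """Build next-token inputs/targets, keeping loss only where mask == 1.

    `ids` and `mask` are parallel lists; mask[i] == 1 means token i is part
    of a reply and should contribute to the loss.
    """
    if len(ids) != len(mask):
        raise ValueError("ids and mask must have the same length")
    inputs = ids[:-1]
    targets = [tok if m == 1 else IGNORE_INDEX
               for tok, m in zip(ids[1:], mask[1:])]
    return inputs, targets


inputs, targets = reply_only_targets([10, 11, 12, 13, 14], [0, 0, 1, 1, 0])
assert inputs == [10, 11, 12, 13]
assert targets == [-1, 12, 13, -1]
```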