Merge branch 'master' into logo/kerning-update

This commit is contained in:
svlandeg
2025-10-29 11:45:40 +01:00
32 changed files with 2171 additions and 307 deletions
-1
View File
@@ -26,7 +26,6 @@ class DistAdamW(torch.optim.Optimizer):
grad_slices = []
for group in self.param_groups:
params: list[Tensor] = group["params"]
grad = torch.empty_like(params[-1]) # TODO is this bug? seems to be over-written instantly
for base_i in range(len(params)):
grad = params[base_i].grad
rank_size = grad.shape[0] // world_size
+64 -10
View File
@@ -5,6 +5,8 @@ Common utilities for nanochat.
import os
import re
import logging
import fcntl
import urllib.request
import torch
import torch.distributed as dist
@@ -56,6 +58,44 @@ def get_base_dir():
os.makedirs(nanochat_dir, exist_ok=True)
return nanochat_dir
def download_file_with_lock(url, filename):
"""
Downloads a file from a URL to a local path in the base directory.
Uses a lock file to prevent concurrent downloads among multiple ranks.
"""
base_dir = get_base_dir()
file_path = os.path.join(base_dir, filename)
lock_path = file_path + ".lock"
if os.path.exists(file_path):
return file_path
with open(lock_path, 'w') as lock_file:
# Only a single rank can acquire this lock
# All other ranks block until it is released
fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX)
if os.path.exists(file_path):
return file_path
print(f"Downloading {url}...")
with urllib.request.urlopen(url) as response:
content = response.read().decode('utf-8')
with open(file_path, 'w') as f:
f.write(content)
print(f"Downloaded to {file_path}")
# Clean up the lock file after the lock is released
try:
os.remove(lock_path)
except OSError:
pass # Ignore if already removed by another process
return file_path
def print0(s="",**kwargs):
ddp_rank = int(os.environ.get('RANK', 0))
if ddp_rank == 0:
@@ -89,32 +129,46 @@ def get_dist_info():
else:
return False, 0, 0, 1
def compute_init():
def autodetect_device_type():
# prefer to use CUDA if available, otherwise use MPS, otherwise fallback on CPU
if torch.cuda.is_available():
device_type = "cuda"
elif torch.backends.mps.is_available():
device_type = "mps"
else:
device_type = "cpu"
print0(f"Autodetected device type: {device_type}")
return device_type
def compute_init(device_type="cuda"): # cuda|cpu|mps
"""Basic initialization that we keep doing over and over, so make common."""
# CUDA is currently required
assert torch.cuda.is_available(), "CUDA is needed for a distributed run atm"
assert device_type in ["cuda", "mps", "cpu"], "Invalid device type atm"
if device_type == "cuda":
assert torch.cuda.is_available(), "Your PyTorch installation is not configured for CUDA but device_type is 'cuda'"
if device_type == "mps":
assert torch.backends.mps.is_available(), "Your PyTorch installation is not configured for MPS but device_type is 'mps'"
# Reproducibility
torch.manual_seed(42)
torch.cuda.manual_seed(42)
if device_type == "cuda":
torch.cuda.manual_seed(42)
# skipping full reproducibility for now, possibly investigate slowdown later
# torch.use_deterministic_algorithms(True)
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False
# Precision
torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls
if device_type == "cuda":
torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls
# Distributed setup: Distributed Data Parallel (DDP), optional
# Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
if ddp:
if ddp and device_type == "cuda":
device = torch.device("cuda", ddp_local_rank)
torch.cuda.set_device(device) # make "cuda" default to this device
dist.init_process_group(backend="nccl", device_id=device)
dist.barrier()
else:
device = torch.device("cuda")
device = torch.device(device_type) # mps|cpu
if ddp_rank == 0:
logger.info(f"Distributed world size: {ddp_world_size}")
+6 -6
View File
@@ -6,7 +6,7 @@ from nanochat.common import get_dist_info
from nanochat.dataset import parquets_iter_batched
from nanochat.tokenizer import get_tokenizer
def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128):
def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda"):
"""Stream pretraining text from parquet files, tokenize, yield training batches."""
assert split in ["train", "val"], "split must be 'train' or 'val'"
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
@@ -16,7 +16,6 @@ def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokeniz
bos_token = tokenizer.get_bos_token_id()
# scratch buffer holds the tokens for one iteration
token_buffer = deque() # we stream tokens on the right and pop from the left
scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=True)
# infinite iterator over document batches
def document_batches():
@@ -38,12 +37,13 @@ def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokeniz
token_buffer.extend(tokens)
batch_index += 1
# Move tokens from the deque into the scratch buffer
for i in range(needed_tokens):
scratch[i] = token_buffer.popleft()
tokens = [token_buffer.popleft() for _ in range(needed_tokens)]
# CUDA supports memory pinning for faster transfers between CPU and GPU:
scratch = torch.tensor(tokens, dtype=torch.int64, pin_memory=(device == "cuda"))
# Create the inputs/targets as 1D tensors
inputs_cpu = scratch[:-1].to(dtype=torch.int32)
targets_cpu = scratch[1:]
# Reshape to 2D and move to GPU async
inputs = inputs_cpu.view(B, T).to(device="cuda", dtype=torch.int32, non_blocking=True)
targets = targets_cpu.view(B, T).to(device="cuda", dtype=torch.int64, non_blocking=True)
inputs = inputs_cpu.view(B, T).to(device=device, dtype=torch.int32, non_blocking=True)
targets = targets_cpu.view(B, T).to(device=device, dtype=torch.int64, non_blocking=True)
yield inputs, targets
+34 -6
View File
@@ -44,12 +44,38 @@ def eval_with_timeout(formula, max_time=3):
return None
def use_calculator(expr):
"""Evaluate a math expression safely."""
"""
Evaluate a Python expression safely.
Supports both math expressions and string operations like .count()
"""
# Remove commas from numbers
expr = expr.replace(",", "")
if any([x not in "0123456789*+-/.() " for x in expr]): # for now disallow non-numeric chars
# Check if it's a pure math expression (old behavior)
if all([x in "0123456789*+-/.() " for x in expr]):
if "**" in expr: # disallow power operator
return None
return eval_with_timeout(expr)
# Check if it's a string operation we support
# Allow: strings (single/double quotes), .count(), letters, numbers, spaces, parens
allowed_chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'\"()._ "
if not all([x in allowed_chars for x in expr]):
return None
if "**" in expr: # for now disallow power operator, could be very expensive
# Disallow dangerous patterns
dangerous_patterns = ['__', 'import', 'exec', 'eval', 'compile', 'open', 'file',
'input', 'raw_input', 'globals', 'locals', 'vars', 'dir',
'getattr', 'setattr', 'delattr', 'hasattr']
expr_lower = expr.lower()
if any(pattern in expr_lower for pattern in dangerous_patterns):
return None
# Only allow .count() method for now (can expand later)
if '.count(' not in expr:
return None
# Evaluate with timeout
return eval_with_timeout(expr)
# -----------------------------------------------------------------------------
@@ -109,9 +135,11 @@ class KVCache:
if t1 > self.kv_cache.size(4):
t_needed = t1 + 1024 # as much as we need plus buffer of 1024
t_needed = (t_needed + 1023) & ~1023 # then round up to the nearest multiple of 1024
current_shape = list(self.kv_cache.shape)
current_shape[4] = t_needed
self.kv_cache.resize_(current_shape)
additional_shape = list(self.kv_cache.shape)
additional_shape[4] = t_needed - self.kv_cache.size(4)
additional_cache = torch.empty(additional_shape, dtype=k.dtype, device=k.device)
self.kv_cache = torch.cat([self.kv_cache, additional_cache], dim=4).contiguous()
self.kv_shape = self.kv_cache.shape
# Insert k, v into the cache
self.kv_cache[layer_idx, 0, :, :, t0:t1] = k
self.kv_cache[layer_idx, 1, :, :, t0:t1] = v
+5 -4
View File
@@ -146,13 +146,12 @@ def reliability_guard(maximum_memory_bytes: Optional[int] = None):
with caution.
"""
if maximum_memory_bytes is not None:
if platform.uname().system != "Darwin":
# These resource limit calls seem to fail on macOS (Darwin), skip?
import resource
resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes))
resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes))
if not platform.uname().system == "Darwin":
resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes))
resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes))
faulthandler.disable()
@@ -225,6 +224,7 @@ def _unsafe_execute(code: str, timeout: float, maximum_memory_bytes: Optional[in
rmtree = shutil.rmtree
rmdir = os.rmdir
chdir = os.chdir
unlink = os.unlink
# Disable functionalities that can make destructive changes to the test.
reliability_guard(maximum_memory_bytes=maximum_memory_bytes)
@@ -282,6 +282,7 @@ def _unsafe_execute(code: str, timeout: float, maximum_memory_bytes: Optional[in
shutil.rmtree = rmtree
os.rmdir = rmdir
os.chdir = chdir
os.unlink = unlink
def execute_code(
+7 -22
View File
@@ -48,19 +48,6 @@ def apply_rotary_emb(x, cos, sin):
out = out.to(x.dtype) # ensure input/output dtypes match
return out
def repeat_kv(x, n_rep):
"""torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
if n_rep == 1:
return x
bs, n_kv_heads, slen, head_dim = x.shape
return (
x[:, :, None, :, :]
.expand(bs, n_kv_heads, n_rep, slen, head_dim)
.reshape(bs, n_kv_heads * n_rep, slen, head_dim)
)
class CausalSelfAttention(nn.Module):
def __init__(self, config, layer_idx):
super().__init__()
@@ -96,19 +83,16 @@ class CausalSelfAttention(nn.Module):
Tq = q.size(2) # number of queries in this forward pass
Tk = k.size(2) # number of keys/values in total (in the cache + current forward pass)
# Apply MQA: replicate the key/value heads for each query head
nrep = self.n_head // self.n_kv_head
k, v = repeat_kv(k, nrep), repeat_kv(v, nrep)
# Attention: queries attend to keys/values autoregressively. A few cases to handle:
enable_gqa = self.n_head != self.n_kv_head # Group Query Attention (GQA): duplicate key/value heads to match query heads if desired
if kv_cache is None or Tq == Tk:
# During training (no KV cache), attend as usual with causal attention
# And even if there is KV cache, we can still use this simple version when Tq == Tk
y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
y = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)
elif Tq == 1:
# During inference but with a single query in this forward pass:
# The query has to attend to all the keys/values in the cache
y = F.scaled_dot_product_attention(q, k, v, is_causal=False)
y = F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa)
else:
# During inference AND we have a chunk of queries in this forward pass:
# First, each query attends to all the cached keys/values (i.e. full prefix)
@@ -118,7 +102,7 @@ class CausalSelfAttention(nn.Module):
attn_mask[:, :prefix_len] = True
# Then, causal attention within this chunk
attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool, device=q.device))
y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa)
# Re-assemble the heads side by side and project back to residual stream
y = y.transpose(1, 2).contiguous().view(B, T, -1)
@@ -169,8 +153,6 @@ class GPT(nn.Module):
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
self.register_buffer("sin", sin, persistent=False)
# Cast the embeddings from fp32 to bf16: optim can tolerate it and it saves memory: both in the model and the activations
self.transformer.wte.to(dtype=torch.bfloat16)
def init_weights(self):
self.apply(self._init_weights)
@@ -184,6 +166,9 @@ class GPT(nn.Module):
head_dim = self.config.n_embd // self.config.n_head
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
self.cos, self.sin = cos, sin
# Cast the embeddings from fp32 to bf16: optim can tolerate it and it saves memory: both in the model and the activations
if self.transformer.wte.weight.device.type == "cuda":
self.transformer.wte.to(dtype=torch.bfloat16)
def _init_weights(self, module):
if isinstance(module, nn.Linear):
+1 -1
View File
@@ -33,7 +33,7 @@ def evaluate_bpb(model, batches, steps, token_bytes):
loss2d = model(x, y, loss_reduction='none') # (B, T)
loss2d = loss2d.view(-1) # flatten
y = y.view(-1) # flatten
if (y < 0).any():
if (y.int() < 0).any(): # mps does not currently have kernel for < 0 for int64, only int32
# slightly more complex code path if some target tokens are ignore_index (e.g. -1)
# any target token < 0 is to be ignored: do NOT index token_bytes with negatives
valid = y >= 0
+4
View File
@@ -283,6 +283,10 @@ class Report:
# capture bloat data for summary later (the stuff after Bloat header and until \n\n)
bloat_data = re.search(r"### Bloat\n(.*?)\n\n", header_content, re.DOTALL)
bloat_data = bloat_data.group(1) if bloat_data else ""
else:
start_time = None # will cause us to not write the total wall clock time
bloat_data = "[bloat data missing]"
print(f"Warning: {header_file} does not exist. Did you forget to run `nanochat reset`?")
# process all the individual sections
for file_name in EXPECTED_FILES:
section_file = os.path.join(report_dir, file_name)
+4 -1
View File
@@ -341,16 +341,19 @@ class RustBPETokenizer:
mask = mask[:max_tokens]
return ids, mask
def visualize_tokenization(self, ids, mask):
def visualize_tokenization(self, ids, mask, with_token_id=False):
"""Small helper function useful in debugging: visualize the tokenization of render_conversation"""
RED = '\033[91m'
GREEN = '\033[92m'
RESET = '\033[0m'
GRAY = '\033[90m'
tokens = []
for i, (token_id, mask_val) in enumerate(zip(ids, mask)):
token_str = self.decode([token_id])
color = GREEN if mask_val == 1 else RED
tokens.append(f"{color}{token_str}{RESET}")
if with_token_id:
tokens.append(f"{GRAY}({token_id}){RESET}")
return '|'.join(tokens)
def render_for_completion(self, conversation):
+181 -13
View File
@@ -2,7 +2,7 @@
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<meta name="viewport" content="width=device-width, initial-scale=1.0, viewport-fit=cover">
<title>NanoChat</title>
<link rel="icon" type="image/svg+xml" href="/logo.svg">
<style>
@@ -18,7 +18,7 @@
font-family: ui-sans-serif, -apple-system, system-ui, "Segoe UI", Helvetica, "Apple Color Emoji", Arial, sans-serif, "Segoe UI Emoji", "Segoe UI Symbol";
background-color: #ffffff;
color: #111827;
min-height: 100vh;
min-height: 100dvh;
margin: 0;
display: flex;
flex-direction: column;
@@ -108,6 +108,15 @@
background: transparent;
border: none;
padding: 0.25rem 0;
cursor: pointer;
border-radius: 0.5rem;
padding: 0.5rem;
margin-left: -0.5rem;
transition: background-color 0.2s ease;
}
.message.assistant .message-content:hover {
background-color: #f9fafb;
}
.message.user .message-content {
@@ -115,11 +124,27 @@
border-radius: 1.25rem;
padding: 0.8rem 1rem;
max-width: 65%;
cursor: pointer;
transition: background-color 0.2s ease;
}
.message.user .message-content:hover {
background-color: #e5e7eb;
}
.message.console .message-content {
font-family: 'Monaco', 'Menlo', 'Ubuntu Mono', 'Consolas', 'Courier New', monospace;
font-size: 0.875rem;
background-color: #fafafa;
padding: 0.75rem 1rem;
color: #374151;
max-width: 80%;
}
.input-container {
background-color: #ffffff;
padding: 1rem;
padding-bottom: calc(1rem + env(safe-area-inset-bottom))
}
.input-wrapper {
@@ -255,6 +280,8 @@
let messages = [];
let isGenerating = false;
let currentTemperature = 0.8;
let currentTopK = 50;
chatInput.addEventListener('input', function() {
this.style.height = 'auto';
@@ -289,7 +316,7 @@
chatInput.focus();
}
function addMessage(role, content) {
function addMessage(role, content, messageIndex = null) {
const messageDiv = document.createElement('div');
messageDiv.className = `message ${role}`;
@@ -297,6 +324,28 @@
contentDiv.className = 'message-content';
contentDiv.textContent = content;
// Add click handler for user messages to enable editing
if (role === 'user' && messageIndex !== null) {
contentDiv.setAttribute('data-message-index', messageIndex);
contentDiv.setAttribute('title', 'Click to edit and restart from here');
contentDiv.addEventListener('click', function() {
if (!isGenerating) {
editMessage(messageIndex);
}
});
}
// Add click handler for assistant messages to enable regeneration
if (role === 'assistant' && messageIndex !== null) {
contentDiv.setAttribute('data-message-index', messageIndex);
contentDiv.setAttribute('title', 'Click to regenerate this response');
contentDiv.addEventListener('click', function() {
if (!isGenerating) {
regenerateMessage(messageIndex);
}
});
}
messageDiv.appendChild(contentDiv);
chatWrapper.appendChild(messageDiv);
@@ -304,17 +353,35 @@
return contentDiv;
}
async function sendMessage() {
const message = chatInput.value.trim();
if (!message || isGenerating) return;
function editMessage(messageIndex) {
// Find the message in the messages array
if (messageIndex < 0 || messageIndex >= messages.length) return;
isGenerating = true;
chatInput.value = '';
const messageToEdit = messages[messageIndex];
if (messageToEdit.role !== 'user') return;
// Copy message content to input
chatInput.value = messageToEdit.content;
chatInput.style.height = 'auto';
sendButton.disabled = true;
chatInput.style.height = Math.min(chatInput.scrollHeight, 200) + 'px';
messages.push({ role: 'user', content: message });
addMessage('user', message);
// Remove this message and all subsequent messages from the array
messages = messages.slice(0, messageIndex);
// Remove message elements from DOM starting from messageIndex
const allMessages = chatWrapper.querySelectorAll('.message');
for (let i = messageIndex; i < allMessages.length; i++) {
allMessages[i].remove();
}
// Enable send button and focus input
sendButton.disabled = false;
chatInput.focus();
}
async function generateAssistantResponse() {
isGenerating = true;
sendButton.disabled = true;
const assistantContent = addMessage('assistant', '');
assistantContent.innerHTML = '<span class="typing-indicator"></span>';
@@ -327,8 +394,8 @@
},
body: JSON.stringify({
messages: messages,
stream: true,
temperature: 0.8,
temperature: currentTemperature,
top_k: currentTopK,
max_tokens: 512
}),
});
@@ -364,8 +431,18 @@
}
}
const assistantMessageIndex = messages.length;
messages.push({ role: 'assistant', content: fullResponse });
// Add click handler to regenerate this assistant message
assistantContent.setAttribute('data-message-index', assistantMessageIndex);
assistantContent.setAttribute('title', 'Click to regenerate this response');
assistantContent.addEventListener('click', function() {
if (!isGenerating) {
regenerateMessage(assistantMessageIndex);
}
});
} catch (error) {
console.error('Error:', error);
assistantContent.innerHTML = `<div class="error-message">Error: ${error.message}</div>`;
@@ -375,6 +452,97 @@
}
}
async function regenerateMessage(messageIndex) {
// Find the message in the messages array
if (messageIndex < 0 || messageIndex >= messages.length) return;
const messageToRegenerate = messages[messageIndex];
if (messageToRegenerate.role !== 'assistant') return;
// Remove this message and all subsequent messages from the array
messages = messages.slice(0, messageIndex);
// Remove message elements from DOM starting from messageIndex
const allMessages = chatWrapper.querySelectorAll('.message');
for (let i = messageIndex; i < allMessages.length; i++) {
allMessages[i].remove();
}
// Regenerate the assistant response
await generateAssistantResponse();
}
function handleSlashCommand(command) {
const parts = command.trim().split(/\s+/);
const cmd = parts[0].toLowerCase();
const arg = parts[1];
if (cmd === '/temperature') {
if (arg === undefined) {
addMessage('console', `Current temperature: ${currentTemperature}`);
} else {
const temp = parseFloat(arg);
if (isNaN(temp) || temp < 0 || temp > 2) {
addMessage('console', 'Invalid temperature. Must be between 0.0 and 2.0');
} else {
currentTemperature = temp;
addMessage('console', `Temperature set to ${currentTemperature}`);
}
}
return true;
} else if (cmd === '/topk') {
if (arg === undefined) {
addMessage('console', `Current top-k: ${currentTopK}`);
} else {
const topk = parseInt(arg);
if (isNaN(topk) || topk < 1 || topk > 200) {
addMessage('console', 'Invalid top-k. Must be between 1 and 200');
} else {
currentTopK = topk;
addMessage('console', `Top-k set to ${currentTopK}`);
}
}
return true;
} else if (cmd === '/clear') {
newConversation();
return true;
} else if (cmd === '/help') {
addMessage('console',
'Available commands:\n' +
'/temperature - Show current temperature\n' +
'/temperature <value> - Set temperature (0.0-2.0)\n' +
'/topk - Show current top-k\n' +
'/topk <value> - Set top-k (1-200)\n' +
'/clear - Clear conversation\n' +
'/help - Show this help message'
);
return true;
}
return false;
}
async function sendMessage() {
const message = chatInput.value.trim();
if (!message || isGenerating) return;
// Handle slash commands
if (message.startsWith('/')) {
chatInput.value = '';
chatInput.style.height = 'auto';
handleSlashCommand(message);
return;
}
chatInput.value = '';
chatInput.style.height = 'auto';
const userMessageIndex = messages.length;
messages.push({ role: 'user', content: message });
addMessage('user', message, userMessageIndex);
await generateAssistantResponse();
}
sendButton.disabled = false;
// Autofocus the chat input on page load