diff --git a/nanochat/ui.html b/nanochat/ui.html
index 39e608f..264a654 100644
--- a/nanochat/ui.html
+++ b/nanochat/ui.html
@@ -327,7 +327,6 @@
},
body: JSON.stringify({
messages: messages,
- stream: true,
temperature: 0.8,
max_tokens: 512
}),
diff --git a/scripts/chat_web.py b/scripts/chat_web.py
index 1a4cfe2..2643417 100644
--- a/scripts/chat_web.py
+++ b/scripts/chat_web.py
@@ -1,26 +1,46 @@
#!/usr/bin/env python3
"""
Unified web chat server - serves both UI and API from a single FastAPI instance.
-Run with: python web_chat.py
-Then open http://localhost:8000 in your browser.
+
+Uses data parallelism to distribute requests across multiple GPUs. Each GPU loads
+a full copy of the model, and incoming requests are distributed to available workers.
+
+Launch examples:
+
+- single available GPU (default)
+python -m scripts.chat_web
+
+- 4 GPUs
+python -m scripts.chat_web --num-gpus 4
+
+To chat, open the URL printed in the console. (If on cloud box, make sure to use public IP)
+
+Endpoints:
+ GET / - Chat UI
+ POST /chat/completions - Chat API (streaming only)
+ GET /health - Health check with worker pool status
+ GET /stats - Worker pool statistics and GPU utilization
"""
import argparse
import json
import os
import torch
+import asyncio
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
from pydantic import BaseModel
from typing import List, Optional, AsyncGenerator
+from dataclasses import dataclass
from nanochat.common import compute_init
from nanochat.checkpoint_manager import load_model
from nanochat.engine import Engine
parser = argparse.ArgumentParser(description='NanoChat Web Server')
+parser.add_argument('-n', '--num-gpus', type=int, default=1, help='Number of GPUs to use (default: 1)')
parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|mid|rl")
parser.add_argument('-t', '--temperature', type=float, default=0.8, help='Default temperature for generation')
parser.add_argument('-k', '--top-k', type=int, default=50, help='Default top-k sampling parameter')
@@ -32,7 +52,55 @@ parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind th
args = parser.parse_args()
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
-autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
+
+@dataclass
+class Worker:
+ """A worker with a model loaded on a specific GPU."""
+ gpu_id: int
+ device: torch.device
+ engine: Engine
+ tokenizer: object
+ autocast_ctx: torch.amp.autocast
+
+class WorkerPool:
+ """Pool of workers, each with a model replica on a different GPU."""
+
+ def __init__(self, num_gpus: Optional[int] = None):
+ self.num_gpus = num_gpus if num_gpus is not None else torch.cuda.device_count()
+ self.workers: List[Worker] = []
+ self.available_workers: asyncio.Queue = asyncio.Queue()
+
+ async def initialize(self, source: str, model_tag: Optional[str] = None, step: Optional[int] = None):
+ """Load model on each GPU."""
+ print(f"Initializing worker pool with {self.num_gpus} GPUs...")
+
+ for gpu_id in range(self.num_gpus):
+ device = torch.device(f"cuda:{gpu_id}")
+ print(f"Loading model on GPU {gpu_id}...")
+
+ model, tokenizer, _ = load_model(source, device, phase="eval", model_tag=model_tag, step=step)
+ engine = Engine(model, tokenizer)
+ autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
+
+ worker = Worker(
+ gpu_id=gpu_id,
+ device=device,
+ engine=engine,
+ tokenizer=tokenizer,
+ autocast_ctx=autocast_ctx
+ )
+ self.workers.append(worker)
+ await self.available_workers.put(worker)
+
+ print(f"All {self.num_gpus} workers initialized!")
+
+ async def acquire_worker(self) -> Worker:
+ """Get an available worker from the pool."""
+ return await self.available_workers.get()
+
+ async def release_worker(self, worker: Worker):
+ """Return a worker to the pool."""
+ await self.available_workers.put(worker)
class ChatMessage(BaseModel):
role: str
@@ -43,14 +111,13 @@ class ChatRequest(BaseModel):
temperature: Optional[float] = None
max_tokens: Optional[int] = None
top_k: Optional[int] = None
- stream: Optional[bool] = True
@asynccontextmanager
async def lifespan(app: FastAPI):
- """Load model on startup."""
- print("Loading nanochat model...")
- app.state.model, app.state.tokenizer, _ = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
- app.state.engine = Engine(app.state.model, app.state.tokenizer)
+ """Load models on all GPUs on startup."""
+ print("Loading nanochat models across GPUs...")
+ app.state.worker_pool = WorkerPool(num_gpus=args.num_gpus)
+ await app.state.worker_pool.initialize(args.source, model_tag=args.model_tag, step=args.step)
print(f"Server ready at http://localhost:{args.port}")
yield
@@ -85,8 +152,7 @@ async def logo():
return FileResponse(logo_path, media_type="image/svg+xml")
async def generate_stream(
- engine,
- tokenizer,
+ worker: Worker,
tokens,
temperature=None,
max_new_tokens=None,
@@ -97,11 +163,11 @@ async def generate_stream(
max_new_tokens = max_new_tokens if max_new_tokens is not None else args.max_tokens
top_k = top_k if top_k is not None else args.top_k
- assistant_end = tokenizer.encode_special("<|assistant_end|>")
- bos = tokenizer.get_bos_token_id()
+ assistant_end = worker.tokenizer.encode_special("<|assistant_end|>")
+ bos = worker.tokenizer.get_bos_token_id()
- with autocast_ctx:
- for token_column, token_masks in engine.generate(
+ with worker.autocast_ctx:
+ for token_column, token_masks in worker.engine.generate(
tokens,
num_samples=1,
max_tokens=max_new_tokens,
@@ -113,82 +179,89 @@ async def generate_stream(
if token == assistant_end or token == bos:
break
- token_text = tokenizer.decode([token])
- yield f"data: {json.dumps({'token': token_text})}\n\n"
+ token_text = worker.tokenizer.decode([token])
+ yield f"data: {json.dumps({'token': token_text, 'gpu': worker.gpu_id})}\n\n"
yield f"data: {json.dumps({'done': True})}\n\n"
@app.post("/chat/completions")
async def chat_completions(request: ChatRequest):
- """Chat completion endpoint with streaming."""
- engine = app.state.engine
- tokenizer = app.state.tokenizer
+ """Chat completion endpoint (streaming only) - uses worker pool for multi-GPU."""
+ worker_pool = app.state.worker_pool
- # Build conversation tokens
- bos = tokenizer.get_bos_token_id()
- user_start = tokenizer.encode_special("<|user_start|>")
- user_end = tokenizer.encode_special("<|user_end|>")
- assistant_start = tokenizer.encode_special("<|assistant_start|>")
- assistant_end = tokenizer.encode_special("<|assistant_end|>")
+ # Acquire a worker from the pool (will wait if all are busy)
+ worker = await worker_pool.acquire_worker()
- conversation_tokens = [bos]
- for message in request.messages:
- if message.role == "user":
- conversation_tokens.append(user_start)
- conversation_tokens.extend(tokenizer.encode(message.content))
- conversation_tokens.append(user_end)
- elif message.role == "assistant":
- conversation_tokens.append(assistant_start)
- conversation_tokens.extend(tokenizer.encode(message.content))
- conversation_tokens.append(assistant_end)
+ try:
+ # Build conversation tokens
+ bos = worker.tokenizer.get_bos_token_id()
+ user_start = worker.tokenizer.encode_special("<|user_start|>")
+ user_end = worker.tokenizer.encode_special("<|user_end|>")
+ assistant_start = worker.tokenizer.encode_special("<|assistant_start|>")
+ assistant_end = worker.tokenizer.encode_special("<|assistant_end|>")
- conversation_tokens.append(assistant_start)
+ conversation_tokens = [bos]
+ for message in request.messages:
+ if message.role == "user":
+ conversation_tokens.append(user_start)
+ conversation_tokens.extend(worker.tokenizer.encode(message.content))
+ conversation_tokens.append(user_end)
+ elif message.role == "assistant":
+ conversation_tokens.append(assistant_start)
+ conversation_tokens.extend(worker.tokenizer.encode(message.content))
+ conversation_tokens.append(assistant_end)
+
+ conversation_tokens.append(assistant_start)
+
+ # Streaming response with worker release after completion
+ async def stream_and_release():
+ try:
+ async for chunk in generate_stream(
+ worker,
+ conversation_tokens,
+ temperature=request.temperature,
+ max_new_tokens=request.max_tokens,
+ top_k=request.top_k
+ ):
+ yield chunk
+ finally:
+ # Release worker back to pool after streaming is done
+ await worker_pool.release_worker(worker)
- if request.stream:
return StreamingResponse(
- generate_stream(
- engine,
- tokenizer,
- conversation_tokens,
- temperature=request.temperature,
- max_new_tokens=request.max_tokens,
- top_k=request.top_k
- ),
+ stream_and_release(),
media_type="text/event-stream"
)
- else:
- # Non-streaming response
- temperature = request.temperature if request.temperature is not None else args.temperature
- max_tokens = request.max_tokens if request.max_tokens is not None else args.max_tokens
- top_k = request.top_k if request.top_k is not None else args.top_k
-
- with autocast_ctx:
- result_tokens, masks = engine.generate_batch(
- conversation_tokens,
- num_samples=1,
- max_tokens=max_tokens,
- temperature=temperature,
- top_k=top_k
- )[0]
-
- response_tokens = result_tokens[len(conversation_tokens):]
- response_text = tokenizer.decode(response_tokens)
- return {
- "choices": [{
- "message": {
- "role": "assistant",
- "content": response_text
- },
- "finish_reason": "stop"
- }]
- }
+ except Exception as e:
+ # Make sure to release worker even on error
+ await worker_pool.release_worker(worker)
+ raise e
@app.get("/health")
async def health():
"""Health check endpoint."""
+ worker_pool = getattr(app.state, 'worker_pool', None)
return {
"status": "ok",
- "ready": hasattr(app.state, 'model') and app.state.model is not None
+ "ready": worker_pool is not None and len(worker_pool.workers) > 0,
+ "num_gpus": worker_pool.num_gpus if worker_pool else 0,
+ "available_workers": worker_pool.available_workers.qsize() if worker_pool else 0
+ }
+
+@app.get("/stats")
+async def stats():
+ """Get worker pool statistics."""
+ worker_pool = app.state.worker_pool
+ return {
+ "total_workers": len(worker_pool.workers),
+ "available_workers": worker_pool.available_workers.qsize(),
+ "busy_workers": len(worker_pool.workers) - worker_pool.available_workers.qsize(),
+ "workers": [
+ {
+ "gpu_id": w.gpu_id,
+ "device": str(w.device)
+ } for w in worker_pool.workers
+ ]
}
if __name__ == "__main__":