delete autocast, an unnecessary thorn in my side, manage dtypes directly
This commit is contained in:
+25
-33
@@ -44,7 +44,6 @@ from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Optional, AsyncGenerator
|
||||
from dataclasses import dataclass
|
||||
from contextlib import nullcontext
|
||||
from nanochat.common import compute_init, autodetect_device_type
|
||||
from nanochat.checkpoint_manager import load_model
|
||||
from nanochat.engine import Engine
|
||||
@@ -69,7 +68,6 @@ parser.add_argument('-m', '--max-tokens', type=int, default=512, help='Default m
|
||||
parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
|
||||
parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
|
||||
parser.add_argument('-p', '--port', type=int, default=8000, help='Port to run the server on')
|
||||
parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
|
||||
parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
|
||||
parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind the server to')
|
||||
args = parser.parse_args()
|
||||
@@ -84,7 +82,6 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
|
||||
|
||||
@dataclass
|
||||
class Worker:
|
||||
@@ -93,7 +90,6 @@ class Worker:
|
||||
device: torch.device
|
||||
engine: Engine
|
||||
tokenizer: object
|
||||
autocast_ctx: torch.amp.autocast
|
||||
|
||||
class WorkerPool:
|
||||
"""Pool of workers, each with a model replica on a different GPU."""
|
||||
@@ -125,14 +121,11 @@ class WorkerPool:
|
||||
|
||||
model, tokenizer, _ = load_model(source, device, phase="eval", model_tag=model_tag, step=step)
|
||||
engine = Engine(model, tokenizer)
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
|
||||
|
||||
worker = Worker(
|
||||
gpu_id=gpu_id,
|
||||
device=device,
|
||||
engine=engine,
|
||||
tokenizer=tokenizer,
|
||||
autocast_ctx=autocast_ctx
|
||||
)
|
||||
self.workers.append(worker)
|
||||
await self.available_workers.put(worker)
|
||||
@@ -279,34 +272,33 @@ async def generate_stream(
|
||||
# Track the last complete UTF-8 string (without replacement characters)
|
||||
last_clean_text = ""
|
||||
|
||||
with worker.autocast_ctx:
|
||||
for token_column, token_masks in worker.engine.generate(
|
||||
tokens,
|
||||
num_samples=1,
|
||||
max_tokens=max_new_tokens,
|
||||
temperature=temperature,
|
||||
top_k=top_k,
|
||||
seed=random.randint(0, 2**31 - 1)
|
||||
):
|
||||
token = token_column[0]
|
||||
for token_column, token_masks in worker.engine.generate(
|
||||
tokens,
|
||||
num_samples=1,
|
||||
max_tokens=max_new_tokens,
|
||||
temperature=temperature,
|
||||
top_k=top_k,
|
||||
seed=random.randint(0, 2**31 - 1)
|
||||
):
|
||||
token = token_column[0]
|
||||
|
||||
# Stopping criteria
|
||||
if token == assistant_end or token == bos:
|
||||
break
|
||||
# Stopping criteria
|
||||
if token == assistant_end or token == bos:
|
||||
break
|
||||
|
||||
# Append the token to sequence
|
||||
accumulated_tokens.append(token)
|
||||
# Decode all accumulated tokens to get proper UTF-8 handling
|
||||
# Note that decode is a quite efficient operation, basically table lookup and string concat
|
||||
current_text = worker.tokenizer.decode(accumulated_tokens)
|
||||
# Only emit text if it doesn't end with a replacement character
|
||||
# This ensures we don't emit incomplete UTF-8 sequences
|
||||
if not current_text.endswith('�'):
|
||||
# Extract only the new text since last clean decode
|
||||
new_text = current_text[len(last_clean_text):]
|
||||
if new_text: # Only yield if there's new content
|
||||
yield f"data: {json.dumps({'token': new_text, 'gpu': worker.gpu_id}, ensure_ascii=False)}\n\n"
|
||||
last_clean_text = current_text
|
||||
# Append the token to sequence
|
||||
accumulated_tokens.append(token)
|
||||
# Decode all accumulated tokens to get proper UTF-8 handling
|
||||
# Note that decode is a quite efficient operation, basically table lookup and string concat
|
||||
current_text = worker.tokenizer.decode(accumulated_tokens)
|
||||
# Only emit text if it doesn't end with a replacement character
|
||||
# This ensures we don't emit incomplete UTF-8 sequences
|
||||
if not current_text.endswith('�'):
|
||||
# Extract only the new text since last clean decode
|
||||
new_text = current_text[len(last_clean_text):]
|
||||
if new_text: # Only yield if there's new content
|
||||
yield f"data: {json.dumps({'token': new_text, 'gpu': worker.gpu_id}, ensure_ascii=False)}\n\n"
|
||||
last_clean_text = current_text
|
||||
|
||||
yield f"data: {json.dumps({'done': True})}\n\n"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user