Merge branch 'master' into logo/kerning-update
This commit is contained in:
@@ -26,7 +26,6 @@ class DistAdamW(torch.optim.Optimizer):
|
||||
grad_slices = []
|
||||
for group in self.param_groups:
|
||||
params: list[Tensor] = group["params"]
|
||||
grad = torch.empty_like(params[-1]) # TODO is this bug? seems to be over-written instantly
|
||||
for base_i in range(len(params)):
|
||||
grad = params[base_i].grad
|
||||
rank_size = grad.shape[0] // world_size
|
||||
|
||||
+64
-10
@@ -5,6 +5,8 @@ Common utilities for nanochat.
|
||||
import os
|
||||
import re
|
||||
import logging
|
||||
import fcntl
|
||||
import urllib.request
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
@@ -56,6 +58,44 @@ def get_base_dir():
|
||||
os.makedirs(nanochat_dir, exist_ok=True)
|
||||
return nanochat_dir
|
||||
|
||||
def download_file_with_lock(url, filename):
    """
    Downloads a file from a URL to a local path in the base directory.
    Uses a lock file to prevent concurrent downloads among multiple ranks.

    Args:
        url: Location to fetch; the response body is decoded as UTF-8 text.
        filename: Name of the destination file, relative to the base directory.

    Returns:
        The full path of the downloaded (or already present) file.
    """
    base_dir = get_base_dir()
    file_path = os.path.join(base_dir, filename)
    lock_path = file_path + ".lock"

    # Fast path, taken without the lock. This is only safe because the file
    # is published atomically below, so it either exists in full or not at all.
    if os.path.exists(file_path):
        return file_path

    with open(lock_path, 'w') as lock_file:

        # Only a single rank can acquire this lock
        # All other ranks block until it is released
        fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX)

        # Another rank may have finished the download while we were waiting
        if os.path.exists(file_path):
            return file_path

        print(f"Downloading {url}...")
        with urllib.request.urlopen(url) as response:
            content = response.read().decode('utf-8')

        # Write to a temporary sibling and rename it into place, so that the
        # unlocked existence check above can never observe a partial file.
        tmp_path = f"{file_path}.tmp.{os.getpid()}"
        try:
            with open(tmp_path, 'w', encoding='utf-8') as f:
                f.write(content)
            os.replace(tmp_path, file_path)
        finally:
            # Only present if the write or rename failed part-way
            if os.path.exists(tmp_path):
                try:
                    os.remove(tmp_path)
                except OSError:
                    pass

        print(f"Downloaded to {file_path}")

    # Clean up the lock file after the lock is released
    try:
        os.remove(lock_path)
    except OSError:
        pass  # Ignore if already removed by another process

    return file_path
|
||||
|
||||
def print0(s="",**kwargs):
|
||||
ddp_rank = int(os.environ.get('RANK', 0))
|
||||
if ddp_rank == 0:
|
||||
@@ -89,32 +129,46 @@ def get_dist_info():
|
||||
else:
|
||||
return False, 0, 0, 1
|
||||
|
||||
def compute_init():
|
||||
def autodetect_device_type():
    """Pick a device type by availability: CUDA first, then MPS, then CPU.

    Reports the choice through print0 and returns it as a string.
    """
    if torch.cuda.is_available():
        chosen = "cuda"
    else:
        # MPS is only probed when CUDA is unavailable
        chosen = "mps" if torch.backends.mps.is_available() else "cpu"
    print0(f"Autodetected device type: {chosen}")
    return chosen
|
||||
|
||||
def compute_init(device_type="cuda"): # cuda|cpu|mps
|
||||
"""Basic initialization that we keep doing over and over, so make common."""
|
||||
|
||||
# CUDA is currently required
|
||||
assert torch.cuda.is_available(), "CUDA is needed for a distributed run atm"
|
||||
assert device_type in ["cuda", "mps", "cpu"], "Invalid device type atm"
|
||||
if device_type == "cuda":
|
||||
assert torch.cuda.is_available(), "Your PyTorch installation is not configured for CUDA but device_type is 'cuda'"
|
||||
if device_type == "mps":
|
||||
assert torch.backends.mps.is_available(), "Your PyTorch installation is not configured for MPS but device_type is 'mps'"
|
||||
|
||||
# Reproducibility
|
||||
torch.manual_seed(42)
|
||||
torch.cuda.manual_seed(42)
|
||||
if device_type == "cuda":
|
||||
torch.cuda.manual_seed(42)
|
||||
# skipping full reproducibility for now, possibly investigate slowdown later
|
||||
# torch.use_deterministic_algorithms(True)
|
||||
# torch.backends.cudnn.deterministic = True
|
||||
# torch.backends.cudnn.benchmark = False
|
||||
|
||||
# Precision
|
||||
torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls
|
||||
if device_type == "cuda":
|
||||
torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls
|
||||
|
||||
# Distributed setup: Distributed Data Parallel (DDP), optional
|
||||
# Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
|
||||
if ddp:
|
||||
if ddp and device_type == "cuda":
|
||||
device = torch.device("cuda", ddp_local_rank)
|
||||
torch.cuda.set_device(device) # make "cuda" default to this device
|
||||
dist.init_process_group(backend="nccl", device_id=device)
|
||||
dist.barrier()
|
||||
else:
|
||||
device = torch.device("cuda")
|
||||
device = torch.device(device_type) # mps|cpu
|
||||
|
||||
if ddp_rank == 0:
|
||||
logger.info(f"Distributed world size: {ddp_world_size}")
|
||||
|
||||
@@ -6,7 +6,7 @@ from nanochat.common import get_dist_info
|
||||
from nanochat.dataset import parquets_iter_batched
|
||||
from nanochat.tokenizer import get_tokenizer
|
||||
|
||||
def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128):
|
||||
def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda"):
|
||||
"""Stream pretraining text from parquet files, tokenize, yield training batches."""
|
||||
assert split in ["train", "val"], "split must be 'train' or 'val'"
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
|
||||
@@ -16,7 +16,6 @@ def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokeniz
|
||||
bos_token = tokenizer.get_bos_token_id()
|
||||
# scratch buffer holds the tokens for one iteration
|
||||
token_buffer = deque() # we stream tokens on the right and pop from the left
|
||||
scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=True)
|
||||
|
||||
# infinite iterator over document batches
|
||||
def document_batches():
|
||||
@@ -38,12 +37,13 @@ def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokeniz
|
||||
token_buffer.extend(tokens)
|
||||
batch_index += 1
|
||||
# Move tokens from the deque into the scratch buffer
|
||||
for i in range(needed_tokens):
|
||||
scratch[i] = token_buffer.popleft()
|
||||
tokens = [token_buffer.popleft() for _ in range(needed_tokens)]
|
||||
# CUDA supports memory pinning for faster transfers between CPU and GPU:
|
||||
scratch = torch.tensor(tokens, dtype=torch.int64, pin_memory=(device == "cuda"))
|
||||
# Create the inputs/targets as 1D tensors
|
||||
inputs_cpu = scratch[:-1].to(dtype=torch.int32)
|
||||
targets_cpu = scratch[1:]
|
||||
# Reshape to 2D and move to GPU async
|
||||
inputs = inputs_cpu.view(B, T).to(device="cuda", dtype=torch.int32, non_blocking=True)
|
||||
targets = targets_cpu.view(B, T).to(device="cuda", dtype=torch.int64, non_blocking=True)
|
||||
inputs = inputs_cpu.view(B, T).to(device=device, dtype=torch.int32, non_blocking=True)
|
||||
targets = targets_cpu.view(B, T).to(device=device, dtype=torch.int64, non_blocking=True)
|
||||
yield inputs, targets
|
||||
|
||||
+34
-6
@@ -44,12 +44,38 @@ def eval_with_timeout(formula, max_time=3):
|
||||
return None
|
||||
|
||||
def use_calculator(expr):
|
||||
"""Evaluate a math expression safely."""
|
||||
"""
|
||||
Evaluate a Python expression safely.
|
||||
Supports both math expressions and string operations like .count()
|
||||
"""
|
||||
# Remove commas from numbers
|
||||
expr = expr.replace(",", "")
|
||||
if any([x not in "0123456789*+-/.() " for x in expr]): # for now disallow non-numeric chars
|
||||
|
||||
# Check if it's a pure math expression (old behavior)
|
||||
if all([x in "0123456789*+-/.() " for x in expr]):
|
||||
if "**" in expr: # disallow power operator
|
||||
return None
|
||||
return eval_with_timeout(expr)
|
||||
|
||||
# Check if it's a string operation we support
|
||||
# Allow: strings (single/double quotes), .count(), letters, numbers, spaces, parens
|
||||
allowed_chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'\"()._ "
|
||||
if not all([x in allowed_chars for x in expr]):
|
||||
return None
|
||||
if "**" in expr: # for now disallow power operator, could be very expensive
|
||||
|
||||
# Disallow dangerous patterns
|
||||
dangerous_patterns = ['__', 'import', 'exec', 'eval', 'compile', 'open', 'file',
|
||||
'input', 'raw_input', 'globals', 'locals', 'vars', 'dir',
|
||||
'getattr', 'setattr', 'delattr', 'hasattr']
|
||||
expr_lower = expr.lower()
|
||||
if any(pattern in expr_lower for pattern in dangerous_patterns):
|
||||
return None
|
||||
|
||||
# Only allow .count() method for now (can expand later)
|
||||
if '.count(' not in expr:
|
||||
return None
|
||||
|
||||
# Evaluate with timeout
|
||||
return eval_with_timeout(expr)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -109,9 +135,11 @@ class KVCache:
|
||||
if t1 > self.kv_cache.size(4):
|
||||
t_needed = t1 + 1024 # as much as we need plus buffer of 1024
|
||||
t_needed = (t_needed + 1023) & ~1023 # then round up to the nearest multiple of 1024
|
||||
current_shape = list(self.kv_cache.shape)
|
||||
current_shape[4] = t_needed
|
||||
self.kv_cache.resize_(current_shape)
|
||||
additional_shape = list(self.kv_cache.shape)
|
||||
additional_shape[4] = t_needed - self.kv_cache.size(4)
|
||||
additional_cache = torch.empty(additional_shape, dtype=k.dtype, device=k.device)
|
||||
self.kv_cache = torch.cat([self.kv_cache, additional_cache], dim=4).contiguous()
|
||||
self.kv_shape = self.kv_cache.shape
|
||||
# Insert k, v into the cache
|
||||
self.kv_cache[layer_idx, 0, :, :, t0:t1] = k
|
||||
self.kv_cache[layer_idx, 1, :, :, t0:t1] = v
|
||||
|
||||
@@ -146,13 +146,12 @@ def reliability_guard(maximum_memory_bytes: Optional[int] = None):
|
||||
with caution.
|
||||
"""
|
||||
|
||||
if maximum_memory_bytes is not None:
|
||||
if platform.uname().system != "Darwin":
|
||||
# These resource limit calls seem to fail on macOS (Darwin), skip?
|
||||
import resource
|
||||
|
||||
resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes))
|
||||
resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes))
|
||||
if not platform.uname().system == "Darwin":
|
||||
resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes))
|
||||
resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes))
|
||||
|
||||
faulthandler.disable()
|
||||
|
||||
@@ -225,6 +224,7 @@ def _unsafe_execute(code: str, timeout: float, maximum_memory_bytes: Optional[in
|
||||
rmtree = shutil.rmtree
|
||||
rmdir = os.rmdir
|
||||
chdir = os.chdir
|
||||
unlink = os.unlink
|
||||
|
||||
# Disable functionalities that can make destructive changes to the test.
|
||||
reliability_guard(maximum_memory_bytes=maximum_memory_bytes)
|
||||
@@ -282,6 +282,7 @@ def _unsafe_execute(code: str, timeout: float, maximum_memory_bytes: Optional[in
|
||||
shutil.rmtree = rmtree
|
||||
os.rmdir = rmdir
|
||||
os.chdir = chdir
|
||||
os.unlink = unlink
|
||||
|
||||
|
||||
def execute_code(
|
||||
|
||||
+7
-22
@@ -48,19 +48,6 @@ def apply_rotary_emb(x, cos, sin):
|
||||
out = out.to(x.dtype) # ensure input/output dtypes match
|
||||
return out
|
||||
|
||||
|
||||
def repeat_kv(x, n_rep):
    """Repeat every head along dim 1 n_rep times.

    Same result as torch.repeat_interleave(x, dim=1, repeats=n_rep).
    When n_rep is 1 the input tensor itself is returned.
    """
    if n_rep != 1:
        batch, kv_heads, seq_len, dim = x.shape
        # Insert a repeat axis after the head axis, broadcast it, then fold it into the heads
        widened = x.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, dim)
        return widened.reshape(batch, kv_heads * n_rep, seq_len, dim)
    return x
|
||||
|
||||
|
||||
class CausalSelfAttention(nn.Module):
|
||||
def __init__(self, config, layer_idx):
|
||||
super().__init__()
|
||||
@@ -96,19 +83,16 @@ class CausalSelfAttention(nn.Module):
|
||||
Tq = q.size(2) # number of queries in this forward pass
|
||||
Tk = k.size(2) # number of keys/values in total (in the cache + current forward pass)
|
||||
|
||||
# Apply MQA: replicate the key/value heads for each query head
|
||||
nrep = self.n_head // self.n_kv_head
|
||||
k, v = repeat_kv(k, nrep), repeat_kv(v, nrep)
|
||||
|
||||
# Attention: queries attend to keys/values autoregressively. A few cases to handle:
|
||||
enable_gqa = self.n_head != self.n_kv_head # Group Query Attention (GQA): duplicate key/value heads to match query heads if desired
|
||||
if kv_cache is None or Tq == Tk:
|
||||
# During training (no KV cache), attend as usual with causal attention
|
||||
# And even if there is KV cache, we can still use this simple version when Tq == Tk
|
||||
y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
|
||||
y = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)
|
||||
elif Tq == 1:
|
||||
# During inference but with a single query in this forward pass:
|
||||
# The query has to attend to all the keys/values in the cache
|
||||
y = F.scaled_dot_product_attention(q, k, v, is_causal=False)
|
||||
y = F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa)
|
||||
else:
|
||||
# During inference AND we have a chunk of queries in this forward pass:
|
||||
# First, each query attends to all the cached keys/values (i.e. full prefix)
|
||||
@@ -118,7 +102,7 @@ class CausalSelfAttention(nn.Module):
|
||||
attn_mask[:, :prefix_len] = True
|
||||
# Then, causal attention within this chunk
|
||||
attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool, device=q.device))
|
||||
y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||
y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa)
|
||||
|
||||
# Re-assemble the heads side by side and project back to residual stream
|
||||
y = y.transpose(1, 2).contiguous().view(B, T, -1)
|
||||
@@ -169,8 +153,6 @@ class GPT(nn.Module):
|
||||
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
|
||||
self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
|
||||
self.register_buffer("sin", sin, persistent=False)
|
||||
# Cast the embeddings from fp32 to bf16: optim can tolerate it and it saves memory: both in the model and the activations
|
||||
self.transformer.wte.to(dtype=torch.bfloat16)
|
||||
|
||||
def init_weights(self):
|
||||
self.apply(self._init_weights)
|
||||
@@ -184,6 +166,9 @@ class GPT(nn.Module):
|
||||
head_dim = self.config.n_embd // self.config.n_head
|
||||
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
|
||||
self.cos, self.sin = cos, sin
|
||||
# Cast the embeddings from fp32 to bf16: optim can tolerate it and it saves memory: both in the model and the activations
|
||||
if self.transformer.wte.weight.device.type == "cuda":
|
||||
self.transformer.wte.to(dtype=torch.bfloat16)
|
||||
|
||||
def _init_weights(self, module):
|
||||
if isinstance(module, nn.Linear):
|
||||
|
||||
@@ -33,7 +33,7 @@ def evaluate_bpb(model, batches, steps, token_bytes):
|
||||
loss2d = model(x, y, loss_reduction='none') # (B, T)
|
||||
loss2d = loss2d.view(-1) # flatten
|
||||
y = y.view(-1) # flatten
|
||||
if (y < 0).any():
|
||||
if (y.int() < 0).any(): # mps does not currently have kernel for < 0 for int64, only int32
|
||||
# slightly more complex code path if some target tokens are ignore_index (e.g. -1)
|
||||
# any target token < 0 is to be ignored: do NOT index token_bytes with negatives
|
||||
valid = y >= 0
|
||||
|
||||
@@ -283,6 +283,10 @@ class Report:
|
||||
# capture bloat data for summary later (the stuff after Bloat header and until \n\n)
|
||||
bloat_data = re.search(r"### Bloat\n(.*?)\n\n", header_content, re.DOTALL)
|
||||
bloat_data = bloat_data.group(1) if bloat_data else ""
|
||||
else:
|
||||
start_time = None # will cause us to not write the total wall clock time
|
||||
bloat_data = "[bloat data missing]"
|
||||
print(f"Warning: {header_file} does not exist. Did you forget to run `nanochat reset`?")
|
||||
# process all the individual sections
|
||||
for file_name in EXPECTED_FILES:
|
||||
section_file = os.path.join(report_dir, file_name)
|
||||
|
||||
@@ -341,16 +341,19 @@ class RustBPETokenizer:
|
||||
mask = mask[:max_tokens]
|
||||
return ids, mask
|
||||
|
||||
def visualize_tokenization(self, ids, mask):
|
||||
def visualize_tokenization(self, ids, mask, with_token_id=False):
    """Small helper function useful in debugging: visualize the tokenization of render_conversation

    Each token is decoded on its own via self.decode and wrapped in an ANSI
    color code: green when its mask value equals 1, red otherwise. When
    with_token_id is true, the numeric id follows each token in gray.
    All pieces are joined with '|'.
    """
    RED = '\033[91m'
    GREEN = '\033[92m'
    RESET = '\033[0m'
    GRAY = '\033[90m'
    tokens = []
    # ids and mask are walked in lockstep; zip stops at the shorter of the two
    for token_id, mask_val in zip(ids, mask):
        token_str = self.decode([token_id])
        color = GREEN if mask_val == 1 else RED
        tokens.append(f"{color}{token_str}{RESET}")
        if with_token_id:
            tokens.append(f"{GRAY}({token_id}){RESET}")
    return '|'.join(tokens)
|
||||
|
||||
def render_for_completion(self, conversation):
|
||||
|
||||
+181
-13
@@ -2,7 +2,7 @@
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0, viewport-fit=cover">
|
||||
<title>NanoChat</title>
|
||||
<link rel="icon" type="image/svg+xml" href="/logo.svg">
|
||||
<style>
|
||||
@@ -18,7 +18,7 @@
|
||||
font-family: ui-sans-serif, -apple-system, system-ui, "Segoe UI", Helvetica, "Apple Color Emoji", Arial, sans-serif, "Segoe UI Emoji", "Segoe UI Symbol";
|
||||
background-color: #ffffff;
|
||||
color: #111827;
|
||||
min-height: 100vh;
|
||||
min-height: 100dvh;
|
||||
margin: 0;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
@@ -108,6 +108,15 @@
|
||||
background: transparent;
|
||||
border: none;
|
||||
padding: 0.25rem 0;
|
||||
cursor: pointer;
|
||||
border-radius: 0.5rem;
|
||||
padding: 0.5rem;
|
||||
margin-left: -0.5rem;
|
||||
transition: background-color 0.2s ease;
|
||||
}
|
||||
|
||||
.message.assistant .message-content:hover {
|
||||
background-color: #f9fafb;
|
||||
}
|
||||
|
||||
.message.user .message-content {
|
||||
@@ -115,11 +124,27 @@
|
||||
border-radius: 1.25rem;
|
||||
padding: 0.8rem 1rem;
|
||||
max-width: 65%;
|
||||
cursor: pointer;
|
||||
transition: background-color 0.2s ease;
|
||||
}
|
||||
|
||||
.message.user .message-content:hover {
|
||||
background-color: #e5e7eb;
|
||||
}
|
||||
|
||||
.message.console .message-content {
|
||||
font-family: 'Monaco', 'Menlo', 'Ubuntu Mono', 'Consolas', 'Courier New', monospace;
|
||||
font-size: 0.875rem;
|
||||
background-color: #fafafa;
|
||||
padding: 0.75rem 1rem;
|
||||
color: #374151;
|
||||
max-width: 80%;
|
||||
}
|
||||
|
||||
.input-container {
|
||||
background-color: #ffffff;
|
||||
padding: 1rem;
|
||||
padding-bottom: calc(1rem + env(safe-area-inset-bottom))
|
||||
}
|
||||
|
||||
.input-wrapper {
|
||||
@@ -255,6 +280,8 @@
|
||||
|
||||
let messages = [];
|
||||
let isGenerating = false;
|
||||
let currentTemperature = 0.8;
|
||||
let currentTopK = 50;
|
||||
|
||||
chatInput.addEventListener('input', function() {
|
||||
this.style.height = 'auto';
|
||||
@@ -289,7 +316,7 @@
|
||||
chatInput.focus();
|
||||
}
|
||||
|
||||
function addMessage(role, content) {
|
||||
function addMessage(role, content, messageIndex = null) {
|
||||
const messageDiv = document.createElement('div');
|
||||
messageDiv.className = `message ${role}`;
|
||||
|
||||
@@ -297,6 +324,28 @@
|
||||
contentDiv.className = 'message-content';
|
||||
contentDiv.textContent = content;
|
||||
|
||||
// Add click handler for user messages to enable editing
|
||||
if (role === 'user' && messageIndex !== null) {
|
||||
contentDiv.setAttribute('data-message-index', messageIndex);
|
||||
contentDiv.setAttribute('title', 'Click to edit and restart from here');
|
||||
contentDiv.addEventListener('click', function() {
|
||||
if (!isGenerating) {
|
||||
editMessage(messageIndex);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Add click handler for assistant messages to enable regeneration
|
||||
if (role === 'assistant' && messageIndex !== null) {
|
||||
contentDiv.setAttribute('data-message-index', messageIndex);
|
||||
contentDiv.setAttribute('title', 'Click to regenerate this response');
|
||||
contentDiv.addEventListener('click', function() {
|
||||
if (!isGenerating) {
|
||||
regenerateMessage(messageIndex);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
messageDiv.appendChild(contentDiv);
|
||||
chatWrapper.appendChild(messageDiv);
|
||||
|
||||
@@ -304,17 +353,35 @@
|
||||
return contentDiv;
|
||||
}
|
||||
|
||||
async function sendMessage() {
|
||||
const message = chatInput.value.trim();
|
||||
if (!message || isGenerating) return;
|
||||
function editMessage(messageIndex) {
|
||||
// Find the message in the messages array
|
||||
if (messageIndex < 0 || messageIndex >= messages.length) return;
|
||||
|
||||
isGenerating = true;
|
||||
chatInput.value = '';
|
||||
const messageToEdit = messages[messageIndex];
|
||||
if (messageToEdit.role !== 'user') return;
|
||||
|
||||
// Copy message content to input
|
||||
chatInput.value = messageToEdit.content;
|
||||
chatInput.style.height = 'auto';
|
||||
sendButton.disabled = true;
|
||||
chatInput.style.height = Math.min(chatInput.scrollHeight, 200) + 'px';
|
||||
|
||||
messages.push({ role: 'user', content: message });
|
||||
addMessage('user', message);
|
||||
// Remove this message and all subsequent messages from the array
|
||||
messages = messages.slice(0, messageIndex);
|
||||
|
||||
// Remove message elements from DOM starting from messageIndex
|
||||
const allMessages = chatWrapper.querySelectorAll('.message');
|
||||
for (let i = messageIndex; i < allMessages.length; i++) {
|
||||
allMessages[i].remove();
|
||||
}
|
||||
|
||||
// Enable send button and focus input
|
||||
sendButton.disabled = false;
|
||||
chatInput.focus();
|
||||
}
|
||||
|
||||
async function generateAssistantResponse() {
|
||||
isGenerating = true;
|
||||
sendButton.disabled = true;
|
||||
|
||||
const assistantContent = addMessage('assistant', '');
|
||||
assistantContent.innerHTML = '<span class="typing-indicator"></span>';
|
||||
@@ -327,8 +394,8 @@
|
||||
},
|
||||
body: JSON.stringify({
|
||||
messages: messages,
|
||||
stream: true,
|
||||
temperature: 0.8,
|
||||
temperature: currentTemperature,
|
||||
top_k: currentTopK,
|
||||
max_tokens: 512
|
||||
}),
|
||||
});
|
||||
@@ -364,8 +431,18 @@
|
||||
}
|
||||
}
|
||||
|
||||
const assistantMessageIndex = messages.length;
|
||||
messages.push({ role: 'assistant', content: fullResponse });
|
||||
|
||||
// Add click handler to regenerate this assistant message
|
||||
assistantContent.setAttribute('data-message-index', assistantMessageIndex);
|
||||
assistantContent.setAttribute('title', 'Click to regenerate this response');
|
||||
assistantContent.addEventListener('click', function() {
|
||||
if (!isGenerating) {
|
||||
regenerateMessage(assistantMessageIndex);
|
||||
}
|
||||
});
|
||||
|
||||
} catch (error) {
|
||||
console.error('Error:', error);
|
||||
assistantContent.innerHTML = `<div class="error-message">Error: ${error.message}</div>`;
|
||||
@@ -375,6 +452,97 @@
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Discard the assistant message at `messageIndex` (and everything after it)
 * and request a fresh assistant response.
 *
 * @param {number} messageIndex Index into the `messages` array.
 */
async function regenerateMessage(messageIndex) {
    // Ignore indices that no longer exist in the conversation
    if (messageIndex < 0 || messageIndex >= messages.length) return;

    // Only assistant turns can be regenerated
    if (messages[messageIndex].role !== 'assistant') return;

    // Remove this message and all subsequent messages from the array
    messages = messages.slice(0, messageIndex);

    // Find the DOM node through its data-message-index attribute instead of
    // its position: console messages are also `.message` elements but are not
    // stored in `messages`, so positions in the DOM and the array can differ.
    const allMessages = Array.from(chatWrapper.querySelectorAll('.message'));
    const anchor = chatWrapper.querySelector(`[data-message-index="${messageIndex}"]`);
    const anchorMessage = anchor ? anchor.closest('.message') : null;
    const anchorPosition = anchorMessage ? allMessages.indexOf(anchorMessage) : -1;

    // Fall back to the positional index if no tagged element is found
    const start = anchorPosition >= 0 ? anchorPosition : messageIndex;
    allMessages.slice(start).forEach((element) => element.remove());

    // Regenerate the assistant response
    await generateAssistantResponse();
}
|
||||
|
||||
/**
 * Execute a slash command typed into the chat input.
 *
 * Supported: /temperature [value], /topk [value], /clear, /help.
 * Output is shown as a console message in the chat.
 *
 * @param {string} command Raw input text, starting with '/'.
 * @returns {boolean} true if the command was recognised, false otherwise.
 */
function handleSlashCommand(command) {
    const parts = command.trim().split(/\s+/);
    const cmd = parts[0].toLowerCase();
    const arg = parts[1];

    if (cmd === '/temperature') {
        if (arg === undefined) {
            addMessage('console', `Current temperature: ${currentTemperature}`);
        } else {
            // Number() rejects trailing garbage ("0.5abc") that parseFloat would accept
            const temp = Number(arg);
            if (!Number.isFinite(temp) || temp < 0 || temp > 2) {
                addMessage('console', 'Invalid temperature. Must be between 0.0 and 2.0');
            } else {
                currentTemperature = temp;
                addMessage('console', `Temperature set to ${currentTemperature}`);
            }
        }
        return true;
    } else if (cmd === '/topk') {
        if (arg === undefined) {
            addMessage('console', `Current top-k: ${currentTopK}`);
        } else {
            // Require a whole number rather than silently truncating "3.7" to 3
            const topk = Number(arg);
            if (!Number.isInteger(topk) || topk < 1 || topk > 200) {
                addMessage('console', 'Invalid top-k. Must be an integer between 1 and 200');
            } else {
                currentTopK = topk;
                addMessage('console', `Top-k set to ${currentTopK}`);
            }
        }
        return true;
    } else if (cmd === '/clear') {
        newConversation();
        return true;
    } else if (cmd === '/help') {
        addMessage('console',
            'Available commands:\n' +
            '/temperature - Show current temperature\n' +
            '/temperature <value> - Set temperature (0.0-2.0)\n' +
            '/topk - Show current top-k\n' +
            '/topk <value> - Set top-k (1-200)\n' +
            '/clear - Clear conversation\n' +
            '/help - Show this help message'
        );
        return true;
    }
    return false;
}
|
||||
|
||||
/**
 * Submit the current contents of the chat input.
 *
 * Input starting with '/' is treated as a slash command and is never sent to
 * the model. Anything else is appended to the conversation as a user message,
 * followed by a request for an assistant response.
 */
async function sendMessage() {
    const message = chatInput.value.trim();
    if (!message || isGenerating) return;

    // Handle slash commands
    if (message.startsWith('/')) {
        chatInput.value = '';
        chatInput.style.height = 'auto';
        // Tell the user when the command is not recognised instead of
        // dropping the input without feedback
        if (!handleSlashCommand(message)) {
            const commandName = message.split(/\s+/)[0];
            addMessage('console', `Unknown command: ${commandName}. Type /help for available commands.`);
        }
        return;
    }

    chatInput.value = '';
    chatInput.style.height = 'auto';

    // Record the index before pushing so the rendered bubble can refer back to it
    const userMessageIndex = messages.length;
    messages.push({ role: 'user', content: message });
    addMessage('user', message, userMessageIndex);

    await generateAssistantResponse();
}
|
||||
|
||||
sendButton.disabled = false;
|
||||
|
||||
// Autofocus the chat input on page load
|
||||
|
||||
Reference in New Issue
Block a user