Add the SpellingBee task so that nanochat can count the r's in "strawberry", etc. Along the way we had to add a bunch of new functionality, e.g. extending the calculator to support Python's str.count method. The current TaskMixture possibly uses far too many synthetic SpellingBee examples, because the eval gives us exactly 100% performance on spelling. We can tune this later to reclaim some wall-clock time, I think.
This commit is contained in:
@@ -5,6 +5,8 @@ Common utilities for nanochat.
|
||||
import os
|
||||
import re
|
||||
import logging
|
||||
import fcntl
|
||||
import urllib.request
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
@@ -56,6 +58,44 @@ def get_base_dir():
|
||||
os.makedirs(nanochat_dir, exist_ok=True)
|
||||
return nanochat_dir
|
||||
|
||||
def download_file_with_lock(url, filename):
    """
    Downloads a file from a URL to a local path in the base directory.
    Uses a lock file to prevent concurrent downloads among multiple ranks.

    Args:
        url: address to fetch; the response body is decoded as UTF-8 text.
        filename: name of the destination file inside get_base_dir().

    Returns:
        The full path of the downloaded (or already present) file.
    """
    base_dir = get_base_dir()
    file_path = os.path.join(base_dir, filename)
    lock_path = file_path + ".lock"

    # Fast path: nothing to do if the file is already in place
    if os.path.exists(file_path):
        return file_path

    with open(lock_path, 'w') as lock_file:

        # Only a single rank can acquire this lock
        # All other ranks block until it is released
        fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX)

        # Re-check under the lock: another rank may have finished the download
        # while this one was waiting
        if not os.path.exists(file_path):
            print(f"Downloading {url}...")
            with urllib.request.urlopen(url) as response:
                content = response.read().decode('utf-8')

            # Write to a temporary sibling and rename it into place, so that the
            # unlocked existence check above never observes a half-written file
            tmp_path = file_path + ".tmp"
            with open(tmp_path, 'w', encoding='utf-8') as f:
                f.write(content)
            os.replace(tmp_path, file_path)

            print(f"Downloaded to {file_path}")

    # Clean up the lock file after the lock is released
    try:
        os.remove(lock_path)
    except OSError:
        pass # Ignore if already removed by another process

    return file_path
|
||||
|
||||
def print0(s="",**kwargs):
|
||||
ddp_rank = int(os.environ.get('RANK', 0))
|
||||
if ddp_rank == 0:
|
||||
|
||||
+29
-3
@@ -44,12 +44,38 @@ def eval_with_timeout(formula, max_time=3):
|
||||
return None
|
||||
|
||||
def use_calculator(expr):
    """
    Evaluate a Python expression safely.
    Supports both math expressions and string operations like .count()

    Returns the result of eval_with_timeout for an accepted expression, or
    None when the expression is rejected by one of the checks below.
    """
    import re # local import: only the string-operation branch needs it

    # Remove commas from numbers
    # (note that this also drops commas that appear inside string literals)
    expr = expr.replace(",", "")

    # Check if it's a pure math expression (old behavior)
    if all(x in "0123456789*+-/.() " for x in expr):
        if "**" in expr: # disallow power operator, could be very expensive
            return None
        return eval_with_timeout(expr)

    # Check if it's a string operation we support
    # Allow: strings (single/double quotes), .count(), letters, numbers, spaces, parens
    allowed_chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'\"()._ "
    if not all(x in allowed_chars for x in expr):
        return None

    # Disallow dangerous patterns
    dangerous_patterns = ['__', 'import', 'exec', 'eval', 'compile', 'open', 'file',
                          'input', 'raw_input', 'globals', 'locals', 'vars', 'dir',
                          'getattr', 'setattr', 'delattr', 'hasattr']
    # Blank out the contents of quoted literals before scanning: the text inside
    # a string is data, and matching it would reject ordinary words such as
    # 'important', 'profile' or 'direction' that merely contain a pattern.
    # Backslashes are not in allowed_chars, so a literal cannot hide an escaped quote.
    code_only = re.sub(r"""(['"]).*?\1""", "''", expr).lower()
    if any(pattern in code_only for pattern in dangerous_patterns):
        return None

    # Only allow .count() method for now (can expand later)
    if '.count(' not in expr:
        return None

    # Evaluate with timeout
    # NOTE(review): the blacklist above is only a heuristic filter; confirm that
    # eval_with_timeout evaluates the expression without access to builtins.
    return eval_with_timeout(expr)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
Reference in New Issue
Block a user