add support for CPU and for MPS. I had to change a few cosmetic things. I also discovered I think a bit of a bug, where I was casting wte to bfloat16 in the wrong place (the model init) instead of in init_weights

This commit is contained in:
karpathy
2025-10-16 10:04:43 -07:00
parent 722da4f543
commit 306bc380ab
6 changed files with 68 additions and 46 deletions
+1 -11
View File
@@ -11,6 +11,7 @@ dependencies = [
"numpy==1.26.4",
"psutil>=7.1.0",
"regex>=2025.9.1",
"setuptools>=80.9.0",
"tiktoken>=0.11.0",
"tokenizers>=0.22.0",
"torch>=2.8.0",
@@ -22,17 +23,6 @@ dependencies = [
requires = ["maturin>=1.7,<2.0"]
build-backend = "maturin"
# target torch to cuda 12.8
[tool.uv.sources]
torch = [
{ index = "pytorch-cu128" },
]
[[tool.uv.index]]
name = "pytorch-cu128"
url = "https://download.pytorch.org/whl/cu128"
explicit = true
[tool.maturin]
module-name = "rustbpe"
bindings = "pyo3"