Add support for CPU and MPS. I had to change a few cosmetic things. I also discovered what I think is a small bug: I was casting wte to bfloat16 in the wrong place (the model init) instead of in init_weights.
This commit is contained in:
+1
-11
@@ -11,6 +11,7 @@ dependencies = [
|
||||
"numpy==1.26.4",
|
||||
"psutil>=7.1.0",
|
||||
"regex>=2025.9.1",
|
||||
"setuptools>=80.9.0",
|
||||
"tiktoken>=0.11.0",
|
||||
"tokenizers>=0.22.0",
|
||||
"torch>=2.8.0",
|
||||
@@ -22,17 +23,6 @@ dependencies = [
|
||||
requires = ["maturin>=1.7,<2.0"]
|
||||
build-backend = "maturin"
|
||||
|
||||
# target torch to cuda 12.8
|
||||
[tool.uv.sources]
|
||||
torch = [
|
||||
{ index = "pytorch-cu128" },
|
||||
]
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch-cu128"
|
||||
url = "https://download.pytorch.org/whl/cu128"
|
||||
explicit = true
|
||||
|
||||
[tool.maturin]
|
||||
module-name = "rustbpe"
|
||||
bindings = "pyo3"
|
||||
|
||||
Reference in New Issue
Block a user