add support for CPU and for MPS. I had to change a few cosmetic things. I also discovered I think a bit of a bug, where I was casting wte to bfloat16 in the wrong place (the model init) instead of in init_weights
This commit is contained in:
+8
-6
@@ -89,11 +89,14 @@ def get_dist_info():
|
||||
else:
|
||||
return False, 0, 0, 1
|
||||
|
||||
def compute_init(device_type="cuda"): # cuda|cpu
|
||||
def compute_init(device_type="cuda"): # cuda|cpu|mps
|
||||
"""Basic initialization that we keep doing over and over, so make common."""
|
||||
|
||||
# CUDA is currently required
|
||||
# assert torch.cuda.is_available(), "CUDA is needed for a distributed run atm"
|
||||
assert device_type in ["cuda", "mps", "cpu"], "Invalid device type atm"
|
||||
if device_type == "cuda":
|
||||
assert torch.cuda.is_available(), "Your PyTorch installation is not configured for CUDA but device_type is 'cuda'"
|
||||
if device_type == "mps":
|
||||
assert torch.backends.mps.is_available(), "Your PyTorch installation is not configured for MPS but device_type is 'mps'"
|
||||
|
||||
# Reproducibility
|
||||
torch.manual_seed(42)
|
||||
@@ -101,11 +104,10 @@ def compute_init(device_type="cuda"): # cuda|cpu
|
||||
torch.cuda.manual_seed(42)
|
||||
# skipping full reproducibility for now, possibly investigate slowdown later
|
||||
# torch.use_deterministic_algorithms(True)
|
||||
# torch.backends.cudnn.deterministic = True
|
||||
# torch.backends.cudnn.benchmark = False
|
||||
|
||||
# Precision
|
||||
torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls
|
||||
if device_type == "cuda":
|
||||
torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls
|
||||
|
||||
# Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
|
||||
|
||||
+3
-2
@@ -169,8 +169,6 @@ class GPT(nn.Module):
|
||||
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
|
||||
self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
|
||||
self.register_buffer("sin", sin, persistent=False)
|
||||
# Cast the embeddings from fp32 to bf16: optim can tolerate it and it saves memory: both in the model and the activations
|
||||
self.transformer.wte.to(dtype=torch.bfloat16)
|
||||
|
||||
def init_weights(self):
|
||||
self.apply(self._init_weights)
|
||||
@@ -184,6 +182,9 @@ class GPT(nn.Module):
|
||||
head_dim = self.config.n_embd // self.config.n_head
|
||||
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
|
||||
self.cos, self.sin = cos, sin
|
||||
# Cast the embeddings from fp32 to bf16: optim can tolerate it and it saves memory: both in the model and the activations
|
||||
if self.transformer.wte.weight.device.type == "cuda":
|
||||
self.transformer.wte.to(dtype=torch.bfloat16)
|
||||
|
||||
def _init_weights(self, module):
|
||||
if isinstance(module, nn.Linear):
|
||||
|
||||
@@ -33,7 +33,7 @@ def evaluate_bpb(model, batches, steps, token_bytes):
|
||||
loss2d = model(x, y, loss_reduction='none') # (B, T)
|
||||
loss2d = loss2d.view(-1) # flatten
|
||||
y = y.view(-1) # flatten
|
||||
if (y < 0).any():
|
||||
if (y.int() < 0).any(): # mps does not currently have kernel for < 0 for int64, only int32
|
||||
# slightly more complex code path if some target tokens are ignore_index (e.g. -1)
|
||||
# any target token < 0 is to be ignored: do NOT index token_bytes with negatives
|
||||
valid = y >= 0
|
||||
|
||||
Reference in New Issue
Block a user