use enable_gqa of pytorch sdpa, allows us to delete some code, didnt realize it's available
This commit is contained in:
@@ -85,7 +85,7 @@ print0(f"Vocab size: {vocab_size:,}")
|
||||
num_layers = depth
|
||||
model_dim = depth * 64 # aspect ratio 64 (usually this is varied from 64 -> 128 as model size increases)
|
||||
num_heads = max(1, (model_dim + 127) // 128) # head dim 128 (the division here is ceil div)
|
||||
num_kv_heads = num_heads # 1:1 MQA ratio
|
||||
num_kv_heads = num_heads # default is 1:1 GQA (Group Query Attention) ratio (i.e. GQA is disabled)
|
||||
print0(f"num_layers: {num_layers}")
|
||||
print0(f"model_dim: {model_dim}")
|
||||
print0(f"num_heads: {num_heads}")
|
||||
|
||||
Reference in New Issue
Block a user