Remove spurious cast: it gets compiled away anyway, but it's confusing people
This commit is contained in:
+2
-4
@@ -41,12 +41,10 @@ def norm(x):
|
|||||||
def apply_rotary_emb(x, cos, sin):
|
def apply_rotary_emb(x, cos, sin):
|
||||||
assert x.ndim == 4 # multihead attention
|
assert x.ndim == 4 # multihead attention
|
||||||
d = x.shape[3] // 2
|
d = x.shape[3] // 2
|
||||||
x1, x2 = x[..., :d], x[..., d:] # split up last time into two halves
|
x1, x2 = x[..., :d], x[..., d:] # split up last dim into two halves
|
||||||
y1 = x1 * cos + x2 * sin # rotate pairs of dims
|
y1 = x1 * cos + x2 * sin # rotate pairs of dims
|
||||||
y2 = x1 * (-sin) + x2 * cos
|
y2 = x1 * (-sin) + x2 * cos
|
||||||
out = torch.cat([y1, y2], 3) # re-assemble
|
return torch.cat([y1, y2], 3)
|
||||||
out = out.to(x.dtype) # ensure input/output dtypes match
|
|
||||||
return out
|
|
||||||
|
|
||||||
class CausalSelfAttention(nn.Module):
|
class CausalSelfAttention(nn.Module):
|
||||||
def __init__(self, config, layer_idx):
|
def __init__(self, config, layer_idx):
|
||||||
|
|||||||
Reference in New Issue
Block a user