Fix kv cache, given resize will destroys the logical structure
This commit is contained in:
+4
-3
@@ -135,9 +135,10 @@ class KVCache:
|
|||||||
if t1 > self.kv_cache.size(4):
|
if t1 > self.kv_cache.size(4):
|
||||||
t_needed = t1 + 1024 # as much as we need plus buffer of 1024
|
t_needed = t1 + 1024 # as much as we need plus buffer of 1024
|
||||||
t_needed = (t_needed + 1023) & ~1023 # then round up to the nearest multiple of 1024
|
t_needed = (t_needed + 1023) & ~1023 # then round up to the nearest multiple of 1024
|
||||||
current_shape = list(self.kv_cache.shape)
|
additional_shape = list(self.kv_cache.shape)
|
||||||
current_shape[4] = t_needed
|
additional_shape[4] = t_needed - self.kv_cache.size(4)
|
||||||
self.kv_cache.resize_(current_shape)
|
additional_cache = torch.empty(additional_shape, dtype=k.dtype, device=k.device)
|
||||||
|
self.kv_cache = torch.cat([self.kv_cache, additional_cache], dim=4).contiguous()
|
||||||
# Insert k, v into the cache
|
# Insert k, v into the cache
|
||||||
self.kv_cache[layer_idx, 0, :, :, t0:t1] = k
|
self.kv_cache[layer_idx, 0, :, :, t0:t1] = k
|
||||||
self.kv_cache[layer_idx, 1, :, :, t0:t1] = v
|
self.kv_cache[layer_idx, 1, :, :, t0:t1] = v
|
||||||
|
|||||||
Reference in New Issue
Block a user