diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py
index 5f20f51209..0100715eef 100644
--- a/unsloth/models/llama.py
+++ b/unsloth/models/llama.py
@@ -193,6 +193,10 @@ def LlamaAttention_fast_forward_inference(
     # cos, sin = self.rotary_emb(Vn, seq_len = kv_seq_len)
     # Qn, Kn = inplace_rope_embedding(Qn, Kn, cos, sin, position_ids)
+
+    # Extend the RoPE cache 2 steps before the short KV cache fills up,
+    # otherwise an error occurs.
+    self.rotary_emb.extend_rope_embedding(Vn, seq_len + 2)
     cos, sin = self.rotary_emb.get_cached(kv_seq_len)
     cos = cos[position_ids].unsqueeze(1)
     sin = sin[position_ids].unsqueeze(1)
 
@@ -1122,7 +1126,7 @@ def get_cached(self, seq_len = None):
     def extend_rope_embedding(self, x, seq_len):
         if seq_len <= self.current_rope_size: return
         # Iteratively grow by increments of 8192
-        self.current_rope_size = math.ceil(seq_len / 8192) * 8192
+        self.current_rope_size = ((seq_len // 8192) + ((seq_len % 8192) != 0)) * 8192
         self._set_cos_sin_cache(self.current_rope_size, device = "cuda:0", dtype = x.dtype)
     pass
 pass
@@ -1248,7 +1252,7 @@ def get_cached(self, seq_len = None):
     def extend_rope_embedding(self, x, seq_len):
         if seq_len <= self.current_rope_size: return
         # Iteratively grow by increments of 8192
-        self.current_rope_size = math.ceil(seq_len / 8192) * 8192
+        self.current_rope_size = ((seq_len // 8192) + ((seq_len % 8192) != 0)) * 8192
         self._set_cos_sin_cache(self.current_rope_size, device = "cuda:0", dtype = x.dtype)
     pass
 pass
@@ -1363,7 +1367,7 @@ def get_cached(self, seq_len = None):
     def extend_rope_embedding(self, x, seq_len):
         if seq_len <= self.current_rope_size: return
         # Iteratively grow by increments of 8192
-        self.current_rope_size = math.ceil(seq_len / 8192) * 8192
+        self.current_rope_size = ((seq_len // 8192) + ((seq_len % 8192) != 0)) * 8192
         self._set_cos_sin_cache(self.current_rope_size, device = "cuda:0", dtype = x.dtype)
     pass
 pass