@@ -1552,14 +1552,14 @@ static void llama_kv_cache_seq_shift(
15521552
15531553 for (uint32_t i = 0 ; i < cache.size ; ++i) {
15541554 if (cache.cells [i].has_seq_id (seq_id) && cache.cells [i].pos >= p0 && cache.cells [i].pos < p1) {
1555- cache.cells [i].pos += delta;
1555+ cache.has_shift = true ;
1556+ cache.cells [i].pos += delta;
1557+ cache.cells [i].delta += delta;
1558+
15561559 if (cache.cells [i].pos < 0 ) {
15571560 cache.cells [i].pos = -1 ;
15581561 cache.cells [i].seq_id .clear ();
15591562 if (new_head == cache.size ) new_head = i;
1560- } else {
1561- cache.has_shift = true ;
1562- cache.cells [i].delta = delta;
15631563 }
15641564 }
15651565 }
@@ -6073,11 +6073,20 @@ static int llama_decode_internal(
60736073#endif
60746074
60756075 // update the kv ring buffer
6076- lctx.kv_self .has_shift = false ;
6077- lctx.kv_self .head += n_tokens;
6078- // Ensure kv cache head points to a valid index.
6079- if (lctx.kv_self .head >= lctx.kv_self .size ) {
6080- lctx.kv_self .head = 0 ;
6076+ {
6077+ if (kv_self.has_shift ) {
6078+ kv_self.has_shift = false ;
6079+ for (uint32_t i = 0 ; i < kv_self.size ; ++i) {
6080+ kv_self.cells [i].delta = 0 ;
6081+ }
6082+ }
6083+
6084+ kv_self.head += n_tokens;
6085+
6086+ // Ensure kv cache head points to a valid index.
6087+ if (kv_self.head >= kv_self.size ) {
6088+ kv_self.head = 0 ;
6089+ }
60816090 }
60826091
60836092#ifdef GGML_PERF
0 commit comments