@@ -1466,17 +1466,12 @@ static int32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) {
14661466 return 0 ;
14671467}
14681468
1469- static void llama_kv_cache_tokens_rm (struct llama_kv_cache & cache, int32_t c0, int32_t c1) {
1470- if (c0 < 0 ) c0 = 0 ;
1471- if (c1 < 0 ) c1 = cache.size ;
1472-
1473- for (int32_t i = c0; i < c1; ++i) {
1469+ static void llama_kv_cache_clear (struct llama_kv_cache & cache) {
1470+ for (int32_t i = 0 ; i < cache.size ; ++i) {
14741471 cache.cells [i].pos = -1 ;
14751472 cache.cells [i].seq_id .clear ();
14761473 }
1477-
1478- // Searching for a free slot can start here since we know it will be empty.
1479- cache.head = uint32_t (c0);
1474+ cache.head = 0 ;
14801475}
14811476
14821477static void llama_kv_cache_seq_rm (
@@ -1490,8 +1485,14 @@ static void llama_kv_cache_seq_rm(
14901485 if (p1 < 0 ) p1 = std::numeric_limits<llama_pos>::max ();
14911486
14921487 for (uint32_t i = 0 ; i < cache.size ; ++i) {
1493- if (cache.cells [i].has_seq_id (seq_id) && cache.cells [i].pos >= p0 && cache.cells [i].pos < p1) {
1494- cache.cells [i].seq_id .erase (seq_id);
1488+ if (cache.cells [i].pos >= p0 && cache.cells [i].pos < p1) {
1489+ if (seq_id < 0 ) {
1490+ cache.cells [i].seq_id .clear ();
1491+ } else if (cache.cells [i].has_seq_id (seq_id)) {
1492+ cache.cells [i].seq_id .erase (seq_id);
1493+ } else {
1494+ continue ;
1495+ }
14951496 if (cache.cells [i].seq_id .empty ()) {
14961497 cache.cells [i].pos = -1 ;
14971498 if (new_head == cache.size ) new_head = i;
@@ -9207,8 +9208,8 @@ int llama_get_kv_cache_token_count(const struct llama_context * ctx) {
92079208 return ctx->kv_self .head ;
92089209}
92099210
9210- void llama_kv_cache_tokens_rm (struct llama_context * ctx, int32_t c0, int32_t c1 ) {
9211- llama_kv_cache_tokens_rm (ctx->kv_self , c0, c1 );
9211+ void llama_kv_cache_clear (struct llama_context * ctx) {
9212+ llama_kv_cache_clear (ctx->kv_self );
92129213}
92139214
92149215void llama_kv_cache_seq_rm (struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
@@ -9654,7 +9655,7 @@ int llama_eval(
96549655 llama_token * tokens,
96559656 int32_t n_tokens,
96569657 int n_past) {
9657- llama_kv_cache_tokens_rm (ctx->kv_self , n_past, -1 );
9658+ llama_kv_cache_seq_rm (ctx->kv_self , - 1 , n_past, -1 );
96589659
96599660 const int ret = llama_decode_internal (*ctx, llama_batch_get_one (tokens, n_tokens, n_past, 0 ));
96609661 if (ret < 0 ) {
@@ -9669,7 +9670,7 @@ int llama_eval_embd(
96699670 float * embd,
96709671 int32_t n_tokens,
96719672 int n_past) {
9672- llama_kv_cache_tokens_rm (ctx->kv_self , n_past, -1 );
9673+ llama_kv_cache_seq_rm (ctx->kv_self , - 1 , n_past, -1 );
96739674
96749675 llama_batch batch = { n_tokens, nullptr , embd, nullptr , nullptr , nullptr , nullptr , n_past, 1 , 0 , };
96759676
0 commit comments