@@ -3541,7 +3541,11 @@ struct server_context {
                                 slot.n_past = 0;
                             }

-                            const auto n_swa = llama_model_n_swa(model);
+                            // note: when n_swa == 0, the model does not use SWA, which is equivalent to a window of 1
+                            const auto n_swa = std::max(1, llama_model_n_swa(model));
+
+                            // the largest pos_min required for a checkpoint to be useful
+                            const auto pos_min_thold = std::max(0, slot.n_past - n_swa);

                             if (slot.n_past > 0 && slot.n_past < (int) slot.cache_tokens.size()) {
                                 const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
@@ -3550,17 +3554,16 @@ struct server_context {
                                     GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237");
                                 }

-                                const auto pos_min_thold = std::max(0, slot.n_past - n_swa);
-
-                                if (pos_min > pos_min_thold + 1) {
+                                if (pos_min > pos_min_thold) {
                                     SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min, n_swa);

                                     // search for a context checkpoint
                                     const auto it = std::find_if(
                                         slot.ctx_checkpoints.rbegin(),
                                         slot.ctx_checkpoints.rend(),
                                         [&](const auto & cur) {
-                                            return cur.pos_min <= pos_min_thold;
+                                            // guarantee that a checkpoint will result in at least one token being processed [TAG_PROMPT_LOGITS]
+                                            return cur.pos_min < pos_min_thold;
                                         }
                                     );

@@ -3577,7 +3580,7 @@ struct server_context {
                                             do_reset = true;
                                             //printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint");
                                         } else {
-                                            slot.n_past = std::min(slot.n_past, it->pos_max);
+                                            slot.n_past = std::min(slot.n_past, std::max(it->pos_min + 1, it->pos_max));
                                             SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) ctx_checkpoint_size / 1024 / 1024);
                                         }
                                     }
@@ -3586,25 +3589,23 @@ struct server_context {
                                         SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA or hybrid/recurrent memory, see %s)\n",
                                                 "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
                                         slot.n_past = 0;
-                                        slot.ctx_checkpoints.clear();
                                     }
                                 }
                             }

-                            if (n_swa > 0) {
-                                const auto pos_min_thold = std::max(0, slot.n_past - n_swa);
-
+                            {
                                 // erase any checkpoints with pos_min > pos_min_thold
                                 for (int i = (int) slot.ctx_checkpoints.size() - 1; i >= 0; i--) {
                                     const auto & cur = slot.ctx_checkpoints[i];
                                     if (cur.pos_min > pos_min_thold) {
-                                        slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin() + i);
                                         SLT_WRN(slot, "erased invalidated context checkpoint for SWA (pos_min = %d, pos_max = %d, n_swa = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, n_swa, (float) cur.data.size() / 1024 / 1024);
+                                        slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin() + i);
                                     }
                                 }
                             }
                         }

+                        // [TAG_PROMPT_LOGITS]
                         if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) {
                             SLT_WRN(slot, "need to evaluate at least 1 token for each active slot (n_past = %d, n_prompt_tokens = %d)\n", slot.n_past, slot.n_prompt_tokens);
                             slot.n_past--;
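
For reference, below is a minimal, self-contained sketch of the threshold/restore arithmetic introduced in the hunks above. The `ctx_checkpoint` struct and the `restore_n_past()` helper are invented for illustration and are not the server's actual API; the real code operates on `server_slot` state, queries `llama_model_n_swa()` / `llama_memory_seq_pos_min()`, and logs through the `SLT_*` macros.

```cpp
// Hypothetical, simplified sketch of the checkpoint selection logic (not the server's API).
#include <algorithm>
#include <cstdio>
#include <vector>

struct ctx_checkpoint {
    int pos_min; // lowest  position covered by the checkpoint
    int pos_max; // highest position covered by the checkpoint
};

// returns the adjusted n_past after trying to reuse a checkpoint,
// or 0 when a full prompt re-processing would be required
static int restore_n_past(int n_past, int n_swa_model, int pos_min_cache,
                          const std::vector<ctx_checkpoint> & checkpoints) {
    // n_swa == 0 -> the model does not use SWA, which behaves like a window of 1
    const int n_swa = std::max(1, n_swa_model);

    // the largest pos_min required for a checkpoint to be useful
    const int pos_min_thold = std::max(0, n_past - n_swa);

    if (pos_min_cache <= pos_min_thold) {
        // the current memory already reaches far enough back - nothing to do
        return n_past;
    }

    // search newest-to-oldest for a checkpoint that leaves at least one token to process
    const auto it = std::find_if(checkpoints.rbegin(), checkpoints.rend(),
        [&](const ctx_checkpoint & cur) {
            return cur.pos_min < pos_min_thold;
        });

    if (it == checkpoints.rend()) {
        return 0; // no usable checkpoint -> full re-processing
    }

    // clamp n_past so the next batch is never empty
    return std::min(n_past, std::max(it->pos_min + 1, it->pos_max));
}

int main() {
    const std::vector<ctx_checkpoint> checkpoints = { {0, 512}, {600, 900} };

    // n_past = 1000, SWA window = 256, cache only holds positions >= 800:
    // thold = 744, the {600, 900} checkpoint qualifies, n_past becomes 900
    printf("%d\n", restore_n_past(1000, 256, 800, checkpoints));

    // same idea, but the only checkpoint starts at or after thold -> full re-processing (0)
    printf("%d\n", restore_n_past(64, 16, 60, { {50, 60} }));

    return 0;
}
```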