@@ -764,7 +764,7 @@ struct completion_token_output {
     }
 };
 
-struct swa_checkpoint {
+struct ctx_checkpoint {
     llama_pos pos_min;
     llama_pos pos_max;
 
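Note: the hunk above only shows the renamed struct's position fields. For orientation, here is a sketch of how the full `ctx_checkpoint` type presumably reads after this change; the `data` member is inferred from the `cur.data` uses later in this diff, so treat it as an assumption rather than the PR's exact definition.

    // sketch only: inferred from this diff, not copied from the PR
    struct ctx_checkpoint {
        llama_pos pos_min;          // smallest position covered by the saved state
        llama_pos pos_max;          // largest position covered by the saved state

        std::vector<uint8_t> data;  // serialized sequence state (assumed member, filled via llama_state_seq_get_data_ext)
    };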
@@ -1460,7 +1460,7 @@ struct server_slot {
 
     std::vector<completion_token_output> generated_token_probs;
 
-    std::vector<swa_checkpoint> swa_checkpoints;
+    std::vector<ctx_checkpoint> ctx_checkpoints;
 
     bool has_next_token = true;
     bool has_new_line   = false;
@@ -3541,7 +3541,11 @@ struct server_context {
                                 slot.n_past = 0;
                             }
 
-                            const auto n_swa = llama_model_n_swa(model);
+                            // note: when n_swa == 0, the model does not use SWA, which is equivalent to a window of 1
+                            const auto n_swa = std::max(1, llama_model_n_swa(model));
+
+                            // the largest pos_min required for a checkpoint to be useful
+                            const auto pos_min_thold = std::max(0, slot.n_past - n_swa);
 
                             if (slot.n_past > 0 && slot.n_past < (int) slot.cache_tokens.size()) {
                                 const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
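Note: a worked example of the two values computed above, with hypothetical numbers chosen only to show the arithmetic:

    // hypothetical values, not from the PR:
    //   llama_model_n_swa(model) = 256, slot.n_past = 1000
    //     -> n_swa         = std::max(1, 256)        = 256
    //     -> pos_min_thold = std::max(0, 1000 - 256) = 744   // only checkpoints with pos_min < 744 are useful
    //
    //   llama_model_n_swa(model) = 0 (no SWA), slot.n_past = 1000
    //     -> n_swa         = std::max(1, 0)          = 1
    //     -> pos_min_thold = std::max(0, 1000 - 1)   = 999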
@@ -3550,66 +3554,62 @@ struct server_context {
                                     GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237");
                                 }
 
-                                const auto pos_min_thold = std::max(0, slot.n_past - n_swa);
-
                                 if (pos_min > pos_min_thold) {
                                     SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min, n_swa);
 
-                                    // search for a SWA checkpoint
+                                    // search for a context checkpoint
                                     const auto it = std::find_if(
-                                        slot.swa_checkpoints.rbegin(),
-                                        slot.swa_checkpoints.rend(),
+                                        slot.ctx_checkpoints.rbegin(),
+                                        slot.ctx_checkpoints.rend(),
                                         [&](const auto & cur) {
-                                            return cur.pos_min <= pos_min_thold;
+                                            // guarantee that a checkpoint will result in at least one token being processed [TAG_PROMPT_LOGITS]
+                                            return cur.pos_min < pos_min_thold;
                                         }
                                     );
 
-                                    bool do_reset = it == slot.swa_checkpoints.rend();
+                                    bool do_reset = it == slot.ctx_checkpoints.rend();
+                                    //printf("[DEBUG] `do_reset` was set to `%s`\n", do_reset ? "true" : "false");
 
                                     if (!do_reset) {
-                                        // restore the checkpoint
-                                        const size_t swa_size = it->data.size();
-                                        const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), swa_size, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);
+                                        // restore the context checkpoint
+                                        const size_t ctx_checkpoint_size = it->data.size();
+                                        const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), ctx_checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
 
-                                        if (n != swa_size) {
-                                            SLT_ERR(slot, "failed to restore SWA checkpoint, pos_min = %d, pos_max = %d, size = %.3f MiB\n", it->pos_min, it->pos_max, (float) swa_size / 1024 / 1024);
+                                        if (n != ctx_checkpoint_size) {
+                                            SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) ctx_checkpoint_size / 1024 / 1024);
                                             do_reset = true;
+                                            //printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint");
                                         } else {
-                                            slot.n_past = std::min(slot.n_past, it->pos_max);
-
-                                            SLT_WRN(slot, "SWA checkpoint restore, pos_min = %d, pos_max = %d, size = %.3f MiB\n", it->pos_min, it->pos_max, (float) swa_size / 1024 / 1024);
+                                            slot.n_past = std::min(slot.n_past, std::max(it->pos_min + 1, it->pos_max));
+                                            SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) ctx_checkpoint_size / 1024 / 1024);
                                         }
                                     }
 
                                     if (do_reset) {
-                                        SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA, see %s)\n",
+                                        SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA or hybrid/recurrent memory, see %s)\n",
                                                 "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
-
                                         slot.n_past = 0;
-                                        slot.swa_checkpoints.clear();
                                     }
                                 }
                             }
 
-                            if (n_swa > 0) {
-                                const auto pos_min_thold = std::max(0, slot.n_past - n_swa);
-
+                            {
                                 // erase any checkpoints with pos_min > pos_min_thold
-                                for (int i = (int) slot.swa_checkpoints.size() - 1; i >= 0; i--) {
-                                    const auto & cur = slot.swa_checkpoints[i];
+                                for (int i = (int) slot.ctx_checkpoints.size() - 1; i >= 0; i--) {
+                                    const auto & cur = slot.ctx_checkpoints[i];
                                     if (cur.pos_min > pos_min_thold) {
-                                        slot.swa_checkpoints.erase(slot.swa_checkpoints.begin() + i);
-
-                                        SLT_WRN(slot, "SWA checkpoint erase, pos_min = %d, pos_max = %d, size = %.3f MiB\n", cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
+                                        SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_swa = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, n_swa, (float) cur.data.size() / 1024 / 1024);
+                                        slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin() + i);
                                     }
                                 }
                             }
                         }
 
+                        // [TAG_PROMPT_LOGITS]
                         if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) {
-                            SLT_WRN(slot, "need to evaluate at least 1 token for each active slot, n_past = %d, n_prompt_tokens = %d\n", slot.n_past, slot.n_prompt_tokens);
-
+                            SLT_WRN(slot, "need to evaluate at least 1 token for each active slot (n_past = %d, n_prompt_tokens = %d)\n", slot.n_past, slot.n_prompt_tokens);
                             slot.n_past--;
+                            SLT_WRN(slot, "n_past was set to %d\n", slot.n_past);
                         }
 
                         slot.n_prompt_tokens_cache = slot.n_past;
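Note: condensed, the restore path in the hunk above amounts to the following. This is a sketch only: the free-standing helper and its name are mine, it assumes the surrounding server.cpp types (`server_slot`, `ctx_checkpoint`) plus `<algorithm>`, and it reuses the `llama_state_seq_set_data_ext` / `LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY` names exactly as they appear in the diff.

    // sketch of the checkpoint-restore logic above (hypothetical helper, not part of the PR)
    static bool restore_latest_checkpoint(llama_context * ctx, server_slot & slot, llama_pos pos_min_thold) {
        // newest-to-oldest scan: pick the most recent checkpoint that still leaves
        // at least one prompt token to be processed (pos_min < pos_min_thold)
        const auto it = std::find_if(
            slot.ctx_checkpoints.rbegin(),
            slot.ctx_checkpoints.rend(),
            [&](const auto & cur) { return cur.pos_min < pos_min_thold; });

        if (it == slot.ctx_checkpoints.rend()) {
            return false; // nothing usable -> caller forces full prompt re-processing
        }

        // load the serialized sequence state back into the context
        const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), it->data.size(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
        if (n != it->data.size()) {
            return false; // restore failed -> also fall back to full re-processing
        }

        // clamp n_past so that at least one token remains to be evaluated after the restore
        slot.n_past = std::min(slot.n_past, std::max(it->pos_min + 1, it->pos_max));
        return true;
    }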
@@ -3623,17 +3623,17 @@ struct server_context {
                         }
                     }
 
-                    // keep only the common part
+                    // truncate any tokens that are beyond n_past for this slot
                     if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.n_past, -1)) {
-                        // could not partially delete (likely using a non-Transformer model)
+                        SLT_WRN(slot, "failed to truncate tokens beyond n_past = %d\n", slot.n_past);
                         llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
 
                         // there is no common part left
                         slot.n_past                 = 0;
                         slot.n_prompt_tokens_cache  = 0;
                     }
 
-                    SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past);
+                    SLT_INF(slot, "n_past = %d, memory_seq_rm [%d, end)\n", slot.n_past, slot.n_past);
 
                     // remove the non-common part from the cache
                     slot.cache_tokens.keep_first(slot.n_past);
@@ -3854,37 +3854,38 @@ struct server_context {
                     // prompt evaluated for next-token prediction
                     slot.state = SLOT_STATE_GENERATING;
 
-                    // make a checkpoint with the SWA memory
-                    // checkpoints are needed only if we are not using "--swa-full"
-                    if (llama_model_n_swa(model) > 0 && !params_base.swa_full && params_base.n_swa_checkpoints > 0) {
-                        if (slot.swa_checkpoints.size() >= (size_t) params_base.n_swa_checkpoints) {
-                            {
-                                const auto & cur = slot.swa_checkpoints.back();
-
-                                SLT_WRN(slot, "SWA checkpoint erase, pos_min = %d, pos_max = %d, size = %.3f MiB\n",
-                                        cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
-                            }
-
-                            slot.swa_checkpoints.erase(slot.swa_checkpoints.begin());
+                    // make a checkpoint of the parts of the memory that cannot be rolled back.
+                    // checkpoints are created only if:
+                    // - the model uses SWA and we are not using `swa_full`
+                    // - the model architecture is marked as recurrent or hybrid
+                    //
+                    // TODO: try to make this conditional on the context or the memory module, instead of the model type
+                    const bool do_checkpoint =
+                        (llama_model_is_recurrent(model) || llama_model_is_hybrid(model)) ||
+                        (llama_model_n_swa(model) > 0 && !params_base.swa_full);
+
+                    if (do_checkpoint && params_base.n_ctx_checkpoints > 0) {
+                        while (slot.ctx_checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
+                            // make room for the new checkpoint, if needed
+                            const auto & cur = slot.ctx_checkpoints.front();
+                            SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
+                                    cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
+
+                            slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin());
                         }
 
-                        const size_t swa_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);
+                        const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
 
-                        auto & cur = slot.swa_checkpoints.emplace_back(swa_checkpoint{
+                        auto & cur = slot.ctx_checkpoints.emplace_back(ctx_checkpoint{
                             /* .pos_min = */ llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id),
                             /* .pos_max = */ llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id),
-                            /* .data    = */ std::vector<uint8_t>(swa_size),
+                            /* .data    = */ std::vector<uint8_t>(checkpoint_size),
                         });
 
-                        llama_state_seq_get_data_ext(ctx, cur.data.data(), swa_size, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);
-
-                        float size_total = 0.0f;
-                        for (const auto & checkpoint : slot.swa_checkpoints) {
-                            size_total += (float) checkpoint.data.size() / 1024 / 1024;
-                        }
+                        llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
 
-                        SLT_WRN(slot, "SWA checkpoint create, pos_min = %d, pos_max = %d, size = %.3f MiB, total = %d/%d (%.3f MiB)\n",
-                                cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024, (int) slot.swa_checkpoints.size(), params_base.n_swa_checkpoints, size_total);
+                        SLT_WRN(slot, "saved context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
+                                (int) slot.ctx_checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
                     }
                 } else if (slot.state != SLOT_STATE_GENERATING) {
                     continue; // continue loop of slots
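Note: the save path above can likewise be read as one helper. Again a sketch, not the PR's code: the helper signature is mine, `common_params` is assumed to be the type behind `params_base`, and the `*_ext` state calls and the `LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY` flag are used exactly as they appear in the diff.

    // sketch of the checkpoint-creation logic above (hypothetical helper, not part of the PR)
    static void maybe_save_checkpoint(llama_context * ctx, const llama_model * model, const common_params & params, server_slot & slot) {
        // checkpoints only matter when part of the memory cannot be rolled back:
        // recurrent/hybrid state, or an SWA cache that is not kept in full
        const bool do_checkpoint =
            (llama_model_is_recurrent(model) || llama_model_is_hybrid(model)) ||
            (llama_model_n_swa(model) > 0 && !params.swa_full);

        if (!do_checkpoint || params.n_ctx_checkpoints <= 0) {
            return;
        }

        // drop the oldest checkpoint(s) until there is room for one more
        while (slot.ctx_checkpoints.size() >= (size_t) params.n_ctx_checkpoints) {
            slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin());
        }

        // serialize only the partial (non-rollbackable) part of the sequence state
        const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

        auto & cur = slot.ctx_checkpoints.emplace_back(ctx_checkpoint{
            /* .pos_min = */ llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id),
            /* .pos_max = */ llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id),
            /* .data    = */ std::vector<uint8_t>(checkpoint_size),
        });

        llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
    }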