@@ -3676,6 +3676,20 @@ struct server_context {
                     alora_disabled_id = enabled_loras[0];
                 }
 
+            bool do_checkpoint = params_base.n_ctx_checkpoints > 0;
+
+            // make a checkpoint of the parts of the memory that cannot be rolled back.
+            // checkpoints are created only if:
+            // - the model uses SWA and we are not using `swa_full`
+            // - the model architecture is marked as recurrent or hybrid
+            //
+            // TODO: try to make this conditional on the context or the memory module, instead of the model type
+            do_checkpoint = do_checkpoint && (
+                llama_model_is_recurrent(model) ||
+                llama_model_is_hybrid(model)    ||
+                (llama_model_n_swa(model) > 0 && !params_base.swa_full)
+            );
+
             // add prompt tokens for processing in the current batch
             while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
                 // get next token to process
@@ -3700,6 +3714,11 @@ struct server_context {
 
                 slot.n_prompt_tokens_processed++;
                 slot.n_past++;
+
+                // process the last few tokens of the prompt separately in order to allow for a checkpoint to be created.
+                if (do_checkpoint && slot.n_prompt_tokens - slot.n_past == 64) {
+                    break;
+                }
             }
 
             // SLT_INF(slot, "new cache_tokens: %s\n", slot.cache_tokens.str().c_str());
@@ -3730,6 +3749,39 @@ struct server_context {
                     slot.i_batch = batch.n_tokens - 1;
 
                     SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens);
+
+                    const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
+                    const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id);
+
+                    // no need for empty or small checkpoints
+                    do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 64);
+
+                    // no need to create checkpoints that are too close together
+                    do_checkpoint = do_checkpoint && (slot.ctx_checkpoints.empty() || pos_max > slot.ctx_checkpoints.back().pos_max + 64);
+
+                    if (do_checkpoint) {
+                        while (slot.ctx_checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
+                            // make room for the new checkpoint, if needed
+                            const auto & cur = slot.ctx_checkpoints.front();
+                            SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
+                                    cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
+
+                            slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin());
+                        }
+
+                        const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
+
+                        auto & cur = slot.ctx_checkpoints.emplace_back(ctx_checkpoint{
+                            /* .pos_min = */ pos_min,
+                            /* .pos_max = */ pos_max,
+                            /* .data    = */ std::vector<uint8_t>(checkpoint_size),
+                        });
+
+                        llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
+
+                        SLT_WRN(slot, "saved context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
+                                (int) slot.ctx_checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
+                    }
                 }
             }
 
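
Note: the hunk above stores each checkpoint as a position range plus an opaque byte blob in slot.ctx_checkpoints. A minimal sketch of the ctx_checkpoint entry it assumes (the actual definition lives elsewhere in server.cpp and may differ in details):

// sketch only: assumed shape of the entries stored in slot.ctx_checkpoints
struct ctx_checkpoint {
    llama_pos pos_min;          // lowest sequence position covered by this checkpoint
    llama_pos pos_max;          // highest sequence position covered by this checkpoint
    std::vector<uint8_t> data;  // opaque sequence state from llama_state_seq_get_data_ext()
};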
@@ -3853,40 +3905,6 @@ struct server_context {
 
                 // prompt evaluated for next-token prediction
                 slot.state = SLOT_STATE_GENERATING;
-
-                // make a checkpoint of the parts of the memory that cannot be rolled back.
-                // checkpoints are created only if:
-                // - the model uses SWA and we are not using `swa_full`
-                // - the model architecture is marked as recurrent or hybrid
-                //
-                // TODO: try to make this conditional on the context or the memory module, instead of the model type
-                const bool do_checkpoint =
-                    (llama_model_is_recurrent(model) || llama_model_is_hybrid(model)) ||
-                    (llama_model_n_swa(model) > 0 && !params_base.swa_full);
-
-                if (do_checkpoint && params_base.n_ctx_checkpoints > 0) {
-                    while (slot.ctx_checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
-                        // make room for the new checkpoint, if needed
-                        const auto & cur = slot.ctx_checkpoints.front();
-                        SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
-                                cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
-
-                        slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin());
-                    }
-
-                    const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
-
-                    auto & cur = slot.ctx_checkpoints.emplace_back(ctx_checkpoint{
-                        /* .pos_min = */ llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id),
-                        /* .pos_max = */ llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id),
-                        /* .data    = */ std::vector<uint8_t>(checkpoint_size),
-                    });
-
-                    llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
-
-                    SLT_WRN(slot, "saved context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
-                            (int) slot.ctx_checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
-                }
             } else if (slot.state != SLOT_STATE_GENERATING) {
                 continue; // continue loop of slots
             }
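
For context, a rough sketch of how a slot might later roll back to one of these checkpoints. It assumes the counterpart llama_state_seq_set_data_ext() call from llama.h and reuses the LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY flag seen above; the helper name and the exact restore logic used in server.cpp are hypothetical here:

// sketch only: restore the newest checkpoint that still covers the target position
// (hypothetical helper, not part of this diff)
static bool restore_ctx_checkpoint(llama_context * ctx, server_slot & slot, llama_pos n_past) {
    for (auto it = slot.ctx_checkpoints.rbegin(); it != slot.ctx_checkpoints.rend(); ++it) {
        if (it->pos_min <= n_past && n_past <= it->pos_max) {
            // write the saved partial state back into this slot's sequence
            const size_t n_read = llama_state_seq_set_data_ext(
                    ctx, it->data.data(), it->data.size(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
            return n_read > 0;
        }
    }
    return false; // nothing usable - the prompt has to be reprocessed from an earlier position
}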