From fc6a7e0d0d1224187098ba4ee6e1634acaa4a385 Mon Sep 17 00:00:00 2001
From: Regrad
Date: Tue, 2 Jun 2026 19:33:38 +0300
Subject: [PATCH] server: improve checkpoint reuse heuristics for recurrent/hybrid models

---
Reviewer note (ignored by `git am`): this patch was re-flowed from a copy whose
line breaks and indentation had been collapsed. The number of lines in every
hunk was checked against its @@ header, but indentation depth, blank context
lines and the split point of the two-line LOG_INF call are reconstructed and
may not match the target tree byte-for-byte; use `git apply --ignore-whitespace`
if the context is rejected. The <uint8_t> element type on data_tgt/data_dft is
assumed to have been stripped during extraction -- confirm against
common/common.h. The author e-mail address was not recoverable.

 common/common.cpp               |  4 +++-
 common/common.h                 |  4 +++-
 tools/server/server-context.cpp | 17 ++++++++++++++---
 3 files changed, 20 insertions(+), 5 deletions(-)

diff --git a/common/common.cpp b/common/common.cpp
index 97daf281783..a4f542cb199 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -2055,10 +2055,12 @@ void common_prompt_checkpoint::clear() {
 void common_prompt_checkpoint::update_pos(
         int64_t n_tokens,
         llama_pos pos_min,
-        llama_pos pos_max) {
+        llama_pos pos_max,
+        llama_pos pos_end) {
     this->n_tokens = n_tokens;
     this->pos_min = pos_min;
     this->pos_max = pos_max;
+    this->pos_end = pos_end;
 }
 
 void common_prompt_checkpoint::update_tgt(
diff --git a/common/common.h b/common/common.h
index 99898800d1d..15da8c429fb 100644
--- a/common/common.h
+++ b/common/common.h
@@ -1054,6 +1054,7 @@ struct common_prompt_checkpoint {
 
     llama_pos pos_min;
     llama_pos pos_max;
+    llama_pos pos_end;
 
     std::vector<uint8_t> data_tgt;
     std::vector<uint8_t> data_dft;
@@ -1066,7 +1067,8 @@ struct common_prompt_checkpoint {
     void update_pos(
             int64_t n_tokens,
             llama_pos pos_min,
-            llama_pos pos_max);
+            llama_pos pos_max,
+            llama_pos pos_end);
 
     void update_tgt(
             llama_context * ctx,
diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp
index ae9e0bf60d8..8ab0923b077 100644
--- a/tools/server/server-context.cpp
+++ b/tools/server/server-context.cpp
@@ -2007,7 +2007,7 @@ struct server_context_impl {
 
         auto & cur = slot.prompt.checkpoints.emplace_back();
 
-        cur.update_pos(slot.prompt.n_tokens() - n_tokens_cur, pos_min, pos_max);
+        cur.update_pos(slot.prompt.n_tokens() - n_tokens_cur, pos_min, pos_max, slot.prompt.tokens.pos_next());
         cur.update_tgt(ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
         cur.update_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
 
@@ -2442,7 +2442,8 @@ struct server_context_impl {
                     slot.spec_ckpt.update_pos(
                             slot.prompt.n_tokens(),
                             llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), slot.id),
-                            llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), slot.id));
+                            llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), slot.id),
+                            slot.prompt.tokens.pos_next());
 
                     if (use_ckpt_dft) {
                         slot.spec_ckpt.update_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
@@ -2714,6 +2715,10 @@ struct server_context_impl {
                     }
 
                     llama_pos pos_next = slot.prompt.tokens.pos_next(n_past);
+                    const llama_pos prompt_end = slot.task->tokens.pos_next();
+                    const bool is_recurrent_or_hybrid =
+                        llama_model_is_recurrent(model_tgt) ||
+                        llama_model_is_hybrid(model_tgt);
 
                     // the largest pos_min required for a checkpoint to be useful
                     const auto pos_min_thold = std::max(0, pos_next - n_swa - 1);
@@ -2774,9 +2779,15 @@ struct server_context_impl {
                                 slot.prompt.checkpoints.rbegin(),
                                 slot.prompt.checkpoints.rend(),
                                 [&, func_name = __func__](const auto & cur) {
+                                    if (cur.pos_end > prompt_end) {
+                                        return false;
+                                    }
                                     // guarantee that a checkpoint will result in at least one token being processed [TAG_PROMPT_LOGITS]
                                     LOG_INF("slot %12.*s: id %2d | task %d | Checking checkpoint with [%d, %d] against %d...\n", 12,
                                             func_name, (slot).id, ((slot).task ? (slot).task->id : -1), cur.pos_min, cur.pos_max, pos_min_thold);
+                                    if (is_recurrent_or_hybrid) {
+                                        return cur.pos_max < pos_next || cur.pos_min == 0;
+                                    }
                                     return cur.pos_min < pos_min_thold || cur.pos_min == 0;
                                 }
                         );
@@ -2806,7 +2817,7 @@ struct server_context_impl {
                         // erase any checkpoints with pos_max > pos_next
                         for (auto it = slot.prompt.checkpoints.begin(); it != slot.prompt.checkpoints.end();) {
                             const auto & cur = *it;
-                            if (cur.pos_max > pos_next) {
+                            if (cur.pos_end > prompt_end || cur.pos_max > pos_next) {
                                 SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_swa = %d, pos_next = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, cur.n_tokens, n_swa, pos_next, (float) cur.size() / 1024 / 1024);
                                 it = slot.prompt.checkpoints.erase(it);
                             } else {