From a7cfa4e1922b2079342618b99605db2a4ad77334 Mon Sep 17 00:00:00 2001 From: Reed Mayhew Date: Thu, 28 May 2026 07:08:50 -0400 Subject: [PATCH] server : checkpoint before every user turn boundary #22929 creates a context checkpoint only before the last user message, so prompts with a stable prefix and content that changes between turns lose all checkpoint cache hits (the surviving checkpoints sit past the divergence point) and re-evaluate the full prompt every turn, notably on SWA models. Create a checkpoint before every user message instead, derived from the same message_spans. prompt_get_n_before_user() becomes prompt_get_user_boundaries(); the prefill batch breaks at each boundary and a checkpoint is allowed at any of them, still bounded by --checkpoint-min-step and --ctx-checkpoints. --- tools/server/server-context.cpp | 105 +++++++++++++++++--------------- tools/server/server-task.h | 4 +- 2 files changed, 58 insertions(+), 51 deletions(-) diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index ae9e0bf60d8..6acd1063e71 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -2920,8 +2920,10 @@ struct server_context_impl { has_mtmd = true; } - const int32_t n_before_user = slot.task->params.n_before_user; - const bool n_before_user_known = n_before_user > 0; + const auto & user_boundaries = slot.task->params.user_boundaries; + const auto is_user_boundary = [&user_boundaries](int32_t pos) { + return std::binary_search(user_boundaries.begin(), user_boundaries.end(), pos); + }; // add prompt tokens for processing in the current batch while (slot.prompt.n_tokens() < slot.task->n_tokens() && batch.n_tokens < n_batch) { @@ -2951,10 +2953,8 @@ struct server_context_impl { slot.n_prompt_tokens_processed++; - // stop the prompt batch exactly before the latest user input, so a checkpoint - // can be created after the previous messages - if (n_before_user_known && - slot.prompt.n_tokens() == n_before_user) { + // stop the prompt batch before each user message so a checkpoint can be created + if (is_user_boundary((int32_t) slot.prompt.n_tokens())) { break; } @@ -3000,7 +3000,7 @@ struct server_context_impl { slot.init_sampler(); } else { // skip ordinary mid-prompt checkpoints - if (!n_before_user_known && !near_prompt_end) { + if (user_boundaries.empty() && !near_prompt_end) { do_checkpoint = false; } } @@ -3012,21 +3012,12 @@ struct server_context_impl { // their token position is the batch start rather than the prompt end const int32_t n_tokens_start = slot.prompt.n_tokens() - n_tokens_cur; - { - const bool is_on_user = - n_before_user_known && - n_tokens_start == n_before_user; - - const bool is_after_user = - n_before_user_known && - n_tokens_start > n_before_user; - + if (do_checkpoint && !user_boundaries.empty()) { const bool is_allowed = - !n_before_user_known || - is_on_user || - (is_after_user && near_prompt_end); + is_user_boundary(n_tokens_start) || + (n_tokens_start > user_boundaries.back() && near_prompt_end); - if (do_checkpoint && !is_allowed) { + if (!is_allowed) { do_checkpoint = false; } } @@ -3558,48 +3549,64 @@ void server_context::on_sleeping_changed(std::function callback) { impl->queue_tasks.on_sleeping_state(std::move(callback)); } -// compute the number of tokens before the last user message in the prompt -static int32_t prompt_get_n_before_user( - const json & message_spans, +static int32_t prompt_n_tokens_before_byte( + int32_t byte_pos, const std::string & prompt, const std::vector & files, const llama_vocab * vocab, mtmd_context * mctx) { - int32_t result = -1; - int32_t byte_pos = -1; + GGML_ASSERT(byte_pos >= 0 && (size_t) byte_pos <= prompt.size()); - for (const auto & span : message_spans) { - const std::string role = json_value(span, "role", std::string()); + const std::string prefix = prompt.substr(0, (size_t) byte_pos); - if (role == "user") { - byte_pos = json_value(span, "pos", -1); - } + const std::string marker = get_media_marker(); + size_t n_prefix_media = 0; + for (size_t pos = 0; (pos = prefix.find(marker, pos)) != std::string::npos; pos += marker.size()) { + n_prefix_media++; } - if (byte_pos >= 0) { - GGML_ASSERT((size_t) byte_pos <= prompt.size()); + GGML_ASSERT(n_prefix_media <= files.size()); - const std::string prefix = prompt.substr(0, (size_t) byte_pos); + if (mctx != nullptr && n_prefix_media > 0) { + // TODO: this makes a copy - avoid it + std::vector prefix_files(files.begin(), files.begin() + n_prefix_media); + return (int32_t) process_mtmd_prompt(mctx, prefix, prefix_files).size(); + } - const std::string marker = get_media_marker(); - size_t n_prefix_media = 0; - for (size_t pos = 0; (pos = prefix.find(marker, pos)) != std::string::npos; pos += marker.size()) { - n_prefix_media++; - } + return (int32_t) tokenize_input_prompts(vocab, nullptr, prefix, true, true)[0].size(); +} + +// compute the number of tokens before each user message in the prompt +static std::vector prompt_get_user_boundaries( + const json & message_spans, + const std::string & prompt, + const std::vector & files, + const llama_vocab * vocab, + mtmd_context * mctx) { + std::vector result; + result.reserve(message_spans.size()); - GGML_ASSERT(n_prefix_media <= files.size()); + for (const auto & span : message_spans) { + if (json_value(span, "role", std::string()) != "user") { + continue; + } - if (mctx != nullptr && n_prefix_media > 0) { - // TODO: this makes a copy - avoid it - std::vector prefix_files(files.begin(), files.begin() + n_prefix_media); + const int32_t byte_pos = json_value(span, "pos", -1); + if (byte_pos < 0) { + continue; + } - result = (int32_t) process_mtmd_prompt(mctx, prefix, prefix_files).size(); - } else { - result = (int32_t) tokenize_input_prompts(vocab, nullptr, prefix, true, true)[0].size(); + const int32_t n_tok = prompt_n_tokens_before_byte(byte_pos, prompt, files, vocab, mctx); + if (n_tok > 0) { + result.push_back(n_tok); } + } + + std::sort(result.begin(), result.end()); + result.erase(std::unique(result.begin(), result.end()), result.end()); - SRV_TRC("message_spans: last user message: byte_pos=%d, media=%zu, n_before_user=%d\n", - byte_pos, n_prefix_media, result); + if (!result.empty()) { + SRV_TRC("message_spans: %zu user turn boundary(ies)\n", result.size()); } return result; @@ -3657,8 +3664,8 @@ std::unique_ptr server_routes::handle_completions_impl( const auto message_spans = json_value(data, "message_spans", json::array()); if (prompt.is_string() && message_spans.is_array()) { - task.params.n_before_user = - prompt_get_n_before_user( + task.params.user_boundaries = + prompt_get_user_boundaries( message_spans, prompt.get(), files, diff --git a/tools/server/server-task.h b/tools/server/server-task.h index 60e216e7927..6576e9db73f 100644 --- a/tools/server/server-task.h +++ b/tools/server/server-task.h @@ -61,8 +61,8 @@ struct task_params { int32_t n_cache_reuse = 0; // min chunk size to attempt reusing from the cache via KV shifting (0 = disabled) - // number of prompt tokens before the latest user message - int32_t n_before_user = -1; + // number of prompt tokens before each user message + std::vector user_boundaries; int64_t t_max_prompt_ms = -1; // TODO: implement int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit