diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index bfe3443c1de..95c2ffd6971 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -2928,8 +2928,10 @@ struct server_context_impl { has_mtmd = true; } - const int32_t n_before_user = slot.task->params.n_before_user; - const bool n_before_user_known = n_before_user > 0; + const auto & user_boundaries = slot.task->params.user_boundaries; + const auto is_user_boundary = [&user_boundaries](int32_t pos) { + return std::binary_search(user_boundaries.begin(), user_boundaries.end(), pos); + }; // add prompt tokens for processing in the current batch while (slot.prompt.n_tokens() < slot.task->n_tokens() && batch.n_tokens < n_batch) { @@ -2959,10 +2961,8 @@ struct server_context_impl { slot.n_prompt_tokens_processed++; - // stop the prompt batch exactly before the latest user input, so a checkpoint - // can be created after the previous messages - if (n_before_user_known && - slot.prompt.n_tokens() == n_before_user) { + // stop the prompt batch before each user message so a checkpoint can be created + if (is_user_boundary((int32_t) slot.prompt.n_tokens())) { break; } @@ -3008,7 +3008,7 @@ struct server_context_impl { slot.init_sampler(); } else { // skip ordinary mid-prompt checkpoints - if (!n_before_user_known && !near_prompt_end) { + if (user_boundaries.empty() && !near_prompt_end) { do_checkpoint = false; } } @@ -3020,21 +3020,12 @@ struct server_context_impl { // their token position is the batch start rather than the prompt end const int32_t n_tokens_start = slot.prompt.n_tokens() - n_tokens_cur; - { - const bool is_on_user = - n_before_user_known && - n_tokens_start == n_before_user; - - const bool is_after_user = - n_before_user_known && - n_tokens_start > n_before_user; - + if (do_checkpoint && !user_boundaries.empty()) { const bool is_allowed = - !n_before_user_known || - is_on_user || - (is_after_user && near_prompt_end); + is_user_boundary(n_tokens_start) || + (n_tokens_start > user_boundaries.back() && near_prompt_end); - if (do_checkpoint && !is_allowed) { + if (!is_allowed) { do_checkpoint = false; } } @@ -3566,48 +3557,64 @@ void server_context::on_sleeping_changed(std::function callback) { impl->queue_tasks.on_sleeping_state(std::move(callback)); } -// compute the number of tokens before the last user message in the prompt -static int32_t prompt_get_n_before_user( - const json & message_spans, +static int32_t prompt_n_tokens_before_byte( + int32_t byte_pos, const std::string & prompt, const std::vector & files, const llama_vocab * vocab, mtmd_context * mctx) { - int32_t result = -1; - int32_t byte_pos = -1; + GGML_ASSERT(byte_pos >= 0 && (size_t) byte_pos <= prompt.size()); - for (const auto & span : message_spans) { - const std::string role = json_value(span, "role", std::string()); + const std::string prefix = prompt.substr(0, (size_t) byte_pos); - if (role == "user") { - byte_pos = json_value(span, "pos", -1); - } + const std::string marker = get_media_marker(); + size_t n_prefix_media = 0; + for (size_t pos = 0; (pos = prefix.find(marker, pos)) != std::string::npos; pos += marker.size()) { + n_prefix_media++; } - if (byte_pos >= 0) { - GGML_ASSERT((size_t) byte_pos <= prompt.size()); + GGML_ASSERT(n_prefix_media <= files.size()); - const std::string prefix = prompt.substr(0, (size_t) byte_pos); + if (mctx != nullptr && n_prefix_media > 0) { + // TODO: this makes a copy - avoid it + std::vector prefix_files(files.begin(), files.begin() + n_prefix_media); + return (int32_t) process_mtmd_prompt(mctx, prefix, prefix_files).size(); + } - const std::string marker = get_media_marker(); - size_t n_prefix_media = 0; - for (size_t pos = 0; (pos = prefix.find(marker, pos)) != std::string::npos; pos += marker.size()) { - n_prefix_media++; - } + return (int32_t) tokenize_input_prompts(vocab, nullptr, prefix, true, true)[0].size(); +} + +// compute the number of tokens before each user message in the prompt +static std::vector prompt_get_user_boundaries( + const json & message_spans, + const std::string & prompt, + const std::vector & files, + const llama_vocab * vocab, + mtmd_context * mctx) { + std::vector result; + result.reserve(message_spans.size()); - GGML_ASSERT(n_prefix_media <= files.size()); + for (const auto & span : message_spans) { + if (json_value(span, "role", std::string()) != "user") { + continue; + } - if (mctx != nullptr && n_prefix_media > 0) { - // TODO: this makes a copy - avoid it - std::vector prefix_files(files.begin(), files.begin() + n_prefix_media); + const int32_t byte_pos = json_value(span, "pos", -1); + if (byte_pos < 0) { + continue; + } - result = (int32_t) process_mtmd_prompt(mctx, prefix, prefix_files).size(); - } else { - result = (int32_t) tokenize_input_prompts(vocab, nullptr, prefix, true, true)[0].size(); + const int32_t n_tok = prompt_n_tokens_before_byte(byte_pos, prompt, files, vocab, mctx); + if (n_tok > 0) { + result.push_back(n_tok); } + } + + std::sort(result.begin(), result.end()); + result.erase(std::unique(result.begin(), result.end()), result.end()); - SRV_TRC("message_spans: last user message: byte_pos=%d, media=%zu, n_before_user=%d\n", - byte_pos, n_prefix_media, result); + if (!result.empty()) { + SRV_TRC("message_spans: %zu user turn boundary(ies)\n", result.size()); } return result; @@ -3665,8 +3672,8 @@ std::unique_ptr server_routes::handle_completions_impl( const auto message_spans = json_value(data, "message_spans", json::array()); if (prompt.is_string() && message_spans.is_array()) { - task.params.n_before_user = - prompt_get_n_before_user( + task.params.user_boundaries = + prompt_get_user_boundaries( message_spans, prompt.get(), files, diff --git a/tools/server/server-task.h b/tools/server/server-task.h index d47dc690cff..906867f8808 100644 --- a/tools/server/server-task.h +++ b/tools/server/server-task.h @@ -61,8 +61,8 @@ struct task_params { int32_t n_cache_reuse = 0; // min chunk size to attempt reusing from the cache via KV shifting (0 = disabled) - // number of prompt tokens before the latest user message - int32_t n_before_user = -1; + // number of prompt tokens before each user message + std::vector user_boundaries; int64_t t_max_prompt_ms = -1; // TODO: implement int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit