diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 4f9a86cc1..7276059a2 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -2572,7 +2572,7 @@ struct server_context {
                     GGML_ASSERT(slot.ga_n == 1);

                     // reuse any previously computed tokens that are common with the new prompt
-                    slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
+                    slot.n_past = common_part(ctx, model, slot.cache_tokens, slot.prompt);

                     // push the prompt into the sampling context (do not apply grammar)
                     for (int i = 0; i < slot.n_past; ++i) {
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index 5911eeeb7..3cd645f12 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -314,6 +314,31 @@ static size_t common_part(const std::string & a, const std::string & b) {
     return i;
 }

+static size_t common_part(const llama_context * ctx, const llama_model * model, const std::vector<llama_token> & a, const std::string & b) {
+    size_t pos = 0;
+    size_t token_idx = 0;
+
+    for (const auto & token : a) {
+        std::string piece = llama_token_to_piece(ctx, token);
+
+        if (pos + piece.size() <= b.size() && b.compare(pos, piece.size(), piece) == 0) {
+            pos += piece.size();
+            token_idx++;
+            continue;
+        }
+
+        // handle the auto-inserted BOS token at the start of the prompt
+        if (token_idx == 0 && token == llama_token_bos(model)) {
+            token_idx++;
+            continue;
+        }
+
+        return token_idx;
+    }
+
+    return token_idx;
+}
+
 static bool ends_with(const std::string & str, const std::string & suffix) {
     return str.size() >= suffix.size() && 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
 }
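
For context (not part of the patch): a minimal, self-contained sketch of the prefix-matching idea behind the new `common_part` overload, using a hypothetical `piece_of()` and integer token ids in place of `llama_token_to_piece` / `llama_token_bos`, so it compiles without llama.cpp.

```cpp
#include <cassert>
#include <string>
#include <vector>

static const int BOS_ID = 1; // hypothetical BOS token id

// hypothetical detokenizer: maps a token id to its text piece
static std::string piece_of(int token) {
    switch (token) {
        case BOS_ID: return "<s>"; // BOS renders as a marker that never appears in the prompt text
        case 10:     return "Hello";
        case 11:     return ", ";
        case 12:     return "world";
        default:     return "?";
    }
}

// mirrors the overload's logic: count how many cached tokens reproduce a prefix of the new prompt string
static size_t common_part_sketch(const std::vector<int> & cache, const std::string & prompt) {
    size_t pos = 0;       // matched characters in the prompt
    size_t token_idx = 0; // matched tokens from the cache

    for (int token : cache) {
        std::string piece = piece_of(token);
        if (pos + piece.size() <= prompt.size() && prompt.compare(pos, piece.size(), piece) == 0) {
            pos += piece.size();
            token_idx++;
            continue;
        }
        // the auto-inserted BOS has no counterpart in the prompt text; accept it at the start
        if (token_idx == 0 && token == BOS_ID) {
            token_idx++;
            continue;
        }
        break; // first mismatch ends the common prefix
    }
    return token_idx;
}

int main() {
    std::vector<int> cache = {BOS_ID, 10, 11, 12};                    // BOS "Hello" ", " "world"
    assert(common_part_sketch(cache, "Hello, world and more") == 4);  // entire cache is reusable
    assert(common_part_sketch(cache, "Hello, there") == 3);           // "world" no longer matches
    return 0;
}
```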