From 91cbeba33b1924e01865c5b3f6e8c64fee52aafd Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Thu, 21 May 2026 10:31:45 +0200 Subject: [PATCH 1/3] common : fix state save in common_prompt_batch_decode This commit fixes a bug in common_prompt_batch_decode that affects the session state store/restore in completion.cpp and test-save-load-state.cpp. Currently the code saves n-1 tokens in both session_tokens and the KV cache. When the session is later loaded and the prompt matches, the last saved token (n-1) is replayed into the next position, effectively replaying the same token at the wrong position. The fix is to store all n tokens in session_tokens, while the memory state still reflects only the n-1 processed tokens, because the state is saved before the last token is decoded in common_prompt_batch_decode. I ran both completion.cpp and test-save-load-state.cpp with a transformer, a recurrent, and a hybrid model. Resolves: https://github.com/ggml-org/llama.cpp/issues/23400 Co-authored-by: fairydreaming <166155368+fairydreaming@users.noreply.github.com> --- common/common.cpp | 4 ++-- tests/test-save-load-state.cpp | 6 +++--- tools/completion/completion.cpp | 15 ++++++--------- 3 files changed, 11 insertions(+), 14 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index d77ddeda10e..72425fa24a5 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1989,8 +1989,8 @@ bool common_prompt_batch_decode( } n_past += n_tokens_before_last; - llama_state_save_file(ctx, state_path.data(), tokens.data(), n_tokens_before_last); - LOG_INF("saved session before last token to %s, n_tokens = %d\n", state_path.data(), n_tokens_before_last); + llama_state_save_file(ctx, state_path.data(), tokens.data(), n_eval); + LOG_INF("saved session before last token to %s, n_tokens = %d\n", state_path.data(), n_eval); llama_token last_token = tokens.back(); llama_batch batch = llama_batch_get_one(&last_token, 1); diff --git 
a/tests/test-save-load-state.cpp b/tests/test-save-load-state.cpp index 97ab7c6de3b..fd0f4b294fe 100644 --- a/tests/test-save-load-state.cpp +++ b/tests/test-save-load-state.cpp @@ -111,7 +111,7 @@ static bool test_state_load(struct llama_model * model, const struct common_para LOG_TRC("%s: loaded state with %zu tokens\n", __func__, n_token_count_out); // Replay last token - int n_past = (int) n_token_count_out; + int n_past = (int) n_token_count_out - 1; if (!common_replay_last_token(ctx.get(), tokens.back(), n_past)) { return false; } @@ -165,7 +165,7 @@ static bool test_seq_cp_host(struct llama_model * model, const struct common_par LOG_TRC("%s: loaded state with %zu tokens\n", __func__, n_token_count_out); // Replay last token - int n_past = (int) n_token_count_out; + int n_past = (int) n_token_count_out - 1; if (!common_replay_last_token(ctx.get(), tokens.back(), n_past)) { return false; } @@ -240,7 +240,7 @@ static bool test_seq_cp_device(struct llama_model * model, const struct common_p LOG_TRC("%s: loaded state with %zu tokens\n", __func__, n_token_count_out); // Replay last token - int n_past = (int) n_token_count_out; + int n_past = (int) n_token_count_out - 1; if (!common_replay_last_token(ctx.get(), tokens.back(), n_past)) { return false; } diff --git a/tools/completion/completion.cpp b/tools/completion/completion.cpp index dffcadd4131..e4c807151a2 100644 --- a/tools/completion/completion.cpp +++ b/tools/completion/completion.cpp @@ -373,16 +373,10 @@ int llama_completion(int argc, char ** argv) { __func__, n_match, embd_inp.size()); } - if (session_tokens.size() == n_match) { - // [TAG_CONTEXT_STATE_LOGITS] - // in this case, we are going to reuse the logits from the session - // if we ever decide to remove the logits from the session, we need to handle this somehow - // ref: https://github.com/ggml-org/llama.cpp/pull/18862#issuecomment-3756330941 - } - // remove any "future" tokens that we might have inherited from the previous session if 
(session_tokens.size() > n_match) { - if (!llama_memory_seq_rm(mem, -1, n_match, -1)) { + llama_pos pos = n_match > 0 ? (llama_pos)(n_match - 1) : 0; + if (!llama_memory_seq_rm(mem, -1, pos, -1)) { LOG_WRN("%s: unable to reuse common prefix (for example, when the memory is recurrent)\n", __func__); llama_memory_clear(mem, true); session_tokens.clear(); @@ -398,7 +392,7 @@ int llama_completion(int argc, char ** argv) { // Logits are not stored as part of the session state so we need to // "replay" the last token to get logits for sampling. if (!session_tokens.empty() && n_match > 0 && n_match == session_tokens.size()) { - if (!common_replay_last_token(ctx, session_tokens.back(), n_match)) { + if (!common_replay_last_token(ctx, session_tokens.back(), n_match - 1)) { return 1; } @@ -991,7 +985,10 @@ int llama_completion(int argc, char ** argv) { if (!path_session.empty() && params.prompt_cache_all && !params.prompt_cache_ro) { LOG("\n%s: saving final output to session file '%s'\n", __func__, path_session.c_str()); + session_tokens.insert(session_tokens.end(), embd.begin(), embd.end()); llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size()); + LOG_INF("saved final session to %s, n_tokens = %ld\n", path_session.data(), session_tokens.size()); + } LOG("\n\n"); From 411c926af1b3b9698a77cb257bb80f0e08cde605 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Thu, 21 May 2026 15:40:31 +0200 Subject: [PATCH 2/3] common : fix session storing (wip) --- common/common.cpp | 25 +++++++++++++------------ common/common.h | 3 ++- tests/test-save-load-state.cpp | 2 +- tools/completion/completion.cpp | 10 ++++++---- 4 files changed, 22 insertions(+), 18 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 72425fa24a5..6c6928ca8d3 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1963,36 +1963,37 @@ bool common_replay_last_token(struct llama_context * ctx, llama_token last_token bool common_prompt_batch_decode( struct 
llama_context * ctx, - const std::vector & tokens, + const std::vector & all_tokens, + int n_tokens, int & n_past, int n_batch, std::string_view state_path, bool save_state) { - const int n_eval = tokens.size(); - if (n_eval == 0) { + if (n_tokens == 0) { return true; } + const int offset = all_tokens.size() - n_tokens; - if (save_state && n_eval > 1) { - const int n_tokens_before_last = n_eval - 1; + if (save_state && n_tokens > 1) { + const int n_tokens_before_last = n_tokens - 1; - GGML_ASSERT(n_eval <= n_batch); + GGML_ASSERT(n_tokens <= n_batch); // Decode all but the last token so we can save the memory state before decoding the last token. // This is done so we can restore the session state later and replay the last token. // Memory implementations in recurrent/hybrid models don't support removing tokens from their // memory, so we can't just remove the last token from the memory and replay the last token which // is the reason for this logic. - if (llama_decode(ctx, llama_batch_get_one(const_cast(tokens.data()), n_tokens_before_last))) { + if (llama_decode(ctx, llama_batch_get_one(const_cast(all_tokens.data() + offset), n_tokens_before_last))) { LOG_ERR("%s : failed to eval\n", __func__); return false; } n_past += n_tokens_before_last; - llama_state_save_file(ctx, state_path.data(), tokens.data(), n_eval); - LOG_INF("saved session before last token to %s, n_tokens = %d\n", state_path.data(), n_eval); + llama_state_save_file(ctx, state_path.data(), all_tokens.data(), all_tokens.size()); + LOG_INF("saved session before last token to %s, n_tokens = %zu\n", state_path.data(), all_tokens.size()); - llama_token last_token = tokens.back(); + llama_token last_token = all_tokens.back(); llama_batch batch = llama_batch_get_one(&last_token, 1); int32_t pos = n_past; batch.pos = &pos; @@ -2003,11 +2004,11 @@ bool common_prompt_batch_decode( } n_past++; } else { - if (llama_decode(ctx, llama_batch_get_one(const_cast(tokens.data()), n_eval))) { + if (llama_decode(ctx, 
llama_batch_get_one(const_cast(all_tokens.data() + offset), n_tokens))) { LOG_ERR("%s : failed to eval\n", __func__); return false; } - n_past += n_eval; + n_past += n_tokens; } return true; diff --git a/common/common.h b/common/common.h index dec90456afa..7f94d0f58e4 100644 --- a/common/common.h +++ b/common/common.h @@ -929,7 +929,8 @@ void common_batch_add( // tokens from memory, so this approach works across all model architectures. bool common_prompt_batch_decode( struct llama_context * ctx, - const std::vector & embd, + const std::vector & all_tokens, + int n_tokens, int & n_past, int n_batch, std::string_view state_path, diff --git a/tests/test-save-load-state.cpp b/tests/test-save-load-state.cpp index fd0f4b294fe..338bcde3097 100644 --- a/tests/test-save-load-state.cpp +++ b/tests/test-save-load-state.cpp @@ -63,7 +63,7 @@ static std::string test_baseline(struct llama_model * model, const struct common auto tokens = common_tokenize(ctx.get(), params.prompt, true); auto n_past = 0; - if (!common_prompt_batch_decode(ctx.get(), tokens, n_past, params.n_batch, params.out_file, true)) { + if (!common_prompt_batch_decode(ctx.get(), tokens, (int)tokens.size(), n_past, params.n_batch, params.out_file, true)) { LOG_ERR("%s: failed to decode prompt\n", __func__); return {}; } diff --git a/tools/completion/completion.cpp b/tools/completion/completion.cpp index e4c807151a2..6d2dcb56b2f 100644 --- a/tools/completion/completion.cpp +++ b/tools/completion/completion.cpp @@ -689,12 +689,14 @@ int llama_completion(int argc, char ** argv) { if (!embd.empty()) { const bool is_last_batch = (n_consumed >= (int) embd_inp.size()); const bool save_now = session_do_save && is_last_batch; - if (!common_prompt_batch_decode(ctx, embd, n_past, params.n_batch, path_session, save_now)) { + session_tokens.insert(session_tokens.end(), embd.begin(), embd.end()); + if (!common_prompt_batch_decode(ctx, session_tokens, embd.size(), n_past, params.n_batch, path_session, save_now)) { return 1; } 
- session_tokens.insert(session_tokens.end(), embd.begin(), embd.end()); - n_session_consumed = session_tokens.size(); - session_do_save = false; + n_session_consumed += embd.size(); + if (save_now) { + session_do_save = false; + } LOG_DBG("n_past = %d\n", n_past); From 5f4ae00d96902d0f921b9dbe6a90f0b257e341d6 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Fri, 22 May 2026 10:22:26 +0200 Subject: [PATCH 3/3] common : rename n_tokens parameter to n_new --- common/common.cpp | 18 +++++++++--------- common/common.h | 2 +- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 6c6928ca8d3..5206743a64c 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1964,20 +1964,20 @@ bool common_replay_last_token(struct llama_context * ctx, llama_token last_token bool common_prompt_batch_decode( struct llama_context * ctx, const std::vector & all_tokens, - int n_tokens, + int n_new, int & n_past, int n_batch, std::string_view state_path, bool save_state) { - if (n_tokens == 0) { + if (n_new == 0) { return true; } - const int offset = all_tokens.size() - n_tokens; + const int offset = all_tokens.size() - n_new; - if (save_state && n_tokens > 1) { - const int n_tokens_before_last = n_tokens - 1; + if (save_state && n_new > 1) { + const int n_tokens_before_last = n_new - 1; - GGML_ASSERT(n_tokens <= n_batch); + GGML_ASSERT(n_new <= n_batch); // Decode all but the last token so we can save the memory state before decoding the last token. // This is done so we can restore the session state later and replay the last token. 
@@ -1991,7 +1991,7 @@ bool common_prompt_batch_decode( n_past += n_tokens_before_last; llama_state_save_file(ctx, state_path.data(), all_tokens.data(), all_tokens.size()); - LOG_INF("saved session before last token to %s, n_tokens = %zu\n", state_path.data(), all_tokens.size()); + LOG_INF("saved session before last token to %s, n_new = %zu\n", state_path.data(), all_tokens.size()); llama_token last_token = all_tokens.back(); llama_batch batch = llama_batch_get_one(&last_token, 1); @@ -2004,11 +2004,11 @@ bool common_prompt_batch_decode( } n_past++; } else { - if (llama_decode(ctx, llama_batch_get_one(const_cast(all_tokens.data() + offset), n_tokens))) { + if (llama_decode(ctx, llama_batch_get_one(const_cast(all_tokens.data() + offset), n_new))) { LOG_ERR("%s : failed to eval\n", __func__); return false; } - n_past += n_tokens; + n_past += n_new; } return true; diff --git a/common/common.h b/common/common.h index 7f94d0f58e4..3fabde02a7c 100644 --- a/common/common.h +++ b/common/common.h @@ -930,7 +930,7 @@ void common_batch_add( bool common_prompt_batch_decode( struct llama_context * ctx, const std::vector & all_tokens, - int n_tokens, + int n_new, int & n_past, int n_batch, std::string_view state_path,