Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 13 additions & 12 deletions common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1963,36 +1963,37 @@ bool common_replay_last_token(struct llama_context * ctx, llama_token last_token

bool common_prompt_batch_decode(
struct llama_context * ctx,
const std::vector<llama_token> & tokens,
const std::vector<llama_token> & all_tokens,
int n_new,
int & n_past,
int n_batch,
std::string_view state_path,
bool save_state) {
const int n_eval = tokens.size();
if (n_eval == 0) {
if (n_new == 0) {
return true;
}
const int offset = all_tokens.size() - n_new;

if (save_state && n_eval > 1) {
const int n_tokens_before_last = n_eval - 1;
if (save_state && n_new > 1) {
const int n_tokens_before_last = n_new - 1;

GGML_ASSERT(n_eval <= n_batch);
GGML_ASSERT(n_new <= n_batch);

// Decode all but the last token so we can save the memory state before decoding the last token.
// This is done so we can restore the session state later and replay the last token.
// This logic is needed because memory implementations in recurrent/hybrid models don't support
// removing tokens from their memory, so we can't simply remove the last token from the memory
// and then replay it.
if (llama_decode(ctx, llama_batch_get_one(const_cast<llama_token*>(tokens.data()), n_tokens_before_last))) {
if (llama_decode(ctx, llama_batch_get_one(const_cast<llama_token*>(all_tokens.data() + offset), n_tokens_before_last))) {
LOG_ERR("%s : failed to eval\n", __func__);
return false;
}
n_past += n_tokens_before_last;

llama_state_save_file(ctx, state_path.data(), tokens.data(), n_tokens_before_last);
LOG_INF("saved session before last token to %s, n_tokens = %d\n", state_path.data(), n_tokens_before_last);
llama_state_save_file(ctx, state_path.data(), all_tokens.data(), all_tokens.size());
LOG_INF("saved session before last token to %s, n_new = %zu\n", state_path.data(), all_tokens.size());

llama_token last_token = tokens.back();
llama_token last_token = all_tokens.back();
llama_batch batch = llama_batch_get_one(&last_token, 1);
int32_t pos = n_past;
batch.pos = &pos;
Expand All @@ -2003,11 +2004,11 @@ bool common_prompt_batch_decode(
}
n_past++;
} else {
if (llama_decode(ctx, llama_batch_get_one(const_cast<llama_token*>(tokens.data()), n_eval))) {
if (llama_decode(ctx, llama_batch_get_one(const_cast<llama_token*>(all_tokens.data() + offset), n_new))) {
LOG_ERR("%s : failed to eval\n", __func__);
return false;
}
n_past += n_eval;
n_past += n_new;
}

return true;
Expand Down
3 changes: 2 additions & 1 deletion common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -929,7 +929,8 @@ void common_batch_add(
// tokens from memory, so this approach works across all model architectures.
bool common_prompt_batch_decode(
struct llama_context * ctx,
const std::vector<llama_token> & embd,
const std::vector<llama_token> & all_tokens,
int n_new,
int & n_past,
int n_batch,
std::string_view state_path,
Expand Down
8 changes: 4 additions & 4 deletions tests/test-save-load-state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ static std::string test_baseline(struct llama_model * model, const struct common
auto tokens = common_tokenize(ctx.get(), params.prompt, true);

auto n_past = 0;
if (!common_prompt_batch_decode(ctx.get(), tokens, n_past, params.n_batch, params.out_file, true)) {
if (!common_prompt_batch_decode(ctx.get(), tokens, (int)tokens.size(), n_past, params.n_batch, params.out_file, true)) {
LOG_ERR("%s: failed to decode prompt\n", __func__);
return {};
}
Expand Down Expand Up @@ -111,7 +111,7 @@ static bool test_state_load(struct llama_model * model, const struct common_para
LOG_TRC("%s: loaded state with %zu tokens\n", __func__, n_token_count_out);

// Replay last token
int n_past = (int) n_token_count_out;
int n_past = (int) n_token_count_out - 1;
if (!common_replay_last_token(ctx.get(), tokens.back(), n_past)) {
return false;
}
Expand Down Expand Up @@ -165,7 +165,7 @@ static bool test_seq_cp_host(struct llama_model * model, const struct common_par
LOG_TRC("%s: loaded state with %zu tokens\n", __func__, n_token_count_out);

// Replay last token
int n_past = (int) n_token_count_out;
int n_past = (int) n_token_count_out - 1;
if (!common_replay_last_token(ctx.get(), tokens.back(), n_past)) {
return false;
}
Expand Down Expand Up @@ -240,7 +240,7 @@ static bool test_seq_cp_device(struct llama_model * model, const struct common_p
LOG_TRC("%s: loaded state with %zu tokens\n", __func__, n_token_count_out);

// Replay last token
int n_past = (int) n_token_count_out;
int n_past = (int) n_token_count_out - 1;
if (!common_replay_last_token(ctx.get(), tokens.back(), n_past)) {
return false;
}
Expand Down
25 changes: 12 additions & 13 deletions tools/completion/completion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -373,16 +373,10 @@ int llama_completion(int argc, char ** argv) {
__func__, n_match, embd_inp.size());
}

if (session_tokens.size() == n_match) {
// [TAG_CONTEXT_STATE_LOGITS]
// in this case, we are going to reuse the logits from the session
// if we ever decide to remove the logits from the session, we need to handle this somehow
// ref: https://github.com/ggml-org/llama.cpp/pull/18862#issuecomment-3756330941
}

// remove any "future" tokens that we might have inherited from the previous session
if (session_tokens.size() > n_match) {
if (!llama_memory_seq_rm(mem, -1, n_match, -1)) {
llama_pos pos = n_match > 0 ? (llama_pos)(n_match - 1) : 0;
if (!llama_memory_seq_rm(mem, -1, pos, -1)) {
LOG_WRN("%s: unable to reuse common prefix (for example, when the memory is recurrent)\n", __func__);
llama_memory_clear(mem, true);
session_tokens.clear();
Expand All @@ -398,7 +392,7 @@ int llama_completion(int argc, char ** argv) {
// Logits are not stored as part of the session state so we need to
// "replay" the last token to get logits for sampling.
if (!session_tokens.empty() && n_match > 0 && n_match == session_tokens.size()) {
if (!common_replay_last_token(ctx, session_tokens.back(), n_match)) {
if (!common_replay_last_token(ctx, session_tokens.back(), n_match - 1)) {
return 1;
}

Expand Down Expand Up @@ -695,12 +689,14 @@ int llama_completion(int argc, char ** argv) {
if (!embd.empty()) {
const bool is_last_batch = (n_consumed >= (int) embd_inp.size());
const bool save_now = session_do_save && is_last_batch;
if (!common_prompt_batch_decode(ctx, embd, n_past, params.n_batch, path_session, save_now)) {
session_tokens.insert(session_tokens.end(), embd.begin(), embd.end());
if (!common_prompt_batch_decode(ctx, session_tokens, embd.size(), n_past, params.n_batch, path_session, save_now)) {
return 1;
}
session_tokens.insert(session_tokens.end(), embd.begin(), embd.end());
n_session_consumed = session_tokens.size();
session_do_save = false;
n_session_consumed += embd.size();
if (save_now) {
session_do_save = false;
}

LOG_DBG("n_past = %d\n", n_past);

Expand Down Expand Up @@ -991,7 +987,10 @@ int llama_completion(int argc, char ** argv) {

if (!path_session.empty() && params.prompt_cache_all && !params.prompt_cache_ro) {
LOG("\n%s: saving final output to session file '%s'\n", __func__, path_session.c_str());
session_tokens.insert(session_tokens.end(), embd.begin(), embd.end());
llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
LOG_INF("saved final session to %s, n_tokens = %ld\n", path_session.data(), session_tokens.size());

}

LOG("\n\n");
Expand Down
Loading