implement context checkpointing for hybrid and recurrent models #16382
Changes from 8 commits
```diff
@@ -764,7 +764,7 @@ struct completion_token_output {
     }
 };
 
-struct swa_checkpoint {
+struct ctx_checkpoint {
     llama_pos pos_min;
     llama_pos pos_max;
 
```
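For orientation, the full checkpoint struct that the rest of the diff manipulates presumably looks something like the sketch below after the rename. The `data` member is not shown in this hunk; it is inferred from how checkpoints are used later in the PR (`it->data.size()`, `std::vector<uint8_t>(checkpoint_size)`).

```cpp
#include <cstdint>
#include <vector>

#include "llama.h" // assumed available for llama_pos

// sketch only: pos_min/pos_max come from the hunk above, the data member is
// inferred from the later hunks in this PR
struct ctx_checkpoint {
    llama_pos pos_min;         // lowest sequence position covered by the checkpoint
    llama_pos pos_max;         // highest sequence position covered by the checkpoint

    std::vector<uint8_t> data; // serialized per-sequence state (SWA / recurrent caches)
};
```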
```diff
@@ -1460,7 +1460,7 @@ struct server_slot {
 
     std::vector<completion_token_output> generated_token_probs;
 
-    std::vector<swa_checkpoint> swa_checkpoints;
+    std::vector<ctx_checkpoint> ctx_checkpoints;
 
     bool has_next_token = true;
     bool has_new_line = false;
```
```diff
@@ -3555,38 +3555,38 @@ struct server_context {
                 if (pos_min > pos_min_thold) {
                     SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min, n_swa);
 
-                    // search for a SWA checkpoint
+                    // search for a context checkpoint
                     const auto it = std::find_if(
-                        slot.swa_checkpoints.rbegin(),
-                        slot.swa_checkpoints.rend(),
+                        slot.ctx_checkpoints.rbegin(),
+                        slot.ctx_checkpoints.rend(),
                         [&](const auto & cur) {
                             return cur.pos_min <= pos_min_thold;
                         }
                     );
 
-                    bool do_reset = it == slot.swa_checkpoints.rend();
+                    bool do_reset = it == slot.ctx_checkpoints.rend();
+                    //printf("[DEBUG] `do_reset` was set to `%s`\n", do_reset ? "true" : "false");
 
                     if (!do_reset) {
-                        // restore the checkpoint
-                        const size_t swa_size = it->data.size();
-                        const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), swa_size, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);
-
-                        if (n != swa_size) {
-                            SLT_ERR(slot, "failed to restore SWA checkpoint, pos_min = %d, pos_max = %d, size = %.3f MiB\n", it->pos_min, it->pos_max, (float) swa_size / 1024 / 1024);
+                        // restore the context checkpoint
+                        const size_t ctx_checkpoint_size = it->data.size();
+                        const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), ctx_checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_CHECKPOINT_ONLY);
+                        if (n != ctx_checkpoint_size) {
+                            SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) ctx_checkpoint_size / 1024 / 1024);
                             do_reset = true;
+                            //printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint");
                         } else {
                             slot.n_past = std::min(slot.n_past, it->pos_max);
-
-                            SLT_WRN(slot, "SWA checkpoint restore, pos_min = %d, pos_max = %d, size = %.3f MiB\n", it->pos_min, it->pos_max, (float) swa_size / 1024 / 1024);
+                            SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) ctx_checkpoint_size / 1024 / 1024);
                         }
                     }
 
                     if (do_reset) {
-                        SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA, see %s)\n",
+                        SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA or hybrid/recurrent memory, see %s)\n",
                             "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
 
                         slot.n_past = 0;
-                        slot.swa_checkpoints.clear();
+                        slot.ctx_checkpoints.clear();
                     }
                 }
             }
```
```diff
@@ -3595,21 +3595,20 @@ struct server_context {
                 const auto pos_min_thold = std::max(0, slot.n_past - n_swa);
 
                 // erase any checkpoints with pos_min > pos_min_thold
-                for (int i = (int) slot.swa_checkpoints.size() - 1; i >= 0; i--) {
-                    const auto & cur = slot.swa_checkpoints[i];
+                for (int i = (int) slot.ctx_checkpoints.size() - 1; i >= 0; i--) {
+                    const auto & cur = slot.ctx_checkpoints[i];
                     if (cur.pos_min > pos_min_thold) {
-                        slot.swa_checkpoints.erase(slot.swa_checkpoints.begin() + i);
-
-                        SLT_WRN(slot, "SWA checkpoint erase, pos_min = %d, pos_max = %d, size = %.3f MiB\n", cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
+                        slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin() + i);
+                        SLT_WRN(slot, "erased invalidated context checkpoint for SWA (pos_min = %d, pos_max = %d, n_swa = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, n_swa, (float) cur.data.size() / 1024 / 1024);
                     }
                 }
             }
         }
 
         if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) {
-            SLT_WRN(slot, "need to evaluate at least 1 token for each active slot, n_past = %d, n_prompt_tokens = %d\n", slot.n_past, slot.n_prompt_tokens);
-
+            SLT_WRN(slot, "need to evaluate at least 1 token for each active slot (n_past = %d, n_prompt_tokens = %d)\n", slot.n_past, slot.n_prompt_tokens);
             slot.n_past--;
+            SLT_WRN(slot, "n_past was set to %d\n", slot.n_past);
         }
 
         slot.n_prompt_tokens_cache = slot.n_past;
```
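A quick numeric illustration of the invalidation rule above, with made-up values that are not from the PR: with a sliding window of 1024 tokens and 4096 reusable cached tokens, the threshold is 3072, so a checkpoint whose `pos_min` is 3500 can no longer serve this request and is erased.

```cpp
#include <algorithm>
#include <cstdio>

// made-up numbers, only to illustrate the threshold computation above
int main() {
    const int n_swa  = 1024;                                // sliding-window size of the model
    const int n_past = 4096;                                // reusable cached prefix
    const int pos_min_thold = std::max(0, n_past - n_swa);  // = 3072

    const int checkpoint_pos_min = 3500;                    // hypothetical checkpoint
    std::printf("pos_min_thold = %d -> checkpoint is %s\n", pos_min_thold,
                checkpoint_pos_min > pos_min_thold ? "erased (invalidated)" : "kept");

    return 0;
}
```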
```diff
@@ -3623,17 +3622,17 @@ struct server_context {
                     }
                 }
 
-                // keep only the common part
+                // truncate any tokens that are beyond n_past for this slot
                 if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.n_past, -1)) {
-                    // could not partially delete (likely using a non-Transformer model)
+                    SLT_WRN(slot, "failed to truncate tokens beyond n_past = %d\n", slot.n_past);
                     llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
 
                     // there is no common part left
                     slot.n_past = 0;
                     slot.n_prompt_tokens_cache = 0;
                 }
 
-                SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past);
+                SLT_INF(slot, "n_past = %d\n", slot.n_past);
 
                 // remove the non-common part from the cache
                 slot.cache_tokens.keep_first(slot.n_past);
```
```diff
@@ -3854,37 +3853,35 @@ struct server_context {
                     // prompt evaluated for next-token prediction
                     slot.state = SLOT_STATE_GENERATING;
 
-                    // make a checkpoint with the SWA memory
-                    // checkpoints are needed only if we are not using "--swa-full"
-                    if (llama_model_n_swa(model) > 0 && !params_base.swa_full && params_base.n_swa_checkpoints > 0) {
-                        if (slot.swa_checkpoints.size() >= (size_t) params_base.n_swa_checkpoints) {
-                            {
-                                const auto & cur = slot.swa_checkpoints.back();
-
-                                SLT_WRN(slot, "SWA checkpoint erase, pos_min = %d, pos_max = %d, size = %.3f MiB\n",
-                                        cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
-                            }
-
-                            slot.swa_checkpoints.erase(slot.swa_checkpoints.begin());
+                    // make a checkpoint of the parts of memory that cannot be rolled back.
+                    // checkpoints are needed only if:
+                    //  - the model uses SWA and we are not using `swa_full`
+                    //  - the model architecture is marked as recurrent or hybrid
+                    bool do_checkpoint = (llama_model_is_recurrent(model) || llama_model_is_hybrid(model)) ||
+                                         (llama_model_n_swa(model) > 0 && !params_base.swa_full);
+
+                    if (do_checkpoint && params_base.n_ctx_checkpoints > 0) {
+                        if (slot.ctx_checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
+                            // make room for the new checkpoint, if needed
+                            const auto & cur = slot.ctx_checkpoints.back();
+                            SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
+                                    cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
+
+                            slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin());
                         }
 
-                        const size_t swa_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);
+                        const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_CHECKPOINT_ONLY);
 
-                        auto & cur = slot.swa_checkpoints.emplace_back(swa_checkpoint{
+                        auto & cur = slot.ctx_checkpoints.emplace_back(ctx_checkpoint{
                             /*.pos_min = */ llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id),
                             /*.pos_max = */ llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id),
-                            /*.data    = */ std::vector<uint8_t>(swa_size),
+                            /*.data    = */ std::vector<uint8_t>(checkpoint_size),
                         });
 
-                        llama_state_seq_get_data_ext(ctx, cur.data.data(), swa_size, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);
-
-                        float size_total = 0.0f;
-                        for (const auto & checkpoint : slot.swa_checkpoints) {
-                            size_total += (float) checkpoint.data.size() / 1024 / 1024;
-                        }
+                        llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_CHECKPOINT_ONLY);
 
-                        SLT_WRN(slot, "SWA checkpoint create, pos_min = %d, pos_max = %d, size = %.3f MiB, total = %d/%d (%.3f MiB)\n",
-                                cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024, (int) slot.swa_checkpoints.size(), params_base.n_swa_checkpoints, size_total);
+                        SLT_WRN(slot, "saved context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
+                                (int) slot.ctx_checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
                     }
                 } else if (slot.state != SLOT_STATE_GENERATING) {
                     continue; // continue loop of slots
```
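The save and restore calls above are the generic per-sequence state API with a flag restricting the operation to the checkpointable part of the memory. A minimal round-trip sketch, assuming the flag is named `LLAMA_STATE_SEQ_FLAGS_CHECKPOINT_ONLY` as in this revision of the PR (the name is still under discussion below):

```cpp
#include <cstdint>
#include <vector>

#include "llama.h"

// save only the non-rollbackable part of the state of sequence `seq_id`
// (SWA / recurrent caches), as done when creating a checkpoint above
static std::vector<uint8_t> checkpoint_save(llama_context * ctx, llama_seq_id seq_id) {
    const size_t size = llama_state_seq_get_size_ext(ctx, seq_id, LLAMA_STATE_SEQ_FLAGS_CHECKPOINT_ONLY);

    std::vector<uint8_t> data(size);
    llama_state_seq_get_data_ext(ctx, data.data(), size, seq_id, LLAMA_STATE_SEQ_FLAGS_CHECKPOINT_ONLY);

    return data;
}

// restore a previously saved checkpoint; a size mismatch signals failure,
// which the server handles by forcing a full prompt re-process
static bool checkpoint_restore(llama_context * ctx, llama_seq_id seq_id, const std::vector<uint8_t> & data) {
    const size_t n = llama_state_seq_set_data_ext(ctx, data.data(), data.size(), seq_id, LLAMA_STATE_SEQ_FLAGS_CHECKPOINT_ONLY);

    return n == data.size();
}
```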
Review comment (on the `LLAMA_STATE_SEQ_FLAGS_CHECKPOINT_ONLY` flag): We need a better name than this. The old name does not work, but the proposed new name is confusing. The purpose of this flag is to indicate that we want to save only the "small" caches, such as SWA, "recr", etc. But I can't think of a good name to call it.

Reply: I see what you mean. I can't think of anything better at the moment.