implement context checkpointing for hybrid and recurrent models #16382
Changes from 5 commits
@@ -543,6 +543,9 @@ extern "C" {
     // Returns true if the model is recurrent (like Mamba, RWKV, etc.)
     LLAMA_API bool llama_model_is_recurrent(const struct llama_model * model);
 
+    // Returns true if the model is hybrid (like Jamba, Granite, etc.)
+    LLAMA_API bool llama_model_is_hybrid(const struct llama_model * model);
+
     // Returns true if the model is diffusion-based (like LLaDA, Dream, etc.)
     LLAMA_API bool llama_model_is_diffusion(const struct llama_model * model);
 
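The new query slots in next to the existing architecture checks in the public C API. As a minimal sketch (not part of this PR; the helper name and the `swa_full` parameter are placeholders for whatever the caller uses to request a full SWA cache), this is how the three checks combine to decide whether per-sequence checkpoints are worth keeping, mirroring the condition used later in the server code:

    // sketch: does this model carry state that cannot be rolled back position-by-position?
    static bool model_needs_checkpoints(const struct llama_model * model, bool swa_full) {
        // recurrent and hybrid architectures keep a fixed-size state with no per-position history
        if (llama_model_is_recurrent(model) || llama_model_is_hybrid(model)) {
            return true;
        }
        // SWA models discard positions outside the window unless the full cache is kept
        return llama_model_n_swa(model) > 0 && !swa_full;
    }
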
@@ -791,7 +794,7 @@ extern "C" {
                                       size_t   n_token_capacity,
                                       size_t * n_token_count_out);
 
-#define LLAMA_STATE_SEQ_FLAGS_SWA_ONLY 1
+#define LLAMA_STATE_SEQ_FLAGS_CHECKPOINT_ONLY 1
 
    typedef uint32_t llama_state_seq_flags;
 
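The renamed flag keeps the same value (1), so the `_ext` state entry points are unchanged; only the meaning widens from "SWA state only" to "the part of the sequence state that must be checkpointed". A hedged sketch of the round trip, using only the calls that appear in this diff (`ctx`, `seq_id`, and the error handling are placeholders):

    const llama_state_seq_flags flags = LLAMA_STATE_SEQ_FLAGS_CHECKPOINT_ONLY;

    // save: snapshot only the non-rollbackable part of the sequence state
    const size_t size = llama_state_seq_get_size_ext(ctx, seq_id, flags);
    std::vector<uint8_t> data(size);
    llama_state_seq_get_data_ext(ctx, data.data(), size, seq_id, flags);

    // restore: the return value is the number of bytes consumed; a mismatch means
    // the restore failed and the caller should clear the sequence and re-process the prompt
    const size_t n  = llama_state_seq_set_data_ext(ctx, data.data(), size, seq_id, flags);
    const bool   ok = (n == size);

The remaining hunks, all in the server code (`server_slot` / `server_context`), generalize the existing SWA-only checkpoints into these context checkpoints, starting with the struct rename below.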
@@ -764,7 +764,7 @@ struct completion_token_output {
     }
 };
 
-struct swa_checkpoint {
+struct ctx_checkpoint {
     llama_pos pos_min;
     llama_pos pos_max;
 
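The hunk cuts off after pos_max; judging from how the struct is used later in the diff (`it->data.size()`, `cur.data.data()`), the renamed type presumably also keeps the serialized state buffer. An inferred sketch, not a verbatim copy of the file:

    struct ctx_checkpoint {
        llama_pos pos_min;   // lowest position covered by this checkpoint
        llama_pos pos_max;   // highest position covered by this checkpoint

        std::vector<uint8_t> data; // serialized per-sequence state (assumed member)
    };
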
@@ -1460,7 +1460,7 @@ struct server_slot {
 
     std::vector<completion_token_output> generated_token_probs;
 
-    std::vector<swa_checkpoint> swa_checkpoints;
+    std::vector<ctx_checkpoint> ctx_checkpoints;
 
     bool has_next_token = true;
     bool has_new_line   = false;
@@ -3555,38 +3555,38 @@ struct server_context {
             if (pos_min > pos_min_thold) {
                 SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min, n_swa);
 
-                // search for a SWA checkpoint
+                // search for a context checkpoint
                 const auto it = std::find_if(
-                    slot.swa_checkpoints.rbegin(),
-                    slot.swa_checkpoints.rend(),
+                    slot.ctx_checkpoints.rbegin(),
+                    slot.ctx_checkpoints.rend(),
                     [&](const auto & cur) {
                         return cur.pos_min <= pos_min_thold;
                     }
                 );
 
-                bool do_reset = it == slot.swa_checkpoints.rend();
+                bool do_reset = it == slot.ctx_checkpoints.rend();
+                //printf("[DEBUG] `do_reset` was set to `%s`\n", do_reset ? "true" : "false");
 
                 if (!do_reset) {
-                    // restore the checkpoint
-                    const size_t swa_size = it->data.size();
-                    const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), swa_size, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);
+                    // restore the context checkpoint
+                    const size_t ctx_checkpoint_size = it->data.size();
+                    const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), ctx_checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_CHECKPOINT_ONLY);
 
-                    if (n != swa_size) {
-                        SLT_ERR(slot, "failed to restore SWA checkpoint, pos_min = %d, pos_max = %d, size = %.3f MiB\n", it->pos_min, it->pos_max, (float) swa_size / 1024 / 1024);
+                    if (n != ctx_checkpoint_size) {
+                        SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) ctx_checkpoint_size / 1024 / 1024);
                         do_reset = true;
+                        //printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint");
                     } else {
                         slot.n_past = std::min(slot.n_past, it->pos_max);
-
-                        SLT_WRN(slot, "SWA checkpoint restore, pos_min = %d, pos_max = %d, size = %.3f MiB\n", it->pos_min, it->pos_max, (float) swa_size / 1024 / 1024);
+                        SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) ctx_checkpoint_size / 1024 / 1024);
                    }
                 }
 
                 if (do_reset) {
                     SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA, see %s)\n",
                             "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
 
                     slot.n_past = 0;
-                    slot.swa_checkpoints.clear();
+                    slot.ctx_checkpoints.clear();
                 }
             }
         }
@@ -3595,21 +3595,20 @@ struct server_context {
             const auto pos_min_thold = std::max(0, slot.n_past - n_swa);
 
             // erase any checkpoints with pos_min > pos_min_thold
-            for (int i = (int) slot.swa_checkpoints.size() - 1; i >= 0; i--) {
-                const auto & cur = slot.swa_checkpoints[i];
+            for (int i = (int) slot.ctx_checkpoints.size() - 1; i >= 0; i--) {
+                const auto & cur = slot.ctx_checkpoints[i];
                 if (cur.pos_min > pos_min_thold) {
-                    slot.swa_checkpoints.erase(slot.swa_checkpoints.begin() + i);
-
-                    SLT_WRN(slot, "SWA checkpoint erase, pos_min = %d, pos_max = %d, size = %.3f MiB\n", cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
+                    slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin() + i);
+                    SLT_WRN(slot, "erased invalidated context checkpoint for SWA (pos_min = %d, pos_max = %d, n_swa = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, n_swa, (float) cur.data.size() / 1024 / 1024);
                 }
             }
         }
     }
 
     if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) {
-        SLT_WRN(slot, "need to evaluate at least 1 token for each active slot, n_past = %d, n_prompt_tokens = %d\n", slot.n_past, slot.n_prompt_tokens);
-
+        SLT_WRN(slot, "need to evaluate at least 1 token for each active slot (n_past = %d, n_prompt_tokens = %d)\n", slot.n_past, slot.n_prompt_tokens);
         slot.n_past--;
+        SLT_WRN(slot, "n_past was set to %d\n", slot.n_past);
     }
 
     slot.n_prompt_tokens_cache = slot.n_past;
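For intuition about the invalidation rule above (the numbers are made up, not from the PR): a checkpoint is only reusable if it still contains every position the sliding window will need when generation resumes, i.e. its pos_min must not exceed n_past - n_swa.

    // toy values, for illustration only
    const int n_swa  = 2048;                                // sliding window size
    const int n_past = 5000;                                // cached prefix we want to reuse
    const int pos_min_thold = std::max(0, n_past - n_swa);  // = 2952
    // a checkpoint with pos_min <= 2952 still holds every position the window needs
    // and can be restored; one with pos_min > 2952 has already discarded positions
    // that would be required, so it is erased
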
@@ -3623,17 +3622,17 @@ struct server_context {
         }
     }
 
-    // keep only the common part
+    // truncate any tokens that are beyond n_past for this slot
     if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.n_past, -1)) {
         // could not partially delete (likely using a non-Transformer model)
+        SLT_WRN(slot, "failed to truncate tokens beyond n_past = %d\n", slot.n_past);
         llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
-
         // there is no common part left
         slot.n_past = 0;
         slot.n_prompt_tokens_cache = 0;
     }
 
-    SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past);
+    SLT_INF(slot, "n_past = %d\n", slot.n_past);
 
     // remove the non-common part from the cache
     slot.cache_tokens.keep_first(slot.n_past);
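The fallback above relies on the range semantics of llama_memory_seq_rm: it removes positions [p0, p1) for a sequence, negative bounds mean "from the beginning" / "to the end", and it returns false when the memory cannot drop a partial range, which is exactly the recurrent/hybrid case this PR targets. A condensed sketch of the same pattern (names taken from the diff, no new API):

    // try to drop only the suffix past n_past ...
    if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.n_past, -1)) {
        // ... and if partial removal is not supported, clear the whole sequence
        llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
        slot.n_past = 0; // nothing cached remains reusable
    }
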
@@ -3854,37 +3853,35 @@ struct server_context {
                 // prompt evaluated for next-token prediction
                 slot.state = SLOT_STATE_GENERATING;
 
-                // make a checkpoint with the SWA memory
-                // checkpoints are needed only if we are not using "--swa-full"
-                if (llama_model_n_swa(model) > 0 && !params_base.swa_full && params_base.n_swa_checkpoints > 0) {
-                    if (slot.swa_checkpoints.size() >= (size_t) params_base.n_swa_checkpoints) {
-                        {
-                            const auto & cur = slot.swa_checkpoints.back();
-
-                            SLT_WRN(slot, "SWA checkpoint erase, pos_min = %d, pos_max = %d, size = %.3f MiB\n",
-                                    cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
-                        }
-
-                        slot.swa_checkpoints.erase(slot.swa_checkpoints.begin());
+                // make a checkpoint of the parts of memory that cannot be rolled back.
+                // checkpoints are needed only if:
+                //  - the model uses SWA and we are not using `swa_full`
+                //  - the model architecture is marked as recurrent or hybrid
+                bool do_checkpoint = (llama_model_is_recurrent(model) || llama_model_is_hybrid(model)) ||
+                                     (llama_model_n_swa(model) > 0 && !params_base.swa_full);
+
+                if (do_checkpoint && params_base.n_ctx_checkpoints > 0) {
+                    if (slot.ctx_checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
+                        // make room for the new checkpoint, if needed
+                        const auto & cur = slot.ctx_checkpoints.back();
+                        SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
+                                cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
+
+                        slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin());
                     }
 
-                    const size_t swa_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);
+                    const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_CHECKPOINT_ONLY);
 
-                    auto & cur = slot.swa_checkpoints.emplace_back(swa_checkpoint{
+                    auto & cur = slot.ctx_checkpoints.emplace_back(ctx_checkpoint{
                        /*.pos_min = */ llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id),
                        /*.pos_max = */ llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id),
-                       /*.data    = */ std::vector<uint8_t>(swa_size),
+                       /*.data    = */ std::vector<uint8_t>(checkpoint_size),
                     });
 
-                    llama_state_seq_get_data_ext(ctx, cur.data.data(), swa_size, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);
-
-                    float size_total = 0.0f;
-                    for (const auto & checkpoint : slot.swa_checkpoints) {
-                        size_total += (float) checkpoint.data.size() / 1024 / 1024;
-                    }
+                    llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_CHECKPOINT_ONLY);
 
-                    SLT_WRN(slot, "SWA checkpoint create, pos_min = %d, pos_max = %d, size = %.3f MiB, total = %d/%d (%.3f MiB)\n",
-                            cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024, (int) slot.swa_checkpoints.size(), params_base.n_swa_checkpoints, size_total);
+                    SLT_WRN(slot, "saved context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
+                            (int) slot.ctx_checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
                 }
             } else if (slot.state != SLOT_STATE_GENERATING) {
                 continue; // continue loop of slots