Skip to content

Commit 829c701

Browse files
committed
server : fix checkpoint logic to support recurrent caches
1 parent e1b68d8 commit 829c701

File tree

1 file changed: +12 additions, -11 deletions

tools/server/server.cpp

Lines changed: 12 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -3541,7 +3541,11 @@ struct server_context {
35413541
slot.n_past = 0;
35423542
}
35433543

3544-
const auto n_swa = llama_model_n_swa(model);
3544+
// note: when n_swa == 0, the model does not use SWA, which is equivalent to a window of 1
3545+
const auto n_swa = std::max(1, llama_model_n_swa(model));
3546+
3547+
// the largest pos_min required for a checkpoint to be useful
3548+
const auto pos_min_thold = std::max(0, slot.n_past - n_swa);
35453549

35463550
if (slot.n_past > 0 && slot.n_past < (int) slot.cache_tokens.size()) {
35473551
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
@@ -3550,17 +3554,16 @@ struct server_context {
35503554
GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237");
35513555
}
35523556

3553-
const auto pos_min_thold = std::max(0, slot.n_past - n_swa);
3554-
3555-
if (pos_min > pos_min_thold + 1) {
3557+
if (pos_min > pos_min_thold) {
35563558
SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min, n_swa);
35573559

35583560
// search for a context checkpoint
35593561
const auto it = std::find_if(
35603562
slot.ctx_checkpoints.rbegin(),
35613563
slot.ctx_checkpoints.rend(),
35623564
[&](const auto & cur) {
3563-
return cur.pos_min <= pos_min_thold;
3565+
// guarantee that a checkpoint will result in at least one token being processed [TAG_PROMPT_LOGITS]
3566+
return cur.pos_min < pos_min_thold;
35643567
}
35653568
);
35663569

@@ -3577,7 +3580,7 @@ struct server_context {
35773580
do_reset = true;
35783581
//printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint");
35793582
} else {
3580-
slot.n_past = std::min(slot.n_past, it->pos_max);
3583+
slot.n_past = std::min(slot.n_past, std::max(it->pos_min + 1, it->pos_max));
35813584
SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) ctx_checkpoint_size / 1024 / 1024);
35823585
}
35833586
}
@@ -3586,25 +3589,23 @@ struct server_context {
35863589
SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA or hybrid/recurrent memory, see %s)\n",
35873590
"https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
35883591
slot.n_past = 0;
3589-
slot.ctx_checkpoints.clear();
35903592
}
35913593
}
35923594
}
35933595

3594-
if (n_swa > 0) {
3595-
const auto pos_min_thold = std::max(0, slot.n_past - n_swa);
3596-
3596+
{
35973597
// erase any checkpoints with pos_min > pos_min_thold
35983598
for (int i = (int) slot.ctx_checkpoints.size() - 1; i >= 0; i--) {
35993599
const auto & cur = slot.ctx_checkpoints[i];
36003600
if (cur.pos_min > pos_min_thold) {
3601-
slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin() + i);
36023601
SLT_WRN(slot, "erased invalidated context checkpoint for SWA (pos_min = %d, pos_max = %d, n_swa = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, n_swa, (float) cur.data.size() / 1024 / 1024);
3602+
slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin() + i);
36033603
}
36043604
}
36053605
}
36063606
}
36073607

3608+
// [TAG_PROMPT_LOGITS]
36083609
if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) {
36093610
SLT_WRN(slot, "need to evaluate at least 1 token for each active slot (n_past = %d, n_prompt_tokens = %d)\n", slot.n_past, slot.n_prompt_tokens);
36103611
slot.n_past--;

0 commit comments

Comments (0)