From 3ce90845dae04f0629964e4cba894a09c6e85a96 Mon Sep 17 00:00:00 2001 From: bjahoor Date: Fri, 15 May 2026 13:15:26 -0700 Subject: [PATCH] server: fix prompt-cache reuse for hybrid/recurrent models Two changes to tools/server/server-context.cpp restore working prompt-cache reuse on hybrid Mamba+attention architectures (Nemotron-H, Jamba, Qwen3.5/3.6/Next, Granite-H, Falcon-H1), which currently force full re-processing on every conversation turn. 1. Checkpoint search predicate: for hybrid/recurrent models pos_min always equals the sequence length, so the SWA-based check never matches. Use pos_max <= pos_next instead. 2. seq_rm failure handling: when partial seq_rm fails after a checkpoint was restored, keep the cached state instead of clearing the slot. Not included in this diff: 3. Checkpoint creation threshold: lower from 64 to 4 tokens for hybrid/recurrent models so short prompts can also be cached. Tested on Qwen3.6-27B (RTX 3090, original work by Tongas) and on Nemotron-Elastic-12B-A2B (Jetson AGX Xavier sm_72). Based on prior work: see PR description for full attribution. --- tools/server/server-context.cpp | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index ce743e6656d..d046c62f97a 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -2554,6 +2554,12 @@ struct server_context_impl { // guarantee that a checkpoint will result in at least one token being processed [TAG_PROMPT_LOGITS] LOG_INF("slot %12.*s: id %2d | task %d | Checking checkpoint with [%d, %d] against %d...\n", 12, func_name, (slot).id, ((slot).task ? (slot).task->id : -1), cur.pos_min, cur.pos_max, pos_min_thold); + // for hybrid/recurrent models (DeltaNet, Mamba), pos_min always equals + // the full sequence length, so the SWA-based pos_min check always fails. + // use pos_max <= pos_next instead to find the most recent valid checkpoint. 
+ if (llama_model_is_recurrent(model_tgt) || llama_model_is_hybrid(model_tgt)) { + return cur.pos_max <= pos_next; + } return cur.pos_min < pos_min_thold || cur.pos_min == 0; } ); @@ -2626,12 +2632,17 @@ struct server_context_impl { SLT_INF(slot, "n_tokens = %d, memory_seq_rm [%d, end)\n", slot.prompt.n_tokens(), p0); if (!llama_memory_seq_rm(llama_get_memory(ctx_tgt), slot.id, p0, -1)) { - SLT_WRN(slot, "failed to truncate tokens with position >= %d - clearing the memory\n", p0); + if (ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL && slot.n_prompt_tokens_cache > 0) { + // hybrid/recurrent: partial seq_rm always fails, but checkpoint restored valid state + SLT_INF(slot, "seq_rm failed (expected for hybrid) - keeping %d cached tokens from checkpoint\n", slot.n_prompt_tokens_cache); + } else { + SLT_WRN(slot, "failed to truncate tokens with position >= %d - clearing the memory\n", p0); - slot.prompt_clear(true); + slot.prompt_clear(true); - // there is no common part left - slot.n_prompt_tokens_cache = 0; + // there is no common part left + slot.n_prompt_tokens_cache = 0; + } } else { if (ctx_dft && !llama_memory_seq_rm(llama_get_memory(ctx_dft.get()), slot.id, p0, -1)) { GGML_ABORT("failed to truncate draft context\n");