From 3074d63cf93c2a4c11b49f7ab067623720a66661 Mon Sep 17 00:00:00 2001
From: hoobnn <111053672+hoobnn@users.noreply.github.com>
Date: Tue, 26 May 2026 08:46:11 +0800
Subject: [PATCH] [V1][Mamba] Opt-in granular prefill for align-mode prefix caching (#43587)

In `mamba_cache_mode="align"`, the Mamba state is only materialized (and
therefore cached) at the final aligned block boundary of each prefill
chunk; intermediate boundaries within a chunk live in null blocks and are
skipped by `cache_full_blocks`. With the default large
`max_num_batched_tokens`, the whole prefix is a single chunk, so a request
caches only one Mamba boundary -- its last full block. A later request
that shares an early prefix but diverges before that boundary (e.g.
incremental multimodal or agentic multi-turn with a fixed instruction
suffix) gets zero Mamba cache hits, and since `BlockPool.get_cached_block`
requires a hit in every KV-cache group, the whole request reports
`num_cached_tokens == 0` even though the shared prefix blocks have
identical hashes.

Add an opt-in env var `VLLM_MAMBA_ALIGN_GRANULAR_PREFILL` that caps each
align-mode prefill step to one aligned block, so every boundary's Mamba
state is materialized and cached, enabling partial prefix-cache hits for
these workloads. It reuses the existing, validated align caching path
(state computation/writeback is unchanged), so generated outputs are
unaffected. The trade-off is prefill throughput (one block per step),
hence it defaults to off.

Co-Authored-By: Claude Opus 4.7
Signed-off-by: hoobnn <111053672+hoobnn@users.noreply.github.com>
---
 vllm/envs.py                    | 12 ++++++++++++
 vllm/v1/core/sched/scheduler.py | 16 ++++++++++++++++
 2 files changed, 28 insertions(+)

diff --git a/vllm/envs.py b/vllm/envs.py
index f5b2759e9934..7519ab37c103 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -262,6 +262,7 @@
     VLLM_MULTI_STREAM_GEMM_TOKEN_THRESHOLD: int = 1024
     VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary"
     VLLM_USE_V2_MODEL_RUNNER: bool | None = None
+    VLLM_MAMBA_ALIGN_GRANULAR_PREFILL: bool = False
     VLLM_LOG_MODEL_INSPECTION: bool = False
     VLLM_DEBUG_MFU_METRICS: bool = False
     VLLM_WEIGHT_OFFLOADING_DISABLE_PIN_MEMORY: bool = False
@@ -1897,6 +1898,17 @@ def _resolve_rust_frontend_path() -> str | None:
     "VLLM_USE_V2_MODEL_RUNNER": lambda: maybe_convert_bool(
         os.getenv("VLLM_USE_V2_MODEL_RUNNER", None)
     ),
+    # In Mamba cache 'align' mode, materialize and cache the Mamba state at
+    # every aligned block boundary by capping each prefill step to one aligned
+    # block. This enables partial prefix-cache hits for requests that share an
+    # early prefix but diverge later (e.g. incremental multimodal / agentic
+    # multi-turn), at the cost of prefill throughput (one block per step).
+    # Without it, a single large prefill chunk only caches the chunk's final
+    # boundary, so such requests miss the shared prefix entirely (see #43587).
+    "VLLM_MAMBA_ALIGN_GRANULAR_PREFILL": lambda: (
+        os.getenv("VLLM_MAMBA_ALIGN_GRANULAR_PREFILL", "0").strip().lower()
+        in ("1", "true")
+    ),
     # Log model inspection after loading.
     # If enabled, logs a transformers-style hierarchical view of the model
     # with quantization methods and attention backends.
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index c69c9a8119ab..37e94c9c3292 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -7,6 +7,7 @@
 from dataclasses import replace
 from typing import Any
 
+import vllm.envs as envs
 from vllm.compilation.cuda_graph import CUDAGraphStat
 from vllm.config import VllmConfig
 from vllm.distributed.ec_transfer.ec_connector.base import (
@@ -324,6 +325,21 @@ def _mamba_block_aligned_split(
             else:
                 # prefill the last few tokens
                 pass
+            # In 'align' mode the Mamba state is only materialized (and thus
+            # cached) at the final aligned boundary of each prefill chunk; the
+            # intermediate boundaries within a chunk are stored in null blocks
+            # and never cached. A single large prefill chunk therefore caches
+            # only one boundary, so a later request that shares an early prefix
+            # but diverges before that boundary gets zero Mamba hits (#43587).
+            # Capping each prefill step to one aligned block materializes and
+            # caches every boundary's state, enabling partial prefix-cache hits
+            # for incremental workloads, at the cost of prefill throughput.
+            if (
+                envs.VLLM_MAMBA_ALIGN_GRANULAR_PREFILL
+                and num_computed_tokens < last_cache_position
+                and num_new_tokens > block_size
+            ):
+                num_new_tokens = block_size
         return num_new_tokens
 
     def schedule(self) -> SchedulerOutput: