Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,7 @@
VLLM_MULTI_STREAM_GEMM_TOKEN_THRESHOLD: int = 1024
VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary"
VLLM_USE_V2_MODEL_RUNNER: bool | None = None
VLLM_MAMBA_ALIGN_GRANULAR_PREFILL: bool = False
VLLM_LOG_MODEL_INSPECTION: bool = False
VLLM_DEBUG_MFU_METRICS: bool = False
VLLM_WEIGHT_OFFLOADING_DISABLE_PIN_MEMORY: bool = False
Expand Down Expand Up @@ -1891,6 +1892,17 @@ def _resolve_rust_frontend_path() -> str | None:
"VLLM_USE_V2_MODEL_RUNNER": lambda: maybe_convert_bool(
os.getenv("VLLM_USE_V2_MODEL_RUNNER", None)
),
# In Mamba cache 'align' mode, materialize and cache the Mamba state at
# every aligned block boundary by capping each prefill step to one aligned
# block. This enables partial prefix-cache hits for requests that share an
# early prefix but diverge later (e.g. incremental multimodal / agentic
# multi-turn), at the cost of prefill throughput (one block per step).
# Without it, a single large prefill chunk only caches the chunk's final
# boundary, so such requests miss the shared prefix entirely (see #43587).
# Enabled only when set to "1" or "true" (case-insensitive, surrounding
# whitespace ignored); any other value, or leaving it unset, disables it.
"VLLM_MAMBA_ALIGN_GRANULAR_PREFILL": lambda: (
os.getenv("VLLM_MAMBA_ALIGN_GRANULAR_PREFILL", "0").strip().lower()
in ("1", "true")
),
# Log model inspection after loading.
# If enabled, logs a transformers-style hierarchical view of the model
# with quantization methods and attention backends.
Expand Down
16 changes: 16 additions & 0 deletions vllm/v1/core/sched/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from dataclasses import replace
from typing import Any

import vllm.envs as envs
from vllm.compilation.cuda_graph import CUDAGraphStat
from vllm.config import VllmConfig
from vllm.distributed.ec_transfer.ec_connector.base import (
Expand Down Expand Up @@ -329,6 +330,21 @@ def _mamba_block_aligned_split(
else:
# prefill the last few tokens
pass
# In 'align' mode the Mamba state is only materialized (and thus
# cached) at the final aligned boundary of each prefill chunk; the
# intermediate boundaries within a chunk are stored in null blocks
# and never cached. A single large prefill chunk therefore caches
# only one boundary, so a later request that shares an early prefix
# but diverges before that boundary gets zero Mamba hits (#43587).
# Capping each prefill step to one aligned block materializes and
# caches every boundary's state, enabling partial prefix-cache hits
# for incremental workloads, at the cost of prefill throughput.
if (
envs.VLLM_MAMBA_ALIGN_GRANULAR_PREFILL
and num_computed_tokens < last_cache_position
and num_new_tokens > block_size
):
num_new_tokens = block_size
return num_new_tokens

def schedule(self) -> SchedulerOutput:
Expand Down
Loading