
[V1] [Hybrid] Lighter Mamba Prefix Caching for Hybrid Models#28176

Closed
peakcrosser7 wants to merge 3 commits into vllm-project:main from
peakcrosser7:ups/mamba_prefix_cache

Conversation

@peakcrosser7
Contributor

@peakcrosser7 peakcrosser7 commented Nov 6, 2025

Purpose

Currently, Automatic Prefix Caching for Mamba-based hybrid models does not support architectures such as GDN. To address this, we propose a lightweight Mamba Prefix Caching design called Lighter-Mamba-Prefix-Cache.
Its core idea is to directly cache Mamba states using a block-aligned scheduling approach, enabling rapid support for Prefix Caching in Mamba models without modifying any kernel code, while maintaining compatibility with SPS, MTP, and Eagle.
This solution has already been validated on Qwen3-Next-80B-A3B-Instruct.

Design Details

Block Allocation Design

For each request, the number of blocks allocated per Mamba group is changed from the original fixed 1 + sps to 2 + sps + N, where:

  • The 1st block: Used to load the Mamba state from a prefix-caching hit. If there is no cache hit, a null-block is used as a placeholder.
  • The next 1 + sps blocks: Reserved for runtime usage, identical to the original 1 + sps blocks without prefix caching.
  • The following N blocks: Used to cache Mamba states.
(figure: block allocation layout)
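The per-group layout above can be sketched in a few lines (function and label names are illustrative, not the PR's actual code):

```python
def mamba_blocks_per_group(sps: int, num_cache_blocks: int) -> list[str]:
    """Sketch of the 2 + sps + N block layout per Mamba group."""
    blocks = ["load"]                       # 1st block: prefix-cache hit state, or a null block on miss
    blocks += ["runtime"] * (1 + sps)       # next 1 + sps blocks: same runtime blocks as without caching
    blocks += ["cache"] * num_cache_blocks  # last N blocks: newly cached Mamba states
    return blocks

layout = mamba_blocks_per_group(sps=2, num_cache_blocks=3)
# 1 load + 3 runtime + 3 cache = 7 blocks per group
```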

Prefix Matching Logic

(figures: scheduler logic and worker logic)

Block-Aligned Scheduling

Since requests are hashed at the granularity of block_size, Mamba states must be aligned to block_size boundaries before caching. This ensures that each Mamba state corresponds to exactly one block hash.
The lighter prefix cache stores variable-length chunk states: the number of tokens (i.e., the incremental length) associated with each cached Mamba state may vary, but it is always a multiple of block_size.

With Mamba Prefix Caching enabled, the scheduler behaves as follows:

  • Decode requests: Scheduling logic remains unchanged (Lighter prefix caching is not yet supported for decode).
  • Prefill requests:
    • The number of tokens scheduled per step must be an integer multiple of block_size, except for the final chunk of the request.
    • The last prefill chunk is split to align with block_size, ensuring its size is ≤ block_size. This maximizes the length of the prompt that can be cached during the prefill phase.
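The prefill chunking rules above can be sketched as follows (a minimal illustration of the described policy, not the scheduler's actual code):

```python
def split_prefill(num_prompt_tokens: int, block_size: int, chunk_size: int) -> list[int]:
    """Split a prompt into chunks whose sizes are multiples of block_size,
    leaving a final chunk of at most block_size tokens."""
    assert chunk_size % block_size == 0
    chunks = []
    remaining = num_prompt_tokens
    while remaining > block_size:
        # schedule a block-aligned chunk, holding back any partial block
        step = min(chunk_size, remaining - remaining % block_size)
        if step == remaining:
            # keep one full block for the final chunk so it stays <= block_size
            step -= block_size
        chunks.append(step)
        remaining -= step
    if remaining:
        chunks.append(remaining)  # final chunk, <= block_size
    return chunks

# e.g. a 29-token prompt with block_size=16 and chunk_size=32 -> [16, 13]
```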

Test Plan

TODO

Test Result

TODO


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Signed-off-by: huanghaoyan.hhy <huanghaoyan.hhy@alibaba-inc.com>
@mergify

mergify bot commented Nov 6, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @peakcrosser7.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Nov 6, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a lightweight prefix caching mechanism for Mamba-based hybrid models, which is a significant feature addition. The implementation seems well-thought-out and consistent with the design described. The changes span across the scheduler, cache manager, and model runner to support block-aligned caching of Mamba states. I've identified one critical issue regarding the handling of a new environment variable, which could lead to incorrect behavior. Other than that, the changes look good.

Comment on lines +1455 to +1457
"VLLM_USE_LIGHTER_MAMBA_CACHE": lambda: os.getenv(
"VLLM_USE_LIGHTER_MAMBA_CACHE", False
),
Contributor


critical

The current implementation for parsing the VLLM_USE_LIGHTER_MAMBA_CACHE environment variable is incorrect. The lambda lambda: os.getenv("VLLM_USE_LIGHTER_MAMBA_CACHE", False) will evaluate to a truthy value for any non-empty string, including "0", which is likely not the intended behavior for a boolean flag. This can lead to the feature being unintentionally enabled. To ensure correct boolean parsing, it should be compared against "1", similar to how other boolean flags are handled in this file.

Suggested change
"VLLM_USE_LIGHTER_MAMBA_CACHE": lambda: os.getenv(
"VLLM_USE_LIGHTER_MAMBA_CACHE", False
),
"VLLM_USE_LIGHTER_MAMBA_CACHE": lambda: os.getenv(
"VLLM_USE_LIGHTER_MAMBA_CACHE", "0"
) == "1",
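The pitfall is easy to reproduce: `os.getenv` returns a string, and any non-empty string is truthy, including `"0"`:

```python
import os

os.environ["VLLM_USE_LIGHTER_MAMBA_CACHE"] = "0"  # user explicitly disables the flag

# raw getenv: "0" is a non-empty string, hence truthy
raw = os.getenv("VLLM_USE_LIGHTER_MAMBA_CACHE", False)
assert bool(raw) is True  # flag wrongly reads as enabled

# comparing against "1" (as other boolean flags do) gives the intended result
fixed = os.getenv("VLLM_USE_LIGHTER_MAMBA_CACHE", "0") == "1"
assert fixed is False
```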


@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.


Comment on lines +1455 to +1457
"VLLM_USE_LIGHTER_MAMBA_CACHE": lambda: os.getenv(
"VLLM_USE_LIGHTER_MAMBA_CACHE", False
),


P1: Parse VLLM_USE_LIGHTER_MAMBA_CACHE as a boolean

The new env flag is exposed as lambda: os.getenv("VLLM_USE_LIGHTER_MAMBA_CACHE", False). Unlike the rest of the boolean envs (which run the value through bool(int(...))), this returns the raw string. As a consequence, setting VLLM_USE_LIGHTER_MAMBA_CACHE=0 or False will still be truthy, and the lighter cache path is enabled unintentionally. This can enable an experimental code path for all deployments even when the user explicitly disables it. The getter should coerce the string to a real boolean, e.g. bool(int(os.getenv(..., "0"))).


Comment on lines 263 to +275
 def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
-    max_model_len = vllm_config.model_config.max_model_len
-    return cdiv(max_model_len, self.block_size) * self.page_size_bytes
+    # We allocate 1 block for each request now, so max_memory_usage_bytes is
+    # the same as page_size_bytes.
+    # Need to update this when supporting prefix caching.
+    if not envs.VLLM_USE_LIGHTER_MAMBA_CACHE:
+        max_model_len = vllm_config.model_config.max_model_len
+        return cdiv(max_model_len, self.block_size) * self.page_size_bytes
+    else:
+        # NOTE: We allocate 1 block per request by default. With prefix
+        # caching enabled, up to 2 additional blocks are required: one
+        # for reading the matched prefix and one for caching the current
+        # state.
+        return self.page_size_bytes * (3 if self.enable_caching else 1)


P1: Include speculative blocks in Mamba memory estimate

In the lighter Mamba branch MambaSpec.max_memory_usage_bytes now returns page_size_bytes * (3 if self.enable_caching else 1) regardless of num_speculative_blocks. However, allocation paths still reserve 1 + num_speculative_blocks blocks (plus an extra for caching) when speculative decoding (EAGLE/MTP) is active. With num_speculative_blocks > 0 the memory calculation now underestimates the number of blocks per request, so the block pool will be sized for at most 3 blocks while execution tries to allocate 4+, causing allocation failures or unexpected preemption. The returned size should include num_speculative_blocks in the multiplier.


Collaborator

@heheda12345 heheda12345 left a comment


  1. If I understand correctly, we will always cache the state at token num_tokens - num_tokens % block_size, but there is no guarantee that tokens at other positions are cached. In this case, can you ensure the system prompt is cached?
  2. Memory layout: I feel that we don't need such a new KV cache memory design. We can keep a block_id list of length num_tokens / block_size + num_spec_decode_tokens, and always ensure that the state from the previous schedule step is at block num_computed_tokens / block_size.

For example, if num_computed_tokens=29 and we schedule 1 new token, the kv cache before this step is:
block 0: N/A
block 1: token 29
block 2: N/A
block 3: N/A

-> run main model
block 0: N/A
block 1: token 30
block 2: token 31
block 3: token 32
-> adjust kv cache based on number of accepted tokens

  • if no token is accepted:
    block 0: N/A
    block 1: token 30
    block 2: N/A
    block 3: N/A
  • if token 30 is accepted:
    block 0: N/A
    block 1: token 31
    block 2: N/A
    block 3: N/A
  • if token 30 & 31 are accepted:
    block 0: N/A
    block 1: token 31
    block 2: token 32
    block 3: N/A
  • if token 30 & 31 & 32 are accepted:
    block 0: token 15
    block 1: token 31
    block 2: token 33
    block 3: N/A

Then, in the next schedule step, the previous state is always at block [num_computed_tokens / block_size]
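The invariant in the walkthrough above can be sketched as follows (illustrative only, not the PR's actual code):

```python
def prev_state_block(num_computed_tokens: int, block_size: int) -> int:
    """In the proposed layout, the state from the previous schedule step
    always lives at block num_computed_tokens // block_size."""
    return num_computed_tokens // block_size

# the example above: num_computed_tokens=29 with block_size=16
assert prev_state_block(29, 16) == 1   # token 29's state sits in block 1

# once the latest state is for token 33 (all drafts accepted), it is read
# from block 2, matching "block 2: token 33" in the layout above
assert prev_state_block(33, 16) == 2
```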

  3. I'm still concerned about whether we should increase the complexity of the scheduler to avoid kernel changes in Mamba layers.

@minminsun

> (quoting @heheda12345's review above)

Thank you for your attention to this PR.

Regarding the first question: there is insufficient memory to store per-token state, so we cannot guarantee that system prompts are cached, especially when their token count is less than block_size.

As for the third concern: We believe modifying the scheduler is significantly less complex than altering attention kernels. Besides FLA, multiple kernel implementations exist for linear attention, making it impractical to update all of them. Moreover, requiring kernels to retain excessive internal token states would degrade performance.

This prefix-caching solution has been stably deployed in Alibaba Cloud’s Qwen3-Next online serving system for nearly a month, maintaining a consistently healthy cache hit ratio.

@peakcrosser7
Contributor Author

> (quoting @heheda12345's review above)

Hi @heheda12345 , thank you very much for your detailed review.
I'd like to provide some supplementary explanations for the three points you raised.

  1. Regarding your first point, there’s a slight misunderstanding. Lighter Mamba Prefix Cache does cache the state at position num_tokens - num_tokens % block_size, but this is the last cached position for a given prompt, not the only one. Built on chunked prefill, Lighter Mamba Prefix Cache caches one state per chunk, where each chunk is block-aligned.
    For long prompts, the scheduler splits them into multiple chunks, and a state is cached after each — see the example in the diagram below. Additionally, for prompts longer than block_size but shorter than chunk_size, block-aligned scheduling still ensures exactly one state is cached. Only prompts shorter than block_size are not cached at all. This design is intentional: (a) short prompts already have low TTFT, so the benefit of prefix caching is marginal; and (b) chunk-granularity caching reduces overall block memory usage.
(figure: cached states per chunk)
  2. On your second point, the memory layout of Lighter Mamba Prefix Cache is fully backward-compatible with the non-prefix-caching case. Without prefix caching enabled, each request uses a fixed 1 + sps blocks per KV group. Lighter Mamba Prefix Cache simply adds one extra block at the beginning (to load a cached state on hit, or a null block on miss) and potentially one at the end (to store the new state).
    Crucially, most Mamba models derive state_indices_tensor from block_table_tensor[:, 0]. With Lighter Mamba Prefix Cache, using block_table_tensor[:, 1:] preserves the exact same layout as the non-caching case. As long as state_indices_tensor is made contiguous via .contiguous(), it integrates seamlessly with existing Mamba code — requiring virtually no code changes to support new Mamba architectures.
  3. Regarding your third concern, Lighter Mamba Prefix Cache introduces minimal scheduler complexity — it primarily enforces block-aligned chunking during prefill, which is conceptually similar to the V0 design. In contrast, modifying the Mamba kernel would require model-specific customizations across different Mamba variants. By avoiding kernel changes entirely, Lighter Mamba Prefix Cache stays lightweight.
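The block-table slicing in point 2 can be illustrated with a toy example (pure-Python stand-in; in vLLM these are torch tensors, and the sliced view must additionally be made contiguous with .contiguous()):

```python
# toy block table: one row per request, columns are physical block ids.
# column 0 is the new "load" block (cached state on hit, null block on miss);
# columns 1: are the same 1 + sps runtime blocks used without prefix caching.
NULL_BLOCK = -1
block_table = [
    [7, 3, 4],           # request 0: prefix hit, loads its state from block 7
    [NULL_BLOCK, 5, 6],  # request 1: miss, placeholder in column 0
]

# most Mamba models build state_indices from column 0 of the block table;
# with the lighter cache they read columns 1: instead, so the layout seen
# by the kernels is identical to the non-caching case
state_indices = [row[1:] for row in block_table]
assert state_indices == [[3, 4], [5, 6]]
```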

Finally, thank you again for your review and feedback on this PR. We truly appreciate your time and insights, and we hope the clarifications above address your concerns.

@QilaiZhang

@minminsun I'm curious about the performance of chunk-granularity caching. Would it be possible to share any cache hit ratio results? Thanks!

Jacki1223 pushed a commit to Jacki1223/sglang that referenced this pull request Nov 20, 2025
…guity

This commit integrates the key optimization from vLLM PR #28176 to improve
Qwen3-Next inference performance by ensuring Mamba state indices tensors
are explicitly contiguous.

## Changes:

### 1. hybrid_linear_attn_backend.py
- Added `.contiguous()` calls to `mamba_cache_indices` in three critical paths:
  * `_forward_metadata()`: Normal forward pass metadata preparation
  * `_capture_metadata()`: CUDA graph capture path
  * `_replay_metadata()`: CUDA graph replay path

### 2. mamba2_metadata.py
- Added `.contiguous()` calls in two metadata preparation methods:
  * `prepare_decode()`: Decode-only path (used during CUDA graph)
  * `prepare_mixed()`: Mixed prefill/decode path

## Rationale:

The vLLM PR #28176 identified that "state indices tensor must be explicitly
contiguous because requests can contain multiple blocks." This optimization
ensures better memory layout and improved kernel performance when processing
batched requests with Mamba-based hybrid models like Qwen3-Next.

## Benefits:

- Improved memory access patterns for Mamba state lookups
- Better performance for multi-block requests
- Consistent with vLLM's lightweight Mamba prefix caching approach
- No functional changes, purely performance optimization

Reference: vllm-project/vllm#28176
@minminsun

@minminsun I'm curious about the performance of chunk-granularity caching. Would it be possible to share any cache hit ratio results? Thanks!

Though I cannot provide specific cache hit ratio metrics for our production services, the drop relative to block-granularity caching is less than 10%.

@joennlae
Contributor

Fantastic work. Any plans on upstreaming this?

joennlae added a commit to 44ai-labs/vllm that referenced this pull request Dec 15, 2025
Copied and rebased from vllm-project#28176

Thanks to @peakcrosser7 and @minminsun

Signed-off-by: Jannis Schönleber <joennlae@gmail.com>
@joennlae
Contributor

I added a rebased version here: #30725

@peakcrosser7
Contributor Author

Hi, @joennlae ! Thanks for your positive feedback and creating the rebase PR!
Regarding the plan to merge upstream, @heheda12345 and I are currently developing a more complete implementation #29272 . It's based on the core ideas from this PR and adopts the same memory layout as FullAttn for organizing block_ids.
We plan to submit it as a new pull request to the main branch once it's finished, which should be soon.

@heheda12345
Collaborator

We are iterating on #29272

@joennlae
Contributor

Ah perfect :-)

@peakcrosser7
Contributor Author

Closed because of #30877
