[Model Runner V2] support mamba hybrid models align prefix cache #42406

Open
izhuhaoran wants to merge 15 commits into vllm-project:main from izhuhaoran:mrv2-qwen35-prefix-cache

@izhuhaoran

@izhuhaoran izhuhaoran commented May 12, 2026

Contributor

Purpose

A follow-up to #35520.

Support mamba_cache_mode="align" for Mamba prefix caching with Model Runner V2. The implementation keeps the align-mode state handling inside MambaHybridModelState and avoids CPU/GPU synchronization in the hot path.

Design

  1. No post-copy needed: The kernel reads from the src block (previous running block) and writes to the window block (which becomes the new committed block). Only the aligned-save (boundary checkpoint) remains — no "copy back" step.

  2. Conv state covered: The conv kernel also receives a separate src physical block + intra-block token offset, handling both decode and prefill paths. Notably, under spec decode the conv and SSM src positions are asymmetric — SSM state lives on the spec block (at state_idx + num_accepted - 1), while conv state always lives on the main block (at state_idx) with a separate token offset.

  3. Kernel-level src/dst support: Modified the Triton SSM and conv kernels to accept explicit src and dst indices controlling read/write positions. This is wired into all three GDN variants (Qwen, Kimi, OLMo) and the packed recurrent decode path, while Mamba1/Mamba2 layers natively support src/dst separation.
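To make the src/dst idea in the design list concrete, here is a minimal pure-Python sketch. It is not the Triton kernel from this PR: the function name `recurrent_step`, the list-based cache, and the toy recurrence `h' = decay * h + x` are all assumptions chosen for illustration. It only shows that the initial state is read from one block and the result is written to another, so the source block is never modified.

```python
# Toy model of a block-indexed recurrent-state cache (illustrative, not vLLM code).
NUM_BLOCKS = 4
STATE_DIM = 3


def recurrent_step(state_cache, src_idx, dst_idx, x, decay=0.5):
    """Read the initial state from block src_idx and write the update to dst_idx.

    The recurrence h' = decay * h + x is a stand-in for the real SSM update.
    """
    src_state = state_cache[src_idx]
    state_cache[dst_idx] = [decay * h + xi for h, xi in zip(src_state, x)]


cache = [[0.0] * STATE_DIM for _ in range(NUM_BLOCKS)]
cache[1] = [2.0, 4.0, 6.0]          # block 1: previously committed state
committed_before = list(cache[1])   # snapshot to check it stays untouched

# One step: read from block 1 (src), write to block 2 (window / dst).
recurrent_step(cache, src_idx=1, dst_idx=2, x=[1.0, 1.0, 1.0])
```

Because the write goes to a different block, no copy is needed before or after the step in this toy model.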

Kernel requirements / limitations

The copy-free "align" path removes the per-step state pre-copy by making kernels read the initial state from a source block while writing the running state to the window block. This requires every mamba backend's kernels to support decoupled read/write state indices ("src redirect"):

  • SSM / recurrent decode kernels must accept src_ssm_state_indices: read the initial recurrent state from the per-sequence source block, write the updated state to the (different) window block given by ssm_state_indices.

  • Conv kernels (prefill + decode) must accept src_conv_state_indices plus src_conv_token_offset: read the initial conv window from the source block at a token offset (needed under spec decode), write to the window block.

  • A src index of NULL_BLOCK_ID (0) means "fresh state" (zero-init / skip the read).

All in-tree Triton kernels used by the GDN family (fused_recurrent, fused_sigmoid_gating_delta_rule_update, causal_conv1d_fn / causal_conv1d_update) have been extended accordingly.
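The conv-side requirements can be sketched in the same spirit. This is a hedged illustration only: `load_conv_window`, `conv_update`, the window length of 3, and the five-slot block layout are hypothetical, while `NULL_BLOCK_ID = 0` meaning "fresh state" and the idea of a source block plus token offset come from the description above.

```python
NULL_BLOCK_ID = 0   # per the PR text: a src index of 0 means "fresh state"
WINDOW = 3          # hypothetical conv window length
SLOTS = 5           # hypothetical slots per block (window + 2 extra token slots)


def load_conv_window(conv_cache, src_idx, src_token_offset):
    """Return the initial conv window, or zeros when the source is the null block."""
    if src_idx == NULL_BLOCK_ID:
        return [0.0] * WINDOW  # zero-init: skip the read entirely
    row = conv_cache[src_idx]
    return row[src_token_offset:src_token_offset + WINDOW]


def conv_update(conv_cache, src_idx, src_token_offset, dst_idx, new_token):
    """Slide the window by one token; read from src, write to dst."""
    window = load_conv_window(conv_cache, src_idx, src_token_offset)
    new_window = window[1:] + [new_token]
    conv_cache[dst_idx] = new_window + [0.0] * (SLOTS - WINDOW)
    return new_window


conv_cache = [[0.0] * SLOTS for _ in range(4)]
conv_cache[1] = [1.0, 2.0, 3.0, 4.0, 5.0]

# Read the window starting at token offset 1 of block 1, write into block 2.
from_offset = conv_update(conv_cache, src_idx=1, src_token_offset=1,
                          dst_idx=2, new_token=9.0)

# Fresh sequence: the null source block yields a zero-initialised window.
fresh = conv_update(conv_cache, src_idx=NULL_BLOCK_ID, src_token_offset=0,
                    dst_idx=3, new_token=7.0)
```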

Test

NVIDIA H20, Qwen3.5-35B-A3B-FP8, tp=2, MTP with num_speculative_tokens=2, language-model-only (text only)

  • Passed MRV2 Mamba prefix cache UT: tests/v1/e2e/general/test_mamba_prefix_cache.py::test_mamba_prefix_cache_mrv2

  • Passed lm_eval gsm8k

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.7938|±  |0.0111|
|     |       |strict-match    |     5|exact_match|↑  |0.7801|±  |0.0114|

  • Performance benchmark
python -m vllm.entrypoints.cli.main bench serve --model=Qwen3.5-35B-A3B-FP8 --dataset-name prefix_repetition --prefix-repetition-prefix-len 4096  \
            --prefix-repetition-suffix-len 128 --prefix-repetition-num-prefixes 1 --prefix-repetition-output-len 512 \
            --max-concurrency=16 --num-prompts=512 --percentile-metrics="ttft,tpot,itl,e2el" --ignore-eos --seed=13 --disable-shuffle

|runner|prefix cache|req/s|output tok/s|total tok/s|mean TTFT ms|mean TPOT ms|mean ITL ms|mean E2EL ms|
|------|------------|----:|-----------:|----------:|-----------:|-----------:|----------:|-----------:|
|v2    |off         | 3.04|     1558.75|   14418.44|      212.02|        9.73|      19.75|     5184.04|
|v2    |on/align    | 3.09|     1583.50|   14647.42|      174.29|        9.68|      19.68|     5118.87|
|v1    |off         | 3.01|     1538.68|   14232.83|      216.91|        9.92|      20.38|     5283.79|
|v1    |on/align    | 2.90|     1483.84|   13725.52|      174.91|       10.33|      21.32|     5452.73|
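For readers who prefer relative numbers, the benchmark rows above can be turned into percentage changes. The snippet below only does arithmetic on the reported values; the helper name `pct_change` is made up for this note, and the percentages are derived here rather than taken from the PR.

```python
def pct_change(new, old):
    """Relative change of new versus old, in percent."""
    return 100.0 * (new - old) / old


# Mean TTFT (ms) with prefix cache on/align versus off, per runner.
v2_ttft_change = pct_change(174.29, 212.02)
v1_ttft_change = pct_change(174.91, 216.91)

# Request throughput (req/s) of v2 on/align versus v1 on/align.
v2_vs_v1_req_s = pct_change(3.09, 2.90)

print(f"v2 TTFT change with align cache: {v2_ttft_change:.1f}%")
print(f"v1 TTFT change with align cache: {v1_ttft_change:.1f}%")
print(f"v2 vs v1 req/s with align cache: {v2_vs_v1_req_s:+.1f}%")
```

On this single setup, both runners see a similar TTFT reduction from align-mode caching, and the v2 runner has somewhat higher request throughput with caching enabled.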

@claude claude Bot left a comment

Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

@mergify mergify Bot added the v1 label May 12, 2026

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request enables support for Mamba hybrid models and the 'align' cache mode within Model Runner V2. It implements Triton kernels for Mamba state copying and alignment, updates the ModelState interface to include pre- and post-processing hooks, and integrates these into the GPUModelRunner. Furthermore, it adds comprehensive end-to-end tests for Mamba prefix caching in the V2 model runner. I have no feedback to provide.

@tdoublep tdoublep self-requested a review May 12, 2026 18:10
@izhuhaoran izhuhaoran force-pushed the mrv2-qwen35-prefix-cache branch from 7a6713a to 6a9594c Compare May 12, 2026 18:30
@izhuhaoran

Contributor Author

@WoosukKwon @TheEpicDolphin This PR adds support for Mamba align-mode prefix caching with Model Runner V2. Could you please review it when you have time? I'm also looking forward to @tdoublep's suggestions (thanks for your interest!).

@xhdidi

xhdidi commented May 13, 2026

Have you tested the performance of mrv2 + prefix-caching? How much of an improvement does it represent compared to mrv1 + prefix-caching?

@izhuhaoran

Contributor Author

> Have you tested the performance of mrv2 + prefix-caching? How much of an improvement does it represent compared to mrv1 + prefix-caching?

I added a benchmark comparison to the PR description. Unfortunately, in my current setup, MRV2 + prefix caching does not show a big performance improvement over MRV1 + prefix caching yet. The results depend heavily on hardware and workload, so I would recommend benchmarking different settings in your own environment and choosing the one that works best for your case.

@TheEpicDolphin

TheEpicDolphin commented May 16, 2026

Collaborator

Hey @izhuhaoran, a few weeks ago I took a stab at implementing mamba align mode, based on your earlier PR. I initially worked on it here, but have cleaned it up into this PR: #42792 for better readability.

I took a different approach from yours that yields identical accuracy results, but better performance on the same benchmark command. The main difference is that I provide the SSM kernel with a separate read index tensor (src_ssm_indices_tensor_d) that always points to the committed block. The kernel reads the initial state from the committed block and writes the updated state to the draft staging blocks. After sampling (during postprocess_state), the last-accepted draft SSM state is copied back to the committed block. This significantly simplified the implementation of this feature.
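A rough sketch of the flow described in this comment, under stated assumptions: states are plain floats, the recurrence `h' = decay * h + x` is a toy stand-in for the SSM update, and `ssm_spec_step` and `copy_back_accepted` are hypothetical names (the latter standing in for what the comment says happens during postprocess_state). It is not the code from #42792.

```python
def ssm_spec_step(cache, committed_idx, draft_idxs, draft_inputs, decay=0.5):
    """Read the initial state from the committed block, write one state per draft token."""
    state = cache[committed_idx]          # read-only source
    for idx, x in zip(draft_idxs, draft_inputs):
        state = decay * state + x         # toy recurrence, not the real kernel
        cache[idx] = state                # write to a draft staging block


def copy_back_accepted(cache, committed_idx, draft_idxs, num_accepted):
    """After sampling, copy the last accepted draft state into the committed block."""
    if num_accepted > 0:
        cache[committed_idx] = cache[draft_idxs[num_accepted - 1]]


cache = [0.0, 8.0, 0.0, 0.0, 0.0]   # block 1 holds the committed state
drafts = [2, 3, 4]                  # draft staging blocks

ssm_spec_step(cache, committed_idx=1, draft_idxs=drafts, draft_inputs=[1.0, 1.0, 1.0])
committed_after_forward = cache[1]  # unchanged by the forward pass

copy_back_accepted(cache, committed_idx=1, draft_idxs=drafts, num_accepted=2)
```

In this toy version the committed block is only overwritten once the number of accepted tokens is known, which mirrors the read-committed, write-draft, copy-back order described above.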

I'd like to get your feedback on my approach. Could it be applied here to simplify things, or does it have limitations that I'm not aware of?

@izhuhaoran

izhuhaoran commented May 16, 2026

Contributor Author

> Hey @izhuhaoran, a few weeks ago I took a stab at implementing mamba align mode, based on your earlier PR. I initially worked on it here, but have cleaned it up into this PR: #42792 for better readability.

@TheEpicDolphin Thanks a lot for your time and effort on this. Splitting the SSM read/write indices via src_ssm_indices_tensor_d is a brilliant simplification, and I'd like to adopt it here. I've left some preliminary thoughts and concerns on #42792.

@njhill njhill added the v2 label May 20, 2026
@mergify

mergify Bot commented May 20, 2026

Contributor

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @izhuhaoran.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@JaheimLee

Hi, any update?

@izhuhaoran

Contributor Author

> Hi, any update?

@JaheimLee Hi, thanks for your interest! Sorry for the delay — I'm currently working on a refactored version based on the review feedback. It's still being tested and polished, so not quite ready yet. Will update once it's in good shape!

Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
@mergify

mergify Bot commented Jun 2, 2026

Contributor

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @izhuhaoran.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Jun 2, 2026
@xhdidi

xhdidi commented Jun 3, 2026

I found that repeatedly sending the same requests significantly improved the mean TTFT, but the service logs consistently showed "Prefix cache hit rate: 0.0%".

Test command:
vllm bench serve --backend vllm --model Qwen3.5-35B-A3B --endpoint /v1/completions --dataset-name sharegpt --dataset-path ShareGPT_V4.3_unfiltered_cleaned_split.json --num-prompts 20
Serving Benchmark Result:
The mean TTFT was 616.08ms in the first test and 453.92ms in the second test.

@izhuhaoran

izhuhaoran commented Jun 3, 2026

Contributor Author

> I found that repeatedly sending the same requests significantly improved the mean TTFT, but the service logs consistently showed "Prefix cache hit rate: 0.0%".
>
> Test command:
> vllm bench serve --backend vllm --model Qwen3.5-35B-A3B --endpoint /v1/completions --dataset-name sharegpt --dataset-path ShareGPT_V4.3_unfiltered_cleaned_split.json --num-prompts 20
> Serving Benchmark Result:
> The mean TTFT was 616.08ms in the first test and 453.92ms in the second test.

Thanks for the report! Yes, I've tested using the bench command from this comment. The prefix cache hit rate is at the same level for Model Runner V1 and V2. Also, align cache mode pads the attention block size to a large number to match the mamba state page size, and it only caches the latest block-boundary checkpoint rather than all blocks. Please try longer shared prefixes to see cache hits.
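One way to read the explanation above, as a back-of-the-envelope model only: if the cached state is a single checkpoint at the last block boundary of an earlier request, a later request can reuse it only when its shared prefix reaches that boundary. The function `reusable_tokens` and the block size of 512 are assumptions for illustration (512 is merely the order of magnitude mentioned later in this thread), and the real scheduler logic may differ.

```python
def reusable_tokens(cached_len, shared_prefix_len, block_size):
    """Tokens a later request could reuse under the single-checkpoint assumption.

    cached_len: token count of the earlier request whose state was checkpointed.
    shared_prefix_len: tokens the later request has in common with it.
    """
    checkpoint = (cached_len // block_size) * block_size  # last block boundary
    if 0 < checkpoint <= shared_prefix_len:
        return checkpoint
    return 0


BLOCK_SIZE = 512  # assumed; the padded block size depends on the model and config

short_prompt = reusable_tokens(cached_len=300, shared_prefix_len=300,
                               block_size=BLOCK_SIZE)
long_prefix = reusable_tokens(cached_len=4224, shared_prefix_len=4096,
                              block_size=BLOCK_SIZE)
diverges_early = reusable_tokens(cached_len=2000, shared_prefix_len=600,
                                 block_size=BLOCK_SIZE)
```

Under this model, prompts shorter than one block never produce a hit, which would be consistent with a 0.0% hit rate on short prompts, while a 4096-token shared prefix lands exactly on a block boundary.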

Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
…amba align

Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
@mergify mergify Bot removed the needs-rebase label Jun 3, 2026

@TheEpicDolphin TheEpicDolphin left a comment

Collaborator

The logic looks sound to me! I left some comments mostly regarding readability and reducing code duplication.

Comment thread vllm/v1/worker/gpu/model_states/interface.py Outdated
Comment thread vllm/v1/worker/gpu/warmup.py
Comment thread vllm/v1/worker/gpu/model_states/mamba_hybrid.py Outdated
Comment thread vllm/model_executor/layers/fla/ops/fused_recurrent.py Outdated
Comment thread vllm/v1/attention/backends/gdn_attn.py Outdated
Comment thread vllm/v1/attention/backends/gdn_attn.py Outdated
@mergify

mergify Bot commented Jun 6, 2026

Contributor

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @izhuhaoran.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Jun 6, 2026
Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>

@TheEpicDolphin TheEpicDolphin left a comment

Collaborator

LGTM! cc: @WoosukKwon

Comment thread vllm/model_executor/layers/mamba/gdn/qwen_gdn_linear_attn.py Outdated
Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
@ZJY0516

ZJY0516 commented Jun 12, 2026

Member

Another concern of mine is that we may use decode kernels from FlashInfer in the future, which don't have a parameter like src_ssm_indices.

@vadiklyutiy

Member

You put a lot of emphasis on decoupling read and write.
But why is it so necessary? The block size is usually pretty big, something like 512 tokens. One copy per 512 tokens doesn't look very expensive.
On the other hand, with decoupled read and write you double memory consumption, which is a real problem for mamba models due to their big state size.

@izhuhaoran

Contributor Author

> You put a lot of emphasis on decoupling read and write. But why is it so necessary? The block size is usually pretty big, something like 512 tokens. One copy per 512 tokens doesn't look very expensive. On the other hand, with decoupled read and write you double memory consumption, which is a real problem for mamba models due to their big state size.

Hi Vadim @vadiklyutiy — good question, but I think there's a misunderstanding: decoupling read/write doesn't add any extra memory.

The key constraint: in mamba align prefix-cache mode, a cache-hit state block is immutable — it's shared across sequences via prefix caching, so writing to it in place would corrupt the cached state that other sequences point at. That's why we can't do an in-place update. A new block has to be allocated for the updated state either way.

v1's model runner handled this with a pre-copy: before the forward, a copy kernel copies the state from the old (immutable) block into the newly-allocated block, then updates it in place. Decoupling src/dst simply lets the kernel read the initial state directly from the immutable src block and write the new state into the newly-allocated dst block — skipping that pre-copy.

So both approaches allocate the exact same dst block; decoupling uses the same memory as pre-copy (no doubling), it just removes the copy step. The performance gain is modest but essentially free.
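The memory argument can be checked with a small model. This is a sketch under stated assumptions: `BlockPool`, `step_pre_copy`, `step_copy_free`, and the toy update `0.5 * state + x` are invented for illustration and do not mirror vLLM's allocator or kernels. It counts block allocations and state copies for one step that starts from an immutable cached block.

```python
class BlockPool:
    """Toy state-block pool that counts allocations and state copies."""

    def __init__(self, num_blocks):
        self.states = [0.0] * num_blocks
        self.next_free = 1      # block 0 is reserved as the null block
        self.allocated = 0
        self.copies = 0

    def allocate(self):
        idx = self.next_free
        self.next_free += 1
        self.allocated += 1
        return idx


def toy_update(state, x):
    return 0.5 * state + x      # stand-in for the real state update


def step_pre_copy(pool, src_idx, x):
    dst_idx = pool.allocate()
    pool.states[dst_idx] = pool.states[src_idx]   # pre-copy into the new block
    pool.copies += 1
    pool.states[dst_idx] = toy_update(pool.states[dst_idx], x)  # in-place update
    return dst_idx


def step_copy_free(pool, src_idx, x):
    dst_idx = pool.allocate()
    # Read from the immutable src block, write straight to the new dst block.
    pool.states[dst_idx] = toy_update(pool.states[src_idx], x)
    return dst_idx


def run(step_fn):
    pool = BlockPool(num_blocks=4)
    cached = pool.allocate()    # the shared, immutable cache-hit block
    pool.states[cached] = 8.0
    dst = step_fn(pool, cached, x=1.0)
    return pool, cached, dst


pre_pool, pre_cached, pre_dst = run(step_pre_copy)
free_pool, free_cached, free_dst = run(step_copy_free)
```

Both variants allocate the same number of blocks and produce the same state; only the copy counter differs.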

(The idea comes from @TheEpicDolphin's discussion here: #42406 (comment))

@izhuhaoran

Contributor Author

That said — if you feel this gain is too minor to justify the constraint it puts on future mamba kernels (they'd need to support the decoupled src/dst read path), I'd also accept reverting to the v1 style (my original implementation). I'd be glad to discuss the design further.

cc @TheEpicDolphin @ZJY0516 — what are your thoughts?

@vadiklyutiy

Member

> That said, if you feel this gain is too minor to justify the constraint it puts on future mamba kernels (they'd need to support the decoupled src/dst read path), I'd also accept reverting to the v1 style (my original implementation). I'd be glad to discuss the design further.
>
> cc @TheEpicDolphin @ZJY0516: what are your thoughts?

Not right now. I want to understand the difference between the approaches.

@tdoublep

Member

Might be worth checking with ppl working on FlashInfer GDN to size how disruptive this change would be

@vadiklyutiy

Member

> The key constraint: in mamba align prefix-cache mode, a cache-hit state block is immutable. It's shared across sequences via prefix caching, so writing to it in place would corrupt the cached state that other sequences point at. That's why we can't do an in-place update. A new block has to be allocated for the updated state either way.

Hm, align mode is for caching multi-turn chat conversations. The case where several requests share an "align" prefix should be uncommon at least (formally it is possible when you fork a chat).

@izhuhaoran

Contributor Author

> Hm, align mode is for caching multi-turn chat conversations. The case where several requests share an "align" prefix should be uncommon at least (formally it is possible when you fork a chat).

Good point @tdoublep. I'm actually already implementing a fallback pre-copy path (WIP) for exactly this — similar to how we gate cudagraph support, it checks whether the backend declares support for the decoupled src/dst read and falls back to pre-copy if not. So a backend like FlashInfer GDN that doesn't (yet) support it would automatically use the v1-style pre-copy, keeping this change non-disruptive while backends that do support it get the copy-free path.
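The gating described here could look roughly like the following. This is a hypothetical sketch: `MambaBackendInfo`, the `supports_src_redirect` flag, `select_state_path`, and the backend names are illustrative assumptions, not identifiers from vLLM or from the WIP fallback.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class MambaBackendInfo:
    """Hypothetical capability record a backend might expose."""
    name: str
    supports_src_redirect: bool   # can the kernels read from a separate src block?


def select_state_path(backend):
    """Pick the copy-free path when supported, else fall back to v1-style pre-copy."""
    return "copy_free" if backend.supports_src_redirect else "pre_copy"


# Illustrative backends only; the capability values are assumptions.
triton_gdn = MambaBackendInfo(name="triton_gdn", supports_src_redirect=True)
external_decode = MambaBackendInfo(name="external_decode", supports_src_redirect=False)

chosen = {b.name: select_state_path(b) for b in (triton_gdn, external_decode)}
```

The trade-off raised elsewhere in this thread still applies: a capability switch like this means both paths need test coverage.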

@vadiklyutiy

Member

I'd prefer to have one path instead of two. Having several paths is more complicated: there is more code, you need to test both paths (otherwise the untested one will likely be buggy), etc.

@tdoublep

Member

> align mode is for caching multi-turn chat conversations.

@vadiklyutiy just FYI align mode can now also support shared common prefix since #37898 was merged

@izhuhaoran

izhuhaoran commented Jun 12, 2026

Contributor Author

> I'd prefer to have one path instead of two. Having several paths is more complicated: there is more code, you need to test both paths (otherwise the untested one will likely be buggy), etc.

Fair point. I agree maintaining two paths has a real cost. The fallback was only meant for the case where some backends can't support the decoupled src/dst read.

So for a single path, the two options are:

  1. pre-copy only: the established v1 approach; works for every backend.
  2. copy-free only: slightly faster, but needs every mamba backend to support the decoupled src/dst read (possibly including future FlashInfer GDN decode kernels).

Which would you prefer? Happy to go whichever way you all think is best. Also glad to hear others' thoughts on this.

8 participants