[Model Runner V2] support mamba hybrid models align prefix cache#42406
izhuhaoran wants to merge 15 commits into
Code Review
This pull request enables support for Mamba hybrid models and the 'align' cache mode within Model Runner V2. It implements Triton kernels for Mamba state copying and alignment, updates the ModelState interface to include pre- and post-processing hooks, and integrates these into the GPUModelRunner. Furthermore, it adds comprehensive end-to-end tests for Mamba prefix caching in the V2 model runner. I have no feedback to provide.
|
@WoosukKwon @TheEpicDolphin This PR adds support for Mamba align-mode prefix caching with Model Runner V2. Could you please review it when you have time? Also looking forward to @tdoublep’s suggestions ( Thanks for your interest! ) |
|
Have you tested the performance of mrv2 + prefix-caching? How much of an improvement does it represent compared to mrv1 + prefix-caching? |
I added a benchmark comparison to the PR description. Unfortunately, in my current setup, MRV2 + prefix caching does not show a big performance improvement over MRV1 + prefix caching yet. The results depend heavily on hardware and workload, so I would recommend benchmarking different settings in your own environment and choosing the one that works best for your case. |
|
Hey @izhuhaoran, a few weeks ago I took a stab at implementing mamba align mode, basing it on your earlier PR. I initially worked on it here, but have cleaned it up into this PR: #42792 for better readability. I took a different approach from yours that yields identical accuracy results, but better performance on the same benchmark command. The main difference is that I provide the SSM kernel with a separate read index tensor ( I'd like to get your feedback on my approach. Could it be applied here to simplify things, or does it have limitations that I'm not aware of? |
@TheEpicDolphin Thanks a lot for your time and effort on this — splitting the SSM read/write indices via |
|
This pull request has merge conflicts that must be resolved before it can be |
|
Hi, any update? |
@JaheimLee Hi, thanks for your interest! Sorry for the delay — I'm currently working on a refactored version based on the review feedback. It's still being tested and polished, so not quite ready yet. Will update once it's in good shape! |
Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
|
I found that repeatedly sending the same requests significantly improved the mean TTFT, but the service logs consistently showed "Prefix cache hit rate: 0.0%". Test command: |
Thanks for the report! Yes, I've tested using the bench command from this comment. The prefix cache hit rate is at the same level between model runner v1 and v2. Also, align cache mode pads the attention block size to a large number to match the mamba state page size, and it will only cache the latest block boundary checkpoint rather than all blocks, so please try longer shared prefixes to see cache hits. |
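As a rough illustration of why short shared prefixes show no hits: with block-granular caching, at most the whole blocks of a shared prefix can be reused. The helper below is a simplified upper bound, not vLLM code; the function name and the block sizes in the usage note are hypothetical.

```python
def cacheable_prefix_tokens(shared_prefix_len: int, block_size: int) -> int:
    """Upper bound on reusable tokens: only whole blocks can be cache hits.

    Simplified model (an assumption for illustration): a prefix shorter than
    one block can never produce a hit, however often it is resent.
    """
    if block_size <= 0:
        raise ValueError("block_size must be positive")
    return (shared_prefix_len // block_size) * block_size
```

Under this model, a 500-token shared prefix with a hypothetical padded block size of 1024 gives `cacheable_prefix_tokens(500, 1024) == 0`, while a 2500-token prefix gives 2048.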
Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
…amba align Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
TheEpicDolphin
left a comment
The logic looks sound to me! I left some comments mostly regarding readability and reducing code duplication.
|
Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
TheEpicDolphin
left a comment
LGTM! cc: @WoosukKwon
Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
|
Another concern of mine is that we may use decode kernels from FlashInfer in the future, which don't have a parameter like |
|
You put a lot of emphasis on decoupling read and write.
Hi Vadim @vadiklyutiy — good question, but I think there's a misunderstanding: decoupling read/write doesn't add any extra memory. The key constraint: in mamba align prefix-cache mode, a cache-hit state block is immutable — it's shared across sequences via prefix caching, so writing to it in place would corrupt the cached state that other sequences point at. That's why we can't do an in-place update. A new block has to be allocated for the updated state either way. v1's model runner handled this with a pre-copy: before the forward, a copy kernel copies the state from the old (immutable) block into the newly-allocated block, then updates it in place. Decoupling src/dst simply lets the kernel read the initial state directly from the immutable src block and write the new state into the newly-allocated dst block — skipping that pre-copy. So both approaches allocate the exact same dst block; decoupling uses the same memory as pre-copy (no doubling), it just removes the copy step. The performance gain is modest but essentially free. (The idea comes from @TheEpicDolphin's discussion here: #42406 (comment)) |
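A toy sketch of the two approaches described above, assuming a state pool modeled as a list of blocks and a made-up recurrent update in place of the real SSM kernel. All names here are hypothetical. It only illustrates that both paths write the same dst block and leave the shared src block untouched, while the decoupled path skips the copy.

```python
def update(state, x, decay=0.5):
    """Toy recurrent update standing in for the real SSM kernel."""
    return [decay * s + x for s in state]


def step_precopy(pool, src, dst, x):
    """v1 style: copy src into dst, then update dst in place."""
    pool[dst] = list(pool[src])       # separate copy pass
    pool[dst] = update(pool[dst], x)  # in-place style update of dst
    return 1                          # number of copy passes performed


def step_decoupled(pool, src, dst, x):
    """Decoupled style: read from src, write the result straight to dst."""
    pool[dst] = update(pool[src], x)
    return 0                          # no copy pass


def make_pool():
    # Block 1 plays the immutable cache-hit block; block 2 is newly allocated.
    return [[0.0, 0.0], [1.0, 2.0], [0.0, 0.0]]


pool_a, pool_b = make_pool(), make_pool()
copies_a = step_precopy(pool_a, src=1, dst=2, x=3.0)
copies_b = step_decoupled(pool_b, src=1, dst=2, x=3.0)
```

Both pools end up identical and have the same number of blocks. Only the copy count differs.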
|
That said — if you feel this gain is too minor to justify the constraint it puts on future mamba kernels (they'd need to support the decoupled src/dst read path), I'd also accept reverting to the v1 style (my original implementation). I'd be glad to discuss the design further. cc @TheEpicDolphin @ZJY0516 — what are your thoughts? |
not right now. I want to understand the difference in approaches. |
|
Might be worth checking with ppl working on FlashInfer GDN to size how disruptive this change would be |
Hm, align mode is for caching multi-turn chat conversations. The case where several requests share an "align" prefix should be uncommon at least (formally, it is possible when you fork a chat). |
Good point @tdoublep. I'm actually already implementing a fallback pre-copy path (WIP) for exactly this — similar to how we gate cudagraph support, it checks whether the backend declares support for the decoupled src/dst read and falls back to pre-copy if not. So a backend like FlashInfer GDN that doesn't (yet) support it would automatically use the v1-style pre-copy, keeping this change non-disruptive while backends that do support it get the copy-free path. |
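A minimal sketch of what such capability gating could look like. The `supports_src_redirect` flag, the `Backend` class, and the backend names are hypothetical and are not taken from the PR.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Backend:
    name: str
    # Hypothetical capability flag: True if the backend's kernels accept
    # separate src (read) and dst (write) state indices.
    supports_src_redirect: bool = False


def state_path_for(backend: Backend) -> str:
    """Pick the copy-free path when supported, else the v1-style pre-copy."""
    return "decoupled" if backend.supports_src_redirect else "precopy"


triton_gdn = Backend("triton_gdn", supports_src_redirect=True)  # hypothetical
flashinfer_gdn = Backend("flashinfer_gdn")                      # hypothetical
```

Here `state_path_for(triton_gdn)` returns `"decoupled"` and `state_path_for(flashinfer_gdn)` returns `"precopy"`. Defaulting the flag to `False` means a new backend gets the pre-copy path unless it opts in.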
|
I'd prefer to have one path instead of two. Having several paths is more complicated: there is more code, you need to test both paths (otherwise the untested one will likely be buggy), etc.
@vadiklyutiy just FYI align mode can now also support shared common prefix since #37898 was merged |
Fair point — I agree maintaining two paths has a real cost. Just a fallback for the case where some backends can't support the decoupled src/dst read. So for a single path, the two options are:
Which would you prefer? Happy to go whichever way you all think is best. Also glad to hear others' thoughts on this. |
Purpose
A follow-up PR to #35520
Support mamba_cache_mode="align" for Mamba prefix caching with Model Runner V2. The implementation keeps the align-mode state handling inside MambaHybridModelState and avoids CPU/GPU synchronization in the hot path.
Design
No post-copy needed: The kernel reads from the src block (previous running block) and writes to the window block (which becomes the new committed block). Only the aligned-save (boundary checkpoint) remains; there is no "copy back" step.
Conv state covered: The conv kernel also receives a separate src physical block + intra-block token offset, handling both decode and prefill paths. Notably, under spec decode the conv and SSM src positions are asymmetric: SSM state lives on the spec block (at state_idx + num_accepted - 1), while conv state always lives on the main block (at state_idx) with a separate token offset.
Kernel-level src/dst support: Modified the Triton SSM and conv kernels to accept explicit src and dst indices controlling read/write positions. This is wired into all three GDN variants (Qwen, Kimi, OLMo) and the packed recurrent decode path, while Mamba1/Mamba2 layers natively support src/dst separation.
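The src-position asymmetry under spec decode can be sketched as follows. This is an illustrative pure-Python model, not the Triton kernels. The helper names and the list-based pool are hypothetical. The index arithmetic follows the design notes, and the value 0 for NULL_BLOCK_ID follows the "fresh state" convention in this PR's kernel requirements.

```python
NULL_BLOCK_ID = 0  # per this PR's description: src index 0 means "fresh state"


def src_slots(state_idx: int, num_accepted: int) -> tuple:
    """Return (ssm_src, conv_src) for one sequence under spec decode.

    SSM state is read from the spec block at state_idx + num_accepted - 1;
    conv state is always read from the main block at state_idx. The conv
    token offset is passed separately and is not modeled here.
    """
    ssm_src = state_idx + num_accepted - 1
    conv_src = state_idx
    return ssm_src, conv_src


def read_initial_state(pool, src_idx: int, width: int):
    """Read the initial state, treating NULL_BLOCK_ID as zero-init."""
    if src_idx == NULL_BLOCK_ID:
        return [0.0] * width      # skip the read, start from a fresh state
    return list(pool[src_idx])    # copy so the src block is never mutated


pool = [[9.0, 9.0], [1.0, 2.0]]   # block 0 is the null block; never read
```

With one accepted token both positions coincide (`src_slots(4, 1)` gives `(4, 4)`); with three accepted tokens they diverge (`src_slots(4, 3)` gives `(6, 4)`).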
Kernel requirements / limitations
The copy-free "align" path removes the per-step state pre-copy by making kernels read the initial state from a source block while writing the running state to the window block. This requires every mamba backend's kernels to support decoupled read/write state indices ("src redirect"):
SSM / recurrent decode kernels must accept src_ssm_state_indices: read the initial recurrent state from the per-sequence source block, write the updated state to the (different) window block given by ssm_state_indices.
Conv kernels (prefill + decode) must accept src_conv_state_indices plus src_conv_token_offset: read the initial conv window from the source block at a token offset (needed under spec decode), write to the window block.
A src index of NULL_BLOCK_ID (0) means "fresh state" (zero-init / skip the read).
All in-tree Triton kernels used by the GDN family (fused_recurrent, fused_sigmoid_gating_delta_rule_update, causal_conv1d_fn / causal_conv1d_update) have been extended accordingly.
Test
Nvidia H20, Qwen3.5-35B-A3B-FP8, tp=2, mtp with num_speculative_tokens=2, language-model-only (text only)
Passed MRV2 Mamba prefix cache UT: tests/v1/e2e/general/test_mamba_prefix_cache.py::test_mamba_prefix_cache_mrv2
Passed lm_eval gsm8k