Skip to content

[WIP][Model Runner V2] support spec decode + mamba align prefix caching#42792

Draft
TheEpicDolphin wants to merge 1 commit into
vllm-project:mainfrom
TheEpicDolphin:mrv2-mamba-hybrid-align
Draft

[WIP][Model Runner V2] support spec decode + mamba align prefix caching#42792
TheEpicDolphin wants to merge 1 commit into
vllm-project:mainfrom
TheEpicDolphin:mrv2-mamba-hybrid-align

Conversation

@TheEpicDolphin

@TheEpicDolphin TheEpicDolphin commented May 15, 2026

Copy link
Copy Markdown
Collaborator

Accuracy Benchmark

Server Command

VLLM_USE_V2_MODEL_RUNNER=1 vllm serve Qwen/Qwen3.5-35B-A3B-FP8 \
  -tp 2 -dp 1 \
  --enable-prefix-caching \
  --mamba-cache-mode align \
  --max-num-seqs 64 \
  --attention-config '{"use_trtllm_attention": 0}' \
  --speculative-config '{"method": "mtp", "num_speculative_tokens": 2}' \
  --default-chat-template-kwargs '{"enable_thinking": false}'

Results

|  Tasks  |Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|---------|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k_cot|      3|flexible-extract|     8|exact_match|↑  |0.7998|±  |0.0110|
|         |       |strict-match    |     8|exact_match|↑  |0.7968|±  |0.0111|

Performance Benchmark

Server Command

VLLM_USE_V2_MODEL_RUNNER=1 vllm serve Qwen/Qwen3.5-35B-A3B-FP8 \
  -tp 2 -dp 1 \
  --enable-prefix-caching \
  --mamba-cache-mode align \
  --max-num-seqs 64 \
  --attention-config '{"use_trtllm_attention": 0}' \
  --speculative-config '{"method": "mtp", "num_speculative_tokens": 2}'

Results

Metric #42406 This PR Delta
General
Successful requests 512 512
Failed requests 0 0
Max request concurrency 16 16
Benchmark duration (s) 119.40 113.97 -4.5%
Total input tokens 2,167,806 2,271,087 +4.8%
Total generated tokens 262,144 262,144
Request throughput (req/s) 4.29 4.49 +4.7%
Output token throughput (tok/s) 2,195.48 2,300.20 +4.8%
Peak output token throughput (tok/s) 1,553.00 1,040.00 -33.0%
Peak concurrent requests 28.00 28.00
Total token throughput (tok/s) 20,351.05 22,227.97 +9.2%
Time to First Token
Mean TTFT (ms) 282.49 105.66 -62.6%
Median TTFT (ms) 145.19 84.13 -42.1%
P99 TTFT (ms) 3,800.32 995.73 -73.8%
Time per Output Token
Mean TPOT (ms) 6.64 6.68 +0.6%
Median TPOT (ms) 6.59 6.61 +0.3%
P99 TPOT (ms) 10.03 7.95 -20.7%
Inter-token Latency
Mean ITL (ms) 16.35 17.04 +4.2%
Median ITL (ms) 10.58 15.83 +49.6%
P99 ITL (ms) 105.52 38.90 -63.1%
End-to-end Latency
Mean E2EL (ms) 3,677.74 3,521.25 -4.3%
Median E2EL (ms) 3,528.08 3,468.17 -1.7%
P99 E2EL (ms) 8,904.19 4,462.88 -49.9%
Speculative Decoding
Acceptance rate (%) 73.88 78.34 +4.46pp
Acceptance length 2.48 2.57 +3.6%
Drafts 105,785 102,119 -3.5%
Draft tokens 211,570 204,238 -3.5%
Accepted tokens 156,310 160,003 +2.4%
Position 0 acceptance (%) 83.86 85.68 +1.82pp
Position 1 acceptance (%) 63.90 71.00 +7.10pp

@mergify mergify Bot added the v1 label May 15, 2026
@mergify

mergify Bot commented May 15, 2026

Copy link
Copy Markdown
Contributor

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @TheEpicDolphin.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label May 15, 2026

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request enables Mamba models and the "align" cache mode for the V2 model runner, introducing src_ssm_indices_tensor_d to support separate read/write blocks in SSM kernels during speculative decoding. Review feedback indicates that the construction of this tensor fails to correctly implement Mamba's recurrence, as tokens should read from the previous token's destination. Furthermore, the logic for copying SSM states from staging to committed blocks in "align" mode is flagged as potentially incorrect or unnecessary.

Comment on lines +463 to +467
src_ssm_indices_tensor_d = (
committed_phys.unsqueeze(1)
.expand_as(state_indices_tensor_d)
.contiguous()
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The construction of src_ssm_indices_tensor_d does not correctly implement recurrence for speculative decoding in "align" mode. Currently, it sets all tokens in the speculative window to read from the same committed_phys block. For a sequential model like Mamba, each token in the window (except the first) should read the state produced by the previous token. To achieve this, the source indices for token $i$ should be the destination indices of token $i-1$.

            src_ssm_indices_tensor_d = torch.cat([
                committed_phys.unsqueeze(1),
                state_indices_tensor_d[:, :-1]
            ], dim=1).contiguous()

Comment on lines +221 to +229
# Destination is the running block.
dst_phys = state_indices_tensor[:, 0]
# Copy for every layer in the group.
for layer_name in group.layer_names:
layer = fwd_ctx[layer_name]
ssm_state = layer.kv_cache[1]
src_idx = src_phys[needs_copy_mask].long()
dst_idx = dst_phys[needs_copy_mask].long()
ssm_state[dst_idx] = ssm_state[src_idx]

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The copy logic in _copy_ssm_staging_to_committed appears incorrect for "align" mamba cache mode. In "align" mode, each position in the sequence is mapped to a specific physical block. Copying the state from the last accepted token's block (src_phys) back to the first block of the window (dst_phys, which corresponds to the current position) is unnecessary and potentially harmful, as the next decoding step will naturally look for the state at the updated sequence position in its corresponding block. This logic seems more suited for a mode where a single "running" block is reused across steps.

@TheEpicDolphin TheEpicDolphin force-pushed the mrv2-mamba-hybrid-align branch 2 times, most recently from 305c01b to d914bac Compare May 15, 2026 23:53
@TheEpicDolphin TheEpicDolphin changed the title test [Model Runner V2] support spec decode + mamba align prefix caching May 15, 2026
@TheEpicDolphin TheEpicDolphin changed the title [Model Runner V2] support spec decode + mamba align prefix caching [WIP][Model Runner V2] support spec decode + mamba align prefix caching May 16, 2026
conv_state_indices=state_indices_tensor_d,
block_idx_last_scheduled_token=block_idx_last_scheduled_token_d,
initial_state_idx=block_idx_last_computed_token_d,
).transpose(0, 1)

@izhuhaoran izhuhaoran May 16, 2026

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

May have cross-block correctness ?
Let K_t = (num_computed_t + 1 + num_spec - 1) // block_size be the running-block position at step t. In this PR causal_conv1d_update is still called with conv_state_indices = state_indices_tensor_d (the same tensor for read and write), so the kernel reads from and writes to column 0 = block_table[K_t]. The intra-block num_accepted_tokens - 1 offset rolls the window inside that block, but it cannot move the window to a different physical block.

When K_{t+1} > K_t — i.e. the previous step pushed num_computed past a block_size boundary — step t+1's column 0 is block_table[K_{t+1}], which in step t was a staging slot that the conv kernel never wrote to. The kernel reads whatever was there (leftover SSM staging or zero-init), so the conv initial state is wrong.

Suggested fix: mirror the SSM trick on the conv side — add a src_conv_indices_tensor_d and let the kernel read from it while still writing to state_indices_tensor_d[:, current_last_index]. This needs a small change inside causal_conv1d_update with the inner triton kernel.

One caveat worth flagging: if the longer-term plan is to switch the conv update path to a FlashInfer cuteDSL implementation (or similar external kernel), it's not obvious that the external kernel will expose a split read/write indices API. So the kernel-side change here may or may not be portable forward — worth a thought.

Comment on lines +962 to 984
state_indices_tensor_d_input = state_indices_tensor_d
else:
# Read from separate set of blocks. Used for MRV2's "align"
# mamba cache mode.
state_indices_tensor_d_input = (
attn_metadata.src_ssm_indices_tensor_d
)
state_indices_tensor_d_output = state_indices_tensor_d

# 2. Convolution sequence transformation
hidden_states_B_C_d = causal_conv1d_update(
hidden_states_B_C_d,
conv_state,
self.conv_weights,
self.conv1d.bias,
self.activation,
conv_state_indices=state_indices_tensor_d,
block_idx_last_scheduled_token=block_idx_last_scheduled_token_d,
initial_state_idx=block_idx_last_computed_token_d,
num_accepted_tokens=num_accepted_tokens,
query_start_loc=query_start_loc_d,
max_query_len=state_indices_tensor_d.size(-1),
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

)
src_phys = state_indices_tensor.gather(1, src_block.unsqueeze(1)).squeeze(1)
# Destination is the running block.
dst_phys = state_indices_tensor[:, 0]

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar with https://github.com/vllm-project/vllm/pull/42792/changes#r3252543339, but on the SSM side.

dst_phys = state_indices_tensor[:, 0] writes the accepted SSM state into block_table[K_t] (this step's view). Step t+1 then reads from src_ssm_indices_tensor_d = block_table[(num_computed_new - 1) // block_size]. These two block indices coincide only when no boundary was crossed by the accepted tokens. With partial acceptance (k < boundary distance), the new committed-block index stays below K_t, and the next step reads stale SSM from a block the postprocess didn't update.

Suggested fix: postprocess dst_phys should be block_table[(num_computed_new - 1) // block_size] instead of state_indices_tensor[:, 0].

num_computed_tokens = common_attn_metadata.compute_num_computed_tokens()
# Get block containing last committed token.
committed_block_idx = torch.clamp(
(num_computed_tokens[:num_decodes] - 1)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

committed_block_idx is computed from common_attn_metadata.compute_num_computed_tokens(). Under async scheduling this returns an optimistic count — it assumes the previous step's drafts were all accepted. When fewer are accepted, committed_block_idx overshoots by one and committed_phys points at a block that doesn't yet hold a committed state, so the SSM kernel may reads from the wrong slot.

@TheEpicDolphin TheEpicDolphin May 18, 2026

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think num_computed_tokens is updated with the actual number of accepted tokens at the end of each step here:

num_computed = tl.load(num_computed_tokens_ptr + req_state_idx)
num_computed += query_len - num_rejected
tl.store(num_computed_tokens_ptr + req_state_idx, num_computed)

The CPU mirror of this tensor is optimistic though, but I'm not using that one.

# Copy for every layer in the group.
for layer_name in group.layer_names:
layer = fwd_ctx[layer_name]
ssm_state = layer.kv_cache[1]

@izhuhaoran izhuhaoran May 16, 2026

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ssm_state = layer.kv_cache[1] assumes index 1 is the temporal state. This holds for Mamba1/Mamba2/GDN, but not for models that register more than one conv state. For example, KDA uses (get_conv_copy_spec, get_conv_copy_spec, get_conv_copy_spec, get_temporal_copy_spec), so kv_cache[1] is conv, not SSM — the postprocess would silently corrupt the conv state in that case.

Also Qwen3-Next routes through GatedDeltaNetAttention , this PR doesn't propagate the change to that path (only mamba1 and mamba2 ?), so MRv2 + spec-decode + align on GDN models may have issues. (The Qwen3.5-35B-A3B-FP8 lm_eval may don't hit prefix cache)

# committed block read is correct.
self._copy_ssm_staging_to_committed(input_batch, num_accepted_tokens)

def _copy_ssm_staging_to_committed(

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the previous implementation (#42406) or v1, when a step's accepted tokens crossed a block_size boundary, postprocess copied the running state into block_table[aligned_new_computed_tokens // block_size - 1] — the slot a future prefix-cache hit on that boundary token will read from.

In this PR, _copy_ssm_staging_to_committed only copies staging[k-1] -> state_indices_tensor[:, 0] (= block_table[K_t], this step's running block). There doesn't seem to be anything else that writes to block_table[aligned_new_computed // block_size - 1] for boundaries crossed during spec decode. If that's the case, a later request that prefix-hits such a boundary would read stale state from the aligned-save slot — am I missing a path that handles it?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're absolutely right, im not correctly handling the boundary crossing cases for either conv or SSM state. Fortunately it's pretty straightforward to fix. I added a check for boundary crossings and copy over both states to the completed blocks when any are detected.

@izhuhaoran izhuhaoran left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot for taking another shot at this and for cleaning it up into a standalone PR — splitting the SSM read/write indices via src_ssm_indices_tensor_d is a genuinely nice simplification compared to the preprocess-copy approach in my earlier PR, and I'd like to adopt that idea.

That said, on a careful first pass I think the current diff has a few correctness issues, mostly around what happens when a step advances the "running" mamba block position. The detailed comments are as above, happy to discuss any of these.

One thought on validation: gsm8k likely has too short a context to reliably exercise the prefix-caching paths (cross-block-boundary spec decode + later prefix hits). Something closer to test_mamba_prefix_cache_mrv2 in #42406 might catch the cases I'm worried about above.

Thanks again for pushing this forward — I'll also take the read/write index split idea back to #42406 and think about how to implement prefix caching better.

Signed-off-by: Giancarlo Delfin <gdelfin@inferact.ai>
@TheEpicDolphin TheEpicDolphin force-pushed the mrv2-mamba-hybrid-align branch from d914bac to 746f6c2 Compare May 18, 2026 21:32
@TheEpicDolphin

Copy link
Copy Markdown
Collaborator Author

Thanks for taking a look @izhuhaoran. For completeness, i updated this PR with my best attempt to handle block boundary crossings for both conv and SSM state. I haven't tested these changes, as I don't plan to publish this PR, but I figured I'd include them in case they are helpful for your final implementation.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants