[WIP][Model Runner V2] support spec decode + mamba align prefix caching #42792
TheEpicDolphin wants to merge 1 commit into
This pull request has merge conflicts that must be resolved before it can be merged.
Code Review
This pull request enables Mamba models and the "align" cache mode for the V2 model runner, introducing src_ssm_indices_tensor_d to support separate read/write blocks in SSM kernels during speculative decoding. Review feedback indicates that the construction of this tensor fails to correctly implement Mamba's recurrence, as tokens should read from the previous token's destination. Furthermore, the logic for copying SSM states from staging to committed blocks in "align" mode is flagged as potentially incorrect or unnecessary.
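The recurrence point in this summary can be illustrated with a minimal, framework-free sketch. The function names and block ids below are hypothetical and are not taken from the PR; plain lists stand in for one row of the index tensors.

```python
def broadcast_sources(committed_block, dest_blocks):
    # What the reviewed diff does: every token in the window reads
    # the same committed block.
    return [committed_block] * len(dest_blocks)


def shifted_sources(committed_block, dest_blocks):
    # What the review suggests: token 0 reads the committed block,
    # token i > 0 reads the block that token i - 1 wrote to.
    if not dest_blocks:
        return []
    return [committed_block] + list(dest_blocks[:-1])


# Hypothetical physical block ids for one request with a 3-token window.
dest_blocks = [11, 12, 13]
print(broadcast_sources(7, dest_blocks))  # [7, 7, 7]
print(shifted_sources(7, dest_blocks))    # [7, 11, 12]
```

With the shifted layout each token's read slot is the previous token's write slot, which is the dependency chain a sequential state-space update needs.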
    src_ssm_indices_tensor_d = (
        committed_phys.unsqueeze(1)
        .expand_as(state_indices_tensor_d)
        .contiguous()
    )
The construction of src_ssm_indices_tensor_d does not correctly implement recurrence for speculative decoding in "align" mode. Currently, it sets all tokens in the speculative window to read from the same committed_phys block. For a sequential model like Mamba, each token in the window (except the first) should read the state produced by the previous token. To achieve this, the source indices for token i (i > 0) should be the destination indices of token i - 1:
    src_ssm_indices_tensor_d = torch.cat([
        committed_phys.unsqueeze(1),
        state_indices_tensor_d[:, :-1]
    ], dim=1).contiguous()

    # Destination is the running block.
    dst_phys = state_indices_tensor[:, 0]
    # Copy for every layer in the group.
    for layer_name in group.layer_names:
        layer = fwd_ctx[layer_name]
        ssm_state = layer.kv_cache[1]
        src_idx = src_phys[needs_copy_mask].long()
        dst_idx = dst_phys[needs_copy_mask].long()
        ssm_state[dst_idx] = ssm_state[src_idx]
The copy logic in _copy_ssm_staging_to_committed appears incorrect for "align" mamba cache mode. In "align" mode, each position in the sequence is mapped to a specific physical block. Copying the state from the last accepted token's block (src_phys) back to the first block of the window (dst_phys, which corresponds to the current position) is unnecessary and potentially harmful, as the next decoding step will naturally look for the state at the updated sequence position in its corresponding block. This logic seems more suited for a mode where a single "running" block is reused across steps.
305c01b to d914bac
        conv_state_indices=state_indices_tensor_d,
        block_idx_last_scheduled_token=block_idx_last_scheduled_token_d,
        initial_state_idx=block_idx_last_computed_token_d,
    ).transpose(0, 1)
This may have a cross-block correctness issue.
Let K_t = (num_computed_t + 1 + num_spec - 1) // block_size be the running-block position at step t. In this PR causal_conv1d_update is still called with conv_state_indices = state_indices_tensor_d (the same tensor for read and write), so the kernel reads from and writes to column 0 = block_table[K_t]. The intra-block num_accepted_tokens - 1 offset rolls the window inside that block, but it cannot move the window to a different physical block.
When K_{t+1} > K_t — i.e. the previous step pushed num_computed past a block_size boundary — step t+1's column 0 is block_table[K_{t+1}], which in step t was a staging slot that the conv kernel never wrote to. The kernel reads whatever was there (leftover SSM staging or zero-init), so the conv initial state is wrong.
Suggested fix: mirror the SSM trick on the conv side — add a src_conv_indices_tensor_d and let the kernel read from it while still writing to state_indices_tensor_d[:, current_last_index]. This needs a small change inside causal_conv1d_update with the inner triton kernel.
One caveat worth flagging: if the longer-term plan is to switch the conv update path to a FlashInfer cuteDSL implementation (or similar external kernel), it's not obvious that the external kernel will expose a split read/write indices API. So the kernel-side change here may or may not be portable forward — worth a thought.
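The boundary condition described in this comment can be checked with a few lines of arithmetic. This is only a sketch: the block size, draft length, and token counts are assumed values chosen for illustration, and `running_block` is a hypothetical helper that evaluates the K_t formula above.

```python
def running_block(num_computed, num_spec, block_size):
    # K_t from the comment: (num_computed_t + 1 + num_spec - 1) // block_size
    return (num_computed + 1 + num_spec - 1) // block_size


BLOCK_SIZE = 16  # assumed block size
NUM_SPEC = 3     # assumed number of speculative tokens

# Step t starts with 12 computed tokens; assume it advances the count to 14.
k_t = running_block(12, NUM_SPEC, BLOCK_SIZE)
k_next = running_block(14, NUM_SPEC, BLOCK_SIZE)

# When this is True, step t+1's column 0 is a different physical block
# than the one the conv kernel read and wrote in step t.
crossed = k_next > k_t
print(k_t, k_next, crossed)  # 0 1 True
```

In this example the conv window written during step t lives in block_table[0], while step t+1 reads its initial conv state from block_table[1].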
        state_indices_tensor_d_input = state_indices_tensor_d
    else:
        # Read from separate set of blocks. Used for MRV2's "align"
        # mamba cache mode.
        state_indices_tensor_d_input = (
            attn_metadata.src_ssm_indices_tensor_d
        )
    state_indices_tensor_d_output = state_indices_tensor_d

    # 2. Convolution sequence transformation
    hidden_states_B_C_d = causal_conv1d_update(
        hidden_states_B_C_d,
        conv_state,
        self.conv_weights,
        self.conv1d.bias,
        self.activation,
        conv_state_indices=state_indices_tensor_d,
        block_idx_last_scheduled_token=block_idx_last_scheduled_token_d,
        initial_state_idx=block_idx_last_computed_token_d,
        num_accepted_tokens=num_accepted_tokens,
        query_start_loc=query_start_loc_d,
        max_query_len=state_indices_tensor_d.size(-1),
    )
    )
    src_phys = state_indices_tensor.gather(1, src_block.unsqueeze(1)).squeeze(1)
    # Destination is the running block.
    dst_phys = state_indices_tensor[:, 0]
Similar to https://github.com/vllm-project/vllm/pull/42792/changes#r3252543339, but on the SSM side.
dst_phys = state_indices_tensor[:, 0] writes the accepted SSM state into block_table[K_t] (this step's view). Step t+1 then reads from src_ssm_indices_tensor_d = block_table[(num_computed_new - 1) // block_size]. These two block indices coincide only when no boundary was crossed by the accepted tokens. With partial acceptance (k < boundary distance), the new committed-block index stays below K_t, and the next step reads stale SSM from a block the postprocess didn't update.
Suggested fix: postprocess dst_phys should be block_table[(num_computed_new - 1) // block_size] instead of state_indices_tensor[:, 0].
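To make the mismatch concrete, here is a small sketch comparing the block the postprocess writes to (K_t) with the block the next step reads from. The helper names and the numbers are assumptions for illustration only; the two formulas are the ones quoted in this thread.

```python
def running_block(num_computed, num_spec, block_size):
    # Write target of the postprocess: block_table[K_t].
    return (num_computed + 1 + num_spec - 1) // block_size


def committed_block(num_computed_new, block_size):
    # Read source of the next step:
    # block_table[(num_computed_new - 1) // block_size].
    return (num_computed_new - 1) // block_size


BLOCK_SIZE, NUM_SPEC = 16, 3  # assumed values
num_computed = 14             # assumed count at the start of step t

k_t = running_block(num_computed, NUM_SPEC, BLOCK_SIZE)

# Partial acceptance: the count only reaches 15, short of the boundary.
read_partial = committed_block(15, BLOCK_SIZE)
# Full acceptance: the count reaches 18, past the boundary.
read_full = committed_block(18, BLOCK_SIZE)

print(k_t, read_partial, read_full)  # 1 0 1
```

With partial acceptance the state is written to block_table[1] but the next step reads block_table[0], which is the stale-read case the comment describes. Using the committed-block formula for the write target makes the two indices agree by construction.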
    num_computed_tokens = common_attn_metadata.compute_num_computed_tokens()
    # Get block containing last committed token.
    committed_block_idx = torch.clamp(
        (num_computed_tokens[:num_decodes] - 1)
committed_block_idx is computed from common_attn_metadata.compute_num_computed_tokens(). Under async scheduling this returns an optimistic count — it assumes the previous step's drafts were all accepted. When fewer are accepted, committed_block_idx overshoots by one and committed_phys points at a block that doesn't yet hold a committed state, so the SSM kernel may read from the wrong slot.
I think num_computed_tokens is updated with the actual number of accepted tokens at the end of each step here:
vllm/vllm/v1/worker/gpu/input_batch.py, lines 472 to 474 in 57fef4e
The CPU mirror of this tensor is optimistic, but I'm not using that one.
    # Copy for every layer in the group.
    for layer_name in group.layer_names:
        layer = fwd_ctx[layer_name]
        ssm_state = layer.kv_cache[1]
ssm_state = layer.kv_cache[1] assumes index 1 is the temporal state. This holds for Mamba1/Mamba2/GDN, but not for models that register more than one conv state. For example, KDA uses (get_conv_copy_spec, get_conv_copy_spec, get_conv_copy_spec, get_temporal_copy_spec), so kv_cache[1] is conv, not SSM — the postprocess would silently corrupt the conv state in that case.
Also, Qwen3-Next routes through GatedDeltaNetAttention, and this PR doesn't propagate the change to that path (only mamba1 and mamba2?), so MRv2 + spec-decode + align on GDN models may have issues. (The Qwen3.5-35B-A3B-FP8 lm_eval run may not hit the prefix cache.)
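One way to avoid the hard-coded index would be to look up the temporal state by kind rather than by position. The sketch below is purely illustrative: the `"conv"` / `"temporal"` labels and the `temporal_state_index` helper are hypothetical stand-ins for whatever the copy specs actually expose, and the two layouts simply mirror the ones described in this comment.

```python
def temporal_state_index(state_kinds):
    # Find the single temporal (SSM) state instead of assuming index 1.
    matches = [i for i, kind in enumerate(state_kinds) if kind == "temporal"]
    if len(matches) != 1:
        raise ValueError(f"expected exactly one temporal state, got {len(matches)}")
    return matches[0]


# Hypothetical layouts mirroring the comment above.
MAMBA_LIKE = ("conv", "temporal")
KDA_LIKE = ("conv", "conv", "conv", "temporal")

print(temporal_state_index(MAMBA_LIKE))  # 1
print(temporal_state_index(KDA_LIKE))    # 3
print(KDA_LIKE[1])                       # conv
```

For the KDA-style layout, index 1 is a conv state, so a copy keyed on `kv_cache[1]` would touch the wrong tensor, while the lookup lands on index 3.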
        # committed block read is correct.
        self._copy_ssm_staging_to_committed(input_batch, num_accepted_tokens)

    def _copy_ssm_staging_to_committed(
In the previous implementation (#42406) or v1, when a step's accepted tokens crossed a block_size boundary, postprocess copied the running state into block_table[aligned_new_computed_tokens // block_size - 1] — the slot a future prefix-cache hit on that boundary token will read from.
In this PR, _copy_ssm_staging_to_committed only copies staging[k-1] -> state_indices_tensor[:, 0] (= block_table[K_t], this step's running block). There doesn't seem to be anything else that writes to block_table[aligned_new_computed // block_size - 1] for boundaries crossed during spec decode. If that's the case, a later request that prefix-hits such a boundary would read stale state from the aligned-save slot — am I missing a path that handles it?
You're absolutely right, I'm not correctly handling the boundary-crossing cases for either conv or SSM state. Fortunately it's pretty straightforward to fix. I added a check for boundary crossings that copies both states over to the completed blocks when any are detected.
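For reference, a boundary-crossing check of the kind discussed in this thread might look roughly like the following. This is a sketch under assumptions, not the code added to the PR: `aligned_save_slot` is a hypothetical helper, and it only computes the block-table position block_table[aligned_new_computed_tokens // block_size - 1] mentioned above, not which staged state should be copied there.

```python
def aligned_save_slot(old_computed, new_computed, block_size):
    """Return the block-table position of the aligned-save slot,
    or None if the step did not cross a block_size boundary."""
    if new_computed // block_size <= old_computed // block_size:
        return None
    # Largest multiple of block_size that is <= new_computed.
    aligned_new_computed = (new_computed // block_size) * block_size
    return aligned_new_computed // block_size - 1


BLOCK_SIZE = 16  # assumed block size

print(aligned_save_slot(14, 17, BLOCK_SIZE))  # 0
print(aligned_save_slot(16, 18, BLOCK_SIZE))  # None
print(aligned_save_slot(30, 33, BLOCK_SIZE))  # 1
```

A step that moves the count from 14 to 17 crosses the boundary at 16, so the state at that boundary belongs in block-table position 0, where a later prefix-cache hit would look for it.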
Thanks a lot for taking another shot at this and for cleaning it up into a standalone PR — splitting the SSM read/write indices via src_ssm_indices_tensor_d is a genuinely nice simplification compared to the preprocess-copy approach in my earlier PR, and I'd like to adopt that idea.
That said, on a careful first pass I think the current diff has a few correctness issues, mostly around what happens when a step advances the "running" mamba block position. The detailed comments are as above, happy to discuss any of these.
One thought on validation: gsm8k likely has too short a context to reliably exercise the prefix-caching paths (cross-block-boundary spec decode + later prefix hits). Something closer to test_mamba_prefix_cache_mrv2 in #42406 might catch the cases I'm worried about above.
Thanks again for pushing this forward — I'll also take the read/write index split idea back to #42406 and think about how to implement prefix caching better.
Signed-off-by: Giancarlo Delfin <gdelfin@inferact.ai>
d914bac to 746f6c2
Thanks for taking a look @izhuhaoran. For completeness, I updated this PR with my best attempt to handle block boundary crossings for both conv and SSM state. I haven't tested these changes, as I don't plan to publish this PR, but I figured I'd include them in case they are helpful for your final implementation.