From 28edfdcfaf6aa8dfc6c062786a6b3671d799f0c2 Mon Sep 17 00:00:00 2001 From: Roi Koren Date: Wed, 29 Apr 2026 13:53:08 +0300 Subject: [PATCH 01/12] Fix "all" mode support with SpecDec Signed-off-by: Roi Koren --- tests/kernels/mamba/test_mamba_mixer2.py | 77 ++++++++++- .../test_mamba_update_block_table.py | 130 +++++++++++++++++- .../layers/mamba/mamba_mixer2.py | 44 ++++-- vllm/v1/attention/backends/mamba_attn.py | 17 ++- 4 files changed, 244 insertions(+), 24 deletions(-) diff --git a/tests/kernels/mamba/test_mamba_mixer2.py b/tests/kernels/mamba/test_mamba_mixer2.py index 973e7885c680..708bf9338dde 100644 --- a/tests/kernels/mamba/test_mamba_mixer2.py +++ b/tests/kernels/mamba/test_mamba_mixer2.py @@ -11,9 +11,14 @@ init_distributed_environment, initialize_model_parallel, ) -from vllm.model_executor.layers.mamba.mamba_mixer2 import Mixer2RMSNormGated +from vllm.model_executor.layers.mamba.mamba_mixer2 import ( + Mixer2RMSNormGated, + _gather_decode_state_indices, +) from vllm.utils.system_utils import update_environment_variables from vllm.utils.torch_utils import set_random_seed +from vllm.v1.attention.backends.utils import mamba_get_block_table_tensor +from vllm.v1.kv_cache_interface import MambaSpec @multi_gpu_test(num_gpus=2) @@ -136,3 +141,73 @@ def mixer2_gated_norm_tensor_parallel( atol=5e-3, rtol=1e-3, ) + + +def test_gather_decode_state_indices_no_spec(): + n, max_blocks = 3, 5 + state_indices = torch.arange(n * max_blocks, dtype=torch.int32).reshape( + n, max_blocks + ) + last_computed = torch.tensor([0, 1, 2], dtype=torch.int32) + last_scheduled = torch.tensor([1, 2, 3], dtype=torch.int32) + + in_slots, out_slots = _gather_decode_state_indices( + state_indices, last_computed, last_scheduled, num_spec_tokens=0 + ) + + assert in_slots.shape == (n,) + assert out_slots.shape == (n,) + torch.testing.assert_close(in_slots, torch.tensor([0, 6, 12], dtype=torch.int32)) + torch.testing.assert_close(out_slots, torch.tensor([1, 7, 13], dtype=torch.int32)) + + +def test_gather_decode_state_indices_with_spec_matches_align_layout(): + n, max_blocks_full, num_spec_tokens, block_size = 3, 5, 1, 16 + full_block_table = torch.arange(n * max_blocks_full, dtype=torch.int32).reshape( + n, max_blocks_full + ) + last_scheduled = torch.tensor([1, 2, 3], dtype=torch.int32) + seq_lens = (last_scheduled.to(torch.int64) + 1) * block_size + + spec = MambaSpec( + block_size=block_size, + shapes=((1,), (1,)), + dtypes=(torch.float32,), + mamba_cache_mode="align", + num_speculative_blocks=num_spec_tokens, + ) + align_indices = mamba_get_block_table_tensor( + full_block_table, seq_lens, spec, "align" + ) + + in_slots, _ = _gather_decode_state_indices( + full_block_table, + block_idx_last_computed_token_d=torch.tensor([0, 1, 2], dtype=torch.int32), + block_idx_last_scheduled_token_d=last_scheduled, + num_spec_tokens=num_spec_tokens, + ) + + assert in_slots.shape == align_indices.shape + torch.testing.assert_close(in_slots.to(align_indices.dtype), align_indices) + + +def test_gather_decode_state_indices_with_spec(): + n, max_blocks, num_spec_tokens = 3, 5, 1 + state_indices = torch.arange(n * max_blocks, dtype=torch.int32).reshape( + n, max_blocks + ) + last_computed = torch.tensor([0, 1, 2], dtype=torch.int32) + last_scheduled = torch.tensor([1, 2, 3], dtype=torch.int32) + + in_slots, out_slots = _gather_decode_state_indices( + state_indices, + last_computed, + last_scheduled, + num_spec_tokens=num_spec_tokens, + ) + + assert in_slots.shape == (n, 1 + num_spec_tokens) + assert out_slots.shape == (n, 1 + num_spec_tokens) + expected = torch.tensor([[1, 2], [7, 8], [13, 14]], dtype=torch.int32) + torch.testing.assert_close(in_slots, expected) + torch.testing.assert_close(out_slots, expected) diff --git a/tests/v1/attention/test_mamba_update_block_table.py b/tests/v1/attention/test_mamba_update_block_table.py index 923939053ece..6746af41b6df 100644 --- a/tests/v1/attention/test_mamba_update_block_table.py +++ b/tests/v1/attention/test_mamba_update_block_table.py @@ -32,17 +32,25 @@ class _ConcreteMambaBuilder( metadata_cls = BaseMambaAttentionMetadata -def _make_vllm_config(block_size, max_model_len, max_num_seqs): +def _make_vllm_config(max_model_len, max_num_seqs, num_speculative_tokens=0): """Create a minimal mock VllmConfig with only the fields the builder accesses, avoiding any model download / HF config inspection.""" + speculative_config = ( + SimpleNamespace( + num_speculative_tokens=num_speculative_tokens, + parallel_drafting=False, + ) + if num_speculative_tokens > 0 + else None + ) return SimpleNamespace( cache_config=SimpleNamespace(mamba_cache_mode="all"), compilation_config=SimpleNamespace( cudagraph_mode=CUDAGraphMode.FULL, max_cudagraph_capture_size=None, ), - speculative_config=None, - num_speculative_tokens=0, + speculative_config=speculative_config, + num_speculative_tokens=num_speculative_tokens, parallel_config=SimpleNamespace(decode_context_parallel_size=1), scheduler_config=SimpleNamespace(max_num_seqs=max_num_seqs), model_config=SimpleNamespace(max_model_len=max_model_len), @@ -59,7 +67,7 @@ def test_update_block_table_copies_block_idx_to_persistent_buffers(): num_reqs = 4 device = torch.device("cpu") - vllm_config = _make_vllm_config(block_size, max_model_len, num_reqs) + vllm_config = _make_vllm_config(max_model_len, num_reqs) spec = MambaSpec( block_size=block_size, @@ -149,3 +157,117 @@ def shares_storage(tensor, buffer): metadata_b.block_idx_last_computed_token, block_idx_vals, ) + + +def test_state_indices_tensor_d_includes_num_speculative_blocks(): + """Regression test for https://github.com/vllm-project/vllm/issues/39809 + bug 1: with mamba_cache_mode='all' and speculative decoding enabled, + the cudagraph buffer for state_indices_tensor_d must allocate the same + per-request column count as the runtime block table, which includes + num_speculative_blocks trailing scratch columns.""" + + block_size = 16 + max_model_len = 256 + max_num_seqs = 4 + num_speculative_tokens = 1 + num_speculative_blocks = 2 + device = torch.device("cpu") + + vllm_config = _make_vllm_config( + max_model_len, + max_num_seqs, + num_speculative_tokens=num_speculative_tokens, + ) + + spec = MambaSpec( + block_size=block_size, + shapes=((1,), (1,)), + dtypes=(torch.float32,), + mamba_cache_mode="all", + num_speculative_blocks=num_speculative_blocks, + ) + + builder = _ConcreteMambaBuilder(spec, ["layer0"], vllm_config, device) + + expected_cols = (max_model_len // block_size) + num_speculative_blocks + assert builder.state_indices_tensor_d.shape == (max_num_seqs, expected_cols) + + +def test_block_idx_cudagraph_capture_padded_by_num_reqs(): + """Regression test for https://github.com/vllm-project/vllm/issues/39809 + bug 2: with mamba_cache_mode='all' and spec decode, _update_metadata_for + _cudagraph_capture must slice block_idx_last_{scheduled,computed}_token + by the request count (padded_bs == num_reqs), not by num_decode_tokens. + Past num_decodes, the slice must be zero-filled.""" + + block_size = 16 + max_model_len = 256 + max_num_seqs = 8 + num_speculative_tokens = 1 + device = torch.device("cpu") + + vllm_config = _make_vllm_config( + max_model_len, + max_num_seqs, + num_speculative_tokens=num_speculative_tokens, + ) + + spec = MambaSpec( + block_size=block_size, + shapes=((1,), (1,)), + dtypes=(torch.float32,), + mamba_cache_mode="all", + num_speculative_blocks=2, + ) + + builder = _ConcreteMambaBuilder(spec, ["layer0"], vllm_config, device) + + builder.block_idx_last_scheduled_token.fill_(-1) + builder.block_idx_last_computed_token.fill_(-1) + + num_decodes = 2 + num_reqs = 3 + num_decode_tokens = num_decodes * (1 + num_speculative_tokens) + seq_lens = torch.full((num_reqs,), 64, dtype=torch.int32, device=device) + block_idx_vals = torch.tensor([3, 5], dtype=torch.int32, device=device) + state_indices_d = torch.zeros( + (num_decodes, builder.state_indices_tensor_d.shape[1]), + dtype=torch.int32, + device=device, + ) + query_start_loc_d = torch.arange( + num_decodes + 1, dtype=torch.int32, device=device + ) * (1 + num_speculative_tokens) + num_accepted_tokens = torch.ones(num_decodes, dtype=torch.int32, device=device) + + metadata = BaseMambaAttentionMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decodes=num_decodes, + num_decode_tokens=num_decode_tokens, + num_reqs=num_reqs, + has_initial_states_p=None, + query_start_loc_p=None, + num_computed_tokens_p=None, + state_indices_tensor_p=None, + state_indices_tensor_d=state_indices_d, + query_start_loc_d=query_start_loc_d, + num_accepted_tokens=num_accepted_tokens, + block_idx_last_scheduled_token=block_idx_vals, + block_idx_first_scheduled_token_p=None, + block_idx_last_computed_token=block_idx_vals, + seq_lens=seq_lens, + ) + + out = builder._update_metadata_for_cudagraph_capture(metadata) + + assert out.block_idx_last_scheduled_token.shape == (num_reqs,) + assert out.block_idx_last_computed_token.shape == (num_reqs,) + torch.testing.assert_close( + out.block_idx_last_scheduled_token[:num_decodes], block_idx_vals + ) + torch.testing.assert_close( + out.block_idx_last_computed_token[:num_decodes], block_idx_vals + ) + assert torch.all(out.block_idx_last_scheduled_token[num_decodes:] == 0) + assert torch.all(out.block_idx_last_computed_token[num_decodes:] == 0) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 674134f373ea..4e8d2549ce92 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -60,6 +60,30 @@ # Added by the IBM Team, 2024 +def _gather_decode_state_indices( + state_indices_tensor_d: torch.Tensor, + block_idx_last_computed_token_d: torch.Tensor, + block_idx_last_scheduled_token_d: torch.Tensor, + num_spec_tokens: int, +) -> tuple[torch.Tensor, torch.Tensor]: + if num_spec_tokens > 0: + offsets = torch.arange( + 1 + num_spec_tokens, + device=block_idx_last_scheduled_token_d.device, + dtype=block_idx_last_scheduled_token_d.dtype, + ) + indices = block_idx_last_scheduled_token_d.unsqueeze(1) + offsets.unsqueeze(0) + gathered = state_indices_tensor_d.gather(1, indices) + return gathered, gathered + input_slots = state_indices_tensor_d.gather( + 1, block_idx_last_computed_token_d.unsqueeze(1) + ).squeeze(1) + output_slots = state_indices_tensor_d.gather( + 1, block_idx_last_scheduled_token_d.unsqueeze(1) + ).squeeze(1) + return input_slots, output_slots + + # Adapted from transformers.models.mamba2.modeling_mamba2.MambaRMSNormGated # --8<-- [start:mixer2_gated_rms_norm] @CustomOp.register("mixer2_gated_rms_norm") @@ -944,18 +968,14 @@ def conv_ssm_forward( if has_decode: assert state_indices_tensor_d is not None if is_mamba_cache_all: - state_indices_tensor_d_input = state_indices_tensor_d.gather( - 1, block_idx_last_computed_token_d.unsqueeze(1) - ).squeeze(1) - state_indices_tensor_d_output = state_indices_tensor_d.gather( - 1, block_idx_last_scheduled_token_d.unsqueeze(1) - ).squeeze(1) - # for decode: - # block_idx_first_scheduled_token_d == - # block_idx_last_scheduled_token_d - # at block boundaries: - # block_idx_first_scheduled_token_d > - # block_idx_last_computed_token_d + state_indices_tensor_d_input, state_indices_tensor_d_output = ( + _gather_decode_state_indices( + state_indices_tensor_d, + block_idx_last_computed_token_d, + block_idx_last_scheduled_token_d, + self.num_spec, + ) + ) else: # Without caching, read and write in-place to the same blocks: state_indices_tensor_d_input = state_indices_tensor_d diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py index 716dfcde592f..9dfe85b94079 100644 --- a/vllm/v1/attention/backends/mamba_attn.py +++ b/vllm/v1/attention/backends/mamba_attn.py @@ -108,12 +108,13 @@ def __init__( ) if self.vllm_config.cache_config.mamba_cache_mode == "all": - max_num_blocks = cdiv( - self.vllm_config.model_config.max_model_len, - self.kv_cache_spec.block_size, + max_num_blocks = ( + cdiv( + self.vllm_config.model_config.max_model_len, + self.kv_cache_spec.block_size, + ) + + self.kv_cache_spec.num_speculative_blocks ) - # Speculative decoding not supported with prefix caching, - # so keep shape consistent with prefill buffer # TODO: reduce this size as needed for decode-only cudagraph capture self.state_indices_tensor_d: torch.Tensor = torch.empty( ( @@ -524,16 +525,18 @@ def _update_metadata_for_cudagraph_capture( non_blocking=True, ) block_idx_last_scheduled_token = self.block_idx_last_scheduled_token[ - : metadata.num_decode_tokens + :padded_bs ] + block_idx_last_scheduled_token[metadata.num_decodes :] = 0 self.block_idx_last_computed_token[: metadata.num_decodes].copy_( block_idx_last_computed_token[: metadata.num_decodes], non_blocking=True, ) block_idx_last_computed_token = self.block_idx_last_computed_token[ - : metadata.num_decode_tokens + :padded_bs ] + block_idx_last_computed_token[metadata.num_decodes :] = 0 return replace( metadata, From 4e42497ea4e9d88b7a05f45bab9331d75a1e2fb7 Mon Sep 17 00:00:00 2001 From: Roi Koren Date: Wed, 29 Apr 2026 13:55:41 +0300 Subject: [PATCH 02/12] Delete some tests Signed-off-by: Roi Koren --- tests/kernels/mamba/test_mamba_mixer2.py | 77 +----------------------- 1 file changed, 1 insertion(+), 76 deletions(-) diff --git a/tests/kernels/mamba/test_mamba_mixer2.py b/tests/kernels/mamba/test_mamba_mixer2.py index 708bf9338dde..973e7885c680 100644 --- a/tests/kernels/mamba/test_mamba_mixer2.py +++ b/tests/kernels/mamba/test_mamba_mixer2.py @@ -11,14 +11,9 @@ init_distributed_environment, initialize_model_parallel, ) -from vllm.model_executor.layers.mamba.mamba_mixer2 import ( - Mixer2RMSNormGated, - _gather_decode_state_indices, -) +from vllm.model_executor.layers.mamba.mamba_mixer2 import Mixer2RMSNormGated from vllm.utils.system_utils import update_environment_variables from vllm.utils.torch_utils import set_random_seed -from vllm.v1.attention.backends.utils import mamba_get_block_table_tensor -from vllm.v1.kv_cache_interface import MambaSpec @multi_gpu_test(num_gpus=2) @@ -141,73 +136,3 @@ def mixer2_gated_norm_tensor_parallel( atol=5e-3, rtol=1e-3, ) - - -def test_gather_decode_state_indices_no_spec(): - n, max_blocks = 3, 5 - state_indices = torch.arange(n * max_blocks, dtype=torch.int32).reshape( - n, max_blocks - ) - last_computed = torch.tensor([0, 1, 2], dtype=torch.int32) - last_scheduled = torch.tensor([1, 2, 3], dtype=torch.int32) - - in_slots, out_slots = _gather_decode_state_indices( - state_indices, last_computed, last_scheduled, num_spec_tokens=0 - ) - - assert in_slots.shape == (n,) - assert out_slots.shape == (n,) - torch.testing.assert_close(in_slots, torch.tensor([0, 6, 12], dtype=torch.int32)) - torch.testing.assert_close(out_slots, torch.tensor([1, 7, 13], dtype=torch.int32)) - - -def test_gather_decode_state_indices_with_spec_matches_align_layout(): - n, max_blocks_full, num_spec_tokens, block_size = 3, 5, 1, 16 - full_block_table = torch.arange(n * max_blocks_full, dtype=torch.int32).reshape( - n, max_blocks_full - ) - last_scheduled = torch.tensor([1, 2, 3], dtype=torch.int32) - seq_lens = (last_scheduled.to(torch.int64) + 1) * block_size - - spec = MambaSpec( - block_size=block_size, - shapes=((1,), (1,)), - dtypes=(torch.float32,), - mamba_cache_mode="align", - num_speculative_blocks=num_spec_tokens, - ) - align_indices = mamba_get_block_table_tensor( - full_block_table, seq_lens, spec, "align" - ) - - in_slots, _ = _gather_decode_state_indices( - full_block_table, - block_idx_last_computed_token_d=torch.tensor([0, 1, 2], dtype=torch.int32), - block_idx_last_scheduled_token_d=last_scheduled, - num_spec_tokens=num_spec_tokens, - ) - - assert in_slots.shape == align_indices.shape - torch.testing.assert_close(in_slots.to(align_indices.dtype), align_indices) - - -def test_gather_decode_state_indices_with_spec(): - n, max_blocks, num_spec_tokens = 3, 5, 1 - state_indices = torch.arange(n * max_blocks, dtype=torch.int32).reshape( - n, max_blocks - ) - last_computed = torch.tensor([0, 1, 2], dtype=torch.int32) - last_scheduled = torch.tensor([1, 2, 3], dtype=torch.int32) - - in_slots, out_slots = _gather_decode_state_indices( - state_indices, - last_computed, - last_scheduled, - num_spec_tokens=num_spec_tokens, - ) - - assert in_slots.shape == (n, 1 + num_spec_tokens) - assert out_slots.shape == (n, 1 + num_spec_tokens) - expected = torch.tensor([[1, 2], [7, 8], [13, 14]], dtype=torch.int32) - torch.testing.assert_close(in_slots, expected) - torch.testing.assert_close(out_slots, expected) From 3d0a9b67435bab23e2fb4376a2fca8e5ea46f70a Mon Sep 17 00:00:00 2001 From: Roi Koren Date: Wed, 29 Apr 2026 13:57:59 +0300 Subject: [PATCH 03/12] Revert "Default to 'align' mamba cache mode for Mamba-based models when speculative decoding is enabled (#40454)" This reverts commit f819265a4ab0187181575c02174ec4a2f91d9220. Signed-off-by: Roi Koren --- vllm/model_executor/models/config.py | 29 +++++++++------------------- 1 file changed, 9 insertions(+), 20 deletions(-) diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 459c16f8ec97..422acca642a9 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -350,26 +350,15 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: if cache_config.enable_prefix_caching: if cache_config.mamba_cache_mode == "none": - if ( - model_config.supports_mamba_prefix_caching - and vllm_config.speculative_config is not None - ): - cache_config.mamba_cache_mode = "align" - logger.warning( - "Mamba cache mode is set to 'align' for %s by default " - "when prefix caching and speculative decoding are enabled", - model_config.architecture, - ) - else: - cache_config.mamba_cache_mode = ( - "all" if model_config.supports_mamba_prefix_caching else "align" - ) - logger.warning( - "Mamba cache mode is set to '%s' for %s by default " - "when prefix caching is enabled", - cache_config.mamba_cache_mode, - model_config.architecture, - ) + cache_config.mamba_cache_mode = ( + "all" if model_config.supports_mamba_prefix_caching else "align" + ) + logger.warning( + "Mamba cache mode is set to '%s' for %s by default " + "when prefix caching is enabled", + cache_config.mamba_cache_mode, + model_config.architecture, + ) if ( cache_config.mamba_cache_mode == "all" and not model_config.supports_mamba_prefix_caching From 432bbdc0354455844c4b70c7df253761e4c45a8b Mon Sep 17 00:00:00 2001 From: Roi Koren Date: Wed, 29 Apr 2026 16:43:49 +0300 Subject: [PATCH 04/12] Fix mypy Signed-off-by: Roi Koren --- vllm/v1/attention/backends/mamba_attn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py index 9dfe85b94079..4ccfb0b87491 100644 --- a/vllm/v1/attention/backends/mamba_attn.py +++ b/vllm/v1/attention/backends/mamba_attn.py @@ -113,7 +113,7 @@ def __init__( self.vllm_config.model_config.max_model_len, self.kv_cache_spec.block_size, ) - + self.kv_cache_spec.num_speculative_blocks + + kv_cache_spec.num_speculative_blocks ) # TODO: reduce this size as needed for decode-only cudagraph capture self.state_indices_tensor_d: torch.Tensor = torch.empty( From 4a0ae44444152e10fb882d4e27c6c44020ecc0e6 Mon Sep 17 00:00:00 2001 From: Roi Koren Date: Tue, 5 May 2026 15:53:04 +0300 Subject: [PATCH 05/12] CR Signed-off-by: Roi Koren --- .../layers/mamba/mamba_mixer2.py | 50 +++++++------------ 1 file changed, 19 insertions(+), 31 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 4e8d2549ce92..297f7dbd32bd 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -60,30 +60,6 @@ # Added by the IBM Team, 2024 -def _gather_decode_state_indices( - state_indices_tensor_d: torch.Tensor, - block_idx_last_computed_token_d: torch.Tensor, - block_idx_last_scheduled_token_d: torch.Tensor, - num_spec_tokens: int, -) -> tuple[torch.Tensor, torch.Tensor]: - if num_spec_tokens > 0: - offsets = torch.arange( - 1 + num_spec_tokens, - device=block_idx_last_scheduled_token_d.device, - dtype=block_idx_last_scheduled_token_d.dtype, - ) - indices = block_idx_last_scheduled_token_d.unsqueeze(1) + offsets.unsqueeze(0) - gathered = state_indices_tensor_d.gather(1, indices) - return gathered, gathered - input_slots = state_indices_tensor_d.gather( - 1, block_idx_last_computed_token_d.unsqueeze(1) - ).squeeze(1) - output_slots = state_indices_tensor_d.gather( - 1, block_idx_last_scheduled_token_d.unsqueeze(1) - ).squeeze(1) - return input_slots, output_slots - - # Adapted from transformers.models.mamba2.modeling_mamba2.MambaRMSNormGated # --8<-- [start:mixer2_gated_rms_norm] @CustomOp.register("mixer2_gated_rms_norm") @@ -968,14 +944,26 @@ def conv_ssm_forward( if has_decode: assert state_indices_tensor_d is not None if is_mamba_cache_all: - state_indices_tensor_d_input, state_indices_tensor_d_output = ( - _gather_decode_state_indices( - state_indices_tensor_d, - block_idx_last_computed_token_d, - block_idx_last_scheduled_token_d, - self.num_spec, + if self.num_spec > 0: + offsets = torch.arange( + 1 + self.num_spec, + device=block_idx_last_scheduled_token_d.device, + dtype=block_idx_last_scheduled_token_d.dtype, ) - ) + indices = block_idx_last_scheduled_token_d.unsqueeze( + 1 + ) + offsets.unsqueeze(0) + state_indices_tensor_d_input = state_indices_tensor_d.gather( + 1, indices + ) + state_indices_tensor_d_output = state_indices_tensor_d_input + else: + state_indices_tensor_d_input = state_indices_tensor_d.gather( + 1, block_idx_last_computed_token_d.unsqueeze(1) + ).squeeze(1) + state_indices_tensor_d_output = state_indices_tensor_d.gather( + 1, block_idx_last_scheduled_token_d.unsqueeze(1) + ).squeeze(1) else: # Without caching, read and write in-place to the same blocks: state_indices_tensor_d_input = state_indices_tensor_d From e1685ebfeb80e9418445821f1f5832e0b7de38a3 Mon Sep 17 00:00:00 2001 From: Roi Koren Date: Tue, 12 May 2026 16:34:36 +0300 Subject: [PATCH 06/12] Plumb prev-step input anchor for all-mode + SpecDec Signed-off-by: Roi Koren --- .../test_mamba_update_block_table.py | 207 ++++++++++++++++++ vllm/v1/attention/backends/mamba2_attn.py | 4 +- vllm/v1/attention/backends/mamba_attn.py | 81 ++++++- vllm/v1/worker/gpu_input_batch.py | 2 + vllm/v1/worker/gpu_model_runner.py | 42 ++++ 5 files changed, 333 insertions(+), 3 deletions(-) diff --git a/tests/v1/attention/test_mamba_update_block_table.py b/tests/v1/attention/test_mamba_update_block_table.py index 6746af41b6df..a601eafd3b72 100644 --- a/tests/v1/attention/test_mamba_update_block_table.py +++ b/tests/v1/attention/test_mamba_update_block_table.py @@ -20,6 +20,7 @@ from vllm.v1.attention.backends.mamba_attn import ( BaseMambaAttentionMetadata, BaseMambaAttentionMetadataBuilder, + _compute_block_idx_last_scheduled_prev_step, ) from vllm.v1.kv_cache_interface import MambaSpec @@ -114,6 +115,7 @@ def test_update_block_table_copies_block_idx_to_persistent_buffers(): block_idx_last_computed_token=( builder_a.block_idx_last_computed_token[:num_reqs] ), + block_idx_last_scheduled_token_prev_step=None, seq_lens=seq_lens, ) @@ -256,6 +258,7 @@ def test_block_idx_cudagraph_capture_padded_by_num_reqs(): block_idx_last_scheduled_token=block_idx_vals, block_idx_first_scheduled_token_p=None, block_idx_last_computed_token=block_idx_vals, + block_idx_last_scheduled_token_prev_step=None, seq_lens=seq_lens, ) @@ -271,3 +274,207 @@ def test_block_idx_cudagraph_capture_padded_by_num_reqs(): ) assert torch.all(out.block_idx_last_scheduled_token[num_decodes:] == 0) assert torch.all(out.block_idx_last_computed_token[num_decodes:] == 0) + + +def test_block_idx_prev_step_persistent_buffer_allocated(): + """With mamba_cache_mode='all' + spec decode, the builder must allocate + block_idx_last_scheduled_token_prev_step as a persistent buffer with the + same shape as the existing block_idx_last_{scheduled,computed}_token + buffers, so cudagraph capture records a stable pointer for the prev-step + input anchor consumed by mamba_mixer2's input gather.""" + block_size = 16 + max_model_len = 256 + max_num_seqs = 8 + num_speculative_tokens = 1 + device = torch.device("cpu") + + vllm_config = _make_vllm_config( + max_model_len, + max_num_seqs, + num_speculative_tokens=num_speculative_tokens, + ) + spec = MambaSpec( + block_size=block_size, + shapes=((1,), (1,)), + dtypes=(torch.float32,), + mamba_cache_mode="all", + num_speculative_blocks=2, + ) + builder = _ConcreteMambaBuilder(spec, ["layer0"], vllm_config, device) + + assert hasattr(builder, "block_idx_last_scheduled_token_prev_step") + assert builder.block_idx_last_scheduled_token_prev_step.shape == (max_num_seqs,) + assert builder.block_idx_last_scheduled_token_prev_step.dtype == torch.int32 + + +def test_block_idx_prev_step_persistent_buffer_skipped_without_spec_decode(): + """Without spec decode, the prev-step buffer is unused and must not be + allocated — the input anchor reduces to last_computed_token.""" + block_size = 16 + max_model_len = 256 + max_num_seqs = 8 + device = torch.device("cpu") + + vllm_config = _make_vllm_config( + max_model_len, max_num_seqs, num_speculative_tokens=0 + ) + spec = MambaSpec( + block_size=block_size, + shapes=((1,), (1,)), + dtypes=(torch.float32,), + mamba_cache_mode="all", + ) + builder = _ConcreteMambaBuilder(spec, ["layer0"], vllm_config, device) + + assert not hasattr(builder, "block_idx_last_scheduled_token_prev_step") + + +def test_block_idx_prev_step_cudagraph_capture_uses_persistent_buffer(): + """_update_metadata_for_cudagraph_capture must copy the prev-step anchor + into the builder's persistent buffer (so cudagraph replay reads from the + same underlying memory), pad past num_decodes with zero, and return a + slice of the persistent buffer in the metadata.""" + block_size = 16 + max_model_len = 256 + max_num_seqs = 8 + num_speculative_tokens = 1 + device = torch.device("cpu") + + vllm_config = _make_vllm_config( + max_model_len, + max_num_seqs, + num_speculative_tokens=num_speculative_tokens, + ) + spec = MambaSpec( + block_size=block_size, + shapes=((1,), (1,)), + dtypes=(torch.float32,), + mamba_cache_mode="all", + num_speculative_blocks=2, + ) + builder = _ConcreteMambaBuilder(spec, ["layer0"], vllm_config, device) + builder.block_idx_last_scheduled_token.fill_(-1) + builder.block_idx_last_computed_token.fill_(-1) + builder.block_idx_last_scheduled_token_prev_step.fill_(-1) + + num_decodes = 2 + num_reqs = 3 + num_decode_tokens = num_decodes * (1 + num_speculative_tokens) + seq_lens = torch.full((num_reqs,), 64, dtype=torch.int32, device=device) + block_idx_vals = torch.tensor([3, 5], dtype=torch.int32, device=device) + prev_step_vals = torch.tensor([2, 4], dtype=torch.int32, device=device) + state_indices_d = torch.zeros( + (num_decodes, builder.state_indices_tensor_d.shape[1]), + dtype=torch.int32, + device=device, + ) + query_start_loc_d = torch.arange( + num_decodes + 1, dtype=torch.int32, device=device + ) * (1 + num_speculative_tokens) + num_accepted_tokens = torch.ones(num_decodes, dtype=torch.int32, device=device) + + metadata = BaseMambaAttentionMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decodes=num_decodes, + num_decode_tokens=num_decode_tokens, + num_reqs=num_reqs, + has_initial_states_p=None, + query_start_loc_p=None, + num_computed_tokens_p=None, + state_indices_tensor_p=None, + state_indices_tensor_d=state_indices_d, + query_start_loc_d=query_start_loc_d, + num_accepted_tokens=num_accepted_tokens, + block_idx_last_scheduled_token=block_idx_vals, + block_idx_first_scheduled_token_p=None, + block_idx_last_computed_token=block_idx_vals, + block_idx_last_scheduled_token_prev_step=prev_step_vals, + seq_lens=seq_lens, + ) + + out = builder._update_metadata_for_cudagraph_capture(metadata) + + # Output field exists and is identity-shared with the persistent buffer. + assert out.block_idx_last_scheduled_token_prev_step is not None + assert ( + out.block_idx_last_scheduled_token_prev_step.untyped_storage().data_ptr() + == builder.block_idx_last_scheduled_token_prev_step.untyped_storage().data_ptr() + ), ( + "prev-step buffer must live in the builder's persistent buffer, not " + "in the caller-provided tensor" + ) + + # Padded by num_reqs (not num_decode_tokens) — same fix as bug 2 for the + # other block_idx_* fields. + assert out.block_idx_last_scheduled_token_prev_step.shape == (num_reqs,) + + # First num_decodes values: input values copied through. + torch.testing.assert_close( + out.block_idx_last_scheduled_token_prev_step[:num_decodes], + prev_step_vals, + ) + + # Tail values past num_decodes: zero-filled padding for cudagraph capture. + assert torch.all(out.block_idx_last_scheduled_token_prev_step[num_decodes:] == 0) + + +def test_prev_step_anchor_first_decode_after_prefill(): + """First decode step after prefill: the previous step (prefill) stored + its terminal mamba state at block (num_computed - 1) // block_size. + The worker-side tracker has no entry for these requests, so we pass + -1 to indicate the fallback path.""" + block_size = 16 + num_computed = torch.tensor([100, 16, 64, 1], dtype=torch.int64) + prev_last_scheduled = torch.full((4,), -1, dtype=torch.int64) + + result = _compute_block_idx_last_scheduled_prev_step( + num_computed, prev_last_scheduled, block_size + ) + + expected = torch.tensor([6, 0, 3, 0], dtype=torch.int64) + torch.testing.assert_close(result, expected) + + +def test_prev_step_anchor_subsequent_decode_uses_tracker(): + """For decode steps with a prior decode step, the worker-side tracker + holds the previous step's last_scheduled block index. The function must + use that value verbatim — even when last_computed (= (num_computed - 1) + // block_size) would give a different answer, which is exactly what + happens after partial draft acceptance straddles a block boundary.""" + block_size = 16 + # num_computed[N] = 16 means the last committed token is at position 15 + # (block 0). Naive last_computed would be 0. But the tracker says the + # previous step ended with last_scheduled = 1 (because step N-1 had + # scheduled tokens up to position 17, in block 1; some drafts were + # rejected so only 2 tokens committed, leaving the committed tail in + # block 0). The tracker's value (1) is what we need. + num_computed = torch.tensor([16], dtype=torch.int64) + prev_last_scheduled = torch.tensor([1], dtype=torch.int64) + + result = _compute_block_idx_last_scheduled_prev_step( + num_computed, prev_last_scheduled, block_size + ) + + expected = torch.tensor([1], dtype=torch.int64) + torch.testing.assert_close(result, expected) + + +def test_prev_step_anchor_mixed_batch(): + """Mixed batch: some requests are first-decode (tracker -1), others are + subsequent decode (tracker has a value). Each path resolves + independently.""" + block_size = 16 + num_computed = torch.tensor([16, 100, 16, 64], dtype=torch.int64) + # req 0: subsequent decode, tracker says block 1 (boundary-crossing case) + # req 1: first decode after prefill, fallback to last_computed = 6 + # req 2: subsequent decode, tracker says block 0 (no boundary involvement) + # req 3: first decode, fallback to last_computed = 3 + prev_last_scheduled = torch.tensor([1, -1, 0, -1], dtype=torch.int64) + + result = _compute_block_idx_last_scheduled_prev_step( + num_computed, prev_last_scheduled, block_size + ) + + expected = torch.tensor([1, 6, 0, 3], dtype=torch.int64) + torch.testing.assert_close(result, expected) diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py index fa7d4bd2ec51..5f25c4a79520 100644 --- a/vllm/v1/attention/backends/mamba2_attn.py +++ b/vllm/v1/attention/backends/mamba2_attn.py @@ -137,7 +137,9 @@ def build( **kwargs: Any, ) -> Mamba2AttentionMetadata: common = self._compute_common_metadata( - common_attn_metadata, num_accepted_tokens=kwargs.get("num_accepted_tokens") + common_attn_metadata, + num_accepted_tokens=kwargs.get("num_accepted_tokens"), + prev_last_scheduled_idx=kwargs.get("prev_last_scheduled_idx"), ) seq_idx_p = None diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py index 4ccfb0b87491..ab28fdf49668 100644 --- a/vllm/v1/attention/backends/mamba_attn.py +++ b/vllm/v1/attention/backends/mamba_attn.py @@ -26,6 +26,19 @@ M = TypeVar("M", bound="BaseMambaAttentionMetadata") +def _compute_block_idx_last_scheduled_prev_step( + num_computed_tokens: torch.Tensor, + prev_last_scheduled_idx: torch.Tensor, + mamba_block_size: int, +) -> torch.Tensor: + fallback = torch.clamp((num_computed_tokens - 1) // mamba_block_size, min=0) + return torch.where( + prev_last_scheduled_idx >= 0, + prev_last_scheduled_idx, + fallback, + ) + + @dataclass class BaseMambaAttentionMetadata: num_prefills: int @@ -56,6 +69,7 @@ class BaseMambaAttentionMetadata: block_idx_last_scheduled_token: torch.Tensor | None block_idx_first_scheduled_token_p: torch.Tensor | None block_idx_last_computed_token: torch.Tensor | None + block_idx_last_scheduled_token_prev_step: torch.Tensor | None # The following tensor is only used for prefix caching in align mode seq_lens: torch.Tensor @@ -134,6 +148,14 @@ def __init__( dtype=torch.int32, device=device, ) + if self.use_spec_decode: + self.block_idx_last_scheduled_token_prev_step: torch.Tensor = ( + torch.empty( + (self.decode_cudagraph_max_bs,), + dtype=torch.int32, + device=device, + ) + ) else: self.state_indices_tensor_d = torch.empty( (self.decode_cudagraph_max_bs, 1 + self.num_spec_tokens), @@ -177,7 +199,23 @@ def build_for_cudagraph_capture( if self.num_spec_tokens > 0: num_accepted_tokens = torch.diff(m.query_start_loc) - return self.build(0, m, num_accepted_tokens=num_accepted_tokens) + prev_last_scheduled_idx = None + if ( + self.use_spec_decode + and self.vllm_config.cache_config.mamba_cache_mode == "all" + ): + prev_last_scheduled_idx = torch.zeros( + (m.num_reqs,), + dtype=torch.int32, + device=m.query_start_loc.device, + ) + + return self.build( + 0, + m, + num_accepted_tokens=num_accepted_tokens, + prev_last_scheduled_idx=prev_last_scheduled_idx, + ) def build( self, @@ -186,6 +224,7 @@ def build( fast_build: bool = False, *, num_accepted_tokens: torch.Tensor | None = None, + prev_last_scheduled_idx: torch.Tensor | None = None, **kwargs: Any, ) -> M: """ @@ -193,7 +232,9 @@ def build( Subclasses (e.g., Mamba2) can override to add additional metadata. """ return self._compute_common_metadata( - common_attn_metadata, num_accepted_tokens=num_accepted_tokens + common_attn_metadata, + num_accepted_tokens=num_accepted_tokens, + prev_last_scheduled_idx=prev_last_scheduled_idx, ) def _compute_chunk_metadata( @@ -342,6 +383,7 @@ def _compute_common_metadata( common_attn_metadata: CommonAttentionMetadata, *, num_accepted_tokens: torch.Tensor | None = None, + prev_last_scheduled_idx: torch.Tensor | None = None, ) -> M: """ Compute metadata common to both Mamba1 and Mamba2. @@ -376,6 +418,7 @@ def _compute_common_metadata( block_idx_first_scheduled_token_p = None block_idx_last_computed_token = None block_idx_last_scheduled_token = None + block_idx_last_scheduled_token_prev_step = None # for causal_conv1d nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None @@ -394,6 +437,14 @@ def _compute_common_metadata( ) = self._compute_prefix_caching_block_indices( common_attn_metadata, mamba_block_size ) + if self.use_spec_decode and prev_last_scheduled_idx is not None: + block_idx_last_scheduled_token_prev_step = ( + _compute_block_idx_last_scheduled_prev_step( + num_computed_tokens, + prev_last_scheduled_idx, + mamba_block_size, + ) + ) else: state_indices_tensor = mamba_get_block_table_tensor( common_attn_metadata.block_table_tensor, @@ -471,6 +522,9 @@ def _compute_common_metadata( block_idx_last_scheduled_token=block_idx_last_scheduled_token, block_idx_first_scheduled_token_p=block_idx_first_scheduled_token_p, block_idx_last_computed_token=block_idx_last_computed_token, + block_idx_last_scheduled_token_prev_step=( + block_idx_last_scheduled_token_prev_step + ), num_computed_tokens_p=num_computed_tokens_p, num_reqs=num_reqs, seq_lens=common_attn_metadata.seq_lens, @@ -494,6 +548,9 @@ def _update_metadata_for_cudagraph_capture( num_accepted_tokens = metadata.num_accepted_tokens block_idx_last_scheduled_token = metadata.block_idx_last_scheduled_token block_idx_last_computed_token = metadata.block_idx_last_computed_token + block_idx_last_scheduled_token_prev_step = ( + metadata.block_idx_last_scheduled_token_prev_step + ) if ( metadata.num_prefills == 0 and metadata.num_decodes <= self.decode_cudagraph_max_bs @@ -538,6 +595,23 @@ def _update_metadata_for_cudagraph_capture( ] block_idx_last_computed_token[metadata.num_decodes :] = 0 + if ( + self.use_spec_decode + and block_idx_last_scheduled_token_prev_step is not None + ): + self.block_idx_last_scheduled_token_prev_step[ + : metadata.num_decodes + ].copy_( + block_idx_last_scheduled_token_prev_step[ + : metadata.num_decodes + ], + non_blocking=True, + ) + block_idx_last_scheduled_token_prev_step = ( + self.block_idx_last_scheduled_token_prev_step[:padded_bs] + ) + block_idx_last_scheduled_token_prev_step[metadata.num_decodes :] = 0 + return replace( metadata, state_indices_tensor_d=state_indices_tensor_d, @@ -545,6 +619,9 @@ def _update_metadata_for_cudagraph_capture( num_accepted_tokens=num_accepted_tokens, block_idx_last_scheduled_token=block_idx_last_scheduled_token, block_idx_last_computed_token=block_idx_last_computed_token, + block_idx_last_scheduled_token_prev_step=( + block_idx_last_scheduled_token_prev_step + ), ) def update_block_table( diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 31571e9db26b..7e92037f426b 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -59,6 +59,8 @@ class CachedRequestState: # Used when both async_scheduling and spec_decode are enabled. prev_num_draft_len: int = 0 + mamba_last_scheduled_idx: int = -1 + # for pooling models pooling_params: PoolingParams | None = None pooling_states: PoolingStates | None = None diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 2c010040bc21..5f35be24bfce 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -888,6 +888,12 @@ def __init__( self.kv_connector_output: KVConnectorOutput | None = None self.mamba_state_idx: dict[str, int] = {} self._mamba_copy_bufs: mamba_utils.MambaCopyBuffers | None = None + self.mamba_prev_last_scheduled_idx: CpuGpuBuffer | None = None + if self.cache_config.mamba_cache_mode == "all" and self.num_spec_tokens > 0: + self.mamba_prev_last_scheduled_idx = self._make_buffer( + self.max_num_reqs, dtype=torch.int32 + ) + self._mamba_block_size: int | None = None self.layerwise_nvtx_hooks_registered = False def update_max_model_len(self, max_model_len: int) -> None: @@ -1310,6 +1316,9 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> Callable | None # Update the cached states. req_state.num_computed_tokens = num_computed_tokens + if resumed_from_preemption: + req_state.mamba_last_scheduled_idx = -1 + if not is_last_rank: if not req_data.new_token_ids: # Async scheduled PP: Sampled tokens propagated via GPU broadcast. @@ -1515,6 +1524,25 @@ def _update_states_after_model_execute( assert self.num_accepted_tokens_event is not None self.num_accepted_tokens_event.record() + if self.mamba_prev_last_scheduled_idx is not None: + if self._mamba_block_size is None: + self._mamba_block_size = next( + g.kv_cache_spec.block_size + for g in self.kv_cache_config.kv_cache_groups + if isinstance(g.kv_cache_spec, MambaSpec) + ) + block_size = self._mamba_block_size + full_decode_len = 1 + self.num_spec_tokens + scheduled = scheduler_output.num_scheduled_tokens + for req_id in self.input_batch.req_ids[:num_reqs]: + req = self.requests[req_id] + num_query = scheduled.get(req_id, 0) + if num_query == full_decode_len: + seq_len = req.num_computed_tokens + num_query + req.mamba_last_scheduled_idx = max(0, (seq_len - 1) // block_size) + else: + req.mamba_last_scheduled_idx = -1 + def _update_streaming_request( self, req_id: str, new_req_data: NewRequestData ) -> CachedRequestState: @@ -2012,6 +2040,13 @@ def _prepare_inputs( self.num_accepted_tokens.np.fill(1) self.num_accepted_tokens.gpu.fill_(1) + if self.mamba_prev_last_scheduled_idx is not None: + np_view = self.mamba_prev_last_scheduled_idx.np + for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]): + np_view[i] = self.requests[req_id].mamba_last_scheduled_idx + np_view[num_reqs:].fill(-1) + self.mamba_prev_last_scheduled_idx.copy_to_gpu() + # Update num_computed_tokens on GPU. In async spec decode, # CPU values are optimistic (all drafts accepted). The kernel # corrects on GPU using the previous step's @@ -2319,6 +2354,13 @@ def _build_attn_group_metadata( :num_reqs_padded ], ) + if ( + isinstance(builder, Mamba2AttentionMetadataBuilder) + and self.mamba_prev_last_scheduled_idx is not None + ): + extra_attn_metadata_args["prev_last_scheduled_idx"] = ( + self.mamba_prev_last_scheduled_idx.gpu[:num_reqs_padded] + ) if for_cudagraph_capture: attn_metadata_i = builder.build_for_cudagraph_capture( From e7ad4fa70ca15a28bc252a098634b52b436ea943 Mon Sep 17 00:00:00 2001 From: Roi Koren Date: Tue, 12 May 2026 16:54:45 +0300 Subject: [PATCH 07/12] Fix state indices Signed-off-by: Roi Koren --- .../layers/mamba/mamba_mixer2.py | 22 ++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 297f7dbd32bd..0848c576f38a 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -755,6 +755,14 @@ def conv_ssm_forward( dim=0, ) ) + if attn_metadata.block_idx_last_scheduled_token_prev_step is not None: + block_idx_last_scheduled_token_prev_step_d, _ = torch.split( + attn_metadata.block_idx_last_scheduled_token_prev_step, + [num_decodes, num_prefills], + dim=0, + ) + else: + block_idx_last_scheduled_token_prev_step_d = None # Prefill-only variables: block_idx_first_scheduled_token_p = ( attn_metadata.block_idx_first_scheduled_token_p @@ -766,6 +774,7 @@ def conv_ssm_forward( block_idx_first_scheduled_token_p = None block_idx_last_scheduled_token_d = None block_idx_last_computed_token_d = None + block_idx_last_scheduled_token_prev_step_d = None num_computed_tokens_p = None preallocated_ssm_out_d, preallocated_ssm_out_p = torch.split( @@ -945,18 +954,25 @@ def conv_ssm_forward( assert state_indices_tensor_d is not None if is_mamba_cache_all: if self.num_spec > 0: + assert block_idx_last_scheduled_token_prev_step_d is not None offsets = torch.arange( 1 + self.num_spec, device=block_idx_last_scheduled_token_d.device, dtype=block_idx_last_scheduled_token_d.dtype, ) - indices = block_idx_last_scheduled_token_d.unsqueeze( + input_indices = ( + block_idx_last_scheduled_token_prev_step_d.unsqueeze(1) + + offsets.unsqueeze(0) + ) + output_indices = block_idx_last_scheduled_token_d.unsqueeze( 1 ) + offsets.unsqueeze(0) state_indices_tensor_d_input = state_indices_tensor_d.gather( - 1, indices + 1, input_indices + ) + state_indices_tensor_d_output = state_indices_tensor_d.gather( + 1, output_indices ) - state_indices_tensor_d_output = state_indices_tensor_d_input else: state_indices_tensor_d_input = state_indices_tensor_d.gather( 1, block_idx_last_computed_token_d.unsqueeze(1) From 99a2791a708f4a787251198cf7a0c5f71a6e35f8 Mon Sep 17 00:00:00 2001 From: Roi Koren Date: Wed, 13 May 2026 13:18:49 +0300 Subject: [PATCH 08/12] Small refactor Signed-off-by: Roi Koren --- vllm/v1/worker/gpu_input_batch.py | 2 - vllm/v1/worker/gpu_model_runner.py | 63 +++++------- vllm/v1/worker/mamba_utils.py | 159 ++++++++++++++++++++--------- 3 files changed, 135 insertions(+), 89 deletions(-) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 7e92037f426b..31571e9db26b 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -59,8 +59,6 @@ class CachedRequestState: # Used when both async_scheduling and spec_decode are enabled. prev_num_draft_len: int = 0 - mamba_last_scheduled_idx: int = -1 - # for pooling models pooling_params: PoolingParams | None = None pooling_states: PoolingStates | None = None diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 5f35be24bfce..996778aa05ee 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -893,7 +893,6 @@ def __init__( self.mamba_prev_last_scheduled_idx = self._make_buffer( self.max_num_reqs, dtype=torch.int32 ) - self._mamba_block_size: int | None = None self.layerwise_nvtx_hooks_registered = False def update_max_model_len(self, max_model_len: int) -> None: @@ -1317,7 +1316,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> Callable | None req_state.num_computed_tokens = num_computed_tokens if resumed_from_preemption: - req_state.mamba_last_scheduled_idx = -1 + self.mamba_state_idx.pop(req_id, None) if not is_last_rank: if not req_data.new_token_ids: @@ -1502,21 +1501,12 @@ def _update_states_after_model_execute( num_reqs = output_token_ids.size(0) self.num_accepted_tokens.gpu[:num_reqs] = (output_token_ids != -1).sum(dim=1) - if self.cache_config.mamba_cache_mode == "align": + is_align = self.cache_config.mamba_cache_mode == "align" + if is_align: for i, num_tokens in enumerate( self.num_accepted_tokens.gpu[:num_reqs].cpu().numpy() ): self.input_batch.num_accepted_tokens_cpu[i] = num_tokens - mamba_utils.postprocess_mamba( - scheduler_output, - self.kv_cache_config, - self.input_batch, - self.requests, - self.mamba_state_idx, - self.compilation_config.static_forward_context, - self.model.get_mamba_state_copy_func(), - self._get_mamba_copy_bufs(), - ) else: self.input_batch.num_accepted_tokens_cpu_tensor[:num_reqs].copy_( self.num_accepted_tokens.gpu[:num_reqs], non_blocking=True @@ -1524,24 +1514,23 @@ def _update_states_after_model_execute( assert self.num_accepted_tokens_event is not None self.num_accepted_tokens_event.record() - if self.mamba_prev_last_scheduled_idx is not None: - if self._mamba_block_size is None: - self._mamba_block_size = next( - g.kv_cache_spec.block_size - for g in self.kv_cache_config.kv_cache_groups - if isinstance(g.kv_cache_spec, MambaSpec) - ) - block_size = self._mamba_block_size - full_decode_len = 1 + self.num_spec_tokens - scheduled = scheduler_output.num_scheduled_tokens - for req_id in self.input_batch.req_ids[:num_reqs]: - req = self.requests[req_id] - num_query = scheduled.get(req_id, 0) - if num_query == full_decode_len: - seq_len = req.num_computed_tokens + num_query - req.mamba_last_scheduled_idx = max(0, (seq_len - 1) // block_size) - else: - req.mamba_last_scheduled_idx = -1 + mamba_utils.postprocess_mamba( + scheduler_output, + self.kv_cache_config, + self.cache_config, + self.input_batch, + self.requests, + self.mamba_state_idx, + self.num_spec_tokens, + num_reqs, + forward_context=( + self.compilation_config.static_forward_context if is_align else None + ), + mamba_state_copy_funcs=( + self.model.get_mamba_state_copy_func() if is_align else None + ), + copy_bufs=self._get_mamba_copy_bufs() if is_align else None, + ) def _update_streaming_request( self, req_id: str, new_req_data: NewRequestData @@ -2041,11 +2030,13 @@ def _prepare_inputs( self.num_accepted_tokens.gpu.fill_(1) if self.mamba_prev_last_scheduled_idx is not None: - np_view = self.mamba_prev_last_scheduled_idx.np - for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]): - np_view[i] = self.requests[req_id].mamba_last_scheduled_idx - np_view[num_reqs:].fill(-1) - self.mamba_prev_last_scheduled_idx.copy_to_gpu() + mamba_utils.preprocess_mamba_all_specdec( + scheduler_output, + self.input_batch, + self.mamba_state_idx, + num_reqs, + self.mamba_prev_last_scheduled_idx, + ) # Update num_computed_tokens on GPU. In async spec decode, # CPU values are optimistic (all drafts accepted). The kernel diff --git a/vllm/v1/worker/mamba_utils.py b/vllm/v1/worker/mamba_utils.py index c832389b1b0a..b33080cb094d 100644 --- a/vllm/v1/worker/mamba_utils.py +++ b/vllm/v1/worker/mamba_utils.py @@ -144,6 +144,24 @@ def do_mamba_copy_block(copy_bufs: MambaCopyBuffers): ) +def cleanup_mamba_state_idx( + scheduler_output: SchedulerOutput, + mamba_state_idx: dict[str, int], +) -> None: + """Pop stale `mamba_state_idx` entries for finished/preempted/resumed reqs. + + Force-preempted requests (e.g., during reset_prefix_cache / KV cache + flush) appear in resumed_req_ids without a corresponding entry in + preempted_req_ids, leaving stale entries that can point to block + indices beyond the new (smaller) block allocation. + """ + finished_req_ids = scheduler_output.finished_req_ids + preempted_req_ids = scheduler_output.preempted_req_ids or set() + resumed_req_ids = scheduler_output.scheduled_cached_reqs.resumed_req_ids + for req_id in itertools.chain(finished_req_ids, preempted_req_ids, resumed_req_ids): + mamba_state_idx.pop(req_id, None) + + def preprocess_mamba( scheduler_output: SchedulerOutput, kv_cache_config: KVCacheConfig, @@ -165,16 +183,7 @@ def preprocess_mamba( # TODO(Chen): we need to optimize this function a lot assert cache_config.enable_prefix_caching block_size = mamba_spec.block_size - finished_req_ids = scheduler_output.finished_req_ids - preempted_req_ids = scheduler_output.preempted_req_ids or set() - # We need to clear mamba_state_idx for resumed requests. When requests are - # force-preempted (e.g., during reset_prefix_cache / KV cache flush), - # they appear in resumed_req_ids without a corresponding entry in - # preempted_req_ids, leaving stale mamba_state_idx entries that can - # point to block indices beyond the new (smaller) block allocation. - resumed_req_ids = scheduler_output.scheduled_cached_reqs.resumed_req_ids - for req_id in itertools.chain(finished_req_ids, preempted_req_ids, resumed_req_ids): - mamba_state_idx.pop(req_id, None) + cleanup_mamba_state_idx(scheduler_output, mamba_state_idx) copy_bufs.offset = 0 for i, req_id in enumerate(input_batch.req_ids): @@ -222,52 +231,100 @@ def preprocess_mamba( def postprocess_mamba( scheduler_output: SchedulerOutput, kv_cache_config: KVCacheConfig, + cache_config: CacheConfig, input_batch: GPUInputBatch, requests: dict[str, CachedRequestState], mamba_state_idx: dict[str, int], - forward_context: dict[str, Any], - mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...], - copy_bufs: MambaCopyBuffers, + num_spec_tokens: int, + num_reqs: int, + *, + forward_context: dict[str, Any] | None = None, + mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...] | None = None, + copy_bufs: MambaCopyBuffers | None = None, ): """ - If a blocks is converted from partial block to full block in this step, copy the - state from the block for running state to the new full block. + Post-model-execute mamba prefix-caching bookkeeping. Dispatched by + cache_config.mamba_cache_mode: + - "align": if a block is converted from partial to full this step, + copy the running state into the new full block. + - "all" + num_spec_tokens > 0: record per-request the block index of + the last token scheduled this step, so the next step can anchor + its in-place writes when accepted drafts leave the sequence at a + non-block-aligned position. """ - num_scheduled_tokens_dict = scheduler_output.num_scheduled_tokens - scheduled_spec_decode_tokens_dict = scheduler_output.scheduled_spec_decode_tokens - num_accepted_tokens_cpu = input_batch.num_accepted_tokens_cpu - mamba_group_ids = copy_bufs.mamba_group_ids - mamba_spec = copy_bufs.mamba_spec - copy_bufs.offset = 0 - for i, req_id in enumerate(input_batch.req_ids): - req_state = requests[req_id] - num_computed_tokens = req_state.num_computed_tokens - num_draft_tokens = len(scheduled_spec_decode_tokens_dict.get(req_id, [])) - num_scheduled_tokens = num_scheduled_tokens_dict[req_id] - num_accepted_tokens = num_accepted_tokens_cpu[i] - num_tokens_running_state = ( - num_computed_tokens + num_scheduled_tokens - num_draft_tokens + if cache_config.mamba_cache_mode == "align": + assert forward_context is not None + assert mamba_state_copy_funcs is not None + assert copy_bufs is not None + num_scheduled_tokens_dict = scheduler_output.num_scheduled_tokens + scheduled_spec_decode_tokens_dict = ( + scheduler_output.scheduled_spec_decode_tokens ) - new_num_computed_tokens = num_tokens_running_state + num_accepted_tokens - 1 - aligned_new_computed_tokens = ( - new_num_computed_tokens // mamba_spec.block_size * mamba_spec.block_size - ) - # TODO: how to ensure all blocks that cache_blocks called are cached here? - if aligned_new_computed_tokens >= num_tokens_running_state: - accept_token_bias = aligned_new_computed_tokens - num_tokens_running_state - src_block_idx = mamba_state_idx[req_id] - dest_block_idx = aligned_new_computed_tokens // mamba_spec.block_size - 1 - collect_mamba_copy_meta( - copy_bufs, - kv_cache_config, - mamba_state_copy_funcs, - mamba_group_ids, - src_block_idx, - dest_block_idx, - accept_token_bias, - req_state, - forward_context, + num_accepted_tokens_cpu = input_batch.num_accepted_tokens_cpu + mamba_group_ids = copy_bufs.mamba_group_ids + mamba_spec = copy_bufs.mamba_spec + copy_bufs.offset = 0 + for i, req_id in enumerate(input_batch.req_ids): + req_state = requests[req_id] + num_computed_tokens = req_state.num_computed_tokens + num_draft_tokens = len(scheduled_spec_decode_tokens_dict.get(req_id, [])) + num_scheduled_tokens = num_scheduled_tokens_dict[req_id] + num_accepted_tokens = num_accepted_tokens_cpu[i] + num_tokens_running_state = ( + num_computed_tokens + num_scheduled_tokens - num_draft_tokens ) - if src_block_idx == dest_block_idx: - num_accepted_tokens_cpu[i] = 1 - do_mamba_copy_block(copy_bufs) + new_num_computed_tokens = num_tokens_running_state + num_accepted_tokens - 1 + aligned_new_computed_tokens = ( + new_num_computed_tokens // mamba_spec.block_size * mamba_spec.block_size + ) + # TODO: how to ensure all blocks that cache_blocks called are cached here? + if aligned_new_computed_tokens >= num_tokens_running_state: + accept_token_bias = ( + aligned_new_computed_tokens - num_tokens_running_state + ) + src_block_idx = mamba_state_idx[req_id] + dest_block_idx = ( + aligned_new_computed_tokens // mamba_spec.block_size - 1 + ) + collect_mamba_copy_meta( + copy_bufs, + kv_cache_config, + mamba_state_copy_funcs, + mamba_group_ids, + src_block_idx, + dest_block_idx, + accept_token_bias, + req_state, + forward_context, + ) + if src_block_idx == dest_block_idx: + num_accepted_tokens_cpu[i] = 1 + do_mamba_copy_block(copy_bufs) + elif cache_config.mamba_cache_mode == "all" and num_spec_tokens > 0: + _, mamba_spec = get_mamba_groups(kv_cache_config) + block_size = mamba_spec.block_size + full_decode_len = 1 + num_spec_tokens + scheduled = scheduler_output.num_scheduled_tokens + for req_id in input_batch.req_ids[:num_reqs]: + num_query = scheduled.get(req_id, 0) + if num_query == full_decode_len: + req = requests[req_id] + seq_len = req.num_computed_tokens + num_query + mamba_state_idx[req_id] = max(0, (seq_len - 1) // block_size) + else: + mamba_state_idx.pop(req_id, None) + + +def preprocess_mamba_all_specdec( + scheduler_output: SchedulerOutput, + input_batch: GPUInputBatch, + mamba_state_idx: dict[str, int], + num_reqs: int, + prev_last_scheduled_idx_buf: CpuGpuBuffer, +) -> None: + cleanup_mamba_state_idx(scheduler_output, mamba_state_idx) + np_view = prev_last_scheduled_idx_buf.np + for i, req_id in enumerate(input_batch.req_ids[:num_reqs]): + np_view[i] = mamba_state_idx.get(req_id, -1) + np_view[num_reqs:].fill(-1) + prev_last_scheduled_idx_buf.copy_to_gpu() From ecc7a3a8042893ca6805d74c82e048e1b6af0ef2 Mon Sep 17 00:00:00 2001 From: Roi Koren Date: Wed, 13 May 2026 13:59:00 +0300 Subject: [PATCH 09/12] Inline function and delete tests Signed-off-by: Roi Koren --- .../test_mamba_update_block_table.py | 62 ------------------- vllm/v1/attention/backends/mamba_attn.py | 26 +++----- 2 files changed, 7 insertions(+), 81 deletions(-) diff --git a/tests/v1/attention/test_mamba_update_block_table.py b/tests/v1/attention/test_mamba_update_block_table.py index a601eafd3b72..619abac952de 100644 --- a/tests/v1/attention/test_mamba_update_block_table.py +++ b/tests/v1/attention/test_mamba_update_block_table.py @@ -20,7 +20,6 @@ from vllm.v1.attention.backends.mamba_attn import ( BaseMambaAttentionMetadata, BaseMambaAttentionMetadataBuilder, - _compute_block_idx_last_scheduled_prev_step, ) from vllm.v1.kv_cache_interface import MambaSpec @@ -417,64 +416,3 @@ def test_block_idx_prev_step_cudagraph_capture_uses_persistent_buffer(): # Tail values past num_decodes: zero-filled padding for cudagraph capture. assert torch.all(out.block_idx_last_scheduled_token_prev_step[num_decodes:] == 0) - - -def test_prev_step_anchor_first_decode_after_prefill(): - """First decode step after prefill: the previous step (prefill) stored - its terminal mamba state at block (num_computed - 1) // block_size. - The worker-side tracker has no entry for these requests, so we pass - -1 to indicate the fallback path.""" - block_size = 16 - num_computed = torch.tensor([100, 16, 64, 1], dtype=torch.int64) - prev_last_scheduled = torch.full((4,), -1, dtype=torch.int64) - - result = _compute_block_idx_last_scheduled_prev_step( - num_computed, prev_last_scheduled, block_size - ) - - expected = torch.tensor([6, 0, 3, 0], dtype=torch.int64) - torch.testing.assert_close(result, expected) - - -def test_prev_step_anchor_subsequent_decode_uses_tracker(): - """For decode steps with a prior decode step, the worker-side tracker - holds the previous step's last_scheduled block index. The function must - use that value verbatim — even when last_computed (= (num_computed - 1) - // block_size) would give a different answer, which is exactly what - happens after partial draft acceptance straddles a block boundary.""" - block_size = 16 - # num_computed[N] = 16 means the last committed token is at position 15 - # (block 0). Naive last_computed would be 0. But the tracker says the - # previous step ended with last_scheduled = 1 (because step N-1 had - # scheduled tokens up to position 17, in block 1; some drafts were - # rejected so only 2 tokens committed, leaving the committed tail in - # block 0). The tracker's value (1) is what we need. - num_computed = torch.tensor([16], dtype=torch.int64) - prev_last_scheduled = torch.tensor([1], dtype=torch.int64) - - result = _compute_block_idx_last_scheduled_prev_step( - num_computed, prev_last_scheduled, block_size - ) - - expected = torch.tensor([1], dtype=torch.int64) - torch.testing.assert_close(result, expected) - - -def test_prev_step_anchor_mixed_batch(): - """Mixed batch: some requests are first-decode (tracker -1), others are - subsequent decode (tracker has a value). Each path resolves - independently.""" - block_size = 16 - num_computed = torch.tensor([16, 100, 16, 64], dtype=torch.int64) - # req 0: subsequent decode, tracker says block 1 (boundary-crossing case) - # req 1: first decode after prefill, fallback to last_computed = 6 - # req 2: subsequent decode, tracker says block 0 (no boundary involvement) - # req 3: first decode, fallback to last_computed = 3 - prev_last_scheduled = torch.tensor([1, -1, 0, -1], dtype=torch.int64) - - result = _compute_block_idx_last_scheduled_prev_step( - num_computed, prev_last_scheduled, block_size - ) - - expected = torch.tensor([1, 6, 0, 3], dtype=torch.int64) - torch.testing.assert_close(result, expected) diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py index ab28fdf49668..c95febb7ded7 100644 --- a/vllm/v1/attention/backends/mamba_attn.py +++ b/vllm/v1/attention/backends/mamba_attn.py @@ -26,19 +26,6 @@ M = TypeVar("M", bound="BaseMambaAttentionMetadata") -def _compute_block_idx_last_scheduled_prev_step( - num_computed_tokens: torch.Tensor, - prev_last_scheduled_idx: torch.Tensor, - mamba_block_size: int, -) -> torch.Tensor: - fallback = torch.clamp((num_computed_tokens - 1) // mamba_block_size, min=0) - return torch.where( - prev_last_scheduled_idx >= 0, - prev_last_scheduled_idx, - fallback, - ) - - @dataclass class BaseMambaAttentionMetadata: num_prefills: int @@ -438,12 +425,13 @@ def _compute_common_metadata( common_attn_metadata, mamba_block_size ) if self.use_spec_decode and prev_last_scheduled_idx is not None: - block_idx_last_scheduled_token_prev_step = ( - _compute_block_idx_last_scheduled_prev_step( - num_computed_tokens, - prev_last_scheduled_idx, - mamba_block_size, - ) + fallback = torch.clamp( + (num_computed_tokens - 1) // mamba_block_size, min=0 + ) + block_idx_last_scheduled_token_prev_step = torch.where( + prev_last_scheduled_idx >= 0, + prev_last_scheduled_idx, + fallback, ) else: state_indices_tensor = mamba_get_block_table_tensor( From 760d15e97882e8a982093a77872a0d19af0677c5 Mon Sep 17 00:00:00 2001 From: Roi Koren Date: Wed, 13 May 2026 14:05:17 +0300 Subject: [PATCH 10/12] Final cleanup Signed-off-by: Roi Koren --- vllm/v1/worker/gpu_model_runner.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 996778aa05ee..b7aa8779fdaa 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1315,9 +1315,6 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> Callable | None # Update the cached states. req_state.num_computed_tokens = num_computed_tokens - if resumed_from_preemption: - self.mamba_state_idx.pop(req_id, None) - if not is_last_rank: if not req_data.new_token_ids: # Async scheduled PP: Sampled tokens propagated via GPU broadcast. From b2d328beb3f4f9f5b10c0185ba3d1ecf3009f3fa Mon Sep 17 00:00:00 2001 From: Roi Koren Date: Mon, 18 May 2026 11:57:08 +0300 Subject: [PATCH 11/12] Last real cleanup maybe Signed-off-by: Roi Koren --- .../layers/mamba/mamba_mixer2.py | 20 ++++++++++--------- vllm/v1/attention/backends/mamba_attn.py | 2 +- vllm/v1/attention/backends/utils.py | 6 ++++-- vllm/v1/kv_cache_interface.py | 4 +++- 4 files changed, 19 insertions(+), 13 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 0848c576f38a..a6524961ea92 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -502,6 +502,12 @@ def __init__( self.prefix = prefix self.num_spec = vllm_config.num_speculative_tokens + if self.num_spec > 0: + self.register_buffer( + "_decode_state_offsets", + torch.arange(1 + self.num_spec, dtype=torch.int32).unsqueeze(0), + persistent=False, + ) # Pre-compute sizes for forward pass self.tped_intermediate_size = self.intermediate_size // self.tp_size @@ -955,18 +961,14 @@ def conv_ssm_forward( if is_mamba_cache_all: if self.num_spec > 0: assert block_idx_last_scheduled_token_prev_step_d is not None - offsets = torch.arange( - 1 + self.num_spec, - device=block_idx_last_scheduled_token_d.device, - dtype=block_idx_last_scheduled_token_d.dtype, - ) input_indices = ( block_idx_last_scheduled_token_prev_step_d.unsqueeze(1) - + offsets.unsqueeze(0) + + self._decode_state_offsets + ) + output_indices = ( + block_idx_last_scheduled_token_d.unsqueeze(1) + + self._decode_state_offsets ) - output_indices = block_idx_last_scheduled_token_d.unsqueeze( - 1 - ) + offsets.unsqueeze(0) state_indices_tensor_d_input = state_indices_tensor_d.gather( 1, input_indices ) diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py index c95febb7ded7..2af9c74ea765 100644 --- a/vllm/v1/attention/backends/mamba_attn.py +++ b/vllm/v1/attention/backends/mamba_attn.py @@ -112,7 +112,7 @@ def __init__( max_num_blocks = ( cdiv( self.vllm_config.model_config.max_model_len, - self.kv_cache_spec.block_size, + kv_cache_spec.block_size, ) + kv_cache_spec.num_speculative_blocks ) diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 43cbcfec1844..d09c01eb9059 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -878,8 +878,10 @@ def mamba_get_block_table_tensor( Get the block table tensor for mamba kernels from the input common_attn_metadata.block_table_tensor given different mamba cache modes. - - "all": input (#requests, cdiv(max_model_len, block_size)); - output (#requests, cdiv(max_model_len, block_size)). + - "all": input (#requests, cdiv(max_model_len, block_size) + + num_speculative_blocks); + output (#requests, cdiv(max_model_len, block_size) + + num_speculative_blocks). - "none": input (#requests, 1 + num_speculative_blocks); output (#requests, 1 + num_speculative_blocks). diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index a63408907b44..31ee89bc72aa 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -582,7 +582,9 @@ def page_size_bytes(self) -> int: def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: if vllm_config.cache_config.mamba_cache_mode == "all": max_model_len = vllm_config.model_config.max_model_len - return cdiv(max_model_len, self.block_size) * self.page_size_bytes + return ( + cdiv(max_model_len, self.block_size) + self.num_speculative_blocks + ) * self.page_size_bytes elif vllm_config.cache_config.mamba_cache_mode == "align": return self.page_size_bytes * (2 + self.num_speculative_blocks) else: From 680ee00b27a21de460ab5e0dc7f381ae36c61f23 Mon Sep 17 00:00:00 2001 From: Roi Koren Date: Mon, 18 May 2026 13:06:27 +0300 Subject: [PATCH 12/12] Fix fake fn Signed-off-by: Roi Koren --- .../v1/e2e/general/test_mamba_prefix_cache.py | 20 +++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/tests/v1/e2e/general/test_mamba_prefix_cache.py b/tests/v1/e2e/general/test_mamba_prefix_cache.py index 6ec9e7656e31..636eb13de886 100644 --- a/tests/v1/e2e/general/test_mamba_prefix_cache.py +++ b/tests/v1/e2e/general/test_mamba_prefix_cache.py @@ -364,26 +364,34 @@ def fake_preprocess_mamba_fn( def fake_post_process_mamba_fn( scheduler_output: SchedulerOutput, kv_cache_config: KVCacheConfig, + cache_config: CacheConfig, input_batch: GPUInputBatch, requests: dict[str, CachedRequestState], mamba_state_idx: dict[str, int], - forward_context: dict[str, Any], - mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...], - copy_bufs: mamba_utils.MambaCopyBuffers, + num_spec_tokens: int, + num_reqs: int, + *, + forward_context: dict[str, Any] | None = None, + mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...] | None = None, + copy_bufs: mamba_utils.MambaCopyBuffers | None = None, ): nonlocal copy_info copy_info = None ret = original_post_process_mamba_fn( scheduler_output, kv_cache_config, + cache_config, input_batch, requests, mamba_state_idx, - forward_context, - mamba_state_copy_funcs, - copy_bufs, + num_spec_tokens, + num_reqs, + forward_context=forward_context, + mamba_state_copy_funcs=mamba_state_copy_funcs, + copy_bufs=copy_bufs, ) if cur_step_action is not None: + assert forward_context is not None check_copy_info( cur_step_action.postprocess_copy_idx, kv_cache_config,