diff --git a/tests/v1/attention/test_mamba_update_block_table.py b/tests/v1/attention/test_mamba_update_block_table.py index 923939053ece..619abac952de 100644 --- a/tests/v1/attention/test_mamba_update_block_table.py +++ b/tests/v1/attention/test_mamba_update_block_table.py @@ -32,17 +32,25 @@ class _ConcreteMambaBuilder( metadata_cls = BaseMambaAttentionMetadata -def _make_vllm_config(block_size, max_model_len, max_num_seqs): +def _make_vllm_config(max_model_len, max_num_seqs, num_speculative_tokens=0): """Create a minimal mock VllmConfig with only the fields the builder accesses, avoiding any model download / HF config inspection.""" + speculative_config = ( + SimpleNamespace( + num_speculative_tokens=num_speculative_tokens, + parallel_drafting=False, + ) + if num_speculative_tokens > 0 + else None + ) return SimpleNamespace( cache_config=SimpleNamespace(mamba_cache_mode="all"), compilation_config=SimpleNamespace( cudagraph_mode=CUDAGraphMode.FULL, max_cudagraph_capture_size=None, ), - speculative_config=None, - num_speculative_tokens=0, + speculative_config=speculative_config, + num_speculative_tokens=num_speculative_tokens, parallel_config=SimpleNamespace(decode_context_parallel_size=1), scheduler_config=SimpleNamespace(max_num_seqs=max_num_seqs), model_config=SimpleNamespace(max_model_len=max_model_len), @@ -59,7 +67,7 @@ def test_update_block_table_copies_block_idx_to_persistent_buffers(): num_reqs = 4 device = torch.device("cpu") - vllm_config = _make_vllm_config(block_size, max_model_len, num_reqs) + vllm_config = _make_vllm_config(max_model_len, num_reqs) spec = MambaSpec( block_size=block_size, @@ -106,6 +114,7 @@ def test_update_block_table_copies_block_idx_to_persistent_buffers(): block_idx_last_computed_token=( builder_a.block_idx_last_computed_token[:num_reqs] ), + block_idx_last_scheduled_token_prev_step=None, seq_lens=seq_lens, ) @@ -149,3 +158,261 @@ def shares_storage(tensor, buffer): metadata_b.block_idx_last_computed_token, block_idx_vals, ) + + +def test_state_indices_tensor_d_includes_num_speculative_blocks(): + """Regression test for https://github.com/vllm-project/vllm/issues/39809 + bug 1: with mamba_cache_mode='all' and speculative decoding enabled, + the cudagraph buffer for state_indices_tensor_d must allocate the same + per-request column count as the runtime block table, which includes + num_speculative_blocks trailing scratch columns.""" + + block_size = 16 + max_model_len = 256 + max_num_seqs = 4 + num_speculative_tokens = 1 + num_speculative_blocks = 2 + device = torch.device("cpu") + + vllm_config = _make_vllm_config( + max_model_len, + max_num_seqs, + num_speculative_tokens=num_speculative_tokens, + ) + + spec = MambaSpec( + block_size=block_size, + shapes=((1,), (1,)), + dtypes=(torch.float32,), + mamba_cache_mode="all", + num_speculative_blocks=num_speculative_blocks, + ) + + builder = _ConcreteMambaBuilder(spec, ["layer0"], vllm_config, device) + + expected_cols = (max_model_len // block_size) + num_speculative_blocks + assert builder.state_indices_tensor_d.shape == (max_num_seqs, expected_cols) + + +def test_block_idx_cudagraph_capture_padded_by_num_reqs(): + """Regression test for https://github.com/vllm-project/vllm/issues/39809 + bug 2: with mamba_cache_mode='all' and spec decode, _update_metadata_for + _cudagraph_capture must slice block_idx_last_{scheduled,computed}_token + by the request count (padded_bs == num_reqs), not by num_decode_tokens. + Past num_decodes, the slice must be zero-filled.""" + + block_size = 16 + max_model_len = 256 + max_num_seqs = 8 + num_speculative_tokens = 1 + device = torch.device("cpu") + + vllm_config = _make_vllm_config( + max_model_len, + max_num_seqs, + num_speculative_tokens=num_speculative_tokens, + ) + + spec = MambaSpec( + block_size=block_size, + shapes=((1,), (1,)), + dtypes=(torch.float32,), + mamba_cache_mode="all", + num_speculative_blocks=2, + ) + + builder = _ConcreteMambaBuilder(spec, ["layer0"], vllm_config, device) + + builder.block_idx_last_scheduled_token.fill_(-1) + builder.block_idx_last_computed_token.fill_(-1) + + num_decodes = 2 + num_reqs = 3 + num_decode_tokens = num_decodes * (1 + num_speculative_tokens) + seq_lens = torch.full((num_reqs,), 64, dtype=torch.int32, device=device) + block_idx_vals = torch.tensor([3, 5], dtype=torch.int32, device=device) + state_indices_d = torch.zeros( + (num_decodes, builder.state_indices_tensor_d.shape[1]), + dtype=torch.int32, + device=device, + ) + query_start_loc_d = torch.arange( + num_decodes + 1, dtype=torch.int32, device=device + ) * (1 + num_speculative_tokens) + num_accepted_tokens = torch.ones(num_decodes, dtype=torch.int32, device=device) + + metadata = BaseMambaAttentionMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decodes=num_decodes, + num_decode_tokens=num_decode_tokens, + num_reqs=num_reqs, + has_initial_states_p=None, + query_start_loc_p=None, + num_computed_tokens_p=None, + state_indices_tensor_p=None, + state_indices_tensor_d=state_indices_d, + query_start_loc_d=query_start_loc_d, + num_accepted_tokens=num_accepted_tokens, + block_idx_last_scheduled_token=block_idx_vals, + block_idx_first_scheduled_token_p=None, + block_idx_last_computed_token=block_idx_vals, + block_idx_last_scheduled_token_prev_step=None, + seq_lens=seq_lens, + ) + + out = builder._update_metadata_for_cudagraph_capture(metadata) + + assert out.block_idx_last_scheduled_token.shape == (num_reqs,) + assert out.block_idx_last_computed_token.shape == (num_reqs,) + torch.testing.assert_close( + out.block_idx_last_scheduled_token[:num_decodes], block_idx_vals + ) + torch.testing.assert_close( + out.block_idx_last_computed_token[:num_decodes], block_idx_vals + ) + assert torch.all(out.block_idx_last_scheduled_token[num_decodes:] == 0) + assert torch.all(out.block_idx_last_computed_token[num_decodes:] == 0) + + +def test_block_idx_prev_step_persistent_buffer_allocated(): + """With mamba_cache_mode='all' + spec decode, the builder must allocate + block_idx_last_scheduled_token_prev_step as a persistent buffer with the + same shape as the existing block_idx_last_{scheduled,computed}_token + buffers, so cudagraph capture records a stable pointer for the prev-step + input anchor consumed by mamba_mixer2's input gather.""" + block_size = 16 + max_model_len = 256 + max_num_seqs = 8 + num_speculative_tokens = 1 + device = torch.device("cpu") + + vllm_config = _make_vllm_config( + max_model_len, + max_num_seqs, + num_speculative_tokens=num_speculative_tokens, + ) + spec = MambaSpec( + block_size=block_size, + shapes=((1,), (1,)), + dtypes=(torch.float32,), + mamba_cache_mode="all", + num_speculative_blocks=2, + ) + builder = _ConcreteMambaBuilder(spec, ["layer0"], vllm_config, device) + + assert hasattr(builder, "block_idx_last_scheduled_token_prev_step") + assert builder.block_idx_last_scheduled_token_prev_step.shape == (max_num_seqs,) + assert builder.block_idx_last_scheduled_token_prev_step.dtype == torch.int32 + + +def test_block_idx_prev_step_persistent_buffer_skipped_without_spec_decode(): + """Without spec decode, the prev-step buffer is unused and must not be + allocated — the input anchor reduces to last_computed_token.""" + block_size = 16 + max_model_len = 256 + max_num_seqs = 8 + device = torch.device("cpu") + + vllm_config = _make_vllm_config( + max_model_len, max_num_seqs, num_speculative_tokens=0 + ) + spec = MambaSpec( + block_size=block_size, + shapes=((1,), (1,)), + dtypes=(torch.float32,), + mamba_cache_mode="all", + ) + builder = _ConcreteMambaBuilder(spec, ["layer0"], vllm_config, device) + + assert not hasattr(builder, "block_idx_last_scheduled_token_prev_step") + + +def test_block_idx_prev_step_cudagraph_capture_uses_persistent_buffer(): + """_update_metadata_for_cudagraph_capture must copy the prev-step anchor + into the builder's persistent buffer (so cudagraph replay reads from the + same underlying memory), pad past num_decodes with zero, and return a + slice of the persistent buffer in the metadata.""" + block_size = 16 + max_model_len = 256 + max_num_seqs = 8 + num_speculative_tokens = 1 + device = torch.device("cpu") + + vllm_config = _make_vllm_config( + max_model_len, + max_num_seqs, + num_speculative_tokens=num_speculative_tokens, + ) + spec = MambaSpec( + block_size=block_size, + shapes=((1,), (1,)), + dtypes=(torch.float32,), + mamba_cache_mode="all", + num_speculative_blocks=2, + ) + builder = _ConcreteMambaBuilder(spec, ["layer0"], vllm_config, device) + builder.block_idx_last_scheduled_token.fill_(-1) + builder.block_idx_last_computed_token.fill_(-1) + builder.block_idx_last_scheduled_token_prev_step.fill_(-1) + + num_decodes = 2 + num_reqs = 3 + num_decode_tokens = num_decodes * (1 + num_speculative_tokens) + seq_lens = torch.full((num_reqs,), 64, dtype=torch.int32, device=device) + block_idx_vals = torch.tensor([3, 5], dtype=torch.int32, device=device) + prev_step_vals = torch.tensor([2, 4], dtype=torch.int32, device=device) + state_indices_d = torch.zeros( + (num_decodes, builder.state_indices_tensor_d.shape[1]), + dtype=torch.int32, + device=device, + ) + query_start_loc_d = torch.arange( + num_decodes + 1, dtype=torch.int32, device=device + ) * (1 + num_speculative_tokens) + num_accepted_tokens = torch.ones(num_decodes, dtype=torch.int32, device=device) + + metadata = BaseMambaAttentionMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decodes=num_decodes, + num_decode_tokens=num_decode_tokens, + num_reqs=num_reqs, + has_initial_states_p=None, + query_start_loc_p=None, + num_computed_tokens_p=None, + state_indices_tensor_p=None, + state_indices_tensor_d=state_indices_d, + query_start_loc_d=query_start_loc_d, + num_accepted_tokens=num_accepted_tokens, + block_idx_last_scheduled_token=block_idx_vals, + block_idx_first_scheduled_token_p=None, + block_idx_last_computed_token=block_idx_vals, + block_idx_last_scheduled_token_prev_step=prev_step_vals, + seq_lens=seq_lens, + ) + + out = builder._update_metadata_for_cudagraph_capture(metadata) + + # Output field exists and is identity-shared with the persistent buffer. + assert out.block_idx_last_scheduled_token_prev_step is not None + assert ( + out.block_idx_last_scheduled_token_prev_step.untyped_storage().data_ptr() + == builder.block_idx_last_scheduled_token_prev_step.untyped_storage().data_ptr() + ), ( + "prev-step buffer must live in the builder's persistent buffer, not " + "in the caller-provided tensor" + ) + + # Padded by num_reqs (not num_decode_tokens) — same fix as bug 2 for the + # other block_idx_* fields. + assert out.block_idx_last_scheduled_token_prev_step.shape == (num_reqs,) + + # First num_decodes values: input values copied through. + torch.testing.assert_close( + out.block_idx_last_scheduled_token_prev_step[:num_decodes], + prev_step_vals, + ) + + # Tail values past num_decodes: zero-filled padding for cudagraph capture. + assert torch.all(out.block_idx_last_scheduled_token_prev_step[num_decodes:] == 0) diff --git a/tests/v1/e2e/general/test_mamba_prefix_cache.py b/tests/v1/e2e/general/test_mamba_prefix_cache.py index 6ec9e7656e31..636eb13de886 100644 --- a/tests/v1/e2e/general/test_mamba_prefix_cache.py +++ b/tests/v1/e2e/general/test_mamba_prefix_cache.py @@ -364,26 +364,34 @@ def fake_preprocess_mamba_fn( def fake_post_process_mamba_fn( scheduler_output: SchedulerOutput, kv_cache_config: KVCacheConfig, + cache_config: CacheConfig, input_batch: GPUInputBatch, requests: dict[str, CachedRequestState], mamba_state_idx: dict[str, int], - forward_context: dict[str, Any], - mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...], - copy_bufs: mamba_utils.MambaCopyBuffers, + num_spec_tokens: int, + num_reqs: int, + *, + forward_context: dict[str, Any] | None = None, + mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...] | None = None, + copy_bufs: mamba_utils.MambaCopyBuffers | None = None, ): nonlocal copy_info copy_info = None ret = original_post_process_mamba_fn( scheduler_output, kv_cache_config, + cache_config, input_batch, requests, mamba_state_idx, - forward_context, - mamba_state_copy_funcs, - copy_bufs, + num_spec_tokens, + num_reqs, + forward_context=forward_context, + mamba_state_copy_funcs=mamba_state_copy_funcs, + copy_bufs=copy_bufs, ) if cur_step_action is not None: + assert forward_context is not None check_copy_info( cur_step_action.postprocess_copy_idx, kv_cache_config, diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 674134f373ea..a6524961ea92 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -502,6 +502,12 @@ def __init__( self.prefix = prefix self.num_spec = vllm_config.num_speculative_tokens + if self.num_spec > 0: + self.register_buffer( + "_decode_state_offsets", + torch.arange(1 + self.num_spec, dtype=torch.int32).unsqueeze(0), + persistent=False, + ) # Pre-compute sizes for forward pass self.tped_intermediate_size = self.intermediate_size // self.tp_size @@ -755,6 +761,14 @@ def conv_ssm_forward( dim=0, ) ) + if attn_metadata.block_idx_last_scheduled_token_prev_step is not None: + block_idx_last_scheduled_token_prev_step_d, _ = torch.split( + attn_metadata.block_idx_last_scheduled_token_prev_step, + [num_decodes, num_prefills], + dim=0, + ) + else: + block_idx_last_scheduled_token_prev_step_d = None # Prefill-only variables: block_idx_first_scheduled_token_p = ( attn_metadata.block_idx_first_scheduled_token_p @@ -766,6 +780,7 @@ def conv_ssm_forward( block_idx_first_scheduled_token_p = None block_idx_last_scheduled_token_d = None block_idx_last_computed_token_d = None + block_idx_last_scheduled_token_prev_step_d = None num_computed_tokens_p = None preallocated_ssm_out_d, preallocated_ssm_out_p = torch.split( @@ -944,18 +959,29 @@ def conv_ssm_forward( if has_decode: assert state_indices_tensor_d is not None if is_mamba_cache_all: - state_indices_tensor_d_input = state_indices_tensor_d.gather( - 1, block_idx_last_computed_token_d.unsqueeze(1) - ).squeeze(1) - state_indices_tensor_d_output = state_indices_tensor_d.gather( - 1, block_idx_last_scheduled_token_d.unsqueeze(1) - ).squeeze(1) - # for decode: - # block_idx_first_scheduled_token_d == - # block_idx_last_scheduled_token_d - # at block boundaries: - # block_idx_first_scheduled_token_d > - # block_idx_last_computed_token_d + if self.num_spec > 0: + assert block_idx_last_scheduled_token_prev_step_d is not None + input_indices = ( + block_idx_last_scheduled_token_prev_step_d.unsqueeze(1) + + self._decode_state_offsets + ) + output_indices = ( + block_idx_last_scheduled_token_d.unsqueeze(1) + + self._decode_state_offsets + ) + state_indices_tensor_d_input = state_indices_tensor_d.gather( + 1, input_indices + ) + state_indices_tensor_d_output = state_indices_tensor_d.gather( + 1, output_indices + ) + else: + state_indices_tensor_d_input = state_indices_tensor_d.gather( + 1, block_idx_last_computed_token_d.unsqueeze(1) + ).squeeze(1) + state_indices_tensor_d_output = state_indices_tensor_d.gather( + 1, block_idx_last_scheduled_token_d.unsqueeze(1) + ).squeeze(1) else: # Without caching, read and write in-place to the same blocks: state_indices_tensor_d_input = state_indices_tensor_d diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 459c16f8ec97..422acca642a9 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -350,26 +350,15 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: if cache_config.enable_prefix_caching: if cache_config.mamba_cache_mode == "none": - if ( - model_config.supports_mamba_prefix_caching - and vllm_config.speculative_config is not None - ): - cache_config.mamba_cache_mode = "align" - logger.warning( - "Mamba cache mode is set to 'align' for %s by default " - "when prefix caching and speculative decoding are enabled", - model_config.architecture, - ) - else: - cache_config.mamba_cache_mode = ( - "all" if model_config.supports_mamba_prefix_caching else "align" - ) - logger.warning( - "Mamba cache mode is set to '%s' for %s by default " - "when prefix caching is enabled", - cache_config.mamba_cache_mode, - model_config.architecture, - ) + cache_config.mamba_cache_mode = ( + "all" if model_config.supports_mamba_prefix_caching else "align" + ) + logger.warning( + "Mamba cache mode is set to '%s' for %s by default " + "when prefix caching is enabled", + cache_config.mamba_cache_mode, + model_config.architecture, + ) if ( cache_config.mamba_cache_mode == "all" and not model_config.supports_mamba_prefix_caching diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py index fa7d4bd2ec51..5f25c4a79520 100644 --- a/vllm/v1/attention/backends/mamba2_attn.py +++ b/vllm/v1/attention/backends/mamba2_attn.py @@ -137,7 +137,9 @@ def build( **kwargs: Any, ) -> Mamba2AttentionMetadata: common = self._compute_common_metadata( - common_attn_metadata, num_accepted_tokens=kwargs.get("num_accepted_tokens") + common_attn_metadata, + num_accepted_tokens=kwargs.get("num_accepted_tokens"), + prev_last_scheduled_idx=kwargs.get("prev_last_scheduled_idx"), ) seq_idx_p = None diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py index 716dfcde592f..2af9c74ea765 100644 --- a/vllm/v1/attention/backends/mamba_attn.py +++ b/vllm/v1/attention/backends/mamba_attn.py @@ -56,6 +56,7 @@ class BaseMambaAttentionMetadata: block_idx_last_scheduled_token: torch.Tensor | None block_idx_first_scheduled_token_p: torch.Tensor | None block_idx_last_computed_token: torch.Tensor | None + block_idx_last_scheduled_token_prev_step: torch.Tensor | None # The following tensor is only used for prefix caching in align mode seq_lens: torch.Tensor @@ -108,12 +109,13 @@ def __init__( ) if self.vllm_config.cache_config.mamba_cache_mode == "all": - max_num_blocks = cdiv( - self.vllm_config.model_config.max_model_len, - self.kv_cache_spec.block_size, + max_num_blocks = ( + cdiv( + self.vllm_config.model_config.max_model_len, + kv_cache_spec.block_size, + ) + + kv_cache_spec.num_speculative_blocks ) - # Speculative decoding not supported with prefix caching, - # so keep shape consistent with prefill buffer # TODO: reduce this size as needed for decode-only cudagraph capture self.state_indices_tensor_d: torch.Tensor = torch.empty( ( @@ -133,6 +135,14 @@ def __init__( dtype=torch.int32, device=device, ) + if self.use_spec_decode: + self.block_idx_last_scheduled_token_prev_step: torch.Tensor = ( + torch.empty( + (self.decode_cudagraph_max_bs,), + dtype=torch.int32, + device=device, + ) + ) else: self.state_indices_tensor_d = torch.empty( (self.decode_cudagraph_max_bs, 1 + self.num_spec_tokens), @@ -176,7 +186,23 @@ def build_for_cudagraph_capture( if self.num_spec_tokens > 0: num_accepted_tokens = torch.diff(m.query_start_loc) - return self.build(0, m, num_accepted_tokens=num_accepted_tokens) + prev_last_scheduled_idx = None + if ( + self.use_spec_decode + and self.vllm_config.cache_config.mamba_cache_mode == "all" + ): + prev_last_scheduled_idx = torch.zeros( + (m.num_reqs,), + dtype=torch.int32, + device=m.query_start_loc.device, + ) + + return self.build( + 0, + m, + num_accepted_tokens=num_accepted_tokens, + prev_last_scheduled_idx=prev_last_scheduled_idx, + ) def build( self, @@ -185,6 +211,7 @@ def build( fast_build: bool = False, *, num_accepted_tokens: torch.Tensor | None = None, + prev_last_scheduled_idx: torch.Tensor | None = None, **kwargs: Any, ) -> M: """ @@ -192,7 +219,9 @@ def build( Subclasses (e.g., Mamba2) can override to add additional metadata. """ return self._compute_common_metadata( - common_attn_metadata, num_accepted_tokens=num_accepted_tokens + common_attn_metadata, + num_accepted_tokens=num_accepted_tokens, + prev_last_scheduled_idx=prev_last_scheduled_idx, ) def _compute_chunk_metadata( @@ -341,6 +370,7 @@ def _compute_common_metadata( common_attn_metadata: CommonAttentionMetadata, *, num_accepted_tokens: torch.Tensor | None = None, + prev_last_scheduled_idx: torch.Tensor | None = None, ) -> M: """ Compute metadata common to both Mamba1 and Mamba2. @@ -375,6 +405,7 @@ def _compute_common_metadata( block_idx_first_scheduled_token_p = None block_idx_last_computed_token = None block_idx_last_scheduled_token = None + block_idx_last_scheduled_token_prev_step = None # for causal_conv1d nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None @@ -393,6 +424,15 @@ def _compute_common_metadata( ) = self._compute_prefix_caching_block_indices( common_attn_metadata, mamba_block_size ) + if self.use_spec_decode and prev_last_scheduled_idx is not None: + fallback = torch.clamp( + (num_computed_tokens - 1) // mamba_block_size, min=0 + ) + block_idx_last_scheduled_token_prev_step = torch.where( + prev_last_scheduled_idx >= 0, + prev_last_scheduled_idx, + fallback, + ) else: state_indices_tensor = mamba_get_block_table_tensor( common_attn_metadata.block_table_tensor, @@ -470,6 +510,9 @@ def _compute_common_metadata( block_idx_last_scheduled_token=block_idx_last_scheduled_token, block_idx_first_scheduled_token_p=block_idx_first_scheduled_token_p, block_idx_last_computed_token=block_idx_last_computed_token, + block_idx_last_scheduled_token_prev_step=( + block_idx_last_scheduled_token_prev_step + ), num_computed_tokens_p=num_computed_tokens_p, num_reqs=num_reqs, seq_lens=common_attn_metadata.seq_lens, @@ -493,6 +536,9 @@ def _update_metadata_for_cudagraph_capture( num_accepted_tokens = metadata.num_accepted_tokens block_idx_last_scheduled_token = metadata.block_idx_last_scheduled_token block_idx_last_computed_token = metadata.block_idx_last_computed_token + block_idx_last_scheduled_token_prev_step = ( + metadata.block_idx_last_scheduled_token_prev_step + ) if ( metadata.num_prefills == 0 and metadata.num_decodes <= self.decode_cudagraph_max_bs @@ -524,16 +570,35 @@ def _update_metadata_for_cudagraph_capture( non_blocking=True, ) block_idx_last_scheduled_token = self.block_idx_last_scheduled_token[ - : metadata.num_decode_tokens + :padded_bs ] + block_idx_last_scheduled_token[metadata.num_decodes :] = 0 self.block_idx_last_computed_token[: metadata.num_decodes].copy_( block_idx_last_computed_token[: metadata.num_decodes], non_blocking=True, ) block_idx_last_computed_token = self.block_idx_last_computed_token[ - : metadata.num_decode_tokens + :padded_bs ] + block_idx_last_computed_token[metadata.num_decodes :] = 0 + + if ( + self.use_spec_decode + and block_idx_last_scheduled_token_prev_step is not None + ): + self.block_idx_last_scheduled_token_prev_step[ + : metadata.num_decodes + ].copy_( + block_idx_last_scheduled_token_prev_step[ + : metadata.num_decodes + ], + non_blocking=True, + ) + block_idx_last_scheduled_token_prev_step = ( + self.block_idx_last_scheduled_token_prev_step[:padded_bs] + ) + block_idx_last_scheduled_token_prev_step[metadata.num_decodes :] = 0 return replace( metadata, @@ -542,6 +607,9 @@ def _update_metadata_for_cudagraph_capture( num_accepted_tokens=num_accepted_tokens, block_idx_last_scheduled_token=block_idx_last_scheduled_token, block_idx_last_computed_token=block_idx_last_computed_token, + block_idx_last_scheduled_token_prev_step=( + block_idx_last_scheduled_token_prev_step + ), ) def update_block_table( diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 43cbcfec1844..d09c01eb9059 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -878,8 +878,10 @@ def mamba_get_block_table_tensor( Get the block table tensor for mamba kernels from the input common_attn_metadata.block_table_tensor given different mamba cache modes. - - "all": input (#requests, cdiv(max_model_len, block_size)); - output (#requests, cdiv(max_model_len, block_size)). + - "all": input (#requests, cdiv(max_model_len, block_size) + + num_speculative_blocks); + output (#requests, cdiv(max_model_len, block_size) + + num_speculative_blocks). - "none": input (#requests, 1 + num_speculative_blocks); output (#requests, 1 + num_speculative_blocks). diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index a63408907b44..31ee89bc72aa 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -582,7 +582,9 @@ def page_size_bytes(self) -> int: def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: if vllm_config.cache_config.mamba_cache_mode == "all": max_model_len = vllm_config.model_config.max_model_len - return cdiv(max_model_len, self.block_size) * self.page_size_bytes + return ( + cdiv(max_model_len, self.block_size) + self.num_speculative_blocks + ) * self.page_size_bytes elif vllm_config.cache_config.mamba_cache_mode == "align": return self.page_size_bytes * (2 + self.num_speculative_blocks) else: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 2c010040bc21..b7aa8779fdaa 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -888,6 +888,11 @@ def __init__( self.kv_connector_output: KVConnectorOutput | None = None self.mamba_state_idx: dict[str, int] = {} self._mamba_copy_bufs: mamba_utils.MambaCopyBuffers | None = None + self.mamba_prev_last_scheduled_idx: CpuGpuBuffer | None = None + if self.cache_config.mamba_cache_mode == "all" and self.num_spec_tokens > 0: + self.mamba_prev_last_scheduled_idx = self._make_buffer( + self.max_num_reqs, dtype=torch.int32 + ) self.layerwise_nvtx_hooks_registered = False def update_max_model_len(self, max_model_len: int) -> None: @@ -1493,21 +1498,12 @@ def _update_states_after_model_execute( num_reqs = output_token_ids.size(0) self.num_accepted_tokens.gpu[:num_reqs] = (output_token_ids != -1).sum(dim=1) - if self.cache_config.mamba_cache_mode == "align": + is_align = self.cache_config.mamba_cache_mode == "align" + if is_align: for i, num_tokens in enumerate( self.num_accepted_tokens.gpu[:num_reqs].cpu().numpy() ): self.input_batch.num_accepted_tokens_cpu[i] = num_tokens - mamba_utils.postprocess_mamba( - scheduler_output, - self.kv_cache_config, - self.input_batch, - self.requests, - self.mamba_state_idx, - self.compilation_config.static_forward_context, - self.model.get_mamba_state_copy_func(), - self._get_mamba_copy_bufs(), - ) else: self.input_batch.num_accepted_tokens_cpu_tensor[:num_reqs].copy_( self.num_accepted_tokens.gpu[:num_reqs], non_blocking=True @@ -1515,6 +1511,24 @@ def _update_states_after_model_execute( assert self.num_accepted_tokens_event is not None self.num_accepted_tokens_event.record() + mamba_utils.postprocess_mamba( + scheduler_output, + self.kv_cache_config, + self.cache_config, + self.input_batch, + self.requests, + self.mamba_state_idx, + self.num_spec_tokens, + num_reqs, + forward_context=( + self.compilation_config.static_forward_context if is_align else None + ), + mamba_state_copy_funcs=( + self.model.get_mamba_state_copy_func() if is_align else None + ), + copy_bufs=self._get_mamba_copy_bufs() if is_align else None, + ) + def _update_streaming_request( self, req_id: str, new_req_data: NewRequestData ) -> CachedRequestState: @@ -2012,6 +2026,15 @@ def _prepare_inputs( self.num_accepted_tokens.np.fill(1) self.num_accepted_tokens.gpu.fill_(1) + if self.mamba_prev_last_scheduled_idx is not None: + mamba_utils.preprocess_mamba_all_specdec( + scheduler_output, + self.input_batch, + self.mamba_state_idx, + num_reqs, + self.mamba_prev_last_scheduled_idx, + ) + # Update num_computed_tokens on GPU. In async spec decode, # CPU values are optimistic (all drafts accepted). The kernel # corrects on GPU using the previous step's @@ -2319,6 +2342,13 @@ def _build_attn_group_metadata( :num_reqs_padded ], ) + if ( + isinstance(builder, Mamba2AttentionMetadataBuilder) + and self.mamba_prev_last_scheduled_idx is not None + ): + extra_attn_metadata_args["prev_last_scheduled_idx"] = ( + self.mamba_prev_last_scheduled_idx.gpu[:num_reqs_padded] + ) if for_cudagraph_capture: attn_metadata_i = builder.build_for_cudagraph_capture( diff --git a/vllm/v1/worker/mamba_utils.py b/vllm/v1/worker/mamba_utils.py index c832389b1b0a..b33080cb094d 100644 --- a/vllm/v1/worker/mamba_utils.py +++ b/vllm/v1/worker/mamba_utils.py @@ -144,6 +144,24 @@ def do_mamba_copy_block(copy_bufs: MambaCopyBuffers): ) +def cleanup_mamba_state_idx( + scheduler_output: SchedulerOutput, + mamba_state_idx: dict[str, int], +) -> None: + """Pop stale `mamba_state_idx` entries for finished/preempted/resumed reqs. + + Force-preempted requests (e.g., during reset_prefix_cache / KV cache + flush) appear in resumed_req_ids without a corresponding entry in + preempted_req_ids, leaving stale entries that can point to block + indices beyond the new (smaller) block allocation. + """ + finished_req_ids = scheduler_output.finished_req_ids + preempted_req_ids = scheduler_output.preempted_req_ids or set() + resumed_req_ids = scheduler_output.scheduled_cached_reqs.resumed_req_ids + for req_id in itertools.chain(finished_req_ids, preempted_req_ids, resumed_req_ids): + mamba_state_idx.pop(req_id, None) + + def preprocess_mamba( scheduler_output: SchedulerOutput, kv_cache_config: KVCacheConfig, @@ -165,16 +183,7 @@ def preprocess_mamba( # TODO(Chen): we need to optimize this function a lot assert cache_config.enable_prefix_caching block_size = mamba_spec.block_size - finished_req_ids = scheduler_output.finished_req_ids - preempted_req_ids = scheduler_output.preempted_req_ids or set() - # We need to clear mamba_state_idx for resumed requests. When requests are - # force-preempted (e.g., during reset_prefix_cache / KV cache flush), - # they appear in resumed_req_ids without a corresponding entry in - # preempted_req_ids, leaving stale mamba_state_idx entries that can - # point to block indices beyond the new (smaller) block allocation. - resumed_req_ids = scheduler_output.scheduled_cached_reqs.resumed_req_ids - for req_id in itertools.chain(finished_req_ids, preempted_req_ids, resumed_req_ids): - mamba_state_idx.pop(req_id, None) + cleanup_mamba_state_idx(scheduler_output, mamba_state_idx) copy_bufs.offset = 0 for i, req_id in enumerate(input_batch.req_ids): @@ -222,52 +231,100 @@ def preprocess_mamba( def postprocess_mamba( scheduler_output: SchedulerOutput, kv_cache_config: KVCacheConfig, + cache_config: CacheConfig, input_batch: GPUInputBatch, requests: dict[str, CachedRequestState], mamba_state_idx: dict[str, int], - forward_context: dict[str, Any], - mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...], - copy_bufs: MambaCopyBuffers, + num_spec_tokens: int, + num_reqs: int, + *, + forward_context: dict[str, Any] | None = None, + mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...] | None = None, + copy_bufs: MambaCopyBuffers | None = None, ): """ - If a blocks is converted from partial block to full block in this step, copy the - state from the block for running state to the new full block. + Post-model-execute mamba prefix-caching bookkeeping. Dispatched by + cache_config.mamba_cache_mode: + - "align": if a block is converted from partial to full this step, + copy the running state into the new full block. + - "all" + num_spec_tokens > 0: record per-request the block index of + the last token scheduled this step, so the next step can anchor + its in-place writes when accepted drafts leave the sequence at a + non-block-aligned position. """ - num_scheduled_tokens_dict = scheduler_output.num_scheduled_tokens - scheduled_spec_decode_tokens_dict = scheduler_output.scheduled_spec_decode_tokens - num_accepted_tokens_cpu = input_batch.num_accepted_tokens_cpu - mamba_group_ids = copy_bufs.mamba_group_ids - mamba_spec = copy_bufs.mamba_spec - copy_bufs.offset = 0 - for i, req_id in enumerate(input_batch.req_ids): - req_state = requests[req_id] - num_computed_tokens = req_state.num_computed_tokens - num_draft_tokens = len(scheduled_spec_decode_tokens_dict.get(req_id, [])) - num_scheduled_tokens = num_scheduled_tokens_dict[req_id] - num_accepted_tokens = num_accepted_tokens_cpu[i] - num_tokens_running_state = ( - num_computed_tokens + num_scheduled_tokens - num_draft_tokens + if cache_config.mamba_cache_mode == "align": + assert forward_context is not None + assert mamba_state_copy_funcs is not None + assert copy_bufs is not None + num_scheduled_tokens_dict = scheduler_output.num_scheduled_tokens + scheduled_spec_decode_tokens_dict = ( + scheduler_output.scheduled_spec_decode_tokens ) - new_num_computed_tokens = num_tokens_running_state + num_accepted_tokens - 1 - aligned_new_computed_tokens = ( - new_num_computed_tokens // mamba_spec.block_size * mamba_spec.block_size - ) - # TODO: how to ensure all blocks that cache_blocks called are cached here? - if aligned_new_computed_tokens >= num_tokens_running_state: - accept_token_bias = aligned_new_computed_tokens - num_tokens_running_state - src_block_idx = mamba_state_idx[req_id] - dest_block_idx = aligned_new_computed_tokens // mamba_spec.block_size - 1 - collect_mamba_copy_meta( - copy_bufs, - kv_cache_config, - mamba_state_copy_funcs, - mamba_group_ids, - src_block_idx, - dest_block_idx, - accept_token_bias, - req_state, - forward_context, + num_accepted_tokens_cpu = input_batch.num_accepted_tokens_cpu + mamba_group_ids = copy_bufs.mamba_group_ids + mamba_spec = copy_bufs.mamba_spec + copy_bufs.offset = 0 + for i, req_id in enumerate(input_batch.req_ids): + req_state = requests[req_id] + num_computed_tokens = req_state.num_computed_tokens + num_draft_tokens = len(scheduled_spec_decode_tokens_dict.get(req_id, [])) + num_scheduled_tokens = num_scheduled_tokens_dict[req_id] + num_accepted_tokens = num_accepted_tokens_cpu[i] + num_tokens_running_state = ( + num_computed_tokens + num_scheduled_tokens - num_draft_tokens ) - if src_block_idx == dest_block_idx: - num_accepted_tokens_cpu[i] = 1 - do_mamba_copy_block(copy_bufs) + new_num_computed_tokens = num_tokens_running_state + num_accepted_tokens - 1 + aligned_new_computed_tokens = ( + new_num_computed_tokens // mamba_spec.block_size * mamba_spec.block_size + ) + # TODO: how to ensure all blocks that cache_blocks called are cached here? + if aligned_new_computed_tokens >= num_tokens_running_state: + accept_token_bias = ( + aligned_new_computed_tokens - num_tokens_running_state + ) + src_block_idx = mamba_state_idx[req_id] + dest_block_idx = ( + aligned_new_computed_tokens // mamba_spec.block_size - 1 + ) + collect_mamba_copy_meta( + copy_bufs, + kv_cache_config, + mamba_state_copy_funcs, + mamba_group_ids, + src_block_idx, + dest_block_idx, + accept_token_bias, + req_state, + forward_context, + ) + if src_block_idx == dest_block_idx: + num_accepted_tokens_cpu[i] = 1 + do_mamba_copy_block(copy_bufs) + elif cache_config.mamba_cache_mode == "all" and num_spec_tokens > 0: + _, mamba_spec = get_mamba_groups(kv_cache_config) + block_size = mamba_spec.block_size + full_decode_len = 1 + num_spec_tokens + scheduled = scheduler_output.num_scheduled_tokens + for req_id in input_batch.req_ids[:num_reqs]: + num_query = scheduled.get(req_id, 0) + if num_query == full_decode_len: + req = requests[req_id] + seq_len = req.num_computed_tokens + num_query + mamba_state_idx[req_id] = max(0, (seq_len - 1) // block_size) + else: + mamba_state_idx.pop(req_id, None) + + +def preprocess_mamba_all_specdec( + scheduler_output: SchedulerOutput, + input_batch: GPUInputBatch, + mamba_state_idx: dict[str, int], + num_reqs: int, + prev_last_scheduled_idx_buf: CpuGpuBuffer, +) -> None: + cleanup_mamba_state_idx(scheduler_output, mamba_state_idx) + np_view = prev_last_scheduled_idx_buf.np + for i, req_id in enumerate(input_batch.req_ids[:num_reqs]): + np_view[i] = mamba_state_idx.get(req_id, -1) + np_view[num_reqs:].fill(-1) + prev_last_scheduled_idx_buf.copy_to_gpu()