diff --git a/vllm/model_executor/layers/mamba/abstract.py b/vllm/model_executor/layers/mamba/abstract.py
index 3c6b0139424d..91ba9e42ac58 100644
--- a/vllm/model_executor/layers/mamba/abstract.py
+++ b/vllm/model_executor/layers/mamba/abstract.py
@@ -57,6 +57,22 @@ def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
             ),
         )
 
+    @staticmethod
+    def clear_stale_decode_states(
+        has_initial_states_d: torch.Tensor | None,
+        indices: torch.Tensor,
+        ssm_state: torch.Tensor,
+        conv_state: torch.Tensor,
+    ) -> None:
+        if has_initial_states_d is None:
+            return
+
+        for state in (ssm_state, conv_state):
+            gathered = state[indices]
+            keep_state = has_initial_states_d.to(gathered.dtype)
+            keep_state = keep_state.view(-1, *([1] * (gathered.dim() - 1)))
+            state[indices] = gathered * keep_state
+
     def get_attn_backend(self) -> type[AttentionBackend]:
         """Get the attention backend class for this Mamba layer."""
         return get_mamba_attn_backend(self.mamba_type)
diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py
index 6a33fc7d6b1b..6302156958f3 100644
--- a/vllm/model_executor/layers/mamba/mamba_mixer.py
+++ b/vllm/model_executor/layers/mamba/mamba_mixer.py
@@ -271,6 +271,7 @@ def forward_impl(self, hidden_states: torch.Tensor, output: torch.Tensor):
         conv_state = self_kv_cache[0].transpose(-1, -2)
         ssm_state = self_kv_cache[1]
         has_initial_states_p = attn_metadata.has_initial_states_p
+        has_initial_states_d = attn_metadata.has_initial_states_d
         cu_chunk_seqlen_p = attn_metadata.cu_chunk_seqlen_p
         last_chunk_indices_p = attn_metadata.last_chunk_indices_p
 
@@ -394,6 +395,14 @@
             else:
                 state_indices_tensor_d_input = state_indices_tensor_d
                 state_indices_tensor_d_output = state_indices_tensor_d
+
+            self.clear_stale_decode_states(
+                has_initial_states_d,
+                state_indices_tensor_d_input,
+                ssm_state,
+                conv_state,
+            )
+
             # 2. Convolution sequence transformation
             conv_out_d = causal_conv1d_update(
                 hidden_states_BC_d.transpose(0, 1),
diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py
index 971581d89c27..465f80240fad 100644
--- a/vllm/model_executor/layers/mamba/mamba_mixer2.py
+++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py
@@ -580,6 +580,7 @@ def conv_ssm_forward(
         conv_state = self_kv_cache[0].transpose(-1, -2)
         ssm_state = self_kv_cache[1]
         has_initial_states_p = attn_metadata.has_initial_states_p
+        has_initial_states_d = attn_metadata.has_initial_states_d
         prep_initial_states = attn_metadata.prep_initial_states
         chunk_size = attn_metadata.chunk_size
         seq_idx_p = attn_metadata.seq_idx_p
@@ -831,6 +832,13 @@
 
             state_indices_tensor_d_input = state_indices_tensor_d
             state_indices_tensor_d_output = state_indices_tensor_d
+            self.clear_stale_decode_states(
+                has_initial_states_d,
+                state_indices_tensor_d_input,
+                ssm_state,
+                conv_state,
+            )
+
             # 2. Convolution sequence transformation
             hidden_states_B_C_d = causal_conv1d_update(
                 hidden_states_B_C_d,
diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py
index 0364d6aee5c7..23adf6298a4b 100644
--- a/vllm/v1/attention/backends/mamba_attn.py
+++ b/vllm/v1/attention/backends/mamba_attn.py
@@ -43,6 +43,7 @@ class BaseMambaAttentionMetadata:
     # The following tensors are used for decode requests and
     # speculative decoding compatibility, and will be None if the batch
     # has no decode requests.
+    has_initial_states_d: torch.Tensor | None
     state_indices_tensor_d: torch.Tensor | None
     query_start_loc_d: torch.Tensor | None
     # shape: [num_decodes + 1,]
@@ -148,6 +149,12 @@ def __init__(
             device=device,
         )
 
+        self.has_initial_states_d: torch.Tensor = torch.ones(
+            (self.decode_cudagraph_max_bs,),
+            dtype=torch.bool,
+            device=device,
+        )
+
         self._init_reorder_batch_threshold(1, self.use_spec_decode)
         if self.use_spec_decode:
             self.supports_update_block_table = False
@@ -364,6 +371,7 @@
 
         # Need flags to indicate if there are initial states
         has_initial_states_p = None
+        has_initial_states_d = None
         query_start_loc_p = None
         query_start_loc_d = None
         num_computed_tokens = None
@@ -414,6 +422,19 @@
             ]
             state_indices_tensor_p = state_indices_tensor_p[:, 0]
 
+        # Only set when there are genuinely new decode requests
+        # (num_computed_tokens == 0, seq_lens > 0). Padded CG slots
+        # (seq_lens == 0) are excluded so their PAD_SLOT_ID doesn't
+        # cause zeroing of the last real cache slot.
+        if num_decodes > 0:
+            if num_computed_tokens is None:
+                num_computed_tokens = common_attn_metadata.compute_num_computed_tokens()
+            has_initial_states_d = (num_computed_tokens[:num_decodes] > 0) | (
+                common_attn_metadata.seq_lens[:num_decodes] == 0
+            )
+            if has_initial_states_d.all():
+                has_initial_states_d = None
+
         if num_decodes > 0 and self.use_spec_decode:
             assert num_accepted_tokens is not None
             query_start_loc_d = common_attn_metadata.query_start_loc[: num_decodes + 1]
@@ -459,6 +480,7 @@
             num_decode_tokens=num_decode_tokens,
             query_start_loc_p=query_start_loc_p,
             has_initial_states_p=has_initial_states_p,
+            has_initial_states_d=has_initial_states_d,
             state_indices_tensor_p=state_indices_tensor_p,
             state_indices_tensor_d=state_indices_tensor_d,
             num_accepted_tokens=num_accepted_tokens,
@@ -485,6 +507,7 @@
         Currently, only decode is supported for full cudagraphs with Mamba.
         """
         state_indices_tensor_d = metadata.state_indices_tensor_d
+        has_initial_states_d = metadata.has_initial_states_d
         query_start_loc_d = metadata.query_start_loc_d
         num_accepted_tokens = metadata.num_accepted_tokens
         block_idx_last_scheduled_token = metadata.block_idx_last_scheduled_token
@@ -501,6 +524,13 @@
 
             state_indices_tensor_d = self.state_indices_tensor_d[:padded_bs]
             state_indices_tensor_d[metadata.num_decodes :] = PAD_SLOT_ID
+            if has_initial_states_d is not None:
+                self.has_initial_states_d[: metadata.num_decodes].copy_(
+                    has_initial_states_d, non_blocking=True
+                )
+                self.has_initial_states_d[metadata.num_decodes :] = True
+                has_initial_states_d = self.has_initial_states_d[:padded_bs]
+
         if self.use_spec_decode:
             assert query_start_loc_d is not None
             assert num_accepted_tokens is not None
@@ -535,6 +565,7 @@
         return replace(
             metadata,
             state_indices_tensor_d=state_indices_tensor_d,
+            has_initial_states_d=has_initial_states_d,
             query_start_loc_d=query_start_loc_d,
             num_accepted_tokens=num_accepted_tokens,
             block_idx_last_scheduled_token=block_idx_last_scheduled_token,