diff --git a/vllm/model_executor/layers/mamba/abstract.py b/vllm/model_executor/layers/mamba/abstract.py index 3c6b0139424d..91ba9e42ac58 100644 --- a/vllm/model_executor/layers/mamba/abstract.py +++ b/vllm/model_executor/layers/mamba/abstract.py @@ -57,6 +57,22 @@ def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None: ), ) + @staticmethod + def clear_stale_decode_states( + has_initial_states_d: torch.Tensor | None, + indices: torch.Tensor, + ssm_state: torch.Tensor, + conv_state: torch.Tensor, + ) -> None: + if has_initial_states_d is None: + return + + for state in (ssm_state, conv_state): + gathered = state[indices] + keep_state = has_initial_states_d.to(gathered.dtype) + keep_state = keep_state.view(-1, *([1] * (gathered.dim() - 1))) + state[indices] = gathered * keep_state + def get_attn_backend(self) -> type[AttentionBackend]: """Get the attention backend class for this Mamba layer.""" return get_mamba_attn_backend(self.mamba_type) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index 6a33fc7d6b1b..6302156958f3 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -271,6 +271,7 @@ def forward_impl(self, hidden_states: torch.Tensor, output: torch.Tensor): conv_state = self_kv_cache[0].transpose(-1, -2) ssm_state = self_kv_cache[1] has_initial_states_p = attn_metadata.has_initial_states_p + has_initial_states_d = attn_metadata.has_initial_states_d cu_chunk_seqlen_p = attn_metadata.cu_chunk_seqlen_p last_chunk_indices_p = attn_metadata.last_chunk_indices_p @@ -394,6 +395,14 @@ def forward_impl(self, hidden_states: torch.Tensor, output: torch.Tensor): else: state_indices_tensor_d_input = state_indices_tensor_d state_indices_tensor_d_output = state_indices_tensor_d + + self.clear_stale_decode_states( + has_initial_states_d, + state_indices_tensor_d_input, + ssm_state, + conv_state, + ) + # 2. Convolution sequence transformation conv_out_d = causal_conv1d_update( hidden_states_BC_d.transpose(0, 1), diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 971581d89c27..465f80240fad 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -580,6 +580,7 @@ def conv_ssm_forward( conv_state = self_kv_cache[0].transpose(-1, -2) ssm_state = self_kv_cache[1] has_initial_states_p = attn_metadata.has_initial_states_p + has_initial_states_d = attn_metadata.has_initial_states_d prep_initial_states = attn_metadata.prep_initial_states chunk_size = attn_metadata.chunk_size seq_idx_p = attn_metadata.seq_idx_p @@ -831,6 +832,13 @@ def conv_ssm_forward( state_indices_tensor_d_input = state_indices_tensor_d state_indices_tensor_d_output = state_indices_tensor_d + self.clear_stale_decode_states( + has_initial_states_d, + state_indices_tensor_d_input, + ssm_state, + conv_state, + ) + # 2. Convolution sequence transformation hidden_states_B_C_d = causal_conv1d_update( hidden_states_B_C_d, diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py index f86e691524c0..23adf6298a4b 100644 --- a/vllm/v1/attention/backends/mamba_attn.py +++ b/vllm/v1/attention/backends/mamba_attn.py @@ -149,6 +149,12 @@ def __init__( device=device, ) + self.has_initial_states_d: torch.Tensor = torch.ones( + (self.decode_cudagraph_max_bs,), + dtype=torch.bool, + device=device, + ) + self._init_reorder_batch_threshold(1, self.use_spec_decode) if self.use_spec_decode: self.supports_update_block_table = False @@ -501,6 +507,7 @@ def _update_metadata_for_cudagraph_capture( Currently, only decode is supported for full cudagraphs with Mamba. """ state_indices_tensor_d = metadata.state_indices_tensor_d + has_initial_states_d = metadata.has_initial_states_d query_start_loc_d = metadata.query_start_loc_d num_accepted_tokens = metadata.num_accepted_tokens block_idx_last_scheduled_token = metadata.block_idx_last_scheduled_token @@ -517,6 +524,13 @@ def _update_metadata_for_cudagraph_capture( state_indices_tensor_d = self.state_indices_tensor_d[:padded_bs] state_indices_tensor_d[metadata.num_decodes :] = PAD_SLOT_ID + if has_initial_states_d is not None: + self.has_initial_states_d[: metadata.num_decodes].copy_( + has_initial_states_d, non_blocking=True + ) + self.has_initial_states_d[metadata.num_decodes :] = True + has_initial_states_d = self.has_initial_states_d[:padded_bs] + if self.use_spec_decode: assert query_start_loc_d is not None assert num_accepted_tokens is not None @@ -551,6 +565,7 @@ def _update_metadata_for_cudagraph_capture( return replace( metadata, state_indices_tensor_d=state_indices_tensor_d, + has_initial_states_d=has_initial_states_d, query_start_loc_d=query_start_loc_d, num_accepted_tokens=num_accepted_tokens, block_idx_last_scheduled_token=block_idx_last_scheduled_token, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index afcc13e4de4a..08dbd614fdcf 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3683,12 +3683,6 @@ def execute_model( ) ) - mamba_utils.clear_stale_mamba_states( - attn_metadata, - self.attn_groups, - self.compilation_config.static_forward_context, - ) - ( input_ids, inputs_embeds, diff --git a/vllm/v1/worker/mamba_utils.py b/vllm/v1/worker/mamba_utils.py index b4f808b04f6e..2bd5d2b3fea8 100644 --- a/vllm/v1/worker/mamba_utils.py +++ b/vllm/v1/worker/mamba_utils.py @@ -13,13 +13,11 @@ ) from vllm.triton_utils import tl, triton from vllm.utils.math_utils import cdiv -from vllm.v1.attention.backends.mamba_attn import BaseMambaAttentionMetadata from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import KVCacheConfig, MambaSpec from vllm.v1.utils import CpuGpuBuffer from vllm.v1.worker.gpu_input_batch import CachedRequestState from vllm.v1.worker.lora_model_runner_mixin import GPUInputBatch -from vllm.v1.worker.utils import AttentionGroup @triton.jit @@ -216,42 +214,6 @@ def preprocess_mamba( do_mamba_copy_block(copy_bufs) -def clear_stale_mamba_states( - attn_metadata: list[dict[str, Any]] | dict[str, Any], - attn_groups: list[list["AttentionGroup"]], - forward_context: dict[str, Any], -) -> None: - """Clear Mamba states for new requests in the decode batch. - - New requests (has_initial_states=False) would otherwise read stale - state from a recycled cache slot. - """ - if not isinstance(attn_metadata, dict): - return - - for kv_cache_groups in attn_groups: - for attn_group in kv_cache_groups: - if not isinstance(attn_group.kv_cache_spec, MambaSpec): - continue - - metadata = attn_metadata.get(attn_group.layer_names[0]) - if not isinstance(metadata, BaseMambaAttentionMetadata): - continue - - has_initial_states_d = metadata.has_initial_states_d - if has_initial_states_d is None: - continue - - assert metadata.state_indices_tensor_d is not None - indices = metadata.state_indices_tensor_d[: metadata.num_decodes] - new_indices = indices[~has_initial_states_d[: metadata.num_decodes]] - - for layer_name in attn_group.layer_names: - layer = forward_context[layer_name] - for state in layer.kv_cache[0]: - state[new_indices] = 0 - - def postprocess_mamba( scheduler_output: SchedulerOutput, kv_cache_config: KVCacheConfig,