diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index 6a33fc7d6b1b..f4537e743cdc 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -271,6 +271,7 @@ def forward_impl(self, hidden_states: torch.Tensor, output: torch.Tensor): conv_state = self_kv_cache[0].transpose(-1, -2) ssm_state = self_kv_cache[1] has_initial_states_p = attn_metadata.has_initial_states_p + has_initial_states_d = attn_metadata.has_initial_states_d cu_chunk_seqlen_p = attn_metadata.cu_chunk_seqlen_p last_chunk_indices_p = attn_metadata.last_chunk_indices_p @@ -394,6 +395,13 @@ def forward_impl(self, hidden_states: torch.Tensor, output: torch.Tensor): else: state_indices_tensor_d_input = state_indices_tensor_d state_indices_tensor_d_output = state_indices_tensor_d + + # Clear stale state for new requests classified as decodes + if has_initial_states_d is not None: + new_indices = state_indices_tensor_d_input[~has_initial_states_d] + conv_state[new_indices] = 0 + ssm_state[new_indices] = 0 + # 2. 
Convolution sequence transformation conv_out_d = causal_conv1d_update( hidden_states_BC_d.transpose(0, 1), diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 971581d89c27..378a862c2f18 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -580,6 +580,7 @@ def conv_ssm_forward( conv_state = self_kv_cache[0].transpose(-1, -2) ssm_state = self_kv_cache[1] has_initial_states_p = attn_metadata.has_initial_states_p + has_initial_states_d = attn_metadata.has_initial_states_d prep_initial_states = attn_metadata.prep_initial_states chunk_size = attn_metadata.chunk_size seq_idx_p = attn_metadata.seq_idx_p @@ -831,6 +832,12 @@ def conv_ssm_forward( state_indices_tensor_d_input = state_indices_tensor_d state_indices_tensor_d_output = state_indices_tensor_d + # Clear stale state for new requests classified as decodes + if has_initial_states_d is not None: + new_indices = state_indices_tensor_d_input[~has_initial_states_d] + conv_state[new_indices] = 0 + ssm_state[new_indices] = 0 + # 2. 
Convolution sequence transformation hidden_states_B_C_d = causal_conv1d_update( hidden_states_B_C_d, diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py index 81ba858d6a7e..07f12e61eb71 100644 --- a/vllm/model_executor/models/plamo2.py +++ b/vllm/model_executor/models/plamo2.py @@ -269,6 +269,7 @@ def forward_impl( state_indices_tensor_p = attn_metadata.state_indices_tensor_p state_indices_tensor_d = attn_metadata.state_indices_tensor_d has_initial_states_p = attn_metadata.has_initial_states_p + has_initial_states_d = attn_metadata.has_initial_states_d prep_initial_states = attn_metadata.prep_initial_states chunk_size = attn_metadata.chunk_size seq_idx_p = attn_metadata.seq_idx_p @@ -393,6 +394,12 @@ def forward_impl( # Process decode requests if has_decode: + # Clear stale state for new requests classified as decodes + if has_initial_states_d is not None: + new_indices = state_indices_tensor_d[~has_initial_states_d] + conv_state[new_indices] = 0 + ssm_state[new_indices] = 0 + # 2. 
Convolution sequence transformation hidden_states_d = causal_conv1d_update( hidden_states_d, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index afcc13e4de4a..08dbd614fdcf 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3683,12 +3683,6 @@ def execute_model( ) ) - mamba_utils.clear_stale_mamba_states( - attn_metadata, - self.attn_groups, - self.compilation_config.static_forward_context, - ) - ( input_ids, inputs_embeds, diff --git a/vllm/v1/worker/mamba_utils.py b/vllm/v1/worker/mamba_utils.py index b4f808b04f6e..2bd5d2b3fea8 100644 --- a/vllm/v1/worker/mamba_utils.py +++ b/vllm/v1/worker/mamba_utils.py @@ -13,13 +13,11 @@ ) from vllm.triton_utils import tl, triton from vllm.utils.math_utils import cdiv -from vllm.v1.attention.backends.mamba_attn import BaseMambaAttentionMetadata from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import KVCacheConfig, MambaSpec from vllm.v1.utils import CpuGpuBuffer from vllm.v1.worker.gpu_input_batch import CachedRequestState from vllm.v1.worker.lora_model_runner_mixin import GPUInputBatch -from vllm.v1.worker.utils import AttentionGroup @triton.jit @@ -216,42 +214,6 @@ def preprocess_mamba( do_mamba_copy_block(copy_bufs) -def clear_stale_mamba_states( - attn_metadata: list[dict[str, Any]] | dict[str, Any], - attn_groups: list[list["AttentionGroup"]], - forward_context: dict[str, Any], -) -> None: - """Clear Mamba states for new requests in the decode batch. - - New requests (has_initial_states=False) would otherwise read stale - state from a recycled cache slot. 
- """ - if not isinstance(attn_metadata, dict): - return - - for kv_cache_groups in attn_groups: - for attn_group in kv_cache_groups: - if not isinstance(attn_group.kv_cache_spec, MambaSpec): - continue - - metadata = attn_metadata.get(attn_group.layer_names[0]) - if not isinstance(metadata, BaseMambaAttentionMetadata): - continue - - has_initial_states_d = metadata.has_initial_states_d - if has_initial_states_d is None: - continue - - assert metadata.state_indices_tensor_d is not None - indices = metadata.state_indices_tensor_d[: metadata.num_decodes] - new_indices = indices[~has_initial_states_d[: metadata.num_decodes]] - - for layer_name in attn_group.layer_names: - layer = forward_context[layer_name] - for state in layer.kv_cache[0]: - state[new_indices] = 0 - - def postprocess_mamba( scheduler_output: SchedulerOutput, kv_cache_config: KVCacheConfig,