From d27f74a364bbea9949cad5e94b5d70ac765ac58d Mon Sep 17 00:00:00 2001 From: Josephasafg Date: Tue, 10 Mar 2026 22:31:12 +0200 Subject: [PATCH 1/4] test Signed-off-by: Josephasafg --- .../layers/mamba/mamba_mixer.py | 8 ++++ vllm/v1/worker/gpu_model_runner.py | 6 --- vllm/v1/worker/mamba_utils.py | 38 ------------------- 3 files changed, 8 insertions(+), 44 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index 6a33fc7d6b1b..f4537e743cdc 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -271,6 +271,7 @@ def forward_impl(self, hidden_states: torch.Tensor, output: torch.Tensor): conv_state = self_kv_cache[0].transpose(-1, -2) ssm_state = self_kv_cache[1] has_initial_states_p = attn_metadata.has_initial_states_p + has_initial_states_d = attn_metadata.has_initial_states_d cu_chunk_seqlen_p = attn_metadata.cu_chunk_seqlen_p last_chunk_indices_p = attn_metadata.last_chunk_indices_p @@ -394,6 +395,13 @@ def forward_impl(self, hidden_states: torch.Tensor, output: torch.Tensor): else: state_indices_tensor_d_input = state_indices_tensor_d state_indices_tensor_d_output = state_indices_tensor_d + + # Clear stale state for new requests classified as decodes + if has_initial_states_d is not None: + new_indices = state_indices_tensor_d_input[~has_initial_states_d] + conv_state[:, new_indices] = 0 + ssm_state[new_indices] = 0 + # 2. Convolution sequence transformation conv_out_d = causal_conv1d_update( hidden_states_BC_d.transpose(0, 1), diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index afcc13e4de4a..08dbd614fdcf 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3683,12 +3683,6 @@ def execute_model( ) ) - mamba_utils.clear_stale_mamba_states( - attn_metadata, - self.attn_groups, - self.compilation_config.static_forward_context, - ) - ( input_ids, inputs_embeds, diff --git a/vllm/v1/worker/mamba_utils.py b/vllm/v1/worker/mamba_utils.py index b4f808b04f6e..2bd5d2b3fea8 100644 --- a/vllm/v1/worker/mamba_utils.py +++ b/vllm/v1/worker/mamba_utils.py @@ -13,13 +13,11 @@ ) from vllm.triton_utils import tl, triton from vllm.utils.math_utils import cdiv -from vllm.v1.attention.backends.mamba_attn import BaseMambaAttentionMetadata from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import KVCacheConfig, MambaSpec from vllm.v1.utils import CpuGpuBuffer from vllm.v1.worker.gpu_input_batch import CachedRequestState from vllm.v1.worker.lora_model_runner_mixin import GPUInputBatch -from vllm.v1.worker.utils import AttentionGroup @triton.jit @@ -216,42 +214,6 @@ def preprocess_mamba( do_mamba_copy_block(copy_bufs) -def clear_stale_mamba_states( - attn_metadata: list[dict[str, Any]] | dict[str, Any], - attn_groups: list[list["AttentionGroup"]], - forward_context: dict[str, Any], -) -> None: - """Clear Mamba states for new requests in the decode batch. - - New requests (has_initial_states=False) would otherwise read stale - state from a recycled cache slot. - """ - if not isinstance(attn_metadata, dict): - return - - for kv_cache_groups in attn_groups: - for attn_group in kv_cache_groups: - if not isinstance(attn_group.kv_cache_spec, MambaSpec): - continue - - metadata = attn_metadata.get(attn_group.layer_names[0]) - if not isinstance(metadata, BaseMambaAttentionMetadata): - continue - - has_initial_states_d = metadata.has_initial_states_d - if has_initial_states_d is None: - continue - - assert metadata.state_indices_tensor_d is not None - indices = metadata.state_indices_tensor_d[: metadata.num_decodes] - new_indices = indices[~has_initial_states_d[: metadata.num_decodes]] - - for layer_name in attn_group.layer_names: - layer = forward_context[layer_name] - for state in layer.kv_cache[0]: - state[new_indices] = 0 - - def postprocess_mamba( scheduler_output: SchedulerOutput, kv_cache_config: KVCacheConfig, From 1ad6f2acbe46c051a256072d5aad3fcaa8c3f2c5 Mon Sep 17 00:00:00 2001 From: Josephasafg Date: Tue, 10 Mar 2026 23:30:33 +0200 Subject: [PATCH 2/4] mod change Signed-off-by: Josephasafg --- vllm/model_executor/layers/mamba/mamba_mixer.py | 15 +++++++++++---- .../model_executor/layers/mamba/mamba_mixer2.py | 17 +++++++++++++++++ vllm/v1/attention/backends/mamba_attn.py | 15 +++++++++++++++ 3 files changed, 43 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index f4537e743cdc..36cb884ed69e 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -396,11 +396,18 @@ def forward_impl(self, hidden_states: torch.Tensor, output: torch.Tensor): state_indices_tensor_d_input = state_indices_tensor_d state_indices_tensor_d_output = state_indices_tensor_d - # Clear stale state for new requests classified as decodes if has_initial_states_d is not None: - new_indices = state_indices_tensor_d_input[~has_initial_states_d] - conv_state[:, new_indices] = 0 - ssm_state[new_indices] = 0 + indices = state_indices_tensor_d_input + + ssm_gathered = ssm_state[indices] + keep_ssm = has_initial_states_d.to(ssm_gathered.dtype) + keep_ssm = keep_ssm.view(-1, *([1] * (ssm_gathered.dim() - 1))) + ssm_state[indices] = ssm_gathered * keep_ssm + + conv_gathered = conv_state[indices] + keep_conv = has_initial_states_d.to(conv_gathered.dtype) + keep_conv = keep_conv.view(-1, *([1] * (conv_gathered.dim() - 1))) + conv_state[indices] = conv_gathered * keep_conv # 2. Convolution sequence transformation conv_out_d = causal_conv1d_update( diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 971581d89c27..f752dc472a6c 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -580,6 +580,7 @@ def conv_ssm_forward( conv_state = self_kv_cache[0].transpose(-1, -2) ssm_state = self_kv_cache[1] has_initial_states_p = attn_metadata.has_initial_states_p + has_initial_states_d = attn_metadata.has_initial_states_d prep_initial_states = attn_metadata.prep_initial_states chunk_size = attn_metadata.chunk_size seq_idx_p = attn_metadata.seq_idx_p @@ -831,6 +832,22 @@ def conv_ssm_forward( state_indices_tensor_d_input = state_indices_tensor_d state_indices_tensor_d_output = state_indices_tensor_d + # Clear stale state for new requests classified as decodes. + # Uses gather-multiply-scatter (fixed-shape ops) instead of + # boolean indexing to stay compatible with CUDA graph capture. + if has_initial_states_d is not None: + indices = state_indices_tensor_d_input + + ssm_gathered = ssm_state[indices] + keep_ssm = has_initial_states_d.to(ssm_gathered.dtype) + keep_ssm = keep_ssm.view(-1, *([1] * (ssm_gathered.dim() - 1))) + ssm_state[indices] = ssm_gathered * keep_ssm + + conv_gathered = conv_state[indices] + keep_conv = has_initial_states_d.to(conv_gathered.dtype) + keep_conv = keep_conv.view(-1, *([1] * (conv_gathered.dim() - 1))) + conv_state[indices] = conv_gathered * keep_conv + # 2. Convolution sequence transformation hidden_states_B_C_d = causal_conv1d_update( hidden_states_B_C_d, diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py index f86e691524c0..23adf6298a4b 100644 --- a/vllm/v1/attention/backends/mamba_attn.py +++ b/vllm/v1/attention/backends/mamba_attn.py @@ -149,6 +149,12 @@ def __init__( device=device, ) + self.has_initial_states_d: torch.Tensor = torch.ones( + (self.decode_cudagraph_max_bs,), + dtype=torch.bool, + device=device, + ) + self._init_reorder_batch_threshold(1, self.use_spec_decode) if self.use_spec_decode: self.supports_update_block_table = False @@ -501,6 +507,7 @@ def _update_metadata_for_cudagraph_capture( Currently, only decode is supported for full cudagraphs with Mamba. """ state_indices_tensor_d = metadata.state_indices_tensor_d + has_initial_states_d = metadata.has_initial_states_d query_start_loc_d = metadata.query_start_loc_d num_accepted_tokens = metadata.num_accepted_tokens block_idx_last_scheduled_token = metadata.block_idx_last_scheduled_token @@ -517,6 +524,13 @@ def _update_metadata_for_cudagraph_capture( state_indices_tensor_d = self.state_indices_tensor_d[:padded_bs] state_indices_tensor_d[metadata.num_decodes :] = PAD_SLOT_ID + if has_initial_states_d is not None: + self.has_initial_states_d[: metadata.num_decodes].copy_( + has_initial_states_d, non_blocking=True + ) + self.has_initial_states_d[metadata.num_decodes :] = True + has_initial_states_d = self.has_initial_states_d[:padded_bs] + if self.use_spec_decode: assert query_start_loc_d is not None assert num_accepted_tokens is not None @@ -551,6 +565,7 @@ def _update_metadata_for_cudagraph_capture( return replace( metadata, state_indices_tensor_d=state_indices_tensor_d, + has_initial_states_d=has_initial_states_d, query_start_loc_d=query_start_loc_d, num_accepted_tokens=num_accepted_tokens, block_idx_last_scheduled_token=block_idx_last_scheduled_token, From 5506afecebe007585070df1bbc7a8ec5d17aff44 Mon Sep 17 00:00:00 2001 From: Josephasafg Date: Tue, 10 Mar 2026 23:31:47 +0200 Subject: [PATCH 3/4] mamba2 Signed-off-by: Josephasafg --- vllm/model_executor/layers/mamba/mamba_mixer2.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index f752dc472a6c..180f386a5f94 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -832,9 +832,6 @@ def conv_ssm_forward( state_indices_tensor_d_input = state_indices_tensor_d state_indices_tensor_d_output = state_indices_tensor_d - # Clear stale state for new requests classified as decodes. - # Uses gather-multiply-scatter (fixed-shape ops) instead of - # boolean indexing to stay compatible with CUDA graph capture. if has_initial_states_d is not None: indices = state_indices_tensor_d_input From c780a501d06d194001cc3495c2c90b49d92b9127 Mon Sep 17 00:00:00 2001 From: Josephasafg Date: Fri, 13 Mar 2026 22:38:07 +0200 Subject: [PATCH 4/4] Added clear_stale_state Signed-off-by: Josephasafg --- vllm/model_executor/layers/mamba/abstract.py | 16 ++++++++++++++++ .../model_executor/layers/mamba/mamba_mixer.py | 18 ++++++------------ .../layers/mamba/mamba_mixer2.py | 18 ++++++------------ 3 files changed, 28 insertions(+), 24 deletions(-) diff --git a/vllm/model_executor/layers/mamba/abstract.py b/vllm/model_executor/layers/mamba/abstract.py index 3c6b0139424d..91ba9e42ac58 100644 --- a/vllm/model_executor/layers/mamba/abstract.py +++ b/vllm/model_executor/layers/mamba/abstract.py @@ -57,6 +57,22 @@ def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None: ), ) + @staticmethod + def clear_stale_decode_states( + has_initial_states_d: torch.Tensor | None, + indices: torch.Tensor, + ssm_state: torch.Tensor, + conv_state: torch.Tensor, + ) -> None: + if has_initial_states_d is None: + return + + for state in (ssm_state, conv_state): + gathered = state[indices] + keep_state = has_initial_states_d.to(gathered.dtype) + keep_state = keep_state.view(-1, *([1] * (gathered.dim() - 1))) + state[indices] = gathered * keep_state + def get_attn_backend(self) -> type[AttentionBackend]: """Get the attention backend class for this Mamba layer.""" return get_mamba_attn_backend(self.mamba_type) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index 36cb884ed69e..6302156958f3 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -396,18 +396,12 @@ def forward_impl(self, hidden_states: torch.Tensor, output: torch.Tensor): state_indices_tensor_d_input = state_indices_tensor_d state_indices_tensor_d_output = state_indices_tensor_d - if has_initial_states_d is not None: - indices = state_indices_tensor_d_input - - ssm_gathered = ssm_state[indices] - keep_ssm = has_initial_states_d.to(ssm_gathered.dtype) - keep_ssm = keep_ssm.view(-1, *([1] * (ssm_gathered.dim() - 1))) - ssm_state[indices] = ssm_gathered * keep_ssm - - conv_gathered = conv_state[indices] - keep_conv = has_initial_states_d.to(conv_gathered.dtype) - keep_conv = keep_conv.view(-1, *([1] * (conv_gathered.dim() - 1))) - conv_state[indices] = conv_gathered * keep_conv + self.clear_stale_decode_states( + has_initial_states_d, + state_indices_tensor_d_input, + ssm_state, + conv_state, + ) # 2. Convolution sequence transformation conv_out_d = causal_conv1d_update( diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 180f386a5f94..465f80240fad 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -832,18 +832,12 @@ def conv_ssm_forward( state_indices_tensor_d_input = state_indices_tensor_d state_indices_tensor_d_output = state_indices_tensor_d - if has_initial_states_d is not None: - indices = state_indices_tensor_d_input - - ssm_gathered = ssm_state[indices] - keep_ssm = has_initial_states_d.to(ssm_gathered.dtype) - keep_ssm = keep_ssm.view(-1, *([1] * (ssm_gathered.dim() - 1))) - ssm_state[indices] = ssm_gathered * keep_ssm - - conv_gathered = conv_state[indices] - keep_conv = has_initial_states_d.to(conv_gathered.dtype) - keep_conv = keep_conv.view(-1, *([1] * (conv_gathered.dim() - 1))) - conv_state[indices] = conv_gathered * keep_conv + self.clear_stale_decode_states( + has_initial_states_d, + state_indices_tensor_d_input, + ssm_state, + conv_state, + ) # 2. Convolution sequence transformation hidden_states_B_C_d = causal_conv1d_update(