From 7746bd95ace367acbf35b0bcfb9eba477fb8932a Mon Sep 17 00:00:00 2001 From: Josephasafg Date: Tue, 20 Jan 2026 22:04:36 +0200 Subject: [PATCH 01/15] Supporting MTP edge case Signed-off-by: Josephasafg --- vllm/v1/attention/backends/utils.py | 27 +++++++++++++++++++++------ 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 1b030eaf140a..7f70aa027fbd 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -512,14 +512,29 @@ def split_decodes_and_prefills( num_reqs = common_attn_metadata.num_reqs num_tokens = common_attn_metadata.num_actual_tokens query_start_loc = common_attn_metadata.query_start_loc_cpu + seq_lens = common_attn_metadata.seq_lens_cpu + + query_lens = query_start_loc[1:] - query_start_loc[:-1] + + # A new request has no prior context (num_computed_tokens == 0), + # which means seq_lens == query_lens. New requests need prefill + # treatment even if query_lens <= decode_threshold (e.g., for Mamba + # state initialization). Exclude padding (query_lens == 0). + is_new_request = (seq_lens == query_lens) & (query_lens > 0) + + # If ALL non-padding requests appear to be "new" (seq_lens == query_lens) + # AND all queries are small, treat as a decode-only batch. This handles + # CUDA graph capture where synthetic batches have seq_lens == query_lens == 1. + # Real prefill batches have large query_lens, so they won't match this. 
+ all_new_requests = torch.all(is_new_request | (query_lens == 0)) + if all_new_requests and max_query_len <= decode_threshold: + is_new_request = torch.zeros_like(is_new_request) if max_query_len <= decode_threshold and ( not require_uniform or decode_threshold <= 1 - ): + ) and not torch.any(is_new_request): return num_reqs, 0, num_tokens, 0 - - query_lens = query_start_loc[1:] - query_start_loc[:-1] - if query_lens[0].item() > decode_threshold: + if query_lens[0].item() > decode_threshold or is_new_request[0].item(): # first request is not decode, so no decode requests return 0, num_reqs, 0, num_tokens @@ -530,9 +545,9 @@ def split_decodes_and_prefills( if torch.all((query_lens == query_lens[0]) | (query_lens == 0)): assert num_reqs * query_lens[0] == num_tokens, "tokens not padded correctly" return num_reqs, 0, num_tokens, 0 # all decodes - is_prefill = query_lens != query_lens[0] + is_prefill = (query_lens != query_lens[0]) | is_new_request else: - is_prefill = query_lens > decode_threshold + is_prefill = (query_lens > decode_threshold) | is_new_request if not torch.any(is_prefill): return num_reqs, 0, num_tokens, 0 From 59b96ff5dae7b1ef66fe765d5726ece1ca988fbc Mon Sep 17 00:00:00 2001 From: Josephasafg Date: Tue, 20 Jan 2026 22:14:42 +0200 Subject: [PATCH 02/15] Remove Full CG support for now Signed-off-by: Josephasafg --- vllm/v1/attention/backends/utils.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 7f70aa027fbd..c517730e943e 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -522,14 +522,6 @@ def split_decodes_and_prefills( # state initialization). Exclude padding (query_lens == 0). is_new_request = (seq_lens == query_lens) & (query_lens > 0) - # If ALL non-padding requests appear to be "new" (seq_lens == query_lens) - # AND all queries are small, treat as a decode-only batch. 
This handles - # CUDA graph capture where synthetic batches have seq_lens == query_lens == 1. - # Real prefill batches have large query_lens, so they won't match this. - all_new_requests = torch.all(is_new_request | (query_lens == 0)) - if all_new_requests and max_query_len <= decode_threshold: - is_new_request = torch.zeros_like(is_new_request) - if max_query_len <= decode_threshold and ( not require_uniform or decode_threshold <= 1 ) and not torch.any(is_new_request): From fb9929945d1c74641ceaf50b6a20b63221e4c756 Mon Sep 17 00:00:00 2001 From: Josephasafg Date: Mon, 2 Mar 2026 13:20:02 +0200 Subject: [PATCH 03/15] Handle cuda graph capture Signed-off-by: Josephasafg --- vllm/v1/attention/backends/utils.py | 13 ++++++++----- vllm/v1/worker/gpu_model_runner.py | 4 ++++ 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index c517730e943e..af6c6109f33c 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -516,11 +516,14 @@ def split_decodes_and_prefills( query_lens = query_start_loc[1:] - query_start_loc[:-1] - # A new request has no prior context (num_computed_tokens == 0), - # which means seq_lens == query_lens. New requests need prefill - # treatment even if query_lens <= decode_threshold (e.g., for Mamba - # state initialization). Exclude padding (query_lens == 0). - is_new_request = (seq_lens == query_lens) & (query_lens > 0) + # A new request has no prior context (num_computed_tokens == 0). + # New requests need prefill treatment even if + # query_lens <= decode_threshold (e.g., for Mamba state init). 
+ num_computed = common_attn_metadata._num_computed_tokens_cpu + if num_computed is not None: + is_new_request = (num_computed[:num_reqs] == 0) & (query_lens > 0) + else: + is_new_request = (seq_lens == query_lens) & (query_lens > 0) if max_query_len <= decode_threshold and ( not require_uniform or decode_threshold <= 1 diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 59a82d4ce6ae..2be96e03af3f 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -4842,6 +4842,10 @@ def _dummy_run( else: seq_lens = max_query_len # type: ignore[assignment] self.seq_lens.np[:num_reqs] = seq_lens + # Mark all dummy requests as having prior context so + # split_decodes_and_prefills won't misclassify them + # as new prefill requests. + self.input_batch.num_computed_tokens_cpu[:num_reqs] = 1 self.seq_lens.np[num_reqs:] = 0 self.seq_lens.copy_to_gpu() From 7ebba3c255c0ac1274b2700eac54869cb4ad2ba9 Mon Sep 17 00:00:00 2001 From: Josephasafg Date: Mon, 2 Mar 2026 19:32:31 +0200 Subject: [PATCH 04/15] Added tests Signed-off-by: Josephasafg --- tests/v1/attention/test_batch_reordering.py | 89 ++++++++++++++++++++- 1 file changed, 88 insertions(+), 1 deletion(-) diff --git a/tests/v1/attention/test_batch_reordering.py b/tests/v1/attention/test_batch_reordering.py index 6265e12f9a7d..45c2e6cbc82a 100644 --- a/tests/v1/attention/test_batch_reordering.py +++ b/tests/v1/attention/test_batch_reordering.py @@ -5,8 +5,46 @@ import numpy as np import pytest +import torch + +from vllm.v1.attention.backend import CommonAttentionMetadata +from vllm.v1.attention.backends.utils import ( + reorder_batch_to_split_decodes_and_prefills, + split_decodes_and_prefills, +) -from vllm.v1.attention.backends.utils import reorder_batch_to_split_decodes_and_prefills + +def _make_common_attn_metadata( + query_lens: list[int], + seq_lens: list[int], + num_computed_tokens: list[int] | None = None, +): + num_reqs = len(query_lens) + num_tokens 
= sum(query_lens) + max_query_len = max(query_lens) if query_lens else 0 + + query_start_loc = torch.zeros(num_reqs + 1, dtype=torch.int32) + for i, ql in enumerate(query_lens): + query_start_loc[i + 1] = query_start_loc[i] + ql + + seq_lens_t = torch.tensor(seq_lens, dtype=torch.int32) + + nct = None + if num_computed_tokens is not None: + nct = torch.tensor(num_computed_tokens, dtype=torch.int32) + + return CommonAttentionMetadata( + query_start_loc=query_start_loc, + query_start_loc_cpu=query_start_loc, + seq_lens=seq_lens_t, + _num_computed_tokens_cpu=nct, + num_reqs=num_reqs, + num_actual_tokens=num_tokens, + max_query_len=max_query_len, + max_seq_len=max(seq_lens) if seq_lens else 0, + block_table_tensor=torch.empty(0), + slot_mapping=torch.empty(0), + ) class MockInputBatch: @@ -145,3 +183,52 @@ def test_reorder_batch_to_split_decodes_and_prefills(test_case: ReorderTestCase) assert input_batch.req_ids == expected_req_ids, ( f"Expected order {expected_req_ids}, got {input_batch.req_ids}" ) + + +@dataclass +class SplitTestCase: + query_lens: list[int] + seq_lens: list[int] + num_computed_tokens: list[int] + decode_threshold: int + expected: tuple[int, int, int, int] # (num_d, num_p, num_dt, num_pt) + + +SPLIT_TEST_CASES = { + "mtp_new_request_is_prefill": SplitTestCase( + query_lens=[3], + seq_lens=[3], + num_computed_tokens=[0], + decode_threshold=4, + expected=(0, 1, 0, 3), + ), + "mtp_cuda_graph_synthetic_decodes": SplitTestCase( + query_lens=[4, 4, 4], + seq_lens=[4, 4, 4], + num_computed_tokens=[1, 1, 1], + decode_threshold=4, + expected=(3, 0, 12, 0), + ), + "mtp_mixed_decodes_and_new_request": SplitTestCase( + query_lens=[4, 4, 3], + seq_lens=[100, 200, 3], + num_computed_tokens=[96, 196, 0], + decode_threshold=4, + expected=(2, 1, 8, 3), + ), +} + + +@pytest.mark.parametrize( + "test_case", SPLIT_TEST_CASES.values(), ids=SPLIT_TEST_CASES.keys() +) +def test_split_decodes_and_prefills(test_case: SplitTestCase): + meta = _make_common_attn_metadata( + 
query_lens=test_case.query_lens, + seq_lens=test_case.seq_lens, + num_computed_tokens=test_case.num_computed_tokens, + ) + result = split_decodes_and_prefills( + meta, decode_threshold=test_case.decode_threshold + ) + assert result == test_case.expected From 95220311857056e07a254b43037443e871b590f4 Mon Sep 17 00:00:00 2001 From: Josephasafg Date: Mon, 2 Mar 2026 19:32:42 +0200 Subject: [PATCH 05/15] formatting Signed-off-by: Josephasafg --- vllm/v1/attention/backends/utils.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index af6c6109f33c..2c815a9e10ee 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -525,9 +525,11 @@ def split_decodes_and_prefills( else: is_new_request = (seq_lens == query_lens) & (query_lens > 0) - if max_query_len <= decode_threshold and ( - not require_uniform or decode_threshold <= 1 - ) and not torch.any(is_new_request): + if ( + max_query_len <= decode_threshold + and (not require_uniform or decode_threshold <= 1) + and not torch.any(is_new_request) + ): return num_reqs, 0, num_tokens, 0 if query_lens[0].item() > decode_threshold or is_new_request[0].item(): # first request is not decode, so no decode requests From 6a62bb8af0840cfb17a719ff956d01b7558a7509 Mon Sep 17 00:00:00 2001 From: Josephasafg Date: Wed, 4 Mar 2026 11:25:58 +0200 Subject: [PATCH 06/15] Added has_context to common attention metadata Signed-off-by: Josephasafg --- tests/v1/attention/test_batch_reordering.py | 9 ++++++--- vllm/v1/attention/backend.py | 7 +++++++ vllm/v1/attention/backends/utils.py | 11 +++++------ vllm/v1/worker/gpu_model_runner.py | 8 ++++++-- vllm/v1/worker/ubatch_utils.py | 4 ++++ 5 files changed, 28 insertions(+), 11 deletions(-) diff --git a/tests/v1/attention/test_batch_reordering.py b/tests/v1/attention/test_batch_reordering.py index 45c2e6cbc82a..b787de2eab6d 100644 --- 
a/tests/v1/attention/test_batch_reordering.py +++ b/tests/v1/attention/test_batch_reordering.py @@ -29,15 +29,18 @@ def _make_common_attn_metadata( seq_lens_t = torch.tensor(seq_lens, dtype=torch.int32) - nct = None + has_context = None if num_computed_tokens is not None: - nct = torch.tensor(num_computed_tokens, dtype=torch.int32) + has_context = torch.tensor( + [nct > 0 for nct in num_computed_tokens], dtype=torch.bool + ) return CommonAttentionMetadata( query_start_loc=query_start_loc, query_start_loc_cpu=query_start_loc, seq_lens=seq_lens_t, - _num_computed_tokens_cpu=nct, + _seq_lens_cpu=seq_lens_t, + has_context=has_context, num_reqs=num_reqs, num_actual_tokens=num_tokens, max_query_len=max_query_len, diff --git a/vllm/v1/attention/backend.py b/vllm/v1/attention/backend.py index 585ad1d793ff..db25610d5b5c 100644 --- a/vllm/v1/attention/backend.py +++ b/vllm/v1/attention/backend.py @@ -333,6 +333,10 @@ class CommonAttentionMetadata: dcp_local_seq_lens_cpu: torch.Tensor | None = None """Sequence lengths of the local rank in decode context parallelism world""" + has_context: torch.Tensor | None = None + """(batch_size,) bool CPU tensor. True if the request has prior computed + context (num_computed_tokens > 0), False for brand-new requests.""" + # WARNING: Deprecated fields. 
Will be removed in a future release (v0.15.0) _seq_lens_cpu: torch.Tensor | None = None _num_computed_tokens_cpu: torch.Tensor | None = None @@ -401,6 +405,9 @@ def unpadded( _num_computed_tokens_cpu=self._num_computed_tokens_cpu[:num_actual_reqs] if self._num_computed_tokens_cpu is not None else None, + has_context=self.has_context[:num_actual_reqs] + if self.has_context is not None + else None, num_reqs=num_actual_reqs, num_actual_tokens=num_actual_tokens, max_query_len=self.max_query_len, diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 2c815a9e10ee..b28059f3597b 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -356,6 +356,7 @@ def make_local_attention_virtual_batches( causal=True, _seq_lens_cpu=seq_lens_cpu, _num_computed_tokens_cpu=torch.from_numpy(num_computed_tokens_local), + has_context=torch.from_numpy(num_computed_tokens_local > 0), ), make_block_table @@ -412,6 +413,7 @@ def make_kv_sharing_fast_prefill_common_attn_metadata( block_table_tensor=common_attn_metadata.block_table_tensor, slot_mapping=common_attn_metadata.slot_mapping, causal=True, + has_context=common_attn_metadata.has_context, _seq_lens_cpu=common_attn_metadata._seq_lens_cpu, _num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu, ) @@ -512,18 +514,15 @@ def split_decodes_and_prefills( num_reqs = common_attn_metadata.num_reqs num_tokens = common_attn_metadata.num_actual_tokens query_start_loc = common_attn_metadata.query_start_loc_cpu - seq_lens = common_attn_metadata.seq_lens_cpu query_lens = query_start_loc[1:] - query_start_loc[:-1] # A new request has no prior context (num_computed_tokens == 0). # New requests need prefill treatment even if # query_lens <= decode_threshold (e.g., for Mamba state init). 
- num_computed = common_attn_metadata._num_computed_tokens_cpu - if num_computed is not None: - is_new_request = (num_computed[:num_reqs] == 0) & (query_lens > 0) - else: - is_new_request = (seq_lens == query_lens) & (query_lens > 0) + has_context = common_attn_metadata.has_context + assert has_context is not None + is_new_request = ~has_context[:num_reqs] & (query_lens > 0) if ( max_query_len <= decode_threshold diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 198ccb3c9da8..78d49086b7be 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1802,6 +1802,9 @@ def _get_block_table(kv_cache_gid: int): if self.model_config.enable_return_routed_experts: self.slot_mapping = slot_mapping_gid_0[:num_tokens].cpu().numpy() + has_context = torch.from_numpy( + self.input_batch.num_computed_tokens_cpu[:num_reqs_padded] > 0 + ) cm_base = CommonAttentionMetadata( query_start_loc=self.query_start_loc.gpu[: num_reqs_padded + 1], query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs_padded + 1], @@ -1810,6 +1813,7 @@ def _get_block_table(kv_cache_gid: int): _num_computed_tokens_cpu=self.input_batch.num_computed_tokens_cpu_tensor[ :num_reqs_padded ], + has_context=has_context, num_reqs=num_reqs_padded, num_actual_tokens=num_tokens_padded, max_query_len=max_query_len, @@ -4843,8 +4847,8 @@ def _dummy_run( seq_lens = max_query_len # type: ignore[assignment] self.seq_lens.np[:num_reqs] = seq_lens # Mark all dummy requests as having prior context so - # split_decodes_and_prefills won't misclassify them - # as new prefill requests. + # has_context is True and split_decodes_and_prefills + # won't misclassify them as new prefill requests. 
self.input_batch.num_computed_tokens_cpu[:num_reqs] = 1 self.seq_lens.np[num_reqs:] = 0 self.seq_lens.copy_to_gpu() diff --git a/vllm/v1/worker/ubatch_utils.py b/vllm/v1/worker/ubatch_utils.py index 7c41726472d5..99e5a7c04684 100644 --- a/vllm/v1/worker/ubatch_utils.py +++ b/vllm/v1/worker/ubatch_utils.py @@ -211,6 +211,9 @@ def _make_metadata_with_slice( block_table_tensor = attn_metadata.block_table_tensor[request_slice] slot_mapping = attn_metadata.slot_mapping[token_slice] + assert attn_metadata.has_context is not None + has_context = attn_metadata.has_context[request_slice] + return CommonAttentionMetadata( query_start_loc=query_start_loc, query_start_loc_cpu=query_start_loc_cpu, @@ -221,6 +224,7 @@ def _make_metadata_with_slice( max_seq_len=max_seq_len, block_table_tensor=block_table_tensor, slot_mapping=slot_mapping, + has_context=has_context, _seq_lens_cpu=seq_lens_cpu, _num_computed_tokens_cpu=num_computed_tokens_cpu, ) From f326fac013c0483da5b6b11a54450a3845d54bef Mon Sep 17 00:00:00 2001 From: Josephasafg Date: Tue, 10 Mar 2026 18:26:29 +0200 Subject: [PATCH 07/15] Changed to use clear_stale_mamba_states instead of has_context Signed-off-by: Josephasafg --- .../layers/mamba/mamba_mixer.py | 1 + vllm/v1/attention/backend.py | 7 --- vllm/v1/attention/backends/mamba_attn.py | 16 +++++++ vllm/v1/attention/backends/utils.py | 25 +++------- vllm/v1/worker/gpu_model_runner.py | 16 +++---- vllm/v1/worker/mamba_utils.py | 48 ++++++++++++++++++- vllm/v1/worker/ubatch_utils.py | 4 -- 7 files changed, 79 insertions(+), 38 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index 6a33fc7d6b1b..3355640d4666 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -394,6 +394,7 @@ def forward_impl(self, hidden_states: torch.Tensor, output: torch.Tensor): else: state_indices_tensor_d_input = state_indices_tensor_d 
state_indices_tensor_d_output = state_indices_tensor_d + # 2. Convolution sequence transformation conv_out_d = causal_conv1d_update( hidden_states_BC_d.transpose(0, 1), diff --git a/vllm/v1/attention/backend.py b/vllm/v1/attention/backend.py index 2b1c4920cff8..3af817a2e08f 100644 --- a/vllm/v1/attention/backend.py +++ b/vllm/v1/attention/backend.py @@ -333,10 +333,6 @@ class CommonAttentionMetadata: dcp_local_seq_lens_cpu: torch.Tensor | None = None """Sequence lengths of the local rank in decode context parallelism world""" - has_context: torch.Tensor | None = None - """(batch_size,) bool CPU tensor. True if the request has prior computed - context (num_computed_tokens > 0), False for brand-new requests.""" - # WARNING: Deprecated fields. Will be removed in a future release (v0.15.0) _seq_lens_cpu: torch.Tensor | None = None _num_computed_tokens_cpu: torch.Tensor | None = None @@ -405,9 +401,6 @@ def unpadded( _num_computed_tokens_cpu=self._num_computed_tokens_cpu[:num_actual_reqs] if self._num_computed_tokens_cpu is not None else None, - has_context=self.has_context[:num_actual_reqs] - if self.has_context is not None - else None, num_reqs=num_actual_reqs, num_actual_tokens=num_actual_tokens, max_query_len=self.max_query_len, diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py index 0364d6aee5c7..f86e691524c0 100644 --- a/vllm/v1/attention/backends/mamba_attn.py +++ b/vllm/v1/attention/backends/mamba_attn.py @@ -43,6 +43,7 @@ class BaseMambaAttentionMetadata: # The following tensors are used for decode requests and # speculative decoding compatibility, and will be None if the batch # has no decode requests. 
+ has_initial_states_d: torch.Tensor | None state_indices_tensor_d: torch.Tensor | None query_start_loc_d: torch.Tensor | None # shape: [num_decodes + 1,] @@ -364,6 +365,7 @@ def _compute_common_metadata( # Need flags to indicate if there are initial states has_initial_states_p = None + has_initial_states_d = None query_start_loc_p = None query_start_loc_d = None num_computed_tokens = None @@ -414,6 +416,19 @@ def _compute_common_metadata( ] state_indices_tensor_p = state_indices_tensor_p[:, 0] + # Only set when there are genuinely new decode requests + # (num_computed_tokens == 0, seq_lens > 0). Padded CG slots + # (seq_lens == 0) are excluded so their PAD_SLOT_ID doesn't + # cause zeroing of the last real cache slot. + if num_decodes > 0: + if num_computed_tokens is None: + num_computed_tokens = common_attn_metadata.compute_num_computed_tokens() + has_initial_states_d = (num_computed_tokens[:num_decodes] > 0) | ( + common_attn_metadata.seq_lens[:num_decodes] == 0 + ) + if has_initial_states_d.all(): + has_initial_states_d = None + if num_decodes > 0 and self.use_spec_decode: assert num_accepted_tokens is not None query_start_loc_d = common_attn_metadata.query_start_loc[: num_decodes + 1] @@ -459,6 +474,7 @@ def _compute_common_metadata( num_decode_tokens=num_decode_tokens, query_start_loc_p=query_start_loc_p, has_initial_states_p=has_initial_states_p, + has_initial_states_d=has_initial_states_d, state_indices_tensor_p=state_indices_tensor_p, state_indices_tensor_d=state_indices_tensor_d, num_accepted_tokens=num_accepted_tokens, diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index b28059f3597b..1b030eaf140a 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -356,7 +356,6 @@ def make_local_attention_virtual_batches( causal=True, _seq_lens_cpu=seq_lens_cpu, _num_computed_tokens_cpu=torch.from_numpy(num_computed_tokens_local), - has_context=torch.from_numpy(num_computed_tokens_local > 
0), ), make_block_table @@ -413,7 +412,6 @@ def make_kv_sharing_fast_prefill_common_attn_metadata( block_table_tensor=common_attn_metadata.block_table_tensor, slot_mapping=common_attn_metadata.slot_mapping, causal=True, - has_context=common_attn_metadata.has_context, _seq_lens_cpu=common_attn_metadata._seq_lens_cpu, _num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu, ) @@ -515,22 +513,13 @@ def split_decodes_and_prefills( num_tokens = common_attn_metadata.num_actual_tokens query_start_loc = common_attn_metadata.query_start_loc_cpu - query_lens = query_start_loc[1:] - query_start_loc[:-1] - - # A new request has no prior context (num_computed_tokens == 0). - # New requests need prefill treatment even if - # query_lens <= decode_threshold (e.g., for Mamba state init). - has_context = common_attn_metadata.has_context - assert has_context is not None - is_new_request = ~has_context[:num_reqs] & (query_lens > 0) - - if ( - max_query_len <= decode_threshold - and (not require_uniform or decode_threshold <= 1) - and not torch.any(is_new_request) + if max_query_len <= decode_threshold and ( + not require_uniform or decode_threshold <= 1 ): return num_reqs, 0, num_tokens, 0 - if query_lens[0].item() > decode_threshold or is_new_request[0].item(): + + query_lens = query_start_loc[1:] - query_start_loc[:-1] + if query_lens[0].item() > decode_threshold: # first request is not decode, so no decode requests return 0, num_reqs, 0, num_tokens @@ -541,9 +530,9 @@ def split_decodes_and_prefills( if torch.all((query_lens == query_lens[0]) | (query_lens == 0)): assert num_reqs * query_lens[0] == num_tokens, "tokens not padded correctly" return num_reqs, 0, num_tokens, 0 # all decodes - is_prefill = (query_lens != query_lens[0]) | is_new_request + is_prefill = query_lens != query_lens[0] else: - is_prefill = (query_lens > decode_threshold) | is_new_request + is_prefill = query_lens > decode_threshold if not torch.any(is_prefill): return num_reqs, 0, num_tokens, 0 
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index edd6805ac43c..3b8ff7eec571 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1912,9 +1912,6 @@ def _get_block_table(kv_cache_gid: int): if self.model_config.enable_return_routed_experts: self.slot_mapping = slot_mapping_gid_0[:num_tokens].cpu().numpy() - has_context = torch.from_numpy( - self.input_batch.num_computed_tokens_cpu[:num_reqs_padded] > 0 - ) cm_base = CommonAttentionMetadata( query_start_loc=self.query_start_loc.gpu[: num_reqs_padded + 1], query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs_padded + 1], @@ -1923,7 +1920,6 @@ def _get_block_table(kv_cache_gid: int): _num_computed_tokens_cpu=self.input_batch.num_computed_tokens_cpu_tensor[ :num_reqs_padded ], - has_context=has_context, num_reqs=num_reqs_padded, num_actual_tokens=num_tokens_padded, max_query_len=max_query_len, @@ -3687,6 +3683,14 @@ def execute_model( ) ) + # New mamba requests classified as decodes + # would read stale state from recycled cache slots. + mamba_utils.clear_stale_mamba_states( + attn_metadata, + self.attn_groups, + self.compilation_config.static_forward_context, + ) + ( input_ids, inputs_embeds, @@ -5089,10 +5093,6 @@ def _dummy_run( else: seq_lens = max_query_len # type: ignore[assignment] self.seq_lens.np[:num_reqs] = seq_lens - # Mark all dummy requests as having prior context so - # has_context is True and split_decodes_and_prefills - # won't misclassify them as new prefill requests. 
- self.input_batch.num_computed_tokens_cpu[:num_reqs] = 1 self.seq_lens.np[num_reqs:] = 0 self.seq_lens.copy_to_gpu() diff --git a/vllm/v1/worker/mamba_utils.py b/vllm/v1/worker/mamba_utils.py index 2bd5d2b3fea8..f60fc51ff76e 100644 --- a/vllm/v1/worker/mamba_utils.py +++ b/vllm/v1/worker/mamba_utils.py @@ -3,7 +3,7 @@ import dataclasses import itertools from collections.abc import Callable -from typing import Any +from typing import TYPE_CHECKING, Any import torch @@ -19,6 +19,9 @@ from vllm.v1.worker.gpu_input_batch import CachedRequestState from vllm.v1.worker.lora_model_runner_mixin import GPUInputBatch +if TYPE_CHECKING: + from vllm.v1.worker.utils import AttentionGroup + @triton.jit def batch_memcpy_kernel(src_ptrs, dst_ptrs, sizes, BLOCK_SIZE: tl.constexpr): @@ -214,6 +217,49 @@ def preprocess_mamba( do_mamba_copy_block(copy_bufs) +def clear_stale_mamba_states( + attn_metadata: dict[str, Any], + attn_groups: list[list["AttentionGroup"]], + forward_context: dict[str, Any], +) -> None: + """Zero Mamba states for new requests in the decode batch. + + Runs outside the CUDA graph so zeroing is not recorded. + New requests (has_initial_states=False) would otherwise read stale + state from a recycled cache slot. 
+ """ + from vllm.v1.attention.backends.mamba_attn import BaseMambaAttentionMetadata + + for kv_cache_groups in attn_groups: + for attn_group in kv_cache_groups: + if not isinstance(attn_group.kv_cache_spec, MambaSpec): + continue + + first_layer = attn_group.layer_names[0] + if first_layer not in attn_metadata: + continue + metadata = attn_metadata[first_layer] + if not isinstance(metadata, BaseMambaAttentionMetadata): + continue + + has_initial_states_d = metadata.has_initial_states_d + if has_initial_states_d is None: + continue + + num_decodes = metadata.num_decodes + indices = metadata.state_indices_tensor_d[:num_decodes] + + # Keep only new-request slots + new_indices = indices[~has_initial_states_d[:num_decodes]] + if new_indices.numel() == 0: + continue + + for layer_name in attn_group.layer_names: + layer = forward_context[layer_name] + for state in layer.kv_cache[0]: + state[new_indices] = 0 + + def postprocess_mamba( scheduler_output: SchedulerOutput, kv_cache_config: KVCacheConfig, diff --git a/vllm/v1/worker/ubatch_utils.py b/vllm/v1/worker/ubatch_utils.py index 99e5a7c04684..7c41726472d5 100644 --- a/vllm/v1/worker/ubatch_utils.py +++ b/vllm/v1/worker/ubatch_utils.py @@ -211,9 +211,6 @@ def _make_metadata_with_slice( block_table_tensor = attn_metadata.block_table_tensor[request_slice] slot_mapping = attn_metadata.slot_mapping[token_slice] - assert attn_metadata.has_context is not None - has_context = attn_metadata.has_context[request_slice] - return CommonAttentionMetadata( query_start_loc=query_start_loc, query_start_loc_cpu=query_start_loc_cpu, @@ -224,7 +221,6 @@ def _make_metadata_with_slice( max_seq_len=max_seq_len, block_table_tensor=block_table_tensor, slot_mapping=slot_mapping, - has_context=has_context, _seq_lens_cpu=seq_lens_cpu, _num_computed_tokens_cpu=num_computed_tokens_cpu, ) From a63da0f0b1c16b7aa5ebd4c3e6fa6d1b258b74a1 Mon Sep 17 00:00:00 2001 From: Josephasafg Date: Tue, 10 Mar 2026 18:26:39 +0200 Subject: [PATCH 08/15] Reverted 
tests Signed-off-by: Josephasafg --- tests/v1/attention/test_batch_reordering.py | 88 --------------------- 1 file changed, 88 deletions(-) diff --git a/tests/v1/attention/test_batch_reordering.py b/tests/v1/attention/test_batch_reordering.py index b787de2eab6d..63fa53d6b44a 100644 --- a/tests/v1/attention/test_batch_reordering.py +++ b/tests/v1/attention/test_batch_reordering.py @@ -5,51 +5,12 @@ import numpy as np import pytest -import torch -from vllm.v1.attention.backend import CommonAttentionMetadata from vllm.v1.attention.backends.utils import ( reorder_batch_to_split_decodes_and_prefills, - split_decodes_and_prefills, ) -def _make_common_attn_metadata( - query_lens: list[int], - seq_lens: list[int], - num_computed_tokens: list[int] | None = None, -): - num_reqs = len(query_lens) - num_tokens = sum(query_lens) - max_query_len = max(query_lens) if query_lens else 0 - - query_start_loc = torch.zeros(num_reqs + 1, dtype=torch.int32) - for i, ql in enumerate(query_lens): - query_start_loc[i + 1] = query_start_loc[i] + ql - - seq_lens_t = torch.tensor(seq_lens, dtype=torch.int32) - - has_context = None - if num_computed_tokens is not None: - has_context = torch.tensor( - [nct > 0 for nct in num_computed_tokens], dtype=torch.bool - ) - - return CommonAttentionMetadata( - query_start_loc=query_start_loc, - query_start_loc_cpu=query_start_loc, - seq_lens=seq_lens_t, - _seq_lens_cpu=seq_lens_t, - has_context=has_context, - num_reqs=num_reqs, - num_actual_tokens=num_tokens, - max_query_len=max_query_len, - max_seq_len=max(seq_lens) if seq_lens else 0, - block_table_tensor=torch.empty(0), - slot_mapping=torch.empty(0), - ) - - class MockInputBatch: def __init__(self, req_ids, num_computed_tokens_cpu): self.req_ids = req_ids @@ -186,52 +147,3 @@ def test_reorder_batch_to_split_decodes_and_prefills(test_case: ReorderTestCase) assert input_batch.req_ids == expected_req_ids, ( f"Expected order {expected_req_ids}, got {input_batch.req_ids}" ) - - -@dataclass -class 
SplitTestCase: - query_lens: list[int] - seq_lens: list[int] - num_computed_tokens: list[int] - decode_threshold: int - expected: tuple[int, int, int, int] # (num_d, num_p, num_dt, num_pt) - - -SPLIT_TEST_CASES = { - "mtp_new_request_is_prefill": SplitTestCase( - query_lens=[3], - seq_lens=[3], - num_computed_tokens=[0], - decode_threshold=4, - expected=(0, 1, 0, 3), - ), - "mtp_cuda_graph_synthetic_decodes": SplitTestCase( - query_lens=[4, 4, 4], - seq_lens=[4, 4, 4], - num_computed_tokens=[1, 1, 1], - decode_threshold=4, - expected=(3, 0, 12, 0), - ), - "mtp_mixed_decodes_and_new_request": SplitTestCase( - query_lens=[4, 4, 3], - seq_lens=[100, 200, 3], - num_computed_tokens=[96, 196, 0], - decode_threshold=4, - expected=(2, 1, 8, 3), - ), -} - - -@pytest.mark.parametrize( - "test_case", SPLIT_TEST_CASES.values(), ids=SPLIT_TEST_CASES.keys() -) -def test_split_decodes_and_prefills(test_case: SplitTestCase): - meta = _make_common_attn_metadata( - query_lens=test_case.query_lens, - seq_lens=test_case.seq_lens, - num_computed_tokens=test_case.num_computed_tokens, - ) - result = split_decodes_and_prefills( - meta, decode_threshold=test_case.decode_threshold - ) - assert result == test_case.expected From fddaf4d4c8a1410afce09877b5e340ee5bc3825f Mon Sep 17 00:00:00 2001 From: Josephasafg Date: Tue, 10 Mar 2026 18:39:09 +0200 Subject: [PATCH 09/15] Simplified function Signed-off-by: Josephasafg --- vllm/v1/worker/mamba_utils.py | 30 +++++++++++------------------- 1 file changed, 11 insertions(+), 19 deletions(-) diff --git a/vllm/v1/worker/mamba_utils.py b/vllm/v1/worker/mamba_utils.py index f60fc51ff76e..b4f808b04f6e 100644 --- a/vllm/v1/worker/mamba_utils.py +++ b/vllm/v1/worker/mamba_utils.py @@ -3,7 +3,7 @@ import dataclasses import itertools from collections.abc import Callable -from typing import TYPE_CHECKING, Any +from typing import Any import torch @@ -13,14 +13,13 @@ ) from vllm.triton_utils import tl, triton from vllm.utils.math_utils import cdiv +from 
vllm.v1.attention.backends.mamba_attn import BaseMambaAttentionMetadata from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import KVCacheConfig, MambaSpec from vllm.v1.utils import CpuGpuBuffer from vllm.v1.worker.gpu_input_batch import CachedRequestState from vllm.v1.worker.lora_model_runner_mixin import GPUInputBatch - -if TYPE_CHECKING: - from vllm.v1.worker.utils import AttentionGroup +from vllm.v1.worker.utils import AttentionGroup @triton.jit @@ -218,27 +217,24 @@ def preprocess_mamba( def clear_stale_mamba_states( - attn_metadata: dict[str, Any], + attn_metadata: list[dict[str, Any]] | dict[str, Any], attn_groups: list[list["AttentionGroup"]], forward_context: dict[str, Any], ) -> None: - """Zero Mamba states for new requests in the decode batch. + """Clear Mamba states for new requests in the decode batch. - Runs outside the CUDA graph so zeroing is not recorded. New requests (has_initial_states=False) would otherwise read stale state from a recycled cache slot. 
""" - from vllm.v1.attention.backends.mamba_attn import BaseMambaAttentionMetadata + if not isinstance(attn_metadata, dict): + return for kv_cache_groups in attn_groups: for attn_group in kv_cache_groups: if not isinstance(attn_group.kv_cache_spec, MambaSpec): continue - first_layer = attn_group.layer_names[0] - if first_layer not in attn_metadata: - continue - metadata = attn_metadata[first_layer] + metadata = attn_metadata.get(attn_group.layer_names[0]) if not isinstance(metadata, BaseMambaAttentionMetadata): continue @@ -246,13 +242,9 @@ def clear_stale_mamba_states( if has_initial_states_d is None: continue - num_decodes = metadata.num_decodes - indices = metadata.state_indices_tensor_d[:num_decodes] - - # Keep only new-request slots - new_indices = indices[~has_initial_states_d[:num_decodes]] - if new_indices.numel() == 0: - continue + assert metadata.state_indices_tensor_d is not None + indices = metadata.state_indices_tensor_d[: metadata.num_decodes] + new_indices = indices[~has_initial_states_d[: metadata.num_decodes]] for layer_name in attn_group.layer_names: layer = forward_context[layer_name] From 9cf03a302752072f38d441eb2d415f37bbbac098 Mon Sep 17 00:00:00 2001 From: Josephasafg Date: Tue, 10 Mar 2026 18:43:42 +0200 Subject: [PATCH 10/15] reverted unnecessary changes Signed-off-by: Josephasafg --- tests/v1/attention/test_batch_reordering.py | 4 +--- vllm/model_executor/layers/mamba/mamba_mixer.py | 1 - 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/v1/attention/test_batch_reordering.py b/tests/v1/attention/test_batch_reordering.py index 63fa53d6b44a..6265e12f9a7d 100644 --- a/tests/v1/attention/test_batch_reordering.py +++ b/tests/v1/attention/test_batch_reordering.py @@ -6,9 +6,7 @@ import numpy as np import pytest -from vllm.v1.attention.backends.utils import ( - reorder_batch_to_split_decodes_and_prefills, -) +from vllm.v1.attention.backends.utils import reorder_batch_to_split_decodes_and_prefills class MockInputBatch: diff --git 
a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index 3355640d4666..6a33fc7d6b1b 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -394,7 +394,6 @@ def forward_impl(self, hidden_states: torch.Tensor, output: torch.Tensor): else: state_indices_tensor_d_input = state_indices_tensor_d state_indices_tensor_d_output = state_indices_tensor_d - # 2. Convolution sequence transformation conv_out_d = causal_conv1d_update( hidden_states_BC_d.transpose(0, 1), From 9f2b598d84091fb3fdb189dbbc3a730a17691aae Mon Sep 17 00:00:00 2001 From: Josephasafg Date: Tue, 10 Mar 2026 20:57:15 +0200 Subject: [PATCH 11/15] Removed redundant comment Signed-off-by: Josephasafg --- vllm/v1/worker/gpu_model_runner.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 3b8ff7eec571..afcc13e4de4a 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3683,8 +3683,6 @@ def execute_model( ) ) - # New mamba requests classified as decodes - # would read stale state from recycled cache slots.
mamba_utils.clear_stale_mamba_states( attn_metadata, self.attn_groups, From d27f74a364bbea9949cad5e94b5d70ac765ac58d Mon Sep 17 00:00:00 2001 From: Josephasafg Date: Tue, 10 Mar 2026 22:31:12 +0200 Subject: [PATCH 12/15] test Signed-off-by: Josephasafg --- .../layers/mamba/mamba_mixer.py | 8 ++++ vllm/v1/worker/gpu_model_runner.py | 6 --- vllm/v1/worker/mamba_utils.py | 38 ------------------- 3 files changed, 8 insertions(+), 44 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index 6a33fc7d6b1b..f4537e743cdc 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -271,6 +271,7 @@ def forward_impl(self, hidden_states: torch.Tensor, output: torch.Tensor): conv_state = self_kv_cache[0].transpose(-1, -2) ssm_state = self_kv_cache[1] has_initial_states_p = attn_metadata.has_initial_states_p + has_initial_states_d = attn_metadata.has_initial_states_d cu_chunk_seqlen_p = attn_metadata.cu_chunk_seqlen_p last_chunk_indices_p = attn_metadata.last_chunk_indices_p @@ -394,6 +395,13 @@ def forward_impl(self, hidden_states: torch.Tensor, output: torch.Tensor): else: state_indices_tensor_d_input = state_indices_tensor_d state_indices_tensor_d_output = state_indices_tensor_d + + # Clear stale state for new requests classified as decodes + if has_initial_states_d is not None: + new_indices = state_indices_tensor_d_input[~has_initial_states_d] + conv_state[:, new_indices] = 0 + ssm_state[new_indices] = 0 + # 2. 
Convolution sequence transformation conv_out_d = causal_conv1d_update( hidden_states_BC_d.transpose(0, 1), diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index afcc13e4de4a..08dbd614fdcf 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3683,12 +3683,6 @@ def execute_model( ) ) - mamba_utils.clear_stale_mamba_states( - attn_metadata, - self.attn_groups, - self.compilation_config.static_forward_context, - ) - ( input_ids, inputs_embeds, diff --git a/vllm/v1/worker/mamba_utils.py b/vllm/v1/worker/mamba_utils.py index b4f808b04f6e..2bd5d2b3fea8 100644 --- a/vllm/v1/worker/mamba_utils.py +++ b/vllm/v1/worker/mamba_utils.py @@ -13,13 +13,11 @@ ) from vllm.triton_utils import tl, triton from vllm.utils.math_utils import cdiv -from vllm.v1.attention.backends.mamba_attn import BaseMambaAttentionMetadata from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import KVCacheConfig, MambaSpec from vllm.v1.utils import CpuGpuBuffer from vllm.v1.worker.gpu_input_batch import CachedRequestState from vllm.v1.worker.lora_model_runner_mixin import GPUInputBatch -from vllm.v1.worker.utils import AttentionGroup @triton.jit @@ -216,42 +214,6 @@ def preprocess_mamba( do_mamba_copy_block(copy_bufs) -def clear_stale_mamba_states( - attn_metadata: list[dict[str, Any]] | dict[str, Any], - attn_groups: list[list["AttentionGroup"]], - forward_context: dict[str, Any], -) -> None: - """Clear Mamba states for new requests in the decode batch. - - New requests (has_initial_states=False) would otherwise read stale - state from a recycled cache slot. 
- """ - if not isinstance(attn_metadata, dict): - return - - for kv_cache_groups in attn_groups: - for attn_group in kv_cache_groups: - if not isinstance(attn_group.kv_cache_spec, MambaSpec): - continue - - metadata = attn_metadata.get(attn_group.layer_names[0]) - if not isinstance(metadata, BaseMambaAttentionMetadata): - continue - - has_initial_states_d = metadata.has_initial_states_d - if has_initial_states_d is None: - continue - - assert metadata.state_indices_tensor_d is not None - indices = metadata.state_indices_tensor_d[: metadata.num_decodes] - new_indices = indices[~has_initial_states_d[: metadata.num_decodes]] - - for layer_name in attn_group.layer_names: - layer = forward_context[layer_name] - for state in layer.kv_cache[0]: - state[new_indices] = 0 - - def postprocess_mamba( scheduler_output: SchedulerOutput, kv_cache_config: KVCacheConfig, From 1ad6f2acbe46c051a256072d5aad3fcaa8c3f2c5 Mon Sep 17 00:00:00 2001 From: Josephasafg Date: Tue, 10 Mar 2026 23:30:33 +0200 Subject: [PATCH 13/15] mod change Signed-off-by: Josephasafg --- vllm/model_executor/layers/mamba/mamba_mixer.py | 15 +++++++++++---- .../model_executor/layers/mamba/mamba_mixer2.py | 17 +++++++++++++++++ vllm/v1/attention/backends/mamba_attn.py | 15 +++++++++++++++ 3 files changed, 43 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index f4537e743cdc..36cb884ed69e 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -396,11 +396,18 @@ def forward_impl(self, hidden_states: torch.Tensor, output: torch.Tensor): state_indices_tensor_d_input = state_indices_tensor_d state_indices_tensor_d_output = state_indices_tensor_d - # Clear stale state for new requests classified as decodes if has_initial_states_d is not None: - new_indices = state_indices_tensor_d_input[~has_initial_states_d] - conv_state[:, new_indices] = 0 - ssm_state[new_indices] 
= 0 + indices = state_indices_tensor_d_input + + ssm_gathered = ssm_state[indices] + keep_ssm = has_initial_states_d.to(ssm_gathered.dtype) + keep_ssm = keep_ssm.view(-1, *([1] * (ssm_gathered.dim() - 1))) + ssm_state[indices] = ssm_gathered * keep_ssm + + conv_gathered = conv_state[indices] + keep_conv = has_initial_states_d.to(conv_gathered.dtype) + keep_conv = keep_conv.view(-1, *([1] * (conv_gathered.dim() - 1))) + conv_state[indices] = conv_gathered * keep_conv # 2. Convolution sequence transformation conv_out_d = causal_conv1d_update( diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 971581d89c27..f752dc472a6c 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -580,6 +580,7 @@ def conv_ssm_forward( conv_state = self_kv_cache[0].transpose(-1, -2) ssm_state = self_kv_cache[1] has_initial_states_p = attn_metadata.has_initial_states_p + has_initial_states_d = attn_metadata.has_initial_states_d prep_initial_states = attn_metadata.prep_initial_states chunk_size = attn_metadata.chunk_size seq_idx_p = attn_metadata.seq_idx_p @@ -831,6 +832,22 @@ def conv_ssm_forward( state_indices_tensor_d_input = state_indices_tensor_d state_indices_tensor_d_output = state_indices_tensor_d + # Clear stale state for new requests classified as decodes. + # Uses gather-multiply-scatter (fixed-shape ops) instead of + # boolean indexing to stay compatible with CUDA graph capture. 
+ if has_initial_states_d is not None: + indices = state_indices_tensor_d_input + + ssm_gathered = ssm_state[indices] + keep_ssm = has_initial_states_d.to(ssm_gathered.dtype) + keep_ssm = keep_ssm.view(-1, *([1] * (ssm_gathered.dim() - 1))) + ssm_state[indices] = ssm_gathered * keep_ssm + + conv_gathered = conv_state[indices] + keep_conv = has_initial_states_d.to(conv_gathered.dtype) + keep_conv = keep_conv.view(-1, *([1] * (conv_gathered.dim() - 1))) + conv_state[indices] = conv_gathered * keep_conv + # 2. Convolution sequence transformation hidden_states_B_C_d = causal_conv1d_update( hidden_states_B_C_d, diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py index f86e691524c0..23adf6298a4b 100644 --- a/vllm/v1/attention/backends/mamba_attn.py +++ b/vllm/v1/attention/backends/mamba_attn.py @@ -149,6 +149,12 @@ def __init__( device=device, ) + self.has_initial_states_d: torch.Tensor = torch.ones( + (self.decode_cudagraph_max_bs,), + dtype=torch.bool, + device=device, + ) + self._init_reorder_batch_threshold(1, self.use_spec_decode) if self.use_spec_decode: self.supports_update_block_table = False @@ -501,6 +507,7 @@ def _update_metadata_for_cudagraph_capture( Currently, only decode is supported for full cudagraphs with Mamba. 
""" state_indices_tensor_d = metadata.state_indices_tensor_d + has_initial_states_d = metadata.has_initial_states_d query_start_loc_d = metadata.query_start_loc_d num_accepted_tokens = metadata.num_accepted_tokens block_idx_last_scheduled_token = metadata.block_idx_last_scheduled_token @@ -517,6 +524,13 @@ def _update_metadata_for_cudagraph_capture( state_indices_tensor_d = self.state_indices_tensor_d[:padded_bs] state_indices_tensor_d[metadata.num_decodes :] = PAD_SLOT_ID + if has_initial_states_d is not None: + self.has_initial_states_d[: metadata.num_decodes].copy_( + has_initial_states_d, non_blocking=True + ) + self.has_initial_states_d[metadata.num_decodes :] = True + has_initial_states_d = self.has_initial_states_d[:padded_bs] + if self.use_spec_decode: assert query_start_loc_d is not None assert num_accepted_tokens is not None @@ -551,6 +565,7 @@ def _update_metadata_for_cudagraph_capture( return replace( metadata, state_indices_tensor_d=state_indices_tensor_d, + has_initial_states_d=has_initial_states_d, query_start_loc_d=query_start_loc_d, num_accepted_tokens=num_accepted_tokens, block_idx_last_scheduled_token=block_idx_last_scheduled_token, From 5506afecebe007585070df1bbc7a8ec5d17aff44 Mon Sep 17 00:00:00 2001 From: Josephasafg Date: Tue, 10 Mar 2026 23:31:47 +0200 Subject: [PATCH 14/15] mamba2 Signed-off-by: Josephasafg --- vllm/model_executor/layers/mamba/mamba_mixer2.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index f752dc472a6c..180f386a5f94 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -832,9 +832,6 @@ def conv_ssm_forward( state_indices_tensor_d_input = state_indices_tensor_d state_indices_tensor_d_output = state_indices_tensor_d - # Clear stale state for new requests classified as decodes. 
- # Uses gather-multiply-scatter (fixed-shape ops) instead of - # boolean indexing to stay compatible with CUDA graph capture. if has_initial_states_d is not None: indices = state_indices_tensor_d_input From c780a501d06d194001cc3495c2c90b49d92b9127 Mon Sep 17 00:00:00 2001 From: Josephasafg Date: Fri, 13 Mar 2026 22:38:07 +0200 Subject: [PATCH 15/15] Added clear_stale_state Signed-off-by: Josephasafg --- vllm/model_executor/layers/mamba/abstract.py | 16 ++++++++++++++++ .../model_executor/layers/mamba/mamba_mixer.py | 18 ++++++------------ .../layers/mamba/mamba_mixer2.py | 18 ++++++------------ 3 files changed, 28 insertions(+), 24 deletions(-) diff --git a/vllm/model_executor/layers/mamba/abstract.py b/vllm/model_executor/layers/mamba/abstract.py index 3c6b0139424d..91ba9e42ac58 100644 --- a/vllm/model_executor/layers/mamba/abstract.py +++ b/vllm/model_executor/layers/mamba/abstract.py @@ -57,6 +57,22 @@ def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None: ), ) + @staticmethod + def clear_stale_decode_states( + has_initial_states_d: torch.Tensor | None, + indices: torch.Tensor, + ssm_state: torch.Tensor, + conv_state: torch.Tensor, + ) -> None: + if has_initial_states_d is None: + return + + for state in (ssm_state, conv_state): + gathered = state[indices] + keep_state = has_initial_states_d.to(gathered.dtype) + keep_state = keep_state.view(-1, *([1] * (gathered.dim() - 1))) + state[indices] = gathered * keep_state + def get_attn_backend(self) -> type[AttentionBackend]: """Get the attention backend class for this Mamba layer.""" return get_mamba_attn_backend(self.mamba_type) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index 36cb884ed69e..6302156958f3 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -396,18 +396,12 @@ def forward_impl(self, hidden_states: torch.Tensor, output: torch.Tensor): 
state_indices_tensor_d_input = state_indices_tensor_d state_indices_tensor_d_output = state_indices_tensor_d - if has_initial_states_d is not None: - indices = state_indices_tensor_d_input - - ssm_gathered = ssm_state[indices] - keep_ssm = has_initial_states_d.to(ssm_gathered.dtype) - keep_ssm = keep_ssm.view(-1, *([1] * (ssm_gathered.dim() - 1))) - ssm_state[indices] = ssm_gathered * keep_ssm - - conv_gathered = conv_state[indices] - keep_conv = has_initial_states_d.to(conv_gathered.dtype) - keep_conv = keep_conv.view(-1, *([1] * (conv_gathered.dim() - 1))) - conv_state[indices] = conv_gathered * keep_conv + self.clear_stale_decode_states( + has_initial_states_d, + state_indices_tensor_d_input, + ssm_state, + conv_state, + ) # 2. Convolution sequence transformation conv_out_d = causal_conv1d_update( diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 180f386a5f94..465f80240fad 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -832,18 +832,12 @@ def conv_ssm_forward( state_indices_tensor_d_input = state_indices_tensor_d state_indices_tensor_d_output = state_indices_tensor_d - if has_initial_states_d is not None: - indices = state_indices_tensor_d_input - - ssm_gathered = ssm_state[indices] - keep_ssm = has_initial_states_d.to(ssm_gathered.dtype) - keep_ssm = keep_ssm.view(-1, *([1] * (ssm_gathered.dim() - 1))) - ssm_state[indices] = ssm_gathered * keep_ssm - - conv_gathered = conv_state[indices] - keep_conv = has_initial_states_d.to(conv_gathered.dtype) - keep_conv = keep_conv.view(-1, *([1] * (conv_gathered.dim() - 1))) - conv_state[indices] = conv_gathered * keep_conv + self.clear_stale_decode_states( + has_initial_states_d, + state_indices_tensor_d_input, + ssm_state, + conv_state, + ) # 2. Convolution sequence transformation hidden_states_B_C_d = causal_conv1d_update(