From d965e6c0944c7b50f8ad6899aafcf6cb844128a2 Mon Sep 17 00:00:00 2001 From: Netanel Haber <58652339+netanel-haber@users.noreply.github.com> Date: Sun, 17 May 2026 11:08:25 +0300 Subject: [PATCH 1/4] fix comment to make clear that this is a bugfix for any final single-token prefill chunk + FULL-CG Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com> --- vllm/v1/attention/backends/mamba_attn.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py index 716dfcde592f..1ffa19f1de23 100644 --- a/vllm/v1/attention/backends/mamba_attn.py +++ b/vllm/v1/attention/backends/mamba_attn.py @@ -355,6 +355,28 @@ def _compute_common_metadata( self.reorder_batch_threshold if num_accepted_tokens is not None else 1 ) + # FULL-CG dispatch is shape-based, so one-token prefills with + # prior Mamba state can replay a decode graph while `is_prefilling` + # is still true. Treat them as decode/update rows. This is required + # for NIXL disagg's h(N-1)->N recompute path and for sporadic + # final single-token prefill chunks that land in a `uniform` FULL-CG + # batch. Relies on `reorder` putting short extends before pure prefills. + is_prefilling = common_attn_metadata.is_prefilling + assert is_prefilling is not None + seq_lens_cpu = common_attn_metadata.seq_lens_cpu_upper_bound + assert seq_lens_cpu is not None + query_lens_cpu = torch.diff(common_attn_metadata.query_start_loc_cpu) + single_token_prefill_rows = is_prefilling & (query_lens_cpu == 1) + # First-token prefills have no prior Mamba state and must stay prefills. + has_prior_state = seq_lens_cpu > 1 + prefill_to_decode = single_token_prefill_rows & has_prior_state + if torch.any(prefill_to_decode).item(): + is_prefilling = is_prefilling.clone() + is_prefilling[prefill_to_decode] = False + common_attn_metadata = common_attn_metadata.replace( + is_prefilling=is_prefilling + ) + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( split_decodes_and_prefills( common_attn_metadata, From dfafc20a02ca2b445929a07b163078ea607431d4 Mon Sep 17 00:00:00 2001 From: Netanel Haber <58652339+netanel-haber@users.noreply.github.com> Date: Sun, 17 May 2026 14:40:56 +0300 Subject: [PATCH 2/4] add unittests Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com> --- .../test_mamba_update_block_table.py | 50 +++++++++++++------ tests/v1/attention/utils.py | 41 ++++++++++++++- .../unit/test_nixl_connector_hma.py | 25 ++++++++++ 3 files changed, 99 insertions(+), 17 deletions(-) diff --git a/tests/v1/attention/test_mamba_update_block_table.py b/tests/v1/attention/test_mamba_update_block_table.py index 923939053ece..cf6e103d32a4 100644 --- a/tests/v1/attention/test_mamba_update_block_table.py +++ b/tests/v1/attention/test_mamba_update_block_table.py @@ -16,27 +16,25 @@ import torch +from tests.v1.attention.utils import MockMambaBuilder from vllm.config.compilation import CUDAGraphMode -from vllm.v1.attention.backends.mamba_attn import ( - BaseMambaAttentionMetadata, - BaseMambaAttentionMetadataBuilder, -) +from vllm.v1.attention.backends.mamba_attn import BaseMambaAttentionMetadata from vllm.v1.kv_cache_interface import MambaSpec -class _ConcreteMambaBuilder( - BaseMambaAttentionMetadataBuilder[BaseMambaAttentionMetadata] +def _make_vllm_config( + block_size: int, + max_model_len: int, + max_num_seqs: int, + mamba_cache_mode: str = "none", ): - """Minimal concrete subclass for testing (base class is ABC).""" - - metadata_cls = BaseMambaAttentionMetadata - - -def _make_vllm_config(block_size, max_model_len, max_num_seqs): """Create a minimal mock VllmConfig with only the fields the builder accesses, avoiding any model download / HF config inspection.""" return SimpleNamespace( - cache_config=SimpleNamespace(mamba_cache_mode="all"), + cache_config=SimpleNamespace( + block_size=block_size, + mamba_cache_mode=mamba_cache_mode, + ), compilation_config=SimpleNamespace( cudagraph_mode=CUDAGraphMode.FULL, max_cudagraph_capture_size=None, @@ -49,6 +47,21 @@ def _make_vllm_config(block_size, max_model_len, max_num_seqs): ) +def test_mamba_single_token_prompt_runs_as_prefill(): + seq_lens = [8, 9, 1] + config = _make_vllm_config(16, 256, len(seq_lens)) + metadata = MockMambaBuilder.build_mamba_metadata( + config, + seq_lens=seq_lens, + query_lens=[1] * len(seq_lens), + is_prefilling=[False, False, True], + ) + + assert metadata.num_decodes == 2 + assert metadata.num_prefills == 1 + assert metadata.has_initial_states_p.tolist() == [False] + + def test_update_block_table_copies_block_idx_to_persistent_buffers(): """update_block_table() must write block_idx tensors to the current builder's persistent buffers, not leave them pointing to a different @@ -59,7 +72,12 @@ def test_update_block_table_copies_block_idx_to_persistent_buffers(): num_reqs = 4 device = torch.device("cpu") - vllm_config = _make_vllm_config(block_size, max_model_len, num_reqs) + vllm_config = _make_vllm_config( + block_size, + max_model_len, + num_reqs, + mamba_cache_mode="all", + ) spec = MambaSpec( block_size=block_size, @@ -69,8 +87,8 @@ def test_update_block_table_copies_block_idx_to_persistent_buffers(): ) # Two builders simulating two KV cache groups with the same MambaSpec. - builder_a = _ConcreteMambaBuilder(spec, ["layer0"], vllm_config, device) - builder_b = _ConcreteMambaBuilder(spec, ["layer1"], vllm_config, device) + builder_a = MockMambaBuilder(spec, ["layer0"], vllm_config, device) + builder_b = MockMambaBuilder(spec, ["layer1"], vllm_config, device) # Sanity: each builder has its own persistent buffer. assert ( diff --git a/tests/v1/attention/utils.py b/tests/v1/attention/utils.py index 1d5eba74693a..3daafcdc6ff8 100644 --- a/tests/v1/attention/utils.py +++ b/tests/v1/attention/utils.py @@ -24,8 +24,16 @@ AttentionType, CommonAttentionMetadata, ) +from vllm.v1.attention.backends.mamba_attn import ( + BaseMambaAttentionMetadata, + BaseMambaAttentionMetadataBuilder, +) from vllm.v1.attention.backends.registry import AttentionBackendEnum -from vllm.v1.kv_cache_interface import EncoderOnlyAttentionSpec, FullAttentionSpec +from vllm.v1.kv_cache_interface import ( + EncoderOnlyAttentionSpec, + FullAttentionSpec, + MambaSpec, +) @dataclass @@ -376,3 +384,34 @@ class BackendConfig: }, ), } + + +class MockMambaBuilder(BaseMambaAttentionMetadataBuilder[BaseMambaAttentionMetadata]): + """Minimal concrete subclass for testing (base class is ABC).""" + + metadata_cls = BaseMambaAttentionMetadata + + @classmethod + def build_mamba_metadata( + cls, + vllm_config: VllmConfig, + seq_lens: list[int], + query_lens: list[int], + is_prefilling: list[bool], + *, + device: torch.device | None = None, + ) -> BaseMambaAttentionMetadata: + block_size = vllm_config.cache_config.block_size + device = device or torch.device("cpu") + mamba_spec = MambaSpec( + block_size=block_size, shapes=((1,), (1,)), dtypes=(torch.float32,) + ) + builder = cls(mamba_spec, ["layer0"], vllm_config, device) + batch_spec = BatchSpec(seq_lens=seq_lens, query_lens=query_lens) + common_metadata = create_common_attn_metadata( + batch_spec, block_size=block_size, device=device, arange_block_indices=True + ) + common_metadata = common_metadata.replace( + is_prefilling=torch.tensor(is_prefilling, dtype=torch.bool) + ) + return builder.build(0, common_metadata) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py index c80dbfc62842..9c163fdf327f 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py @@ -8,6 +8,7 @@ import pytest import torch +from tests.v1.attention.utils import MockMambaBuilder from vllm import LLM, SamplingParams from vllm.config import KVTransferConfig from vllm.v1.core.single_type_kv_cache_manager import ( @@ -636,6 +637,30 @@ def test_mamba_n1_d_side(has_mamba, is_hma_required, expected_count): assert is_async is True +@pytest.mark.cpu_test +def test_mamba_n1_d_side_builds_decode_metadata(): + req = create_request(num_tokens=10, do_remote_prefill=True) + sched = make_nixl_scheduler(has_mamba=True, is_hma_required=True) + + num_computed_tokens, is_async = sched.get_num_new_matched_tokens( + req, num_computed_tokens=0 + ) + + assert num_computed_tokens == req.num_prompt_tokens - 1 + assert is_async is True + + vllm_config = create_vllm_config() + metadata = MockMambaBuilder.build_mamba_metadata( + vllm_config, + seq_lens=[req.num_prompt_tokens], + query_lens=[1], + is_prefilling=[True], + ) + + assert metadata.num_decodes == 1 + assert metadata.num_prefills == 0 + + @pytest.mark.cpu_test def test_mamba_n1_p_side_truncation(): """P-side: Mamba truncates prompt to N-1, sets max_tokens=1. From 423b92be5516197d5ce9663fde7e94c0dfaf00f3 Mon Sep 17 00:00:00 2001 From: Netanel Haber <58652339+netanel-haber@users.noreply.github.com> Date: Mon, 18 May 2026 16:35:24 +0300 Subject: [PATCH 3/4] minimize diff Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com> --- tests/v1/attention/test_mamba_update_block_table.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/tests/v1/attention/test_mamba_update_block_table.py b/tests/v1/attention/test_mamba_update_block_table.py index 2ee6897d9186..931bd39a24f7 100644 --- a/tests/v1/attention/test_mamba_update_block_table.py +++ b/tests/v1/attention/test_mamba_update_block_table.py @@ -23,10 +23,10 @@ def _make_vllm_config( - block_size: int, max_model_len: int, max_num_seqs: int, num_speculative_tokens: int = 0, + block_size: int | None = None, ): """Create a minimal mock VllmConfig with only the fields the builder accesses, avoiding any model download / HF config inspection.""" @@ -57,7 +57,7 @@ def _make_vllm_config( def test_mamba_single_token_prompt_runs_as_prefill(): seq_lens = [8, 9, 1] - config = _make_vllm_config(16, 256, len(seq_lens)) + config = _make_vllm_config(256, len(seq_lens), block_size=16) metadata = MockMambaBuilder.build_mamba_metadata( config, seq_lens=seq_lens, @@ -81,7 +81,6 @@ def test_update_block_table_copies_block_idx_to_persistent_buffers(): device = torch.device("cpu") vllm_config = _make_vllm_config( - block_size, max_model_len, num_reqs, ) @@ -192,7 +191,6 @@ def test_state_indices_tensor_d_includes_num_speculative_blocks(): device = torch.device("cpu") vllm_config = _make_vllm_config( - block_size, max_model_len, max_num_seqs, num_speculative_tokens=num_speculative_tokens, @@ -226,7 +224,6 @@ def test_block_idx_cudagraph_capture_padded_by_num_reqs(): device = torch.device("cpu") vllm_config = _make_vllm_config( - block_size, max_model_len, max_num_seqs, num_speculative_tokens=num_speculative_tokens, @@ -307,7 +304,6 @@ def test_block_idx_prev_step_persistent_buffer_allocated(): device = torch.device("cpu") vllm_config = _make_vllm_config( - block_size, max_model_len, max_num_seqs, num_speculative_tokens=num_speculative_tokens, @@ -335,7 +331,6 @@ def test_block_idx_prev_step_persistent_buffer_skipped_without_spec_decode(): device = torch.device("cpu") vllm_config = _make_vllm_config( - block_size, max_model_len, max_num_seqs, num_speculative_tokens=0, @@ -363,7 +358,6 @@ def test_block_idx_prev_step_cudagraph_capture_uses_persistent_buffer(): device = torch.device("cpu") vllm_config = _make_vllm_config( - block_size, max_model_len, max_num_seqs, num_speculative_tokens=num_speculative_tokens, From 7e6eadcf30fe8405db6eea184cd8444743316e7f Mon Sep 17 00:00:00 2001 From: Netanel Haber <58652339+netanel-haber@users.noreply.github.com> Date: Mon, 18 May 2026 16:38:06 +0300 Subject: [PATCH 4/4] minimize diff Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com> --- tests/v1/attention/test_mamba_update_block_table.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/tests/v1/attention/test_mamba_update_block_table.py b/tests/v1/attention/test_mamba_update_block_table.py index 931bd39a24f7..99dcb09ab154 100644 --- a/tests/v1/attention/test_mamba_update_block_table.py +++ b/tests/v1/attention/test_mamba_update_block_table.py @@ -80,10 +80,7 @@ def test_update_block_table_copies_block_idx_to_persistent_buffers(): num_reqs = 4 device = torch.device("cpu") - vllm_config = _make_vllm_config( - max_model_len, - num_reqs, - ) + vllm_config = _make_vllm_config(max_model_len, num_reqs) spec = MambaSpec( block_size=block_size, @@ -331,9 +328,7 @@ def test_block_idx_prev_step_persistent_buffer_skipped_without_spec_decode(): device = torch.device("cpu") vllm_config = _make_vllm_config( - max_model_len, - max_num_seqs, - num_speculative_tokens=0, + max_model_len, max_num_seqs, num_speculative_tokens=0 ) spec = MambaSpec( block_size=block_size,