From ef012fe9d386a8e71c8b104fd4e495615d1f3ba1 Mon Sep 17 00:00:00 2001 From: khluu Date: Tue, 19 May 2026 19:27:00 -0700 Subject: [PATCH] Revert "[Bugfix] mamba: run single-token extends as decodes (#42430)" This reverts commit 47829b1159335a010521ea3e5361d51744a36b0a. --- .../test_mamba_update_block_table.py | 53 +++++++------------ tests/v1/attention/utils.py | 41 +------------- .../unit/test_nixl_connector_hma.py | 25 --------- vllm/v1/attention/backends/mamba_attn.py | 22 -------- 4 files changed, 21 insertions(+), 120 deletions(-) diff --git a/tests/v1/attention/test_mamba_update_block_table.py b/tests/v1/attention/test_mamba_update_block_table.py index 99dcb09ab154..619abac952de 100644 --- a/tests/v1/attention/test_mamba_update_block_table.py +++ b/tests/v1/attention/test_mamba_update_block_table.py @@ -16,18 +16,23 @@ import torch -from tests.v1.attention.utils import MockMambaBuilder from vllm.config.compilation import CUDAGraphMode -from vllm.v1.attention.backends.mamba_attn import BaseMambaAttentionMetadata +from vllm.v1.attention.backends.mamba_attn import ( + BaseMambaAttentionMetadata, + BaseMambaAttentionMetadataBuilder, +) from vllm.v1.kv_cache_interface import MambaSpec -def _make_vllm_config( - max_model_len: int, - max_num_seqs: int, - num_speculative_tokens: int = 0, - block_size: int | None = None, +class _ConcreteMambaBuilder( + BaseMambaAttentionMetadataBuilder[BaseMambaAttentionMetadata] ): + """Minimal concrete subclass for testing (base class is ABC).""" + + metadata_cls = BaseMambaAttentionMetadata + + +def _make_vllm_config(max_model_len, max_num_seqs, num_speculative_tokens=0): """Create a minimal mock VllmConfig with only the fields the builder accesses, avoiding any model download / HF config inspection.""" speculative_config = ( @@ -39,10 +44,7 @@ def _make_vllm_config( else None ) return SimpleNamespace( - cache_config=SimpleNamespace( - block_size=block_size, - mamba_cache_mode="all", - ), + cache_config=SimpleNamespace(mamba_cache_mode="all"), compilation_config=SimpleNamespace( cudagraph_mode=CUDAGraphMode.FULL, max_cudagraph_capture_size=None, @@ -55,21 +57,6 @@ def _make_vllm_config( ) -def test_mamba_single_token_prompt_runs_as_prefill(): - seq_lens = [8, 9, 1] - config = _make_vllm_config(256, len(seq_lens), block_size=16) - metadata = MockMambaBuilder.build_mamba_metadata( - config, - seq_lens=seq_lens, - query_lens=[1] * len(seq_lens), - is_prefilling=[False, False, True], - ) - - assert metadata.num_decodes == 2 - assert metadata.num_prefills == 1 - assert metadata.has_initial_states_p.tolist() == [False] - - def test_update_block_table_copies_block_idx_to_persistent_buffers(): """update_block_table() must write block_idx tensors to the current builder's persistent buffers, not leave them pointing to a different @@ -90,8 +77,8 @@ def test_update_block_table_copies_block_idx_to_persistent_buffers(): ) # Two builders simulating two KV cache groups with the same MambaSpec. - builder_a = MockMambaBuilder(spec, ["layer0"], vllm_config, device) - builder_b = MockMambaBuilder(spec, ["layer1"], vllm_config, device) + builder_a = _ConcreteMambaBuilder(spec, ["layer0"], vllm_config, device) + builder_b = _ConcreteMambaBuilder(spec, ["layer1"], vllm_config, device) # Sanity: each builder has its own persistent buffer. assert ( @@ -201,7 +188,7 @@ def test_state_indices_tensor_d_includes_num_speculative_blocks(): num_speculative_blocks=num_speculative_blocks, ) - builder = MockMambaBuilder(spec, ["layer0"], vllm_config, device) + builder = _ConcreteMambaBuilder(spec, ["layer0"], vllm_config, device) expected_cols = (max_model_len // block_size) + num_speculative_blocks assert builder.state_indices_tensor_d.shape == (max_num_seqs, expected_cols) @@ -234,7 +221,7 @@ def test_block_idx_cudagraph_capture_padded_by_num_reqs(): num_speculative_blocks=2, ) - builder = MockMambaBuilder(spec, ["layer0"], vllm_config, device) + builder = _ConcreteMambaBuilder(spec, ["layer0"], vllm_config, device) builder.block_idx_last_scheduled_token.fill_(-1) builder.block_idx_last_computed_token.fill_(-1) @@ -312,7 +299,7 @@ def test_block_idx_prev_step_persistent_buffer_allocated(): mamba_cache_mode="all", num_speculative_blocks=2, ) - builder = MockMambaBuilder(spec, ["layer0"], vllm_config, device) + builder = _ConcreteMambaBuilder(spec, ["layer0"], vllm_config, device) assert hasattr(builder, "block_idx_last_scheduled_token_prev_step") assert builder.block_idx_last_scheduled_token_prev_step.shape == (max_num_seqs,) @@ -336,7 +323,7 @@ def test_block_idx_prev_step_persistent_buffer_skipped_without_spec_decode(): dtypes=(torch.float32,), mamba_cache_mode="all", ) - builder = MockMambaBuilder(spec, ["layer0"], vllm_config, device) + builder = _ConcreteMambaBuilder(spec, ["layer0"], vllm_config, device) assert not hasattr(builder, "block_idx_last_scheduled_token_prev_step") @@ -364,7 +351,7 @@ def test_block_idx_prev_step_cudagraph_capture_uses_persistent_buffer(): mamba_cache_mode="all", num_speculative_blocks=2, ) - builder = MockMambaBuilder(spec, ["layer0"], vllm_config, device) + builder = _ConcreteMambaBuilder(spec, ["layer0"], vllm_config, device) builder.block_idx_last_scheduled_token.fill_(-1) builder.block_idx_last_computed_token.fill_(-1) builder.block_idx_last_scheduled_token_prev_step.fill_(-1) diff --git a/tests/v1/attention/utils.py b/tests/v1/attention/utils.py index 3daafcdc6ff8..1d5eba74693a 100644 --- a/tests/v1/attention/utils.py +++ b/tests/v1/attention/utils.py @@ -24,16 +24,8 @@ AttentionType, CommonAttentionMetadata, ) -from vllm.v1.attention.backends.mamba_attn import ( - BaseMambaAttentionMetadata, - BaseMambaAttentionMetadataBuilder, -) from vllm.v1.attention.backends.registry import AttentionBackendEnum -from vllm.v1.kv_cache_interface import ( - EncoderOnlyAttentionSpec, - FullAttentionSpec, - MambaSpec, -) +from vllm.v1.kv_cache_interface import EncoderOnlyAttentionSpec, FullAttentionSpec @dataclass @@ -384,34 +376,3 @@ class BackendConfig: }, ), } - - -class MockMambaBuilder(BaseMambaAttentionMetadataBuilder[BaseMambaAttentionMetadata]): - """Minimal concrete subclass for testing (base class is ABC).""" - - metadata_cls = BaseMambaAttentionMetadata - - @classmethod - def build_mamba_metadata( - cls, - vllm_config: VllmConfig, - seq_lens: list[int], - query_lens: list[int], - is_prefilling: list[bool], - *, - device: torch.device | None = None, - ) -> BaseMambaAttentionMetadata: - block_size = vllm_config.cache_config.block_size - device = device or torch.device("cpu") - mamba_spec = MambaSpec( - block_size=block_size, shapes=((1,), (1,)), dtypes=(torch.float32,) - ) - builder = cls(mamba_spec, ["layer0"], vllm_config, device) - batch_spec = BatchSpec(seq_lens=seq_lens, query_lens=query_lens) - common_metadata = create_common_attn_metadata( - batch_spec, block_size=block_size, device=device, arange_block_indices=True - ) - common_metadata = common_metadata.replace( - is_prefilling=torch.tensor(is_prefilling, dtype=torch.bool) - ) - return builder.build(0, common_metadata) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py index 9c163fdf327f..c80dbfc62842 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py @@ -8,7 +8,6 @@ import pytest import torch -from tests.v1.attention.utils import MockMambaBuilder from vllm import LLM, SamplingParams from vllm.config import KVTransferConfig from vllm.v1.core.single_type_kv_cache_manager import ( @@ -637,30 +636,6 @@ def test_mamba_n1_d_side(has_mamba, is_hma_required, expected_count): assert is_async is True -@pytest.mark.cpu_test -def test_mamba_n1_d_side_builds_decode_metadata(): - req = create_request(num_tokens=10, do_remote_prefill=True) - sched = make_nixl_scheduler(has_mamba=True, is_hma_required=True) - - num_computed_tokens, is_async = sched.get_num_new_matched_tokens( - req, num_computed_tokens=0 - ) - - assert num_computed_tokens == req.num_prompt_tokens - 1 - assert is_async is True - - vllm_config = create_vllm_config() - metadata = MockMambaBuilder.build_mamba_metadata( - vllm_config, - seq_lens=[req.num_prompt_tokens], - query_lens=[1], - is_prefilling=[True], - ) - - assert metadata.num_decodes == 1 - assert metadata.num_prefills == 0 - - @pytest.mark.cpu_test def test_mamba_n1_p_side_truncation(): """P-side: Mamba truncates prompt to N-1, sets max_tokens=1. diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py index 16e292e21d2f..2af9c74ea765 100644 --- a/vllm/v1/attention/backends/mamba_attn.py +++ b/vllm/v1/attention/backends/mamba_attn.py @@ -385,28 +385,6 @@ def _compute_common_metadata( self.reorder_batch_threshold if num_accepted_tokens is not None else 1 ) - # FULL-CG dispatch is shape-based, so one-token prefills with - # prior Mamba state can replay a decode graph while `is_prefilling` - # is still true. Treat them as decode/update rows. This is required - # for NIXL disagg's h(N-1)->N recompute path and for sporadic - # final single-token prefill chunks that land in a `uniform` FULL-CG - # batch. Relies on `reorder` putting short extends before pure prefills. - is_prefilling = common_attn_metadata.is_prefilling - assert is_prefilling is not None - seq_lens_cpu = common_attn_metadata.seq_lens_cpu_upper_bound - assert seq_lens_cpu is not None - query_lens_cpu = torch.diff(common_attn_metadata.query_start_loc_cpu) - single_token_prefill_rows = is_prefilling & (query_lens_cpu == 1) - # First-token prefills have no prior Mamba state and must stay prefills. - has_prior_state = seq_lens_cpu > 1 - prefill_to_decode = single_token_prefill_rows & has_prior_state - if torch.any(prefill_to_decode).item(): - is_prefilling = is_prefilling.clone() - is_prefilling[prefill_to_decode] = False - common_attn_metadata = common_attn_metadata.replace( - is_prefilling=is_prefilling - ) - num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( split_decodes_and_prefills( common_attn_metadata,