diff --git a/tests/v1/attention/test_mamba_update_block_table.py b/tests/v1/attention/test_mamba_update_block_table.py index 619abac952de..99dcb09ab154 100644 --- a/tests/v1/attention/test_mamba_update_block_table.py +++ b/tests/v1/attention/test_mamba_update_block_table.py @@ -16,23 +16,18 @@ import torch +from tests.v1.attention.utils import MockMambaBuilder from vllm.config.compilation import CUDAGraphMode -from vllm.v1.attention.backends.mamba_attn import ( - BaseMambaAttentionMetadata, - BaseMambaAttentionMetadataBuilder, -) +from vllm.v1.attention.backends.mamba_attn import BaseMambaAttentionMetadata from vllm.v1.kv_cache_interface import MambaSpec -class _ConcreteMambaBuilder( - BaseMambaAttentionMetadataBuilder[BaseMambaAttentionMetadata] +def _make_vllm_config( + max_model_len: int, + max_num_seqs: int, + num_speculative_tokens: int = 0, + block_size: int | None = None, ): - """Minimal concrete subclass for testing (base class is ABC).""" - - metadata_cls = BaseMambaAttentionMetadata - - -def _make_vllm_config(max_model_len, max_num_seqs, num_speculative_tokens=0): """Create a minimal mock VllmConfig with only the fields the builder accesses, avoiding any model download / HF config inspection.""" speculative_config = ( @@ -44,7 +39,10 @@ def _make_vllm_config(max_model_len, max_num_seqs, num_speculative_tokens=0): else None ) return SimpleNamespace( - cache_config=SimpleNamespace(mamba_cache_mode="all"), + cache_config=SimpleNamespace( + block_size=block_size, + mamba_cache_mode="all", + ), compilation_config=SimpleNamespace( cudagraph_mode=CUDAGraphMode.FULL, max_cudagraph_capture_size=None, @@ -57,6 +55,21 @@ def _make_vllm_config(max_model_len, max_num_seqs, num_speculative_tokens=0): ) +def test_mamba_single_token_prompt_runs_as_prefill(): + seq_lens = [8, 9, 1] + config = _make_vllm_config(256, len(seq_lens), block_size=16) + metadata = MockMambaBuilder.build_mamba_metadata( + config, + seq_lens=seq_lens, + query_lens=[1] * len(seq_lens), + is_prefilling=[False, False, True], + ) + + assert metadata.num_decodes == 2 + assert metadata.num_prefills == 1 + assert metadata.has_initial_states_p.tolist() == [False] + + def test_update_block_table_copies_block_idx_to_persistent_buffers(): """update_block_table() must write block_idx tensors to the current builder's persistent buffers, not leave them pointing to a different @@ -77,8 +90,8 @@ def test_update_block_table_copies_block_idx_to_persistent_buffers(): ) # Two builders simulating two KV cache groups with the same MambaSpec. - builder_a = _ConcreteMambaBuilder(spec, ["layer0"], vllm_config, device) - builder_b = _ConcreteMambaBuilder(spec, ["layer1"], vllm_config, device) + builder_a = MockMambaBuilder(spec, ["layer0"], vllm_config, device) + builder_b = MockMambaBuilder(spec, ["layer1"], vllm_config, device) # Sanity: each builder has its own persistent buffer. assert ( @@ -188,7 +201,7 @@ def test_state_indices_tensor_d_includes_num_speculative_blocks(): num_speculative_blocks=num_speculative_blocks, ) - builder = _ConcreteMambaBuilder(spec, ["layer0"], vllm_config, device) + builder = MockMambaBuilder(spec, ["layer0"], vllm_config, device) expected_cols = (max_model_len // block_size) + num_speculative_blocks assert builder.state_indices_tensor_d.shape == (max_num_seqs, expected_cols) @@ -221,7 +234,7 @@ def test_block_idx_cudagraph_capture_padded_by_num_reqs(): num_speculative_blocks=2, ) - builder = _ConcreteMambaBuilder(spec, ["layer0"], vllm_config, device) + builder = MockMambaBuilder(spec, ["layer0"], vllm_config, device) builder.block_idx_last_scheduled_token.fill_(-1) builder.block_idx_last_computed_token.fill_(-1) @@ -299,7 +312,7 @@ def test_block_idx_prev_step_persistent_buffer_allocated(): mamba_cache_mode="all", num_speculative_blocks=2, ) - builder = _ConcreteMambaBuilder(spec, ["layer0"], vllm_config, device) + builder = MockMambaBuilder(spec, ["layer0"], vllm_config, device) assert hasattr(builder, "block_idx_last_scheduled_token_prev_step") assert builder.block_idx_last_scheduled_token_prev_step.shape == (max_num_seqs,) @@ -323,7 +336,7 @@ def test_block_idx_prev_step_persistent_buffer_skipped_without_spec_decode(): dtypes=(torch.float32,), mamba_cache_mode="all", ) - builder = _ConcreteMambaBuilder(spec, ["layer0"], vllm_config, device) + builder = MockMambaBuilder(spec, ["layer0"], vllm_config, device) assert not hasattr(builder, "block_idx_last_scheduled_token_prev_step") @@ -351,7 +364,7 @@ def test_block_idx_prev_step_cudagraph_capture_uses_persistent_buffer(): mamba_cache_mode="all", num_speculative_blocks=2, ) - builder = _ConcreteMambaBuilder(spec, ["layer0"], vllm_config, device) + builder = MockMambaBuilder(spec, ["layer0"], vllm_config, device) builder.block_idx_last_scheduled_token.fill_(-1) builder.block_idx_last_computed_token.fill_(-1) builder.block_idx_last_scheduled_token_prev_step.fill_(-1) diff --git a/tests/v1/attention/utils.py b/tests/v1/attention/utils.py index 1d5eba74693a..3daafcdc6ff8 100644 --- a/tests/v1/attention/utils.py +++ b/tests/v1/attention/utils.py @@ -24,8 +24,16 @@ AttentionType, CommonAttentionMetadata, ) +from vllm.v1.attention.backends.mamba_attn import ( + BaseMambaAttentionMetadata, + BaseMambaAttentionMetadataBuilder, +) from vllm.v1.attention.backends.registry import AttentionBackendEnum -from vllm.v1.kv_cache_interface import EncoderOnlyAttentionSpec, FullAttentionSpec +from vllm.v1.kv_cache_interface import ( + EncoderOnlyAttentionSpec, + FullAttentionSpec, + MambaSpec, +) @dataclass @@ -376,3 +384,34 @@ class BackendConfig: }, ), } + + +class MockMambaBuilder(BaseMambaAttentionMetadataBuilder[BaseMambaAttentionMetadata]): + """Minimal concrete subclass for testing (base class is ABC).""" + + metadata_cls = BaseMambaAttentionMetadata + + @classmethod + def build_mamba_metadata( + cls, + vllm_config: VllmConfig, + seq_lens: list[int], + query_lens: list[int], + is_prefilling: list[bool], + *, + device: torch.device | None = None, + ) -> BaseMambaAttentionMetadata: + block_size = vllm_config.cache_config.block_size + device = device or torch.device("cpu") + mamba_spec = MambaSpec( + block_size=block_size, shapes=((1,), (1,)), dtypes=(torch.float32,) + ) + builder = cls(mamba_spec, ["layer0"], vllm_config, device) + batch_spec = BatchSpec(seq_lens=seq_lens, query_lens=query_lens) + common_metadata = create_common_attn_metadata( + batch_spec, block_size=block_size, device=device, arange_block_indices=True + ) + common_metadata = common_metadata.replace( + is_prefilling=torch.tensor(is_prefilling, dtype=torch.bool) + ) + return builder.build(0, common_metadata) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py index c80dbfc62842..9c163fdf327f 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py @@ -8,6 +8,7 @@ import pytest import torch +from tests.v1.attention.utils import MockMambaBuilder from vllm import LLM, SamplingParams from vllm.config import KVTransferConfig from vllm.v1.core.single_type_kv_cache_manager import ( @@ -636,6 +637,30 @@ def test_mamba_n1_d_side(has_mamba, is_hma_required, expected_count): assert is_async is True +@pytest.mark.cpu_test +def test_mamba_n1_d_side_builds_decode_metadata(): + req = create_request(num_tokens=10, do_remote_prefill=True) + sched = make_nixl_scheduler(has_mamba=True, is_hma_required=True) + + num_computed_tokens, is_async = sched.get_num_new_matched_tokens( + req, num_computed_tokens=0 + ) + + assert num_computed_tokens == req.num_prompt_tokens - 1 + assert is_async is True + + vllm_config = create_vllm_config() + metadata = MockMambaBuilder.build_mamba_metadata( + vllm_config, + seq_lens=[req.num_prompt_tokens], + query_lens=[1], + is_prefilling=[True], + ) + + assert metadata.num_decodes == 1 + assert metadata.num_prefills == 0 + + @pytest.mark.cpu_test def test_mamba_n1_p_side_truncation(): """P-side: Mamba truncates prompt to N-1, sets max_tokens=1. diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py index 2af9c74ea765..16e292e21d2f 100644 --- a/vllm/v1/attention/backends/mamba_attn.py +++ b/vllm/v1/attention/backends/mamba_attn.py @@ -385,6 +385,28 @@ def _compute_common_metadata( self.reorder_batch_threshold if num_accepted_tokens is not None else 1 ) + # FULL-CG dispatch is shape-based, so one-token prefills with + # prior Mamba state can replay a decode graph while `is_prefilling` + # is still true. Treat them as decode/update rows. This is required + # for NIXL disagg's h(N-1)->N recompute path and for sporadic + # final single-token prefill chunks that land in a `uniform` FULL-CG + # batch. Relies on `reorder` putting short extends before pure prefills. + is_prefilling = common_attn_metadata.is_prefilling + assert is_prefilling is not None + seq_lens_cpu = common_attn_metadata.seq_lens_cpu_upper_bound + assert seq_lens_cpu is not None + query_lens_cpu = torch.diff(common_attn_metadata.query_start_loc_cpu) + single_token_prefill_rows = is_prefilling & (query_lens_cpu == 1) + # First-token prefills have no prior Mamba state and must stay prefills. + has_prior_state = seq_lens_cpu > 1 + prefill_to_decode = single_token_prefill_rows & has_prior_state + if torch.any(prefill_to_decode).item(): + is_prefilling = is_prefilling.clone() + is_prefilling[prefill_to_decode] = False + common_attn_metadata = common_attn_metadata.replace( + is_prefilling=is_prefilling + ) + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( split_decodes_and_prefills( common_attn_metadata,