Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 33 additions & 20 deletions tests/v1/attention/test_mamba_update_block_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,18 @@

import torch

from tests.v1.attention.utils import MockMambaBuilder
from vllm.config.compilation import CUDAGraphMode
from vllm.v1.attention.backends.mamba_attn import (
BaseMambaAttentionMetadata,
BaseMambaAttentionMetadataBuilder,
)
from vllm.v1.attention.backends.mamba_attn import BaseMambaAttentionMetadata
from vllm.v1.kv_cache_interface import MambaSpec


class _ConcreteMambaBuilder(
BaseMambaAttentionMetadataBuilder[BaseMambaAttentionMetadata]
def _make_vllm_config(
max_model_len: int,
max_num_seqs: int,
num_speculative_tokens: int = 0,
block_size: int | None = None,
):
"""Minimal concrete subclass for testing (base class is ABC)."""

metadata_cls = BaseMambaAttentionMetadata


def _make_vllm_config(max_model_len, max_num_seqs, num_speculative_tokens=0):
"""Create a minimal mock VllmConfig with only the fields the builder
accesses, avoiding any model download / HF config inspection."""
speculative_config = (
Expand All @@ -44,7 +39,10 @@ def _make_vllm_config(max_model_len, max_num_seqs, num_speculative_tokens=0):
else None
)
return SimpleNamespace(
cache_config=SimpleNamespace(mamba_cache_mode="all"),
cache_config=SimpleNamespace(
block_size=block_size,
mamba_cache_mode="all",
),
compilation_config=SimpleNamespace(
cudagraph_mode=CUDAGraphMode.FULL,
max_cudagraph_capture_size=None,
Expand All @@ -57,6 +55,21 @@ def _make_vllm_config(max_model_len, max_num_seqs, num_speculative_tokens=0):
)


def test_mamba_single_token_prompt_runs_as_prefill():
seq_lens = [8, 9, 1]
config = _make_vllm_config(256, len(seq_lens), block_size=16)
metadata = MockMambaBuilder.build_mamba_metadata(
config,
seq_lens=seq_lens,
query_lens=[1] * len(seq_lens),
is_prefilling=[False, False, True],
)

assert metadata.num_decodes == 2
assert metadata.num_prefills == 1
assert metadata.has_initial_states_p.tolist() == [False]


def test_update_block_table_copies_block_idx_to_persistent_buffers():
"""update_block_table() must write block_idx tensors to the current
builder's persistent buffers, not leave them pointing to a different
Expand All @@ -77,8 +90,8 @@ def test_update_block_table_copies_block_idx_to_persistent_buffers():
)

# Two builders simulating two KV cache groups with the same MambaSpec.
builder_a = _ConcreteMambaBuilder(spec, ["layer0"], vllm_config, device)
builder_b = _ConcreteMambaBuilder(spec, ["layer1"], vllm_config, device)
builder_a = MockMambaBuilder(spec, ["layer0"], vllm_config, device)
builder_b = MockMambaBuilder(spec, ["layer1"], vllm_config, device)

# Sanity: each builder has its own persistent buffer.
assert (
Expand Down Expand Up @@ -188,7 +201,7 @@ def test_state_indices_tensor_d_includes_num_speculative_blocks():
num_speculative_blocks=num_speculative_blocks,
)

builder = _ConcreteMambaBuilder(spec, ["layer0"], vllm_config, device)
builder = MockMambaBuilder(spec, ["layer0"], vllm_config, device)

expected_cols = (max_model_len // block_size) + num_speculative_blocks
assert builder.state_indices_tensor_d.shape == (max_num_seqs, expected_cols)
Expand Down Expand Up @@ -221,7 +234,7 @@ def test_block_idx_cudagraph_capture_padded_by_num_reqs():
num_speculative_blocks=2,
)

builder = _ConcreteMambaBuilder(spec, ["layer0"], vllm_config, device)
builder = MockMambaBuilder(spec, ["layer0"], vllm_config, device)

builder.block_idx_last_scheduled_token.fill_(-1)
builder.block_idx_last_computed_token.fill_(-1)
Expand Down Expand Up @@ -299,7 +312,7 @@ def test_block_idx_prev_step_persistent_buffer_allocated():
mamba_cache_mode="all",
num_speculative_blocks=2,
)
builder = _ConcreteMambaBuilder(spec, ["layer0"], vllm_config, device)
builder = MockMambaBuilder(spec, ["layer0"], vllm_config, device)

assert hasattr(builder, "block_idx_last_scheduled_token_prev_step")
assert builder.block_idx_last_scheduled_token_prev_step.shape == (max_num_seqs,)
Expand All @@ -323,7 +336,7 @@ def test_block_idx_prev_step_persistent_buffer_skipped_without_spec_decode():
dtypes=(torch.float32,),
mamba_cache_mode="all",
)
builder = _ConcreteMambaBuilder(spec, ["layer0"], vllm_config, device)
builder = MockMambaBuilder(spec, ["layer0"], vllm_config, device)

assert not hasattr(builder, "block_idx_last_scheduled_token_prev_step")

Expand Down Expand Up @@ -351,7 +364,7 @@ def test_block_idx_prev_step_cudagraph_capture_uses_persistent_buffer():
mamba_cache_mode="all",
num_speculative_blocks=2,
)
builder = _ConcreteMambaBuilder(spec, ["layer0"], vllm_config, device)
builder = MockMambaBuilder(spec, ["layer0"], vllm_config, device)
builder.block_idx_last_scheduled_token.fill_(-1)
builder.block_idx_last_computed_token.fill_(-1)
builder.block_idx_last_scheduled_token_prev_step.fill_(-1)
Expand Down
41 changes: 40 additions & 1 deletion tests/v1/attention/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,16 @@
AttentionType,
CommonAttentionMetadata,
)
from vllm.v1.attention.backends.mamba_attn import (
BaseMambaAttentionMetadata,
BaseMambaAttentionMetadataBuilder,
)
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.kv_cache_interface import EncoderOnlyAttentionSpec, FullAttentionSpec
from vllm.v1.kv_cache_interface import (
EncoderOnlyAttentionSpec,
FullAttentionSpec,
MambaSpec,
)


@dataclass
Expand Down Expand Up @@ -376,3 +384,34 @@ class BackendConfig:
},
),
}


class MockMambaBuilder(BaseMambaAttentionMetadataBuilder[BaseMambaAttentionMetadata]):
"""Minimal concrete subclass for testing (base class is ABC)."""

metadata_cls = BaseMambaAttentionMetadata

@classmethod
def build_mamba_metadata(
cls,
vllm_config: VllmConfig,
seq_lens: list[int],
query_lens: list[int],
is_prefilling: list[bool],
*,
device: torch.device | None = None,
) -> BaseMambaAttentionMetadata:
block_size = vllm_config.cache_config.block_size
device = device or torch.device("cpu")
mamba_spec = MambaSpec(
block_size=block_size, shapes=((1,), (1,)), dtypes=(torch.float32,)
)
builder = cls(mamba_spec, ["layer0"], vllm_config, device)
batch_spec = BatchSpec(seq_lens=seq_lens, query_lens=query_lens)
common_metadata = create_common_attn_metadata(
batch_spec, block_size=block_size, device=device, arange_block_indices=True
)
common_metadata = common_metadata.replace(
is_prefilling=torch.tensor(is_prefilling, dtype=torch.bool)
)
return builder.build(0, common_metadata)
25 changes: 25 additions & 0 deletions tests/v1/kv_connector/unit/test_nixl_connector_hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pytest
import torch

from tests.v1.attention.utils import MockMambaBuilder
from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig
from vllm.v1.core.single_type_kv_cache_manager import (
Expand Down Expand Up @@ -636,6 +637,30 @@ def test_mamba_n1_d_side(has_mamba, is_hma_required, expected_count):
assert is_async is True


@pytest.mark.cpu_test
def test_mamba_n1_d_side_builds_decode_metadata():
req = create_request(num_tokens=10, do_remote_prefill=True)
sched = make_nixl_scheduler(has_mamba=True, is_hma_required=True)

num_computed_tokens, is_async = sched.get_num_new_matched_tokens(
req, num_computed_tokens=0
)

assert num_computed_tokens == req.num_prompt_tokens - 1
assert is_async is True

vllm_config = create_vllm_config()
metadata = MockMambaBuilder.build_mamba_metadata(
vllm_config,
seq_lens=[req.num_prompt_tokens],
query_lens=[1],
is_prefilling=[True],
)

assert metadata.num_decodes == 1
assert metadata.num_prefills == 0


@pytest.mark.cpu_test
def test_mamba_n1_p_side_truncation():
"""P-side: Mamba truncates prompt to N-1, sets max_tokens=1.
Expand Down
22 changes: 22 additions & 0 deletions vllm/v1/attention/backends/mamba_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,28 @@ def _compute_common_metadata(
self.reorder_batch_threshold if num_accepted_tokens is not None else 1
)

# FULL-CG dispatch is shape-based, so one-token prefills with
# prior Mamba state can replay a decode graph while `is_prefilling`
# is still true. Treat them as decode/update rows. This is required
# for NIXL disagg's h(N-1)->N recompute path and for sporadic
# final single-token prefill chunks that land in a `uniform` FULL-CG
# batch. Relies on `reorder` putting short extends before pure prefills.
is_prefilling = common_attn_metadata.is_prefilling
assert is_prefilling is not None
seq_lens_cpu = common_attn_metadata.seq_lens_cpu_upper_bound
assert seq_lens_cpu is not None
query_lens_cpu = torch.diff(common_attn_metadata.query_start_loc_cpu)
single_token_prefill_rows = is_prefilling & (query_lens_cpu == 1)
# First-token prefills have no prior Mamba state and must stay prefills.
has_prior_state = seq_lens_cpu > 1
prefill_to_decode = single_token_prefill_rows & has_prior_state
if torch.any(prefill_to_decode).item():

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

calling item here may cause cpu synchronization

is_prefilling = is_prefilling.clone()
is_prefilling[prefill_to_decode] = False
common_attn_metadata = common_attn_metadata.replace(
is_prefilling=is_prefilling
)

num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(
common_attn_metadata,
Expand Down
Loading