Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 68 additions & 27 deletions tests/v1/attention/test_mla_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,13 @@
)
from vllm import _custom_ops as ops
from vllm.config.vllm import set_current_vllm_config
from vllm.model_executor.layers.attention.mla_attention import QueryLenSupport
from vllm.model_executor.layers.attention.mla_attention import (
QueryLenSupport,
_DecodeConcatQuantFP8,
)
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.attention.backend import CommonAttentionMetadata
Expand Down Expand Up @@ -50,6 +55,7 @@
if not is_flashmla_dense_supported()[0]:
BACKENDS_TO_TEST.remove(AttentionBackendEnum.FLASHMLA)


SPEC_DECODE_BACKENDS = []
for backend in BACKENDS_TO_TEST:
builder_cls, _ = try_get_attention_backend(backend)
Expand Down Expand Up @@ -144,9 +150,8 @@ def create_and_prepopulate_kv_cache(
common_attn_metadata: Common attention metadata
randomize_blocks: Whether to randomly permute blocks
or use sequential order
kv_cache_dtype: Optional kv cache dtype string. When set to
"fp8_ds_mla" the cache is populated using the
fp8 DeepSeek MLA layout via concat_and_cache_mla.
kv_cache_dtype: Optional kv cache dtype string. For fp8 cache dtype,
the cache is populated via concat_and_cache_mla.
scale: Scaling factor forwarded to concat_and_cache_mla when the
fp8 cache layout is requested.

Expand All @@ -163,18 +168,21 @@ def create_and_prepopulate_kv_cache(
block_table = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping

fp8_attention = kv_cache_dtype and kv_cache_dtype.startswith("fp8")
use_fp8_ds_mla = kv_cache_dtype == "fp8_ds_mla"

if use_fp8_ds_mla:
if not kv_c_contexts:
raise ValueError(
"kv_c_contexts cannot be empty when using fp8_ds_mla cache dtype"
)
kv_lora_rank = kv_c_contexts[0].shape[-1]
rope_dim = k_pe_contexts[0].shape[-1]
entry_size = kv_lora_rank + 4 * 4 + 2 * rope_dim
if fp8_attention:
if use_fp8_ds_mla:
kv_lora_rank = kv_c_contexts[0].shape[-1]
rope_dim = k_pe_contexts[0].shape[-1]
# 4 * 4: 4 float32 scale values for 128-element tiles
# 2 * rope_dim: 16-bit RoPE values
kv_entry_size = kv_lora_rank + 4 * 4 + 2 * rope_dim
else:
kv_entry_size = head_size

kv_cache = torch.zeros(
num_blocks, block_size, entry_size, dtype=torch.uint8, device=device
num_blocks, block_size, kv_entry_size, dtype=torch.uint8, device=device
)
scale_tensor = (
scale
Expand All @@ -201,14 +209,14 @@ def create_and_prepopulate_kv_cache(

start = start_block_idx * block_size

if use_fp8_ds_mla:
if fp8_attention:
slots = torch.arange(context_len, device=device, dtype=torch.long) + start
ops.concat_and_cache_mla(
kv_c_context,
k_pe_context.squeeze(1),
kv_cache,
slots,
kv_cache_dtype="fp8_ds_mla",
kv_cache_dtype=kv_cache_dtype,
scale=scale_tensor,
)
else:
Expand Down Expand Up @@ -329,8 +337,9 @@ def forward_impl(
output: torch.Tensor,
) -> torch.Tensor:
"""Forward for sparse MLA - uses forward_mqa for all tokens."""
# Write to KV cache
kv_cache_dtype = getattr(self.impl, "kv_cache_dtype", "auto")

# Write to KV cache
if kv_cache.numel() > 0:
ops.concat_and_cache_mla(
kv_c,
Expand Down Expand Up @@ -426,6 +435,12 @@ def __init__(
self._k_scale_float = 1.0
self._v_scale_float = 1.0

self._decode_concat_quant_fp8_op = _DecodeConcatQuantFP8(
static=True,
group_shape=GroupShape.PER_TENSOR,
compile_native=True,
)

def get_attn_backend(self):
    """Unsupported accessor: always raises ``NotImplementedError``.

    Raises:
        NotImplementedError: Unconditionally, on every call.
    """
    # NOTE(review): presumably this class is a test stand-in that drives the
    # backend impl directly via forward_impl, so no backend lookup is needed;
    # confirm against the enclosing class, which is not visible here.
    raise NotImplementedError

Expand All @@ -443,16 +458,21 @@ def forward_impl(
) -> torch.Tensor:
"""Replicates MLAAttention.forward_impl logic for testing."""
# Write to KV cache
kv_cache_dtype = getattr(self.impl, "kv_cache_dtype", "auto")
fp8_attention = kv_cache_dtype.startswith("fp8")
if kv_cache.numel() > 0:
ops.concat_and_cache_mla(
kv_c,
k_pe.squeeze(1),
kv_cache,
attn_metadata.slot_mapping.flatten(),
kv_cache_dtype="auto",
kv_cache_dtype=kv_cache_dtype,
scale=self._k_scale,
)

if fp8_attention and kv_cache_dtype != "fp8_ds_mla":
kv_cache = kv_cache.view(current_platform.fp8_dtype())

# Determine decode vs prefill split
num_decode_tokens = attn_metadata.num_decode_tokens or 0
has_decode = (attn_metadata.num_decodes or 0) > 0
Expand Down Expand Up @@ -491,8 +511,14 @@ def forward_impl(
# Convert from (N, B, L) to (B, N, L)
mqa_ql_nope = mqa_ql_nope.transpose(0, 1)

# Pass as tuple to forward_mqa
mqa_q = (mqa_ql_nope, mqa_q_pe)
if fp8_attention and self.impl.supports_quant_query_input:
assert mqa_ql_nope.shape[0] == mqa_q_pe.shape[0]
assert mqa_ql_nope.shape[1] == mqa_q_pe.shape[1]
mqa_q = self._decode_concat_quant_fp8_op(
mqa_ql_nope, mqa_q_pe, self._q_scale
)
else:
mqa_q = (mqa_ql_nope, mqa_q_pe)

attn_out, _ = self.impl.forward_mqa(mqa_q, kv_cache, attn_metadata, self)

Expand Down Expand Up @@ -526,6 +552,7 @@ def run_attention_backend(
qk_rope_head_dim: int,
v_head_dim: int,
mock_kv_b_proj,
kv_cache_dtype: str = "auto",
) -> torch.Tensor:
"""Run attention computation using the specified backend's AttentionImpl."""

Expand All @@ -550,7 +577,7 @@ def run_attention_backend(
num_kv_heads=num_kv_heads,
alibi_slopes=None,
sliding_window=None,
kv_cache_dtype="auto",
kv_cache_dtype=kv_cache_dtype,
logits_soft_cap=None,
attn_type="decoder",
kv_sharing_target_layer_name=None,
Expand Down Expand Up @@ -630,12 +657,14 @@ def run_attention_backend(
)
@pytest.mark.parametrize("model", ["deepseek-ai/DeepSeek-R1"])
@pytest.mark.parametrize("tensor_parallel_size", [1, 4, 8, 16])
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8", "fp8_e4m3"])
def test_backend_correctness(
default_vllm_config,
dist_init,
batch_spec_name: str,
model: str,
tensor_parallel_size: int,
kv_cache_dtype: str,
):
"""
Test that all backends produce similar outputs to a reference implementation
Expand All @@ -658,9 +687,18 @@ def test_backend_correctness(
head counts.
"""

# Filter backends to those that support the requested kv_cache_dtype
backends_to_test = [
b
for b in BACKENDS_TO_TEST
if kv_cache_dtype in b.get_class().supported_kv_cache_dtypes
]
if not backends_to_test:
pytest.skip(f"No backends support kv_cache_dtype={kv_cache_dtype}")

batch_spec = BATCH_SPECS[batch_spec_name]
is_spec_decode_test = batch_spec_name.startswith("spec_decode")
unique_block_sizes = sorted(set(BACKEND_BLOCK_SIZES.values()))
unique_block_sizes = sorted(set(BACKEND_BLOCK_SIZES[b] for b in backends_to_test))
default_block_size = unique_block_sizes[0]
required_blocks = sum(
(seq_len + default_block_size - 1) // default_block_size
Expand Down Expand Up @@ -694,6 +732,7 @@ def test_backend_correctness(
block_size=default_block_size,
hf_config_override=hf_config_override,
)
vllm_config.cache_config.cache_dtype = kv_cache_dtype

# For spec decode tests, add a speculative_config to set the reorder_batch_threshold
if is_spec_decode_test:
Expand Down Expand Up @@ -751,7 +790,7 @@ def test_backend_correctness(

kv_b_proj_weight = torch.cat([W_UK, W_UV], dim=-1)

for i, backend in enumerate(BACKENDS_TO_TEST):
for i, backend in enumerate(backends_to_test):
all_sdpa_outputs.append([])

for i in range(batch_size):
Expand Down Expand Up @@ -785,7 +824,7 @@ def test_backend_correctness(
# pipeline (MHA-style). This ensures the reference implementation
# matches each backend's actual decode/prefill pipeline path.
is_decode = []
for backend_idx, backend in enumerate(BACKENDS_TO_TEST):
for backend_idx, backend in enumerate(backends_to_test):
builder_cls, _ = try_get_attention_backend(backend)
if is_spec_decode_test:
query_len_support = getattr(
Expand Down Expand Up @@ -885,7 +924,7 @@ def test_backend_correctness(
sdpa_out_i_prefill = sdpa_out_i_prefill.transpose(1, 2).squeeze(0)
sdpa_out_i_prefill = sdpa_out_i_prefill.flatten(start_dim=-2)

for backend_idx, backend in enumerate(BACKENDS_TO_TEST):
for backend_idx, backend in enumerate(backends_to_test):
if is_decode[backend_idx]:
all_sdpa_outputs[backend_idx].append(sdpa_out_i_decode)
else:
Expand All @@ -905,7 +944,7 @@ def test_backend_correctness(
kv_c_vllm = torch.cat(all_kv_c_vllm, dim=0)
k_pe_vllm = torch.cat(all_k_pe_vllm, dim=0)
sdpa_outputs = {}
for backend_idx, backend in enumerate(BACKENDS_TO_TEST):
for backend_idx, backend in enumerate(backends_to_test):
sdpa_outputs[backend] = torch.cat(all_sdpa_outputs[backend_idx], dim=0)

# Create mock kv_b_proj using the same weights as reference implementation
Expand Down Expand Up @@ -973,12 +1012,13 @@ def test_backend_correctness(
num_blocks=num_blocks_for_size,
common_attn_metadata=common_attn_metadata,
randomize_blocks=True,
kv_cache_dtype=kv_cache_dtype,
)
kv_cache_per_block_size[block_size] = kv_cache

# 4. Run vLLM backends and compare
failures = []
for backend_idx, backend_name in enumerate(BACKENDS_TO_TEST):
for backend_idx, backend_name in enumerate(backends_to_test):
# Skip backends that don't support spec decode for spec decode tests
if is_spec_decode_test and backend_name not in SPEC_DECODE_BACKENDS:
continue
Expand All @@ -997,7 +1037,7 @@ def test_backend_correctness(
head_size=vllm_config.model_config.get_head_size(),
dtype=vllm_config.model_config.dtype,
sliding_window=vllm_config.model_config.get_sliding_window(),
cache_dtype_str=vllm_config.cache_config.cache_dtype,
cache_dtype_str=kv_cache_dtype,
)

backend_output = run_attention_backend(
Expand All @@ -1016,6 +1056,7 @@ def test_backend_correctness(
qk_rope_head_dim,
v_head_dim,
mock_kv_b_proj,
kv_cache_dtype=kv_cache_dtype,
)

# Use backend_idx to get the correct SDPA output for this backend
Expand Down