Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
257 changes: 247 additions & 10 deletions tests/v1/attention/test_mla_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,18 +274,241 @@ def forward(self, *_args, **_kwargs):
raise NotImplementedError


class MockSparseMLAAttentionLayer:
    """Mock sparse MLA attention layer used by the backend tests.

    Sparse MLA implementations route every token through forward_mqa
    (decode-style attention), so this mock only exercises that path.

    Unlike dense MLA impls, sparse impls carry no W_UK_T / W_UV weights of
    their own — the up/down projections live on the layer (MLAAttention).
    This mock therefore takes W_UK and W_UV directly and pre-arranges them
    into the layouts forward_impl needs.
    """

    def __init__(
        self,
        impl,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        kv_lora_rank: int,
        device: torch.device,
        W_UK: torch.Tensor,
        W_UV: torch.Tensor,
    ):
        self.impl = impl
        self.num_heads = num_heads
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.kv_lora_rank = kv_lora_rank

        # Pre-permute the weights into the layouts forward_impl consumes:
        #   W_UK (L, N, P) -> W_UK_T (N, P, L)
        #   W_UV (L, N, V) -> W_UV   (N, L, V)
        self.W_UK_T = W_UK.permute(1, 2, 0)
        self.W_UV = W_UV.transpose(0, 1)

        # Per-layer scale attributes the attention backends expect to find.
        one = torch.tensor(1.0, device=device)
        self._q_scale = one
        self._k_scale = torch.tensor(1.0, device=device)
        self._v_scale = torch.tensor(1.0, device=device)
        self._prob_scale = torch.tensor(1.0, device=device)
        self._q_scale_float = 1.0
        self._k_scale_float = 1.0
        self._v_scale_float = 1.0

    def forward_impl(
        self,
        q: torch.Tensor,
        kv_c: torch.Tensor,
        k_pe: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata,
        output: torch.Tensor,
    ) -> torch.Tensor:
        """Sparse MLA forward: all tokens go through forward_mqa."""
        # Append the latent KV and rope keys to the cache first (skipped
        # when the cache tensor is empty).
        cache_dtype = getattr(self.impl, "kv_cache_dtype", "auto")
        if kv_cache.numel() > 0:
            ops.concat_and_cache_mla(
                kv_c,
                k_pe.squeeze(1),
                kv_cache,
                attn_metadata.slot_mapping.flatten(),
                kv_cache_dtype=cache_dtype,
                scale=self._k_scale,
            )

        token_count = q.shape[0]

        # Split query into its nope / rope halves along the head dim.
        q_nope, q_pe = q.split(
            [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
        )

        # Project q_nope into the latent space:
        # (B, N, P) -> (N, B, P), bmm with (N, P, L) -> (N, B, L) -> (B, N, L)
        ql_nope = torch.bmm(q_nope.transpose(0, 1), self.W_UK_T).transpose(0, 1)

        # forward_mqa takes the (latent-nope, rope) pair as a tuple.
        attn_out, _ = self.impl.forward_mqa(
            (ql_nope, q_pe), kv_cache, attn_metadata, self
        )

        # v_up projection back out of the latent space:
        # attn_out (B, N, L) x W_UV (N, L, V) -> (B, N, V), flattened to
        # (B, N*V) and written into the caller-provided output buffer.
        projected = torch.bmm(
            attn_out.transpose(0, 1), self.W_UV
        ).transpose(0, 1)
        output[:token_count] = projected.reshape(
            token_count, self.num_heads * self.v_head_dim
        )

        return output


class MockMLAAttentionLayer(AttentionLayerBase):
    """A mock MLA attention layer for testing.

    This replicates the forward_impl logic from MLAAttention to allow
    testing MLA backends without the full layer infrastructure.

    The W_UK_T and W_UV weight matrices are created on the layer (like in
    MLAAttention.process_weights_after_loading), not on the impl.
    """

    def __init__(
        self,
        impl,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        kv_lora_rank: int,
        device: torch.device,
        kv_b_proj,
    ):
        self.impl = impl
        self.num_heads = num_heads
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.kv_lora_rank = kv_lora_rank

        # Compute weight matrices from kv_b_proj (like MLAAttention does).
        # This replicates MLAAttention.process_weights_after_loading logic:
        # reshape the projection weight to (L, N, P+V), then split into the
        # key up-projection (W_UK) and value up-projection (W_UV).
        kv_b_proj_weight = kv_b_proj.weight.T
        kv_b_proj_weight = kv_b_proj_weight.view(
            kv_lora_rank,
            num_heads,
            qk_nope_head_dim + v_head_dim,
        )
        W_UK, W_UV = kv_b_proj_weight.split([qk_nope_head_dim, v_head_dim], dim=-1)
        # Convert from (L, N, V) to (N, L, V)
        self.W_UV = W_UV.transpose(0, 1)
        # Convert from (L, N, P) to (N, P, L)
        self.W_UK_T = W_UK.permute(1, 2, 0)

        # Scale attributes needed by attention backends
        self._q_scale = torch.tensor(1.0, device=device)
        self._k_scale = torch.tensor(1.0, device=device)
        self._v_scale = torch.tensor(1.0, device=device)
        self._prob_scale = torch.tensor(1.0, device=device)
        self._q_scale_float = 1.0
        self._k_scale_float = 1.0
        self._v_scale_float = 1.0

    def get_attn_backend(self):
        # Not needed for these tests; required by the AttentionLayerBase API.
        raise NotImplementedError

    def get_kv_cache_spec(self, vllm_config):
        # Not needed for these tests; required by the AttentionLayerBase API.
        raise NotImplementedError

    def forward_impl(
        self,
        q: torch.Tensor,
        kv_c: torch.Tensor,
        k_pe: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata,
        output: torch.Tensor,
    ) -> torch.Tensor:
        """Replicates MLAAttention.forward_impl logic for testing.

        Writes the new KV entries into the cache, then runs the prefill
        tokens through forward_mha and the decode tokens through
        forward_mqa, writing both results into the shared output buffer.
        """
        # Write to KV cache (skipped when the cache tensor is empty).
        if kv_cache.numel() > 0:
            ops.concat_and_cache_mla(
                kv_c,
                k_pe.squeeze(1),
                kv_cache,
                attn_metadata.slot_mapping.flatten(),
                kv_cache_dtype="auto",
                scale=self._k_scale,
            )

        # Determine decode vs prefill split. Decode tokens come first in
        # the batch, so num_decode_tokens is also the prefill offset.
        num_decode_tokens = attn_metadata.num_decode_tokens or 0
        has_decode = (attn_metadata.num_decodes or 0) > 0
        has_prefill = (attn_metadata.num_prefills or 0) > 0

        # Run prefill with forward_mha
        if has_prefill:
            prefill_q = q[num_decode_tokens:]
            prefill_k_pe = k_pe[num_decode_tokens:]
            prefill_k_c = kv_c[num_decode_tokens:]
            self.impl.forward_mha(
                prefill_q,
                prefill_k_c,
                prefill_k_pe,
                kv_cache,
                attn_metadata,
                self._k_scale,
                output=output[num_decode_tokens:],
            )

        # Run decode with forward_mqa
        if has_decode:
            decode_q = q[:num_decode_tokens]

            # Split q into nope and pe parts
            mqa_q_nope, mqa_q_pe = decode_q.split(
                [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
            )

            # Project into latent space:
            # (B, N, P) -> (N, B, P), bmm with W_UK_T (N, P, L) -> (N, B, L),
            # then back to (B, N, L).
            mqa_q_nope = mqa_q_nope.transpose(0, 1)
            mqa_ql_nope = torch.bmm(mqa_q_nope, self.W_UK_T)
            mqa_ql_nope = mqa_ql_nope.transpose(0, 1)

            # Pass as tuple to forward_mqa
            mqa_q = (mqa_ql_nope, mqa_q_pe)

            attn_out, _ = self.impl.forward_mqa(mqa_q, kv_cache, attn_metadata, self)

            # v_up projection: multiply by W_UV
            # attn_out shape: (B, N, L) where L = kv_lora_rank
            # W_UV shape: (N, L, V)
            # output shape: (B, N, V) -> flatten to (B, N*V)
            decode_output = torch.bmm(attn_out.transpose(0, 1), self.W_UV).transpose(
                0, 1
            )
            output[:num_decode_tokens] = decode_output.reshape(
                num_decode_tokens, self.num_heads * self.v_head_dim
            )

        return output


def run_attention_backend(
backend: AttentionBackendEnum,
Expand Down Expand Up @@ -340,14 +563,31 @@ def run_attention_backend(
kv_b_proj=mock_kv_b_proj,
)

# Process weights to create W_UK_T and W_UV attributes needed by MLA
# Process weights on the impl
act_dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype)
impl.process_weights_after_loading(act_dtype)

# Initialize DCP attributes (normally set by MLAAttention.forward
# before calling forward_mha, see mla_attention.py:511-512)
if impl.dcp_world_size == -1:
impl.dcp_world_size = 1

# Create mock MLA layer
mock_layer = MockMLAAttentionLayer(
impl=impl,
num_heads=num_heads,
qk_nope_head_dim=qk_nope_head_dim,
qk_rope_head_dim=qk_rope_head_dim,
v_head_dim=v_head_dim,
kv_lora_rank=kv_lora_rank,
device=device,
kv_b_proj=mock_kv_b_proj,
)

# Populate static_forward_context with mock attention layers
for layer_name in layer_names:
vllm_config.compilation_config.static_forward_context[layer_name] = (
MockMLAAttentionLayer(impl)
mock_layer
)

# Build metadata
Expand All @@ -357,18 +597,15 @@ def run_attention_backend(
common_attn_metadata=common_attn_metadata,
)

# Create mock layer and output buffer
mock_layer = MockAttentionLayer(device)
# Create output buffer
num_tokens = query.shape[0]
output = torch.empty(
num_tokens, num_heads * v_head_dim, dtype=query.dtype, device=query.device
)

# Run forward pass
# NOTE: The query, key, and value are already shaped correctly
# in the calling test function.
output = impl.forward(
mock_layer, query, kv_c, k_pe, kv_cache, attn_metadata, output=output
output = mock_layer.forward_impl(
query, kv_c, k_pe, kv_cache, attn_metadata, output
)

return output
Expand Down
21 changes: 16 additions & 5 deletions tests/v1/attention/test_sparse_mla_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from tests.v1.attention.test_mla_backends import (
BATCH_SPECS,
BatchSpec,
MockAttentionLayer,
MockSparseMLAAttentionLayer,
create_and_prepopulate_kv_cache,
)
from tests.v1.attention.utils import (
Expand Down Expand Up @@ -408,20 +408,31 @@ def test_sparse_backend_decode_correctness(

impl.process_weights_after_loading(dtype)

layer = MockAttentionLayer(device)
# Create mock sparse MLA layer with weight matrices
mock_layer = MockSparseMLAAttentionLayer(
impl=impl,
num_heads=num_heads,
qk_nope_head_dim=qk_nope_head_dim,
qk_rope_head_dim=qk_rope_head_dim,
v_head_dim=v_head_dim,
kv_lora_rank=kv_lora_rank,
device=device,
W_UK=W_UK,
W_UV=W_UV,
)

out_buffer = torch.empty(
metadata.num_actual_tokens, num_heads * v_head_dim, dtype=dtype, device=device
)

with torch.inference_mode():
backend_output = impl.forward(
layer,
backend_output = mock_layer.forward_impl(
query_vllm,
kv_c_vllm,
k_pe_vllm,
kv_cache,
metadata,
output=out_buffer,
out_buffer,
)

assert backend_output.shape == sdpa_reference.shape
Expand Down
4 changes: 2 additions & 2 deletions vllm/model_executor/layers/attention/attention.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

import torch
import torch.nn as nn
Expand Down Expand Up @@ -562,7 +562,7 @@ def maybe_calc_kv_scales_fake(

def get_attention_context(
layer_name: str,
) -> tuple[dict | object | None, "Attention | MLAAttention", torch.Tensor]:
) -> tuple[Any, "Attention | MLAAttention", torch.Tensor]:
"""Extract attention context for a given layer.

This helper function extracts the attention metadata, attention layer
Expand Down
Loading