diff --git a/tests/v1/attention/test_mla_backends.py b/tests/v1/attention/test_mla_backends.py index 0744db0b0d66..815274e1cca1 100644 --- a/tests/v1/attention/test_mla_backends.py +++ b/tests/v1/attention/test_mla_backends.py @@ -274,11 +274,157 @@ def forward(self, *_args, **_kwargs): raise NotImplementedError +class MockSparseMLAAttentionLayer: + """A mock sparse MLA attention layer for testing. + + Sparse MLA implementations only support forward_mqa (decode-style attention) + for all tokens, so this class only implements that path. + + Unlike regular MLA impls, sparse MLA impls don't have W_UK_T and W_UV + attributes. These transformations are done by the layer (MLAAttention), + not the impl. This mock layer accepts these weight matrices directly. + """ + + def __init__( + self, + impl, + num_heads: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + v_head_dim: int, + kv_lora_rank: int, + device: torch.device, + W_UK: torch.Tensor, + W_UV: torch.Tensor, + ): + self.impl = impl + self.num_heads = num_heads + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.v_head_dim = v_head_dim + self.kv_lora_rank = kv_lora_rank + + # Compute weight matrices in the format expected by forward_impl + # W_UK shape: (L, N, P) -> W_UK_T shape: (N, P, L) + self.W_UK_T = W_UK.permute(1, 2, 0) + # W_UV shape: (L, N, V) -> (N, L, V) + self.W_UV = W_UV.transpose(0, 1) + + # Scale attributes needed by attention backends + self._q_scale = torch.tensor(1.0, device=device) + self._k_scale = torch.tensor(1.0, device=device) + self._v_scale = torch.tensor(1.0, device=device) + self._prob_scale = torch.tensor(1.0, device=device) + self._q_scale_float = 1.0 + self._k_scale_float = 1.0 + self._v_scale_float = 1.0 + + def forward_impl( + self, + q: torch.Tensor, + kv_c: torch.Tensor, + k_pe: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata, + output: torch.Tensor, + ) -> torch.Tensor: + """Forward for sparse MLA - uses forward_mqa for all tokens.""" + # Write to KV cache + kv_cache_dtype = getattr(self.impl, "kv_cache_dtype", "auto") + if kv_cache.numel() > 0: + ops.concat_and_cache_mla( + kv_c, + k_pe.squeeze(1), + kv_cache, + attn_metadata.slot_mapping.flatten(), + kv_cache_dtype=kv_cache_dtype, + scale=self._k_scale, + ) + + num_tokens = q.shape[0] + + # Sparse MLA uses forward_mqa for all tokens + # Split q into nope and pe parts + mqa_q_nope, mqa_q_pe = q.split( + [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + + # Convert from (B, N, P) to (N, B, P) + mqa_q_nope = mqa_q_nope.transpose(0, 1) + + # Multiply (N, B, P) x (N, P, L) -> (N, B, L) + mqa_ql_nope = torch.bmm(mqa_q_nope, self.W_UK_T) + + # Convert from (N, B, L) to (B, N, L) + mqa_ql_nope = mqa_ql_nope.transpose(0, 1) + + # Pass as tuple to forward_mqa + mqa_q = (mqa_ql_nope, mqa_q_pe) + + attn_out, _ = self.impl.forward_mqa(mqa_q, kv_cache, attn_metadata, self) + + # v_up projection: multiply by W_UV + # attn_out shape: (B, N, L) where L = kv_lora_rank + # W_UV shape: (N, L, V) + # output shape: (B, N, V) -> flatten to (B, N*V) + decode_output = torch.bmm(attn_out.transpose(0, 1), self.W_UV).transpose(0, 1) + output[:num_tokens] = decode_output.reshape( + num_tokens, self.num_heads * self.v_head_dim + ) + + return output + + class MockMLAAttentionLayer(AttentionLayerBase): - """A mock MLA attention layer for populating static_forward_context.""" + """A mock MLA attention layer for testing. + + This replicates the forward_impl logic from MLAAttention to allow + testing MLA backends without the full layer infrastructure. + + The W_UK_T and W_UV weight matrices are created on the layer (like in + MLAAttention.process_weights_after_loading), not on the impl. + """ - def __init__(self, impl): + def __init__( + self, + impl, + num_heads: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + v_head_dim: int, + kv_lora_rank: int, + device: torch.device, + kv_b_proj, + ): self.impl = impl + self.num_heads = num_heads + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.v_head_dim = v_head_dim + self.kv_lora_rank = kv_lora_rank + + # Compute weight matrices from kv_b_proj (like MLAAttention does) + # This replicates MLAAttention.process_weights_after_loading logic + kv_b_proj_weight = kv_b_proj.weight.T + kv_b_proj_weight = kv_b_proj_weight.view( + kv_lora_rank, + num_heads, + qk_nope_head_dim + v_head_dim, + ) + W_UK, W_UV = kv_b_proj_weight.split([qk_nope_head_dim, v_head_dim], dim=-1) + # Convert from (L, N, V) to (N, L, V) + self.W_UV = W_UV.transpose(0, 1) + # Convert from (L, N, P) to (N, P, L) + self.W_UK_T = W_UK.permute(1, 2, 0) + + # Scale attributes needed by attention backends + self._q_scale = torch.tensor(1.0, device=device) + self._k_scale = torch.tensor(1.0, device=device) + self._v_scale = torch.tensor(1.0, device=device) + self._prob_scale = torch.tensor(1.0, device=device) + self._q_scale_float = 1.0 + self._k_scale_float = 1.0 + self._v_scale_float = 1.0 def get_attn_backend(self): raise NotImplementedError @@ -286,6 +432,83 @@ def get_attn_backend(self): def get_kv_cache_spec(self, vllm_config): raise NotImplementedError + def forward_impl( + self, + q: torch.Tensor, + kv_c: torch.Tensor, + k_pe: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata, + output: torch.Tensor, + ) -> torch.Tensor: + """Replicates MLAAttention.forward_impl logic for testing.""" + # Write to KV cache + if kv_cache.numel() > 0: + ops.concat_and_cache_mla( + kv_c, + k_pe.squeeze(1), + kv_cache, + attn_metadata.slot_mapping.flatten(), + kv_cache_dtype="auto", + scale=self._k_scale, + ) + + # Determine decode vs prefill split + num_decode_tokens = attn_metadata.num_decode_tokens or 0 + has_decode = (attn_metadata.num_decodes or 0) > 0 + has_prefill = (attn_metadata.num_prefills or 0) > 0 + + # Run prefill with forward_mha + if has_prefill: + prefill_q = q[num_decode_tokens:] + prefill_k_pe = k_pe[num_decode_tokens:] + prefill_k_c = kv_c[num_decode_tokens:] + self.impl.forward_mha( + prefill_q, + prefill_k_c, + prefill_k_pe, + kv_cache, + attn_metadata, + self._k_scale, + output=output[num_decode_tokens:], + ) + + # Run decode with forward_mqa + if has_decode: + decode_q = q[:num_decode_tokens] + + # Split q into nope and pe parts + mqa_q_nope, mqa_q_pe = decode_q.split( + [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + + # Convert from (B, N, P) to (N, B, P) + mqa_q_nope = mqa_q_nope.transpose(0, 1) + + # Multiply (N, B, P) x (N, P, L) -> (N, B, L) + mqa_ql_nope = torch.bmm(mqa_q_nope, self.W_UK_T) + + # Convert from (N, B, L) to (B, N, L) + mqa_ql_nope = mqa_ql_nope.transpose(0, 1) + + # Pass as tuple to forward_mqa + mqa_q = (mqa_ql_nope, mqa_q_pe) + + attn_out, _ = self.impl.forward_mqa(mqa_q, kv_cache, attn_metadata, self) + + # v_up projection: multiply by W_UV + # attn_out shape: (B, N, L) where L = kv_lora_rank + # W_UV shape: (N, L, V) + # output shape: (B, N, V) -> flatten to (B, N*V) + decode_output = torch.bmm(attn_out.transpose(0, 1), self.W_UV).transpose( + 0, 1 + ) + output[:num_decode_tokens] = decode_output.reshape( + num_decode_tokens, self.num_heads * self.v_head_dim + ) + + return output + def run_attention_backend( backend: AttentionBackendEnum, @@ -340,14 +563,31 @@ def run_attention_backend( kv_b_proj=mock_kv_b_proj, ) - # Process weights to create W_UK_T and W_UV attributes needed by MLA + # Process weights on the impl act_dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype) impl.process_weights_after_loading(act_dtype) + # Initialize DCP attributes (normally set by MLAAttention.forward + # before calling forward_mha, see mla_attention.py:511-512) + if impl.dcp_world_size == -1: + impl.dcp_world_size = 1 + + # Create mock MLA layer + mock_layer = MockMLAAttentionLayer( + impl=impl, + num_heads=num_heads, + qk_nope_head_dim=qk_nope_head_dim, + qk_rope_head_dim=qk_rope_head_dim, + v_head_dim=v_head_dim, + kv_lora_rank=kv_lora_rank, + device=device, + kv_b_proj=mock_kv_b_proj, + ) + # Populate static_forward_context with mock attention layers for layer_name in layer_names: vllm_config.compilation_config.static_forward_context[layer_name] = ( - MockMLAAttentionLayer(impl) + mock_layer ) # Build metadata @@ -357,18 +597,15 @@ def run_attention_backend( common_attn_metadata=common_attn_metadata, ) - # Create mock layer and output buffer - mock_layer = MockAttentionLayer(device) + # Create output buffer num_tokens = query.shape[0] output = torch.empty( num_tokens, num_heads * v_head_dim, dtype=query.dtype, device=query.device ) # Run forward pass - # NOTE: The query, key, and value are already shaped correctly - # in the calling test function. - output = impl.forward( - mock_layer, query, kv_c, k_pe, kv_cache, attn_metadata, output=output + output = mock_layer.forward_impl( + query, kv_c, k_pe, kv_cache, attn_metadata, output ) return output diff --git a/tests/v1/attention/test_sparse_mla_backends.py b/tests/v1/attention/test_sparse_mla_backends.py index 8d1f5cc46ba9..e4ffd12ca6ef 100644 --- a/tests/v1/attention/test_sparse_mla_backends.py +++ b/tests/v1/attention/test_sparse_mla_backends.py @@ -12,7 +12,7 @@ from tests.v1.attention.test_mla_backends import ( BATCH_SPECS, BatchSpec, - MockAttentionLayer, + MockSparseMLAAttentionLayer, create_and_prepopulate_kv_cache, ) from tests.v1.attention.utils import ( @@ -408,20 +408,31 @@ def test_sparse_backend_decode_correctness( impl.process_weights_after_loading(dtype) - layer = MockAttentionLayer(device) + # Create mock sparse MLA layer with weight matrices + mock_layer = MockSparseMLAAttentionLayer( + impl=impl, + num_heads=num_heads, + qk_nope_head_dim=qk_nope_head_dim, + qk_rope_head_dim=qk_rope_head_dim, + v_head_dim=v_head_dim, + kv_lora_rank=kv_lora_rank, + device=device, + W_UK=W_UK, + W_UV=W_UV, + ) + out_buffer = torch.empty( metadata.num_actual_tokens, num_heads * v_head_dim, dtype=dtype, device=device ) with torch.inference_mode(): - backend_output = impl.forward( - layer, + backend_output = mock_layer.forward_impl( query_vllm, kv_c_vllm, k_pe_vllm, kv_cache, metadata, - output=out_buffer, + out_buffer, ) assert backend_output.shape == sdpa_reference.shape diff --git a/vllm/model_executor/layers/attention/attention.py b/vllm/model_executor/layers/attention/attention.py index 25917294ab7e..170de6a87034 100644 --- a/vllm/model_executor/layers/attention/attention.py +++ b/vllm/model_executor/layers/attention/attention.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import torch import torch.nn as nn @@ -562,7 +562,7 @@ def maybe_calc_kv_scales_fake( def get_attention_context( layer_name: str, -) -> tuple[dict | object | None, "Attention | MLAAttention", torch.Tensor]: +) -> tuple[Any, "Attention | MLAAttention", torch.Tensor]: """Extract attention context for a given layer. This helper function extracts the attention metadata, attention layer diff --git a/vllm/model_executor/layers/attention/mla_attention.py b/vllm/model_executor/layers/attention/mla_attention.py index 8b5edc0d3b75..501b939c11b1 100644 --- a/vllm/model_executor/layers/attention/mla_attention.py +++ b/vllm/model_executor/layers/attention/mla_attention.py @@ -63,7 +63,7 @@ W_O project v to h_t shape [N * V, H] -## Compute Friendly Approach (i.e. "_forward_prefill"): +## Compute Friendly Approach (i.e. "forward_mha"): q_c = h_t @ W_DQ q_nope = (q_c @ W_UQ).view(Sq, N, P) @@ -91,7 +91,7 @@ `out_proj` is W_O -## Data-Movement Friendly Approach (i.e. "_forward_decode"): +## Data-Movement Friendly Approach (i.e. "forward_mqa"): Runtime q_c = h_t @ W_DQ @@ -243,6 +243,7 @@ AttentionType, CommonAttentionMetadata, MLAAttentionImpl, + SparseMLAAttentionImpl, ) from vllm.v1.attention.backends.fa_utils import get_flash_attn_version from vllm.v1.attention.backends.utils import ( @@ -266,6 +267,9 @@ class MLAAttention(nn.Module, AttentionLayerBase): """Multi-Head Latent Attention layer. + NOTE: Please read the comment at the top of the file before trying to + understand this class + This class takes query, and compressed key/value tensors as input. The class does the following: @@ -289,6 +293,7 @@ def __init__( prefix: str = "", use_sparse: bool = False, indexer: object | None = None, + q_pad_num_heads: int | None = None, **extra_impl_args, ): super().__init__() @@ -299,8 +304,14 @@ def __init__( self.v_head_dim = v_head_dim self.q_lora_rank = q_lora_rank self.kv_lora_rank = kv_lora_rank + self.kv_b_proj = kv_b_proj self.head_size = kv_lora_rank + qk_rope_head_dim self.layer_name = prefix + self.indexer = indexer + self.q_pad_num_heads = q_pad_num_heads + + self.num_kv_heads = 1 + self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim if cache_config is not None: kv_cache_dtype = cache_config.cache_dtype @@ -364,6 +375,7 @@ def __init__( v_head_dim=self.v_head_dim, kv_b_proj=kv_b_proj, indexer=indexer, + q_pad_num_heads=q_pad_num_heads, **extra_impl_args, ) @@ -388,6 +400,26 @@ def __init__( self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32) self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32) + self.is_aiter_triton_fp8_bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled() + + # If kv_b_proj_weight is unquantized, quantize it to mxfp4 if supported + self.is_aiter_triton_fp4_bmm_enabled = ( + rocm_aiter_ops.is_fp4bmm_enabled() + and self.kv_b_proj.weight.dtype == torch.bfloat16 + ) + + # Attributes for forward_impl method + self.chunked_prefill_workspace_size = ( + MLACommonMetadataBuilder.determine_chunked_prefill_workspace_size( + get_current_vllm_config() + ) + ) + self._decode_concat_quant_fp8_op = _DecodeConcatQuantFP8( + static=True, + group_shape=GroupShape.PER_TENSOR, + compile_native=True, + ) + def forward( self, q: torch.Tensor, @@ -407,8 +439,7 @@ def forward( if self.attn_backend.accept_output_buffer: output = torch.empty(output_shape, dtype=q.dtype, device=q.device) - self.impl.forward( - self, + self.forward_impl( q, kv_c_normed, k_pe, @@ -418,8 +449,8 @@ def forward( ) return output else: - return self.impl.forward( - self, q, kv_c_normed, k_pe, self_kv_cache, attn_metadata + return self.forward_impl( + q, kv_c_normed, k_pe, self_kv_cache, attn_metadata ) else: if self.attn_backend.accept_output_buffer: @@ -440,9 +471,282 @@ def forward( self.layer_name, ) + def forward_impl( + self, + q: torch.Tensor, + k_c_normed: torch.Tensor, # key in unified attn + k_pe: torch.Tensor, # value in unified attn + kv_cache: torch.Tensor, + attn_metadata: "MLACommonMetadata", + output: torch.Tensor | None = None, + output_scale: torch.Tensor | None = None, + output_block_scale: torch.Tensor | None = None, + ) -> torch.Tensor: + assert output is not None, "Output tensor must be provided." + + if output_scale is not None or output_block_scale is not None: + raise NotImplementedError( + "fused output quantization is not yet supported for MLA" + ) + + if attn_metadata is None: + # During the profile run try to simulate to worse case output size + # for `self.kv_b_proj(kv_c_normed)` in `_compute_prefill_context` + # since this can be large + _ = torch.empty( + ( + self.chunked_prefill_workspace_size, + self.num_heads, + self.qk_nope_head_dim + self.v_head_dim, + ), + device=k_c_normed.device, + dtype=k_c_normed.dtype, + ) + + # The zero fill is required when used with DP + EP + # to ensure all ranks within a DP group compute the + # same expert outputs. + return output.fill_(0) + + if self.impl.dcp_world_size == -1: + self.impl.dcp_world_size = get_dcp_group().world_size + + fp8_attention = self.kv_cache_dtype.startswith("fp8") + + num_actual_toks = attn_metadata.num_actual_tokens + + # Inputs and outputs may be padded for CUDA graphs + output_padded = output + output = output[:num_actual_toks, ...] + q = q[:num_actual_toks, ...] + k_c_normed = k_c_normed[:num_actual_toks, ...] + k_pe = k_pe[:num_actual_toks, ...] + + assert ( + attn_metadata.num_decodes is not None + and attn_metadata.num_prefills is not None + and attn_metadata.num_decode_tokens is not None + ) + + has_decode = attn_metadata.num_decodes > 0 + has_prefill = attn_metadata.num_prefills > 0 + num_decode_tokens = attn_metadata.num_decode_tokens + + decode_q = q[:num_decode_tokens] + + prefill_q = q[num_decode_tokens:] + prefill_k_pe = k_pe[num_decode_tokens:] + prefill_k_c_normed = k_c_normed[num_decode_tokens:] + + # write the latent and rope to kv cache + if kv_cache.numel() > 0: + ops.concat_and_cache_mla( + k_c_normed, + k_pe.squeeze(1), + kv_cache, + attn_metadata.slot_mapping.flatten(), + kv_cache_dtype=self.kv_cache_dtype, + scale=self._k_scale, + ) + + if fp8_attention: + kv_cache = kv_cache.view(current_platform.fp8_dtype()) + + # Sparse MLA impls only support forward_mqa (decode-style attention) + is_sparse_impl = isinstance(self.impl, SparseMLAAttentionImpl) + + if has_prefill and not is_sparse_impl: + self.impl.forward_mha( + prefill_q, + prefill_k_c_normed, + prefill_k_pe, + kv_cache, + attn_metadata, + self._k_scale, + output=output[num_decode_tokens:], + ) + + if has_decode or (has_prefill and is_sparse_impl): + # For sparse impl, we always use forward_mqa for all tokens + # For non-sparse impl, we only use forward_mqa for decode tokens + if is_sparse_impl: + mqa_q = q + mqa_output_slice = output + else: + assert attn_metadata.decode is not None + mqa_q = decode_q + mqa_output_slice = output[:num_decode_tokens] + + mqa_q_nope, mqa_q_pe = mqa_q.split( + [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + + # Convert from (B, N, P) to (N, B, P) + mqa_q_nope = mqa_q_nope.transpose(0, 1) + + if self.q_pad_num_heads is not None: + B, N, L = mqa_q_pe.shape + mqa_pe_padded = mqa_q_pe.new_empty((B, self.q_pad_num_heads, L)) + mqa_pe_padded.resize_((B, N, L)) + mqa_pe_padded.copy_(mqa_q_pe) + mqa_q_pe = mqa_pe_padded + + if self.is_aiter_triton_fp4_bmm_enabled: + from aiter.ops.triton.batched_gemm_a16wfp4 import batched_gemm_a16wfp4 + + mqa_ql_nope = batched_gemm_a16wfp4( + mqa_q_nope, + self.W_K, + self.W_K_scale, + transpose_bm=True, + prequant=True, + y_scale=self._q_scale if fp8_attention else None, + ) + elif self.is_aiter_triton_fp8_bmm_enabled: + # Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L) + mqa_ql_nope = rocm_aiter_ops.triton_fp8_bmm( + mqa_q_nope, + self.W_K, + self.W_K_scale, + group_size=128, + transpose_bm=True, + ) + else: + # Pads the head_dim if necessary (for the underlying kernel) + N, B, P = mqa_q_nope.shape + _, _, L = self.W_UK_T.shape + + if self.q_pad_num_heads is not None: + mqa_ql_nope = mqa_q_nope.new_empty((self.q_pad_num_heads, B, L)) + mqa_ql_nope.resize_((N, B, L)) + else: + mqa_ql_nope = mqa_q_nope.new_empty((N, B, L)) + + # Multiply (N, B, P) x (N, P, L) -> (N, B, L) + torch.bmm(mqa_q_nope, self.W_UK_T, out=mqa_ql_nope) + + # Convert from (N, B, L) to (B, N, L) + mqa_ql_nope = mqa_ql_nope.transpose(0, 1) + + if fp8_attention: + assert mqa_ql_nope.shape[0] == mqa_q_pe.shape[0] + assert mqa_ql_nope.shape[1] == mqa_q_pe.shape[1] + mqa_q = self._decode_concat_quant_fp8_op( + mqa_ql_nope, mqa_q_pe, self._q_scale + ) + else: + mqa_q = (mqa_ql_nope, mqa_q_pe) + if self.impl.dcp_world_size > 1: + assert not fp8_attention, "DCP not support fp8 kvcache now." + # concatenate mqa_ql_nope and mqa_q_pe -> (B, N, L + P) + mqa_q = torch.cat(mqa_q, dim=-1) + # mqa_q do allgather in head dim. + mqa_q = get_dcp_group().all_gather(mqa_q, dim=1) + + # call decode attn + attn_out, lse = self.impl.forward_mqa(mqa_q, kv_cache, attn_metadata, self) + + # correct dcp attn_out with lse. + if self.impl.dcp_world_size > 1: + attn_out = cp_lse_ag_out_rs( + attn_out, + lse, + get_dcp_group(), + is_lse_base_on_e=not getattr(self, "_use_fi_prefill", False), + ) + + # v_up projection + self._v_up_proj(attn_out, out=mqa_output_slice) + return output_padded + def process_weights_after_loading(self, act_dtype: torch.dtype): - if hasattr(self.impl, "process_weights_after_loading"): - self.impl.process_weights_after_loading(act_dtype) + # we currently do not have quantized bmm's which are needed for + # `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform + # the bmm's in 16-bit, the extra memory overhead of this is fairly low + kv_b_proj_weight = get_and_maybe_dequant_weights( + self.kv_b_proj, out_dtype=act_dtype + ).T + + assert kv_b_proj_weight.shape == ( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + ), ( + f"{kv_b_proj_weight.shape=}, " + f"{self.kv_lora_rank=}, " + f"{self.num_heads=}, " + f"{self.qk_nope_head_dim=}, " + f"{self.v_head_dim=}" + ) + kv_b_proj_weight = kv_b_proj_weight.view( + self.kv_lora_rank, + self.num_heads, + self.qk_nope_head_dim + self.v_head_dim, + ) + + W_UK, W_UV = kv_b_proj_weight.split( + [self.qk_nope_head_dim, self.v_head_dim], dim=-1 + ) + + # If kv_b_proj_weight is unquantized, quantize it to mxfp4 if supported + if self.is_aiter_triton_fp4_bmm_enabled: + from vllm.model_executor.layers.quantization.quark.utils import ( + quark_quantize_weight_to_mxfp4, + ) + + self.W_K, self.W_K_scale = quark_quantize_weight_to_mxfp4(W_UK) + # Convert from (L, N, P) to (N, L, P) + self.W_K = self.W_K.transpose(0, 1) + self.W_K_scale = self.W_K_scale.transpose(0, 1) + + self.W_V, self.W_V_scale = quark_quantize_weight_to_mxfp4( + W_UV.permute(1, 2, 0) + ) + elif self.is_aiter_triton_fp8_bmm_enabled: + W_K = W_UK.transpose(0, 1) # 16 512 128 + W_V = W_UV.permute(1, 2, 0) # 16 128 512 + self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant( + W_K, dtype=current_platform.fp8_dtype() + ) + self.W_V, self.W_V_scale = dynamic_per_batched_tensor_quant( + W_V, dtype=current_platform.fp8_dtype() + ) + + # The kernel operates on non-padded inputs. Hence, pre-compiling + # triton kernel to avoid runtime compilation for unseen batch sizes + # Pre-compile for batch sizes 1 to 1024 to cover most use-cases. + # On DS-R1, this step adds roughly 50s to the model loading time. + max_batch_size = 1024 # [ToDo] Find the optimal upper limit + pre_compilation_list = list(range(1, max_batch_size + 1)) + if is_global_first_rank(): + pre_compilation_list = tqdm( + pre_compilation_list, + desc="[Aiter Triton] Pre-compiling fp8 BMM kernel", + total=max_batch_size, + ) + + for m in pre_compilation_list: + x = torch.empty( + (self.W_K.shape[0], m, self.W_K.shape[2]), + dtype=torch.bfloat16, + device=self.W_K.device, + ) + rocm_aiter_ops.triton_fp8_bmm( + x, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True + ) + + x = torch.empty( + (self.W_V.shape[0], m, self.W_V.shape[2]), + dtype=torch.bfloat16, + device=self.W_V.device, + ) + rocm_aiter_ops.triton_fp8_bmm( + x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True + ) + else: + # Convert from (L, N, V) to (N, L, V) + self.W_UV = W_UV.transpose(0, 1) + # Convert from (L, N, P) to (N, P, L) + self.W_UK_T = W_UK.permute(1, 2, 0) # If we should not load quant weights, we initialize the scales to 1.0 # as the default value. See [Note: Register q/k/v/prob scales in state dict] @@ -492,6 +796,41 @@ def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec: cache_dtype_str=vllm_config.cache_config.cache_dtype, ) + def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor): + # Convert from (B, N, L) to (N, B, L) + x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) + out = out.view(-1, self.num_heads, self.v_head_dim) + if self.is_aiter_triton_fp4_bmm_enabled: + out = rocm_aiter_ops.batched_gemm_a16wfp4( + x, + self.W_V, + self.W_V_scale, + out, + transpose_bm=True, + prequant=True, + y_scale=None, + ) + x = out.view(-1, self.num_heads * self.v_head_dim) + elif self.is_aiter_triton_fp8_bmm_enabled: + # Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V) + x = rocm_aiter_ops.triton_fp8_bmm( + x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True, YQ=out + ) + else: + # Convert from (B, N * V) to (N, B, V) + out = out.transpose(0, 1) + + # Multiply (N, B, L) x (N, L, V) -> (N, B, V) + torch.bmm(x, self.W_UV, out=out) # Reuse "out" to make it "hot" + + # Convert from (N, B, V) to (B, N * V) + out_new = out.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim) + + # Adjust output buffer shape back to the original (B, N * V) + N, B, V = out.shape + out.resize_((B, N * V)) + out.copy_(out_new) # Copy result + @maybe_transfer_kv_layer def unified_mla_attention( @@ -500,8 +839,8 @@ def unified_mla_attention( k_pe: torch.Tensor, layer_name: str, ) -> torch.Tensor: - attn_metadata, self, kv_cache = get_attention_context(layer_name) - output = self.impl.forward(self, q, kv_c_normed, k_pe, kv_cache, attn_metadata) + attn_metadata, layer, kv_cache = get_attention_context(layer_name) + output = layer.forward_impl(q, kv_c_normed, k_pe, kv_cache, attn_metadata) return output @@ -534,9 +873,8 @@ def unified_mla_attention_with_output( output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, ) -> None: - attn_metadata, self, kv_cache = get_attention_context(layer_name) - self.impl.forward( - self, + attn_metadata, layer, kv_cache = get_attention_context(layer_name) + layer.forward_impl( q, kv_c_normed, k_pe, @@ -1460,247 +1798,104 @@ def reorg_kvcache( padded_local_chunk_seq_lens_lst: local chunk context lengths under current CP rank. local_context_lens_allranks: local context lengths on each CP rank. - sum_seq_len: the sum of cp_chunk_seq_lens_lst. - max_seq_len: the max value of cp_chunk_seq_lens_lst. - chunk_size: the local padded max context chunk from - chunked_context_metadata building. - chunk_idx: chunk idx of chunked_prefill. - toks: the number of tokens for local gather cache. - """ - kv_c_segments = [] - k_pe_segments = [] - src_token_idx = 0 - max_seq_len_check = 0 - for padded_local_chunk_seq_len, local_context_lens in zip( - padded_local_chunk_seq_lens_lst, local_context_lens_allranks - ): - cur_seq_len = 0 - for rank, local_context_len in enumerate(local_context_lens): - # Note(qcs): We split the context into multiple chunks, - # depending on the size of the workspace. - # local_context in dcp0: |-----------------| - # local_context in dcp1: |--------------| - # n*padded_local_chunk: |-----|-----|-----| - # local_chunk_len in dcp1: |-----|-----|--| - # so we need update the last chunk length in dcp1. - local_chunk_len = min( - max(0, local_context_len - chunk_idx * chunk_size), - padded_local_chunk_seq_len, - ) - if local_chunk_len != 0: - kv_c_segment = allgatered_kv_c_normed[ - rank * toks + src_token_idx : rank * toks - + src_token_idx - + local_chunk_len - ] - k_pe_segment = allgatered_k_pe[ - rank * toks + src_token_idx : rank * toks - + src_token_idx - + local_chunk_len - ] - kv_c_segments.append(kv_c_segment) - k_pe_segments.append(k_pe_segment) - cur_seq_len += local_chunk_len - max_seq_len_check = max(max_seq_len_check, cur_seq_len) - src_token_idx += padded_local_chunk_seq_len - reorganized_kv_c_normed = torch.cat(kv_c_segments, dim=0) - reorganized_k_pe = torch.cat(k_pe_segments, dim=0) - assert reorganized_kv_c_normed.shape[0] == sum_seq_len - assert reorganized_k_pe.shape[0] == sum_seq_len - assert max_seq_len_check == max_seq_len - return reorganized_kv_c_normed, reorganized_k_pe - - -# TODO(Lucas): rename MLACommonBaseImpl -> MLACommonImpl, -# and MLACommonImpl -> MLACommonDenseImpl or somthing like that -class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]): - """ - NOTE: Please read the comment at the top of the file before trying to - understand this class - """ - - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: list[float] | None, - sliding_window: int | None, - kv_cache_dtype: str, - logits_soft_cap: float | None, - attn_type: str, - kv_sharing_target_layer_name: str | None, - # MLA Specific Arguments - q_lora_rank: int | None, - kv_lora_rank: int, - qk_nope_head_dim: int, - qk_rope_head_dim: int, - qk_head_dim: int, - v_head_dim: int, - kv_b_proj: ColumnParallelLinear, - indexer=None, - q_pad_num_heads: int | None = None, - ) -> None: - if kv_sharing_target_layer_name is not None: - raise NotImplementedError("KV sharing is not supported for MLA") - - self.num_heads = num_heads - self.head_size = head_size - self.scale = float(scale) - self.num_kv_heads = num_kv_heads - self.kv_cache_dtype = kv_cache_dtype - - self.q_lora_rank = q_lora_rank - self.kv_lora_rank = kv_lora_rank - self.qk_nope_head_dim = qk_nope_head_dim - self.qk_rope_head_dim = qk_rope_head_dim - self.qk_head_dim = qk_head_dim - self.v_head_dim = v_head_dim - self.kv_b_proj = kv_b_proj - self.indexer = indexer - self.q_pad_num_heads = q_pad_num_heads - self.is_aiter_triton_fp8_bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled() - - # If kv_b_proj_weight is unquantized, quantize it to mxfp4 if supported - self.is_aiter_triton_fp4_bmm_enabled = ( - rocm_aiter_ops.is_fp4bmm_enabled() - and self.kv_b_proj.weight.dtype == torch.bfloat16 - ) - - def process_weights_after_loading(self, act_dtype: torch.dtype): - # we currently do not have quantized bmm's which are needed for - # `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform - # the bmm's in 16-bit, the extra memory overhead of this is fairly low - kv_b_proj_weight = get_and_maybe_dequant_weights( - self.kv_b_proj, out_dtype=act_dtype - ).T - - assert kv_b_proj_weight.shape == ( - self.kv_lora_rank, - self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), - ), ( - f"{kv_b_proj_weight.shape=}, " - f"{self.kv_lora_rank=}, " - f"{self.num_heads=}, " - f"{self.qk_nope_head_dim=}, " - f"{self.v_head_dim=}" - ) - kv_b_proj_weight = kv_b_proj_weight.view( - self.kv_lora_rank, - self.num_heads, - self.qk_nope_head_dim + self.v_head_dim, - ) - - W_UK, W_UV = kv_b_proj_weight.split( - [self.qk_nope_head_dim, self.v_head_dim], dim=-1 - ) - - # If kv_b_proj_weight is unquantized, quantize it to mxfp4 if supported - if self.is_aiter_triton_fp4_bmm_enabled: - from vllm.model_executor.layers.quantization.quark.utils import ( - quark_quantize_weight_to_mxfp4, - ) - - self.W_K, self.W_K_scale = quark_quantize_weight_to_mxfp4(W_UK) - # Convert from (L, N, P) to (N, L, P) - self.W_K = self.W_K.transpose(0, 1) - self.W_K_scale = self.W_K_scale.transpose(0, 1) - - self.W_V, self.W_V_scale = quark_quantize_weight_to_mxfp4( - W_UV.permute(1, 2, 0) - ) - elif self.is_aiter_triton_fp8_bmm_enabled: - W_K = W_UK.transpose(0, 1) # 16 512 128 - W_V = W_UV.permute(1, 2, 0) # 16 128 512 - self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant( - W_K, dtype=current_platform.fp8_dtype() - ) - self.W_V, self.W_V_scale = dynamic_per_batched_tensor_quant( - W_V, dtype=current_platform.fp8_dtype() - ) - - # The kernel operates on non-padded inputs. Hence, pre-compiling - # triton kernel to avoid runtime compilation for unseen batch sizes - # Pre-compile for batch sizes 1 to 1024 to cover most use-cases. - # On DS-R1, this step adds roughly 50s to the model loading time. - max_batch_size = 1024 # [ToDo] Find the optimal upper limit - pre_compilation_list = list(range(1, max_batch_size + 1)) - if is_global_first_rank(): - pre_compilation_list = tqdm( - pre_compilation_list, - desc="[Aiter Triton] Pre-compiling fp8 BMM kernel", - total=max_batch_size, - ) - - for m in pre_compilation_list: - x = torch.empty( - (self.W_K.shape[0], m, self.W_K.shape[2]), - dtype=torch.bfloat16, - device=self.W_K.device, - ) - rocm_aiter_ops.triton_fp8_bmm( - x, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True - ) - - x = torch.empty( - (self.W_V.shape[0], m, self.W_V.shape[2]), - dtype=torch.bfloat16, - device=self.W_V.device, - ) - rocm_aiter_ops.triton_fp8_bmm( - x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True - ) - else: - # Convert from (L, N, V) to (N, L, V) - self.W_UV = W_UV.transpose(0, 1) - # Convert from (L, N, P) to (N, P, L) - self.W_UK_T = W_UK.permute(1, 2, 0) - - def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor): - # Convert from (B, N, L) to (N, B, L) - x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) - out = out.view(-1, self.num_heads, self.v_head_dim) - if self.is_aiter_triton_fp4_bmm_enabled: - out = rocm_aiter_ops.batched_gemm_a16wfp4( - x, - self.W_V, - self.W_V_scale, - out, - transpose_bm=True, - prequant=True, - y_scale=None, - ) - x = out.view(-1, self.num_heads * self.v_head_dim) - elif self.is_aiter_triton_fp8_bmm_enabled: - # Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V) - x = rocm_aiter_ops.triton_fp8_bmm( - x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True, YQ=out + sum_seq_len: the sum of cp_chunk_seq_lens_lst. + max_seq_len: the max value of cp_chunk_seq_lens_lst. + chunk_size: the local padded max context chunk from + chunked_context_metadata building. + chunk_idx: chunk idx of chunked_prefill. + toks: the number of tokens for local gather cache. + """ + kv_c_segments = [] + k_pe_segments = [] + src_token_idx = 0 + max_seq_len_check = 0 + for padded_local_chunk_seq_len, local_context_lens in zip( + padded_local_chunk_seq_lens_lst, local_context_lens_allranks + ): + cur_seq_len = 0 + for rank, local_context_len in enumerate(local_context_lens): + # Note(qcs): We split the context into multiple chunks, + # depending on the size of the workspace. + # local_context in dcp0: |-----------------| + # local_context in dcp1: |--------------| + # n*padded_local_chunk: |-----|-----|-----| + # local_chunk_len in dcp1: |-----|-----|--| + # so we need update the last chunk length in dcp1. + local_chunk_len = min( + max(0, local_context_len - chunk_idx * chunk_size), + padded_local_chunk_seq_len, ) - else: - # Convert from (B, N * V) to (N, B, V) - out = out.transpose(0, 1) - - # Multiply (N, B, L) x (N, L, V) -> (N, B, V) - torch.bmm(x, self.W_UV, out=out) # Reuse "out" to make it "hot" - - # Convert from (N, B, V) to (B, N * V) - out_new = out.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim) - - # Adjust output buffer shape back to the original (B, N * V) - N, B, V = out.shape - out.resize_((B, N * V)) - out.copy_(out_new) # Copy result + if local_chunk_len != 0: + kv_c_segment = allgatered_kv_c_normed[ + rank * toks + src_token_idx : rank * toks + + src_token_idx + + local_chunk_len + ] + k_pe_segment = allgatered_k_pe[ + rank * toks + src_token_idx : rank * toks + + src_token_idx + + local_chunk_len + ] + kv_c_segments.append(kv_c_segment) + k_pe_segments.append(k_pe_segment) + cur_seq_len += local_chunk_len + max_seq_len_check = max(max_seq_len_check, cur_seq_len) + src_token_idx += padded_local_chunk_seq_len + reorganized_kv_c_normed = torch.cat(kv_c_segments, dim=0) + reorganized_k_pe = torch.cat(k_pe_segments, dim=0) + assert reorganized_kv_c_normed.shape[0] == sum_seq_len + assert reorganized_k_pe.shape[0] == sum_seq_len + assert max_seq_len_check == max_seq_len + return reorganized_kv_c_normed, reorganized_k_pe -class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): +class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): """ NOTE: Please read the comment at the top of the file before trying to understand this class """ - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: list[float] | None, + sliding_window: int | None, + kv_cache_dtype: str, + logits_soft_cap: float | None, + attn_type: str, + kv_sharing_target_layer_name: str | None, + # MLA Specific Arguments + q_lora_rank: int | None, + kv_lora_rank: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + qk_head_dim: int, + v_head_dim: int, + kv_b_proj: ColumnParallelLinear, + indexer: object | None = None, + q_pad_num_heads: int | None = None, + ) -> None: + if kv_sharing_target_layer_name is not None: + raise NotImplementedError("KV sharing is not supported for MLA") + + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + self.kv_cache_dtype = kv_cache_dtype + + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.qk_head_dim = qk_head_dim + self.v_head_dim = v_head_dim + self.kv_b_proj = kv_b_proj + self.indexer = indexer + self.q_pad_num_heads = q_pad_num_heads if use_trtllm_ragged_deepseek_prefill(): logger.info_once( @@ -1750,19 +1945,9 @@ def __init__(self, *args, **kwargs) -> None: self.dcp_world_size: int = -1 - self.chunked_prefill_workspace_size = ( - MLACommonMetadataBuilder.determine_chunked_prefill_workspace_size( - get_current_vllm_config() - ) - ) self.cp_kv_cache_interleave_size: int = ( get_current_vllm_config().parallel_config.cp_kv_cache_interleave_size ) - self._decode_concat_quant_fp8_op = _DecodeConcatQuantFP8( - static=True, - group_shape=GroupShape.PER_TENSOR, - compile_native=True, - ) def _flash_attn_varlen_diff_headdims( self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs @@ -2193,7 +2378,7 @@ def _context_parallel_compute_prefill_context( return output, output_lse - def _forward_prefill( + def forward_mha( self, q: torch.Tensor, kv_c_normed: torch.Tensor, @@ -2258,7 +2443,7 @@ def _forward_prefill( output.copy_(output_prefill) @abstractmethod - def _forward_decode( + def forward_mqa( self, q: torch.Tensor | tuple[torch.Tensor, torch.Tensor], kv_c_and_k_pe_cache: torch.Tensor, @@ -2266,185 +2451,3 @@ def _forward_decode( layer: AttentionLayer, ) -> tuple[torch.Tensor, torch.Tensor | None]: raise NotImplementedError - - def forward( - self, - layer: AttentionLayer, - q: torch.Tensor, - k_c_normed: torch.Tensor, # key in unified attn - k_pe: torch.Tensor, # value in unified attn - kv_cache: torch.Tensor, - attn_metadata: M, - output: torch.Tensor | None = None, - output_scale: torch.Tensor | None = None, - output_block_scale: torch.Tensor | None = None, - ) -> torch.Tensor: - assert output is not None, "Output tensor must be provided." - - if output_scale is not None or output_block_scale is not None: - raise NotImplementedError( - "fused output quantization is not yet supported for MLACommonImpl" - ) - - if attn_metadata is None: - # During the profile run try to simulate to worse case output size - # for `self.kv_b_proj(kv_c_normed)` in `_compute_prefill_context` - # since this can be large - _ = torch.empty( - ( - self.chunked_prefill_workspace_size, - self.num_heads, - self.qk_nope_head_dim + self.v_head_dim, - ), - device=k_c_normed.device, - dtype=k_c_normed.dtype, - ) - - # The zero fill is required when used with DP + EP - # to ensure all ranks within a DP group compute the - # same expert outputs. - return output.fill_(0) - - if self.dcp_world_size == -1: - self.dcp_world_size = get_dcp_group().world_size - - fp8_attention = self.kv_cache_dtype.startswith("fp8") - - num_actual_toks = attn_metadata.num_actual_tokens - - # Inputs and outputs may be padded for CUDA graphs - output_padded = output - output = output[:num_actual_toks, ...] - q = q[:num_actual_toks, ...] - k_c_normed = k_c_normed[:num_actual_toks, ...] - k_pe = k_pe[:num_actual_toks, ...] - - assert ( - attn_metadata.num_decodes is not None - and attn_metadata.num_prefills is not None - and attn_metadata.num_decode_tokens is not None - ) - - has_decode = attn_metadata.num_decodes > 0 - has_prefill = attn_metadata.num_prefills > 0 - num_decode_tokens = attn_metadata.num_decode_tokens - - decode_q = q[:num_decode_tokens] - - prefill_q = q[num_decode_tokens:] - prefill_k_pe = k_pe[num_decode_tokens:] - prefill_k_c_normed = k_c_normed[num_decode_tokens:] - - # write the latent and rope to kv cache - if kv_cache.numel() > 0: - ops.concat_and_cache_mla( - k_c_normed, - k_pe.squeeze(1), - kv_cache, - attn_metadata.slot_mapping.flatten(), - kv_cache_dtype=self.kv_cache_dtype, - scale=layer._k_scale, - ) - - if fp8_attention: - kv_cache = kv_cache.view(current_platform.fp8_dtype()) - - if has_prefill: - self._forward_prefill( - prefill_q, - prefill_k_c_normed, - prefill_k_pe, - kv_cache, - attn_metadata, - layer._k_scale, - output=output[num_decode_tokens:], - ) - - if has_decode: - assert attn_metadata.decode is not None - - decode_q_nope, decode_q_pe = decode_q.split( - [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 - ) - - # Convert from (B, N, P) to (N, B, P) - decode_q_nope = decode_q_nope.transpose(0, 1) - - if self.q_pad_num_heads is not None: - B, N, L = decode_q_pe.shape - decode_pe_padded = decode_q_pe.new_empty((B, self.q_pad_num_heads, L)) - decode_pe_padded.resize_((B, N, L)) - decode_pe_padded.copy_(decode_q_pe) - decode_q_pe = decode_pe_padded - - if self.is_aiter_triton_fp4_bmm_enabled: - from aiter.ops.triton.batched_gemm_a16wfp4 import batched_gemm_a16wfp4 - - decode_ql_nope = batched_gemm_a16wfp4( - decode_q_nope, - self.W_K, - self.W_K_scale, - transpose_bm=True, - prequant=True, - y_scale=layer._q_scale if fp8_attention else None, - ) - elif self.is_aiter_triton_fp8_bmm_enabled: - # Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L) - decode_ql_nope = rocm_aiter_ops.triton_fp8_bmm( - decode_q_nope, - self.W_K, - self.W_K_scale, - group_size=128, - transpose_bm=True, - ) - else: - # Pads the head_dim if necessary (for the underlying kernel) - N, B, P = decode_q_nope.shape - _, _, L = self.W_UK_T.shape - - if self.q_pad_num_heads is not None: - decode_ql_nope = decode_q_nope.new_empty( - (self.q_pad_num_heads, B, L) - ) - decode_ql_nope.resize_((N, B, L)) - else: - decode_ql_nope = decode_q_nope.new_empty((N, B, L)) - - # Multiply (N, B, P) x (N, P, L) -> (N, B, L) - torch.bmm(decode_q_nope, self.W_UK_T, out=decode_ql_nope) - - # Convert from (N, B, L) to (B, N, L) - decode_ql_nope = decode_ql_nope.transpose(0, 1) - - if fp8_attention: - assert decode_ql_nope.shape[0] == decode_q_pe.shape[0] - assert decode_ql_nope.shape[1] == decode_q_pe.shape[1] - decode_q = self._decode_concat_quant_fp8_op( - decode_ql_nope, decode_q_pe, layer._q_scale - ) - else: - decode_q = (decode_ql_nope, decode_q_pe) - if self.dcp_world_size > 1: - assert not fp8_attention, "DCP not support fp8 kvcache now." - # concatenate decode_ql_nope and decode_q_pe -> (B, N, L + P) - decode_q = torch.cat(decode_q, dim=-1) - # decode_q do allgather in head dim. - decode_q = get_dcp_group().all_gather(decode_q, dim=1) - - # call decode attn - attn_out, lse = self._forward_decode( - decode_q, kv_cache, attn_metadata, layer - ) - - # correct dcp attn_out with lse. - if self.dcp_world_size > 1: - attn_out = cp_lse_ag_out_rs( - attn_out, - lse, - get_dcp_group(), - is_lse_base_on_e=not getattr(self, "_use_fi_prefill", False), - ) - - # v_up projection - self._v_up_proj(attn_out, out=output[:num_decode_tokens]) - return output_padded diff --git a/vllm/v1/attention/backend.py b/vllm/v1/attention/backend.py index 32a143f8ee67..13082608c47c 100644 --- a/vllm/v1/attention/backend.py +++ b/vllm/v1/attention/backend.py @@ -67,7 +67,7 @@ def get_name() -> str: @staticmethod @abstractmethod - def get_impl_cls() -> type["AttentionImpl"]: + def get_impl_cls() -> type["AttentionImplBase"]: raise NotImplementedError @staticmethod @@ -594,7 +594,14 @@ def forward( ) -> torch.Tensor: ... -class AttentionImpl(ABC, Generic[T]): +class AttentionImplBase(ABC, Generic[T]): + """Base class for attention implementations. + + Contains common attributes and initialization logic shared by both + standard AttentionImpl and MLAAttentionImpl. Does not define a forward + method - subclasses define their own forward interfaces. + """ + # Required attributes that all impls should have num_heads: int head_size: int @@ -662,6 +669,13 @@ def __new__(cls, *args, **kwargs): ) return self + def process_weights_after_loading(self, act_dtype: torch.dtype): + pass + + +class AttentionImpl(AttentionImplBase[T], Generic[T]): + """Standard attention implementation with forward method.""" + @abstractmethod def __init__( self, @@ -704,11 +718,10 @@ def fused_output_quant_supported(self, quant_key: "QuantKey"): """ return False - def process_weights_after_loading(self, act_dtype: torch.dtype): - pass +class MLAAttentionImpl(AttentionImplBase[T], Generic[T]): + """MLA attention implementation with forward_mqa and forward_mha methods.""" -class MLAAttentionImpl(AttentionImpl[T], Generic[T]): @abstractmethod def __init__( self, @@ -731,22 +744,78 @@ def __init__( v_head_dim: int, kv_b_proj: "ColumnParallelLinear", indexer: object | None = None, + q_pad_num_heads: int | None = None, ) -> None: raise NotImplementedError @abstractmethod - def forward( + def forward_mha( self, - layer: AttentionLayer, - hidden_states_or_cq: torch.Tensor, + q: torch.Tensor, kv_c_normed: torch.Tensor, k_pe: torch.Tensor, - kv_cache: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: T, - output: torch.Tensor | None = None, - output_scale: torch.Tensor | None = None, - output_block_scale: torch.Tensor | None = None, - ) -> torch.Tensor: + k_scale: torch.Tensor, + output: torch.Tensor, + ) -> None: + """MHA-style prefill forward pass.""" + raise NotImplementedError + + @abstractmethod + def forward_mqa( + self, + q: torch.Tensor | tuple[torch.Tensor, torch.Tensor], + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: T, + layer: AttentionLayer, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + """MQA-style decode forward pass.""" + raise NotImplementedError + + +class SparseMLAAttentionImpl(AttentionImplBase[T], Generic[T]): + """Sparse MLA attention implementation with only forward_mqa method. + + Sparse MLA implementations only support decode (MQA-style) attention. + They do not support prefill (MHA-style) attention. + """ + + @abstractmethod + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: list[float] | None, + sliding_window: int | None, + kv_cache_dtype: str, + logits_soft_cap: float | None, + attn_type: str, + kv_sharing_target_layer_name: str | None, + # MLA Specific Arguments + q_lora_rank: int | None, + kv_lora_rank: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + qk_head_dim: int, + v_head_dim: int, + kv_b_proj: "ColumnParallelLinear", + indexer: object | None = None, + q_pad_num_heads: int | None = None, + ) -> None: + raise NotImplementedError + + @abstractmethod + def forward_mqa( + self, + q: torch.Tensor | tuple[torch.Tensor, torch.Tensor], + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: T, + layer: AttentionLayer, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + """MQA-style decode forward pass.""" raise NotImplementedError diff --git a/vllm/v1/attention/backends/mla/cutlass_mla.py b/vllm/v1/attention/backends/mla/cutlass_mla.py index a8ba10080d54..6d10a9d66e20 100644 --- a/vllm/v1/attention/backends/mla/cutlass_mla.py +++ b/vllm/v1/attention/backends/mla/cutlass_mla.py @@ -244,7 +244,7 @@ def _sm100_cutlass_mla_decode( return out, lse - def _forward_decode( + def forward_mqa( self, q: torch.Tensor | tuple[torch.Tensor, torch.Tensor], kv_c_and_k_pe_cache: torch.Tensor, diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index b912a3d3d785..e160d3255688 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -293,7 +293,7 @@ def __init__( "FlashAttnMLA V1 with FP8 KV cache not yet supported" ) - def _forward_decode( + def forward_mqa( self, q: torch.Tensor | tuple[torch.Tensor, torch.Tensor], kv_c_and_k_pe_cache: torch.Tensor, diff --git a/vllm/v1/attention/backends/mla/flashinfer_mla.py b/vllm/v1/attention/backends/mla/flashinfer_mla.py index d1314ccf2955..58d4bec7c92e 100644 --- a/vllm/v1/attention/backends/mla/flashinfer_mla.py +++ b/vllm/v1/attention/backends/mla/flashinfer_mla.py @@ -150,7 +150,7 @@ def __init__( self.bmm1_scale: float | None = None self.bmm2_scale: float | None = None - def _forward_decode( + def forward_mqa( self, q: torch.Tensor | tuple[torch.Tensor, torch.Tensor], kv_c_and_k_pe_cache: torch.Tensor, diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index 2dd8f4a51006..37ab148095f7 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -234,7 +234,7 @@ def __init__( "FlashMLAImpl" ) - def _forward_decode( + def forward_mqa( self, q: torch.Tensor | tuple[torch.Tensor, torch.Tensor], kv_c_and_k_pe_cache: torch.Tensor, diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index 2f77e3c031b7..8ef957a93c20 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -11,7 +11,6 @@ from vllm.config.cache import CacheDType from vllm.logger import init_logger from vllm.model_executor.layers.attention.mla_attention import ( - MLACommonBaseImpl, get_mla_dims, ) from vllm.platforms import current_platform @@ -25,6 +24,7 @@ AttentionMetadataBuilder, CommonAttentionMetadata, MultipleOf, + SparseMLAAttentionImpl, ) from vllm.v1.attention.backends.utils import ( reshape_attn_output_for_spec_decode, @@ -686,7 +686,7 @@ def build( return metadata -class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]): +class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]): @staticmethod def _compute_fp8_decode_padded_heads(num_heads: int) -> int: # FP8 decode kernel only supports h_q = 64 or 128 @@ -710,19 +710,12 @@ def __init__( indexer: Optional["Indexer"] = None, **mla_args, ) -> None: - super().__init__( - num_heads, - head_size, - scale, - num_kv_heads, - alibi_slopes, - sliding_window, - kv_cache_dtype, - logits_soft_cap, - attn_type, - kv_sharing_target_layer_name, - **mla_args, - ) + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + self.kv_cache_dtype = kv_cache_dtype + self.kv_lora_rank: int = mla_args["kv_lora_rank"] self.softmax_scale = scale assert indexer is not None self.topk_indices_buffer: torch.Tensor | None = indexer.topk_indices_buffer @@ -974,78 +967,39 @@ def _bf16_flash_mla_kernel( output = output[:, : self.num_heads, :] return output - def forward( + def forward_mqa( self, + q: torch.Tensor | tuple[torch.Tensor, torch.Tensor], + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: FlashMLASparseMetadata, layer: AttentionLayer, - q: torch.Tensor, - k_c_normed: torch.Tensor, # key in unified attn - k_pe: torch.Tensor, # value in unified attn - kv_cache: torch.Tensor, - attn_metadata: FlashMLASparseMetadata | None, - output: torch.Tensor | None = None, - output_scale: torch.Tensor | None = None, - output_block_scale: torch.Tensor | None = None, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, torch.Tensor | None]: # NOTE(lucas): for the sparse FlashMLA kernels the kernels want to use # MQA 576/512 approach for both prefill and decode - assert output is not None, "Output tensor must be provided." - - if output_scale is not None or output_block_scale is not None: - raise NotImplementedError( - "fused output quantization is not yet supported for MLACommonImpl" - ) - - if attn_metadata is None: - # Dummy run - no need to allocate buffers - # The zero fill is required when used with DP + EP - # to ensure all ranks within a DP group compute the - # same expert outputs. - return output.fill_(0) - - num_actual_toks = attn_metadata.num_actual_tokens + # Concatenate q if it's a tuple (ql_nope, q_pe) + if isinstance(q, tuple): + q = torch.cat(q, dim=-1) - # Inputs and outputs may be padded for CUDA graphs + num_actual_toks = q.shape[0] - q = q[:num_actual_toks, ...] - k_c_normed = k_c_normed[:num_actual_toks, ...] - k_pe = k_pe[:num_actual_toks, ...] + # Get topk indices assert self.topk_indices_buffer is not None topk_indices = self.topk_indices_buffer[:num_actual_toks] - q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) - # Convert from (B, N, P) to (N, B, P) - q_nope = q_nope.transpose(0, 1) - # Multiply (N, B, P) x (N, P, L) -> (N, B, L) - ql_nope = torch.bmm(q_nope, self.W_UK_T) - # Convert from (N, B, L) to (B, N, L) - ql_nope = ql_nope.transpose(0, 1) - use_fp8_cache = self.kv_cache_dtype == "fp8_ds_mla" - q = torch.cat([ql_nope, q_pe], dim=-1) - - # write the latent and rope to kv cache - if kv_cache.numel() > 0: - ops.concat_and_cache_mla( - k_c_normed, - k_pe.squeeze(1), - kv_cache, - attn_metadata.slot_mapping.flatten(), - kv_cache_dtype=self.kv_cache_dtype, - scale=layer._k_scale, - ) - if not use_fp8_cache: - attn_out = self._forward_bf16_kv(q, kv_cache, topk_indices, attn_metadata) + attn_out = self._forward_bf16_kv( + q, kv_c_and_k_pe_cache, topk_indices, attn_metadata + ) elif attn_metadata.fp8_use_mixed_batch: attn_out = self._forward_fp8_kv_mixed_batch( - q, kv_cache, topk_indices, attn_metadata + q, kv_c_and_k_pe_cache, topk_indices, attn_metadata ) else: attn_out = self._forward_fp8_kv_separate_prefill_decode( - q, kv_cache, topk_indices, attn_metadata + q, kv_c_and_k_pe_cache, topk_indices, attn_metadata ) - self._v_up_proj(attn_out, out=output[:num_actual_toks]) - return output + return attn_out, None diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index 3abf8ad309d3..57a1d32d2d47 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -241,7 +241,7 @@ def _flash_attn_varlen_diff_headdims( return output - def _forward_decode( + def forward_mqa( self, q: torch.Tensor | tuple[torch.Tensor, torch.Tensor], kv_c_and_k_pe_cache: torch.Tensor, diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py index 08a7336b5fec..8e7f7bd27763 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py @@ -7,12 +7,10 @@ import numpy as np import torch -from vllm import _custom_ops as ops from vllm._aiter_ops import rocm_aiter_ops from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.attention.mla_attention import ( - MLACommonBaseImpl, get_mla_dims, ) from vllm.triton_utils import tl, triton @@ -23,6 +21,7 @@ AttentionMetadata, AttentionMetadataBuilder, CommonAttentionMetadata, + SparseMLAAttentionImpl, ) from vllm.v1.attention.backends.mla.flashmla_sparse import ( triton_convert_req_index_to_global_index, @@ -269,7 +268,7 @@ def log2sumexp2(a: torch.Tensor, dim: int) -> torch.Tensor: return (result, lse) -class ROCMAiterMLASparseImpl(MLACommonBaseImpl[ROCMAiterMLASparseMetadata]): +class ROCMAiterMLASparseImpl(SparseMLAAttentionImpl[ROCMAiterMLASparseMetadata]): def __init__( self, num_heads: int, @@ -287,23 +286,15 @@ def __init__( indexer: Optional["Indexer"] = None, **mla_args, ) -> None: - super().__init__( - num_heads, - head_size, - scale, - num_kv_heads, - alibi_slopes, - sliding_window, - kv_cache_dtype, - logits_soft_cap, - attn_type, - kv_sharing_target_layer_name, - **mla_args, - ) + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + self.kv_cache_dtype = kv_cache_dtype + self.kv_lora_rank: int = mla_args["kv_lora_rank"] self.softmax_scale = scale assert indexer is not None self.topk_indices_buffer: torch.Tensor | None = indexer.topk_indices_buffer - self.is_fp8bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled() def _forward_bf16_kv( self, @@ -342,56 +333,23 @@ def _forward_bf16_kv( return output[:, : self.num_heads, :] - def forward( + def forward_mqa( self, - layer: AttentionLayer, - q: torch.Tensor, - k_c_normed: torch.Tensor, # key in unified attn - k_pe: torch.Tensor, # value in unified attn - kv_cache: torch.Tensor, + q: torch.Tensor | tuple[torch.Tensor, torch.Tensor], + kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: ROCMAiterMLASparseMetadata, - output: torch.Tensor | None = None, - output_scale: torch.Tensor | None = None, - output_block_scale: torch.Tensor | None = None, - ) -> torch.Tensor: + layer: AttentionLayer, + ) -> tuple[torch.Tensor, torch.Tensor | None]: # NOTE(lucas): for the sparse FlashMLA kernels the kernels want to use # MQA 576/512 approach for both prefill and decode - assert output is not None, "Output tensor must be provided." - - if output_scale is not None or output_block_scale is not None: - raise NotImplementedError( - "fused output quantization is not yet supported for ROCMAiterMLASparse" - ) - - if attn_metadata is None: - # The zero fill is required when used with DP + EP - # to ensure all ranks within a DP group compute the - # same expert outputs. - return output.fill_(0) - - num_actual_toks = attn_metadata.num_actual_tokens - - # Inputs and outputs may be padded for CUDA graphs - - q = q[:num_actual_toks, ...] - k_c_normed = k_c_normed[:num_actual_toks, ...] - k_pe = k_pe[:num_actual_toks, ...] - - q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) - # Convert from (B, N, P) to (N, B, P) - q_nope = q_nope.transpose(0, 1) - if self.is_fp8bmm_enabled: - # Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L) - ql_nope = rocm_aiter_ops.triton_fp8_bmm( - q_nope, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True - ) - else: - # Multiply (N, B, P) x (N, P, L) -> (N, B, L) - ql_nope = torch.bmm(q_nope, self.W_UK_T) - # Convert from (N, B, L) to (B, N, L) - ql_nope = ql_nope.transpose(0, 1) + # Concatenate q if it's a tuple (ql_nope, q_pe) + if isinstance(q, tuple): + q = torch.cat(q, dim=-1) + num_actual_toks = q.shape[0] + + # Get topk indices assert self.topk_indices_buffer is not None topk_indices = self.topk_indices_buffer[:num_actual_toks] @@ -403,22 +361,8 @@ def forward( NUM_TOPK_TOKENS=attn_metadata.topk_tokens, ) - q = torch.cat([ql_nope, q_pe], dim=-1) - - # write the latent and rope to kv cache - if kv_cache.numel() > 0: - ops.concat_and_cache_mla( - k_c_normed, - k_pe.squeeze(1), - kv_cache, - attn_metadata.slot_mapping.flatten(), - kv_cache_dtype=self.kv_cache_dtype, - scale=layer._k_scale, - ) - attn_out = self._forward_bf16_kv( - q, kv_cache, topk_indices_global, attn_metadata + q, kv_c_and_k_pe_cache, topk_indices_global, attn_metadata ) - self._v_up_proj(attn_out, out=output[:num_actual_toks]) - return output + return attn_out, None diff --git a/vllm/v1/attention/backends/mla/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py index 84e025dcd358..2403dcc61313 100644 --- a/vllm/v1/attention/backends/mla/triton_mla.py +++ b/vllm/v1/attention/backends/mla/triton_mla.py @@ -110,7 +110,7 @@ def _flash_attn_varlen_diff_headdims( **kwargs, ) - def _forward_decode( + def forward_mqa( self, q: torch.Tensor | tuple[torch.Tensor, torch.Tensor], kv_c_and_k_pe_cache: torch.Tensor,