diff --git a/tests/v1/attention/test_mla_backends.py b/tests/v1/attention/test_mla_backends.py
index ba70c8251745..32c0b9064275 100644
--- a/tests/v1/attention/test_mla_backends.py
+++ b/tests/v1/attention/test_mla_backends.py
@@ -19,8 +19,13 @@
 )
 from vllm import _custom_ops as ops
 from vllm.config.vllm import set_current_vllm_config
-from vllm.model_executor.layers.attention.mla_attention import QueryLenSupport
+from vllm.model_executor.layers.attention.mla_attention import (
+    QueryLenSupport,
+    _DecodeConcatQuantFP8,
+)
 from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
+from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
+from vllm.platforms import current_platform
 from vllm.utils.math_utils import cdiv
 from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
 from vllm.v1.attention.backend import CommonAttentionMetadata
@@ -50,6 +55,7 @@
 if not is_flashmla_dense_supported()[0]:
     BACKENDS_TO_TEST.remove(AttentionBackendEnum.FLASHMLA)
 
+
 SPEC_DECODE_BACKENDS = []
 for backend in BACKENDS_TO_TEST:
     builder_cls, _ = try_get_attention_backend(backend)
@@ -144,9 +150,8 @@ def create_and_prepopulate_kv_cache(
         common_attn_metadata: Common attention metadata
         randomize_blocks: Whether to randomly permute blocks or use
             sequential order
-        kv_cache_dtype: Optional kv cache dtype string. When set to
-            "fp8_ds_mla" the cache is populated using the
-            fp8 DeepSeek MLA layout via concat_and_cache_mla.
+        kv_cache_dtype: Optional kv cache dtype string. For fp8 cache dtype,
+            the cache is populated via concat_and_cache_mla.
         scale: Scaling factor forwarded to concat_and_cache_mla when the
             fp8 cache layout is requested.
 
@@ -163,18 +168,21 @@
 
     block_table = common_attn_metadata.block_table_tensor
     slot_mapping = common_attn_metadata.slot_mapping
 
+    fp8_attention = kv_cache_dtype and kv_cache_dtype.startswith("fp8")
     use_fp8_ds_mla = kv_cache_dtype == "fp8_ds_mla"
-    if use_fp8_ds_mla:
-        if not kv_c_contexts:
-            raise ValueError(
-                "kv_c_contexts cannot be empty when using fp8_ds_mla cache dtype"
-            )
-        kv_lora_rank = kv_c_contexts[0].shape[-1]
-        rope_dim = k_pe_contexts[0].shape[-1]
-        entry_size = kv_lora_rank + 4 * 4 + 2 * rope_dim
+    if fp8_attention:
+        if use_fp8_ds_mla:
+            kv_lora_rank = kv_c_contexts[0].shape[-1]
+            rope_dim = k_pe_contexts[0].shape[-1]
+            # 4 * 4: 4 float32 scale values for 128-element tiles
+            # 2 * rope_dim: 16-bit RoPE values
+            kv_entry_size = kv_lora_rank + 4 * 4 + 2 * rope_dim
+        else:
+            kv_entry_size = head_size
+
         kv_cache = torch.zeros(
-            num_blocks, block_size, entry_size, dtype=torch.uint8, device=device
+            num_blocks, block_size, kv_entry_size, dtype=torch.uint8, device=device
         )
         scale_tensor = (
             scale
@@ -201,14 +209,14 @@
 
         start = start_block_idx * block_size
 
-        if use_fp8_ds_mla:
+        if fp8_attention:
             slots = torch.arange(context_len, device=device, dtype=torch.long) + start
             ops.concat_and_cache_mla(
                 kv_c_context,
                 k_pe_context.squeeze(1),
                 kv_cache,
                 slots,
-                kv_cache_dtype="fp8_ds_mla",
+                kv_cache_dtype=kv_cache_dtype,
                 scale=scale_tensor,
             )
         else:
@@ -329,8 +337,9 @@ def forward_impl(
         output: torch.Tensor,
     ) -> torch.Tensor:
         """Forward for sparse MLA - uses forward_mqa for all tokens."""
-        # Write to KV cache
         kv_cache_dtype = getattr(self.impl, "kv_cache_dtype", "auto")
+
+        # Write to KV cache
         if kv_cache.numel() > 0:
             ops.concat_and_cache_mla(
                 kv_c,
@@ -426,6 +435,12 @@ def __init__(
         self._k_scale_float = 1.0
         self._v_scale_float = 1.0
 
+        self._decode_concat_quant_fp8_op = _DecodeConcatQuantFP8(
+            static=True,
+            group_shape=GroupShape.PER_TENSOR,
+            compile_native=True,
+        )
+
     def get_attn_backend(self):
         raise NotImplementedError
 
@@ -443,16 +458,21 @@ def forward_impl(
     ) -> torch.Tensor:
         """Replicates MLAAttention.forward_impl logic for testing."""
         # Write to KV cache
+        kv_cache_dtype = getattr(self.impl, "kv_cache_dtype", "auto")
+        fp8_attention = kv_cache_dtype.startswith("fp8")
         if kv_cache.numel() > 0:
             ops.concat_and_cache_mla(
                 kv_c,
                 k_pe.squeeze(1),
                 kv_cache,
                 attn_metadata.slot_mapping.flatten(),
-                kv_cache_dtype="auto",
+                kv_cache_dtype=kv_cache_dtype,
                 scale=self._k_scale,
             )
 
+        if fp8_attention and kv_cache_dtype != "fp8_ds_mla":
+            kv_cache = kv_cache.view(current_platform.fp8_dtype())
+
         # Determine decode vs prefill split
         num_decode_tokens = attn_metadata.num_decode_tokens or 0
         has_decode = (attn_metadata.num_decodes or 0) > 0
@@ -491,8 +511,14 @@
             # Convert from (N, B, L) to (B, N, L)
             mqa_ql_nope = mqa_ql_nope.transpose(0, 1)
 
-            # Pass as tuple to forward_mqa
-            mqa_q = (mqa_ql_nope, mqa_q_pe)
+            if fp8_attention and self.impl.supports_quant_query_input:
+                assert mqa_ql_nope.shape[0] == mqa_q_pe.shape[0]
+                assert mqa_ql_nope.shape[1] == mqa_q_pe.shape[1]
+                mqa_q = self._decode_concat_quant_fp8_op(
+                    mqa_ql_nope, mqa_q_pe, self._q_scale
+                )
+            else:
+                mqa_q = (mqa_ql_nope, mqa_q_pe)
 
             attn_out, _ = self.impl.forward_mqa(mqa_q, kv_cache, attn_metadata, self)
 
@@ -526,6 +552,7 @@
     qk_rope_head_dim: int,
     v_head_dim: int,
     mock_kv_b_proj,
+    kv_cache_dtype: str = "auto",
 ) -> torch.Tensor:
     """Run attention computation using the specified backend's AttentionImpl."""
 
@@ -550,7 +577,7 @@
         num_kv_heads=num_kv_heads,
         alibi_slopes=None,
         sliding_window=None,
-        kv_cache_dtype="auto",
+        kv_cache_dtype=kv_cache_dtype,
         logits_soft_cap=None,
         attn_type="decoder",
         kv_sharing_target_layer_name=None,
@@ -630,12 +657,14 @@
 )
 @pytest.mark.parametrize("model", ["deepseek-ai/DeepSeek-R1"])
 @pytest.mark.parametrize("tensor_parallel_size", [1, 4, 8, 16])
+@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8", "fp8_e4m3"])
 def test_backend_correctness(
     default_vllm_config,
     dist_init,
     batch_spec_name: str,
     model: str,
     tensor_parallel_size: int,
+    kv_cache_dtype: str,
 ):
     """
     Test that all backends produce similar outputs to a reference implementation
@@ -658,9 +687,18 @@
 
     head counts.
     """
+    # Filter backends to those that support the requested kv_cache_dtype
+    backends_to_test = [
+        b
+        for b in BACKENDS_TO_TEST
+        if kv_cache_dtype in b.get_class().supported_kv_cache_dtypes
+    ]
+    if not backends_to_test:
+        pytest.skip(f"No backends support kv_cache_dtype={kv_cache_dtype}")
+
     batch_spec = BATCH_SPECS[batch_spec_name]
     is_spec_decode_test = batch_spec_name.startswith("spec_decode")
-    unique_block_sizes = sorted(set(BACKEND_BLOCK_SIZES.values()))
+    unique_block_sizes = sorted(set(BACKEND_BLOCK_SIZES[b] for b in backends_to_test))
     default_block_size = unique_block_sizes[0]
     required_blocks = sum(
         (seq_len + default_block_size - 1) // default_block_size
@@ -694,6 +732,7 @@
         block_size=default_block_size,
         hf_config_override=hf_config_override,
     )
+    vllm_config.cache_config.cache_dtype = kv_cache_dtype
 
     # For spec decode tests, add a speculative_config to set the reorder_batch_threshold
     if is_spec_decode_test:
@@ -751,7 +790,7 @@
 
     kv_b_proj_weight = torch.cat([W_UK, W_UV], dim=-1)
 
-    for i, backend in enumerate(BACKENDS_TO_TEST):
+    for i, backend in enumerate(backends_to_test):
         all_sdpa_outputs.append([])
 
     for i in range(batch_size):
@@ -785,7 +824,7 @@
         # pipeline (MHA-style). This ensures the reference implementation
        # matches each backend's actual decode/prefill pipeline path.
         is_decode = []
-        for backend_idx, backend in enumerate(BACKENDS_TO_TEST):
+        for backend_idx, backend in enumerate(backends_to_test):
             builder_cls, _ = try_get_attention_backend(backend)
             if is_spec_decode_test:
                 query_len_support = getattr(
@@ -885,7 +924,7 @@
         sdpa_out_i_prefill = sdpa_out_i_prefill.transpose(1, 2).squeeze(0)
         sdpa_out_i_prefill = sdpa_out_i_prefill.flatten(start_dim=-2)
 
-        for backend_idx, backend in enumerate(BACKENDS_TO_TEST):
+        for backend_idx, backend in enumerate(backends_to_test):
             if is_decode[backend_idx]:
                 all_sdpa_outputs[backend_idx].append(sdpa_out_i_decode)
             else:
@@ -905,7 +944,7 @@
     kv_c_vllm = torch.cat(all_kv_c_vllm, dim=0)
     k_pe_vllm = torch.cat(all_k_pe_vllm, dim=0)
     sdpa_outputs = {}
-    for backend_idx, backend in enumerate(BACKENDS_TO_TEST):
+    for backend_idx, backend in enumerate(backends_to_test):
         sdpa_outputs[backend] = torch.cat(all_sdpa_outputs[backend_idx], dim=0)
 
     # Create mock kv_b_proj using the same weights as reference implementation
@@ -973,12 +1012,13 @@
             num_blocks=num_blocks_for_size,
             common_attn_metadata=common_attn_metadata,
             randomize_blocks=True,
+            kv_cache_dtype=kv_cache_dtype,
         )
         kv_cache_per_block_size[block_size] = kv_cache
 
     # 4. Run vLLM backends and compare
     failures = []
-    for backend_idx, backend_name in enumerate(BACKENDS_TO_TEST):
+    for backend_idx, backend_name in enumerate(backends_to_test):
         # Skip backends that don't support spec decode for spec decode tests
         if is_spec_decode_test and backend_name not in SPEC_DECODE_BACKENDS:
             continue
@@ -997,7 +1037,7 @@
             head_size=vllm_config.model_config.get_head_size(),
             dtype=vllm_config.model_config.dtype,
             sliding_window=vllm_config.model_config.get_sliding_window(),
-            cache_dtype_str=vllm_config.cache_config.cache_dtype,
+            cache_dtype_str=kv_cache_dtype,
         )
 
         backend_output = run_attention_backend(
@@ -1016,6 +1056,7 @@
             qk_rope_head_dim,
             v_head_dim,
             mock_kv_b_proj,
+            kv_cache_dtype=kv_cache_dtype,
        )
 
         # Use backend_idx to get the correct SDPA output for this backend