diff --git a/vllm/model_executor/layers/attention/mla_attention.py b/vllm/model_executor/layers/attention/mla_attention.py index 862f84939853..4859af43ae41 100644 --- a/vllm/model_executor/layers/attention/mla_attention.py +++ b/vllm/model_executor/layers/attention/mla_attention.py @@ -530,7 +530,7 @@ def forward_impl( scale=self._k_scale, ) - if fp8_attention: + if fp8_attention and self.kv_cache_dtype != "fp8_ds_mla": kv_cache = kv_cache.view(current_platform.fp8_dtype()) # Sparse MLA impls only support forward_mqa (decode-style attention) @@ -614,7 +614,7 @@ def forward_impl( # Convert from (N, B, L) to (B, N, L) mqa_ql_nope = mqa_ql_nope.transpose(0, 1) - if fp8_attention: + if fp8_attention and self.impl.supports_quant_query_input: assert mqa_ql_nope.shape[0] == mqa_q_pe.shape[0] assert mqa_ql_nope.shape[1] == mqa_q_pe.shape[1] mqa_q = self._decode_concat_quant_fp8_op( @@ -1885,6 +1885,8 @@ def __init__( self.indexer = indexer self.q_pad_num_heads = q_pad_num_heads + self.supports_quant_query_input = True + # Use flashinfer's optimized concat_mla_k kernel when available. # The kernel is optimized for DeepSeek V3 dimensions: # num_heads=128, nope_dim=128, rope_dim=64