diff --git a/docs/design/attention_backends.md b/docs/design/attention_backends.md index e726d99256f5..b6cd49bce125 100644 --- a/docs/design/attention_backends.md +++ b/docs/design/attention_backends.md @@ -213,4 +213,4 @@ configuration. | `ROCM_AITER_MLA` | fp16, bf16 | `auto` | 1 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A | | `ROCM_AITER_MLA_SPARSE` | fp16, bf16 | `auto` | Any | 576 | ❌ | ❌ | ❌ | ❌ | Decoder | N/A | | `ROCM_AITER_TRITON_MLA` | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A | -| `TRITON_MLA` | fp16, bf16 | `auto`, `bfloat16` | Any | Any | ❌ | ❌ | ❌ | ✅ | Decoder | Any | +| `TRITON_MLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3` | Any | Any | ❌ | ❌ | ❌ | ✅ | Decoder | Any | diff --git a/vllm/v1/attention/backends/mla/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py index f6c1790f60c8..da02363a5888 100644 --- a/vllm/v1/attention/backends/mla/triton_mla.py +++ b/vllm/v1/attention/backends/mla/triton_mla.py @@ -15,11 +15,11 @@ from vllm.model_executor.layers.batch_invariant import ( vllm_is_batch_invariant, ) +from vllm.platforms import current_platform from vllm.platforms.interface import DeviceCapability from vllm.v1.attention.backend import ( AttentionLayer, AttentionType, - is_quantized_kv_cache, ) from vllm.v1.attention.ops.triton_decode_attention import decode_attention_fwd @@ -31,6 +31,8 @@ class TritonMLABackend(MLACommonBackend): supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ "auto", "bfloat16", + "fp8", + "fp8_e4m3", ] @staticmethod @@ -93,11 +95,6 @@ def __init__( "TritonMLAImpl" ) - if is_quantized_kv_cache(self.kv_cache_dtype): - raise NotImplementedError( - "TritonMLA V1 with FP8 KV cache not yet supported" - ) - def _flash_attn_varlen_diff_headdims( self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs ): @@ -120,8 +117,8 @@ def forward_mqa( assert kv_c_and_k_pe_cache.numel() > 0 assert attn_metadata.decode is not None - if self.kv_cache_dtype.startswith("fp8"): - raise NotImplementedError("FP8 Triton MLA not yet supported") + # Determine if we're using FP8 KV cache + use_fp8 = self.kv_cache_dtype.startswith("fp8") if type(q) is tuple: q = torch.cat(q, dim=-1) @@ -129,10 +126,15 @@ def forward_mqa( assert isinstance(q, torch.Tensor) B = q.shape[0] q_num_heads = q.shape[1] + # Cast FP8 q to bfloat16 for Blackwell compatibility (FP8 tl.dot + # may produce invalid instructions) and for V up-proj torch.bmm. + if use_fp8 and q.dtype not in (torch.float16, torch.bfloat16): + q = q.to(torch.bfloat16) + out_dtype = q.dtype o = torch.zeros( - B, q_num_heads, self.kv_lora_rank, dtype=q.dtype, device=q.device + B, q_num_heads, self.kv_lora_rank, dtype=out_dtype, device=q.device ) - lse = torch.zeros(B, q_num_heads, dtype=q.dtype, device=q.device) + lse = torch.zeros(B, q_num_heads, dtype=out_dtype, device=q.device) # For batch invariance, use only 1 split to ensure deterministic reduction num_kv_splits = 1 if vllm_is_batch_invariant() else 4 @@ -151,12 +153,18 @@ def forward_mqa( device=q.device, ) + # View as FP8 and pass scale for in-kernel dequantization. + if use_fp8: + kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view(current_platform.fp8_dtype()) + k_scale = layer._k_scale + else: + k_scale = None + # Add a head dim of 1 kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2) kv_c_cache = kv_c_and_k_pe_cache[..., : self.kv_lora_rank] PAGE_SIZE = kv_c_and_k_pe_cache.size(1) - # Run MQA decode_attention_fwd( q, kv_c_and_k_pe_cache, @@ -169,6 +177,7 @@ def forward_mqa( num_kv_splits, self.scale, PAGE_SIZE, + k_scale=k_scale, ) return o, lse diff --git a/vllm/v1/attention/ops/triton_decode_attention.py b/vllm/v1/attention/ops/triton_decode_attention.py index 1ed9698c507a..daacd1e18395 100644 --- a/vllm/v1/attention/ops/triton_decode_attention.py +++ b/vllm/v1/attention/ops/triton_decode_attention.py @@ -9,6 +9,7 @@ # Changes: # - Add support for page size >= 1. +# - Add support for FP8 quantized KV cache. # Copyright 2025 vLLM Team # Copyright 2023-2024 SGLang Team @@ -27,10 +28,12 @@ """ Memory-efficient attention for decoding. It supports page size >= 1. +It supports FP8 quantized KV cache with on-the-fly dequantization. """ import logging +import torch from packaging import version from vllm.platforms import current_platform @@ -64,6 +67,7 @@ def _fwd_kernel_stage1( Req_to_tokens, B_Seqlen, Att_Out, + K_scale, # FP8 dequantization scale for K/V cache stride_req_to_tokens_b, stride_qbs, stride_qh, @@ -83,6 +87,7 @@ def _fwd_kernel_stage1( logit_cap: tl.constexpr, Lk: tl.constexpr, Lv: tl.constexpr, + USE_FP8: tl.constexpr, ): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) @@ -129,6 +134,11 @@ def _fwd_kernel_stage1( mask=(offs_n[:, None] < split_kv_end) & (mask_d[None, :]), other=0.0, ) + # Dequantize FP8 KV cache on-the-fly + if USE_FP8: + k_scale = tl.load(K_scale) + k = k.to(tl.float32) * k_scale + qk = tl.sum(q[None, :] * k, 1) qk *= sm_scale @@ -147,6 +157,9 @@ def _fwd_kernel_stage1( mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]), other=0.0, ) + # Dequantize FP8 KV cache on-the-fly + if USE_FP8: + v = v.to(tl.float32) * k_scale n_e_max = tl.maximum(tl.max(qk, 0), e_max) re_scale = tl.exp(e_max - n_e_max) @@ -194,6 +207,7 @@ def _decode_att_m_fwd( sm_scale, page_size, logit_cap, + k_scale=None, ): BLOCK = 64 if not is_hip_ else 8 @@ -213,6 +227,15 @@ def _decode_att_m_fwd( BLOCK_DMODEL = triton.next_power_of_2(Lk) BLOCK_DV = triton.next_power_of_2(Lv) + # Determine if we're using FP8 KV cache + use_fp8 = k_scale is not None + + # Create a dummy scale tensor if not using FP8 (Triton requires valid tensor) + if k_scale is None: + k_scale = torch.empty(1, dtype=torch.float32, device=q.device) + + num_stages = 1 if use_fp8 else 2 + _fwd_kernel_stage1[grid]( q, k_buffer, @@ -221,6 +244,7 @@ def _decode_att_m_fwd( Req_to_tokens, B_Seqlen, att_out, + k_scale, Req_to_tokens.stride(0), q.stride(0), q.stride(1), @@ -239,9 +263,10 @@ def _decode_att_m_fwd( PAGE_SIZE=page_size, logit_cap=logit_cap, num_warps=num_warps, - num_stages=2, + num_stages=num_stages, Lk=Lk, Lv=Lv, + USE_FP8=use_fp8, ) @@ -254,6 +279,7 @@ def _fwd_grouped_kernel_stage1( Req_to_tokens, B_Seqlen, Att_Out, + K_scale, # FP8 dequantization scale for K/V cache stride_req_to_tokens_b, stride_qbs, stride_qh, @@ -276,6 +302,7 @@ def _fwd_grouped_kernel_stage1( logit_cap: tl.constexpr, Lk: tl.constexpr, Lv: tl.constexpr, + USE_FP8: tl.constexpr, ): cur_batch = tl.program_id(0) cur_head_id = tl.program_id(1) @@ -336,6 +363,11 @@ def _fwd_grouped_kernel_stage1( mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]), other=0.0, ) + # Dequantize FP8 KV cache on-the-fly + if USE_FP8: + k_scale = tl.load(K_scale) + k = k.to(tl.float32) * k_scale + qk = tl.dot(q, k.to(q.dtype)) if BLOCK_DPE > 0: offs_buf_kpe = ( @@ -348,6 +380,9 @@ def _fwd_grouped_kernel_stage1( mask=(offs_n[None, :] < split_kv_end) & (mask_dpe[:, None]), other=0.0, ) + # Dequantize FP8 KV cache on-the-fly + if USE_FP8: + kpe = kpe.to(tl.float32) * k_scale qk += tl.dot(qpe, kpe.to(qpe.dtype)) qk *= sm_scale @@ -368,6 +403,9 @@ def _fwd_grouped_kernel_stage1( mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]), other=0.0, ) + # Dequantize FP8 KV cache on-the-fly + if USE_FP8: + v = v.to(tl.float32) * k_scale n_e_max = tl.maximum(tl.max(qk, 1), e_max) re_scale = tl.exp(e_max - n_e_max) @@ -416,6 +454,7 @@ def _decode_grouped_att_m_fwd( sm_scale, page_size, logit_cap, + k_scale=None, ): BLOCK = 32 Lk = k_buffer.shape[-1] @@ -455,6 +494,18 @@ def _decode_grouped_att_m_fwd( extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2} num_stages = 1 + # Determine if we're using FP8 KV cache + use_fp8 = k_scale is not None + + # Reduce pipeline stages for FP8 to avoid shared memory overflow + # from float32 intermediates during dequantization. + if use_fp8: + num_stages = 1 + + # Create a dummy scale tensor if not using FP8 (Triton requires valid tensor) + if k_scale is None: + k_scale = torch.empty(1, dtype=torch.float32, device=q.device) + _fwd_grouped_kernel_stage1[grid]( q, k_buffer, @@ -463,6 +514,7 @@ def _decode_grouped_att_m_fwd( Req_to_tokens, B_Seqlen, att_out, + k_scale, Req_to_tokens.stride(0), q.stride(0), q.stride(1), @@ -487,6 +539,7 @@ def _decode_grouped_att_m_fwd( num_stages=num_stages, Lk=Lk, Lv=Lv, + USE_FP8=use_fp8, **extra_kargs, ) @@ -609,6 +662,7 @@ def decode_attention_fwd_normal( sm_scale, page_size, logit_cap=0.0, + k_scale=None, ): _decode_att_m_fwd( q, @@ -621,6 +675,7 @@ def decode_attention_fwd_normal( sm_scale, page_size, logit_cap, + k_scale, ) _decode_softmax_reducev_fwd( attn_logits, q, o, lse, v_buffer, b_seq_len, num_kv_splits @@ -640,6 +695,7 @@ def decode_attention_fwd_grouped( sm_scale, page_size, logit_cap=0.0, + k_scale=None, ): _decode_grouped_att_m_fwd( q, @@ -652,6 +708,7 @@ def decode_attention_fwd_grouped( sm_scale, page_size, logit_cap, + k_scale, ) _decode_softmax_reducev_fwd( attn_logits, q, o, lse, v_buffer, b_seq_len, num_kv_splits @@ -671,6 +728,7 @@ def decode_attention_fwd( sm_scale, page_size=1, logit_cap=0.0, + k_scale=None, ): assert num_kv_splits == attn_logits.shape[2] kv_group_num = q.shape[1] // v_buffer.shape[-2] @@ -690,6 +748,7 @@ def decode_attention_fwd( sm_scale, page_size, logit_cap, + k_scale, ) else: # GQA/MQA/MLA @@ -706,4 +765,5 @@ def decode_attention_fwd( sm_scale, page_size, logit_cap, + k_scale, )