diff --git a/docs/design/attention_backends.md b/docs/design/attention_backends.md index 40108e490740..a8d2fd687fff 100644 --- a/docs/design/attention_backends.md +++ b/docs/design/attention_backends.md @@ -213,5 +213,5 @@ configuration. | `ROCM_AITER_MLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 1 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A | | `ROCM_AITER_MLA_SPARSE` | fp16, bf16 | `auto`, `bfloat16` | 1 | Any | ❌ | ✅ | ❌ | ❌ | Decoder | N/A | | `ROCM_AITER_TRITON_MLA` | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A | -| `TRITON_MLA` | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | Any | +| `TRITON_MLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3` | %16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | Any | | `XPU_MLA_SPARSE` | fp16, bf16 | `auto`, `bfloat16` | Any | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | Any | diff --git a/tests/kernels/attention/test_triton_decode_attention.py b/tests/kernels/attention/test_triton_decode_attention.py index f6b066a7bd1e..a9b881629441 100644 --- a/tests/kernels/attention/test_triton_decode_attention.py +++ b/tests/kernels/attention/test_triton_decode_attention.py @@ -90,3 +90,137 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE): ) assert torch.allclose(o, o1) + + +def _quantize_to_fp8(tensor: torch.Tensor): + """Quantize a BF16 tensor to FP8 e4m3fn with per-tensor scale. + + Returns (fp8_tensor, scale) where: + fp8_tensor ≈ tensor / scale (stored as float8_e4m3fn) + tensor ≈ fp8_tensor.to(float32) * scale (dequantized) + """ + amax = tensor.abs().amax() + # float8_e4m3fn max representable value is 448.0 + scale = (amax / 448.0).clamp(min=1e-12).to(torch.float32) + fp8_tensor = ( + (tensor.to(torch.float32) / scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn) + ) + return fp8_tensor, scale + + +@pytest.mark.parametrize("B", [3]) +@pytest.mark.parametrize("L", [1025]) +@pytest.mark.parametrize("H_Q", [32]) +@pytest.mark.parametrize("H_KV", [32, 8]) +@pytest.mark.parametrize("D_QK", [128, 576]) +@pytest.mark.parametrize("D_V", [128, 512]) +@pytest.mark.parametrize("CACHE_SIZE", [16384]) +@pytest.mark.parametrize("PAGE_SIZE", [1, 16]) +def test_decode_attention_fp8(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE): + """Test FP8 KV cache path: quantize K/V to FP8, run kernel with scales, + and compare against BF16 reference output.""" + assert CACHE_SIZE % PAGE_SIZE == 0 + dtype = torch.bfloat16 + seq_len = L + sm_scale = 1.0 / (D_QK**0.5) + num_kv_splits = 8 + + num_pages_per_batch = cdiv(seq_len, PAGE_SIZE) + req_to_page = torch.randint( + 0, CACHE_SIZE // PAGE_SIZE, (B, num_pages_per_batch, 1), device="cuda" + ) + req_to_token = req_to_page * PAGE_SIZE + req_to_token = req_to_token.expand(B, num_pages_per_batch, PAGE_SIZE) + req_to_token = req_to_token + torch.arange(PAGE_SIZE, device="cuda").view(1, 1, -1) + req_to_token = req_to_token.view(B, -1) + req_to_token = req_to_token[:, :seq_len].contiguous() + + q = torch.randn(B, H_Q, D_QK, dtype=dtype, device="cuda") + + # Create BF16 K/V as reference + k_bf16 = torch.randn(CACHE_SIZE, H_KV, D_QK, dtype=dtype, device="cuda") + v_bf16 = torch.randn(CACHE_SIZE, H_KV, D_V, dtype=dtype, device="cuda") + + # --- BF16 reference --- + o_ref = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda") + lse_ref = torch.zeros(B, H_Q, dtype=dtype, device="cuda") + attn_logits = torch.empty( + (B, H_Q, num_kv_splits, D_V + 1), dtype=torch.float32, device="cuda" + ) + + if PAGE_SIZE == 1: + decode_attention_fwd( + q, + k_bf16, + v_bf16, + o_ref, + lse_ref, + req_to_token, + b_seq_len=torch.full((B,), seq_len, device="cuda"), + attn_logits=attn_logits, + num_kv_splits=num_kv_splits, + sm_scale=sm_scale, + ) + else: + k_paged = k_bf16.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_QK) + v_paged = v_bf16.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_V) + decode_attention_fwd( + q, + k_paged, + v_paged, + o_ref, + lse_ref, + req_to_page, + b_seq_len=torch.full((B,), seq_len, device="cuda"), + attn_logits=attn_logits, + num_kv_splits=num_kv_splits, + sm_scale=sm_scale, + page_size=PAGE_SIZE, + ) + + # --- FP8 path --- + k_fp8, k_scale = _quantize_to_fp8(k_bf16) + v_fp8, v_scale = _quantize_to_fp8(v_bf16) + + o_fp8 = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda") + lse_fp8 = torch.zeros(B, H_Q, dtype=dtype, device="cuda") + attn_logits_fp8 = torch.empty( + (B, H_Q, num_kv_splits, D_V + 1), dtype=torch.float32, device="cuda" + ) + + if PAGE_SIZE == 1: + decode_attention_fwd( + q, + k_fp8, + v_fp8, + o_fp8, + lse_fp8, + req_to_token, + b_seq_len=torch.full((B,), seq_len, device="cuda"), + attn_logits=attn_logits_fp8, + num_kv_splits=num_kv_splits, + sm_scale=sm_scale, + k_scale=k_scale, + v_scale=v_scale, + ) + else: + k_fp8_paged = k_fp8.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_QK) + v_fp8_paged = v_fp8.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_V) + decode_attention_fwd( + q, + k_fp8_paged, + v_fp8_paged, + o_fp8, + lse_fp8, + req_to_page, + b_seq_len=torch.full((B,), seq_len, device="cuda"), + attn_logits=attn_logits_fp8, + num_kv_splits=num_kv_splits, + sm_scale=sm_scale, + page_size=PAGE_SIZE, + k_scale=k_scale, + v_scale=v_scale, + ) + + # FP8 tolerances match test_mla_backends.py test_backend_correctness. + torch.testing.assert_close(o_ref, o_fp8, atol=5e-1, rtol=1e-2) diff --git a/vllm/v1/attention/backends/mla/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py index 2da2bbd6bb5a..ca9f7452e311 100644 --- a/vllm/v1/attention/backends/mla/triton_mla.py +++ b/vllm/v1/attention/backends/mla/triton_mla.py @@ -32,6 +32,8 @@ class TritonMLABackend(MLACommonBackend): supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ "auto", "bfloat16", + "fp8", + "fp8_e4m3", ] @classmethod @@ -108,10 +110,11 @@ def __init__( "TritonMLAImpl" ) + # For FP8 KV cache, we dequantize to BF16 on load inside the + # Triton kernel. Tell the common layer not to quantize queries + # to FP8 — we handle FP8 KV cache with BF16 queries (Mode 1). if is_quantized_kv_cache(self.kv_cache_dtype): - raise NotImplementedError( - "TritonMLA V1 with FP8 KV cache not yet supported" - ) + self.supports_quant_query_input = False def _flash_attn_varlen_diff_headdims( self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs @@ -135,9 +138,6 @@ def forward_mqa( assert kv_c_and_k_pe_cache.numel() > 0 assert attn_metadata.decode is not None - if self.kv_cache_dtype.startswith("fp8"): - raise NotImplementedError("FP8 Triton MLA not yet supported") - if type(q) is tuple: q = torch.cat(q, dim=-1) @@ -171,7 +171,8 @@ def forward_mqa( kv_c_cache = kv_c_and_k_pe_cache[..., : self.kv_lora_rank] PAGE_SIZE = kv_c_and_k_pe_cache.size(1) - # Run MQA + # Run MQA — always pass layer scales. When KV cache is + # BF16 the kernel's `if dtype.is_fp8()` check is a no-op. decode_attention_fwd( q, kv_c_and_k_pe_cache, @@ -184,6 +185,8 @@ def forward_mqa( num_kv_splits, self.scale, PAGE_SIZE, + k_scale=layer._k_scale, + v_scale=layer._v_scale, ) return o, lse diff --git a/vllm/v1/attention/ops/triton_decode_attention.py b/vllm/v1/attention/ops/triton_decode_attention.py index 1ed9698c507a..63263bc92e24 100644 --- a/vllm/v1/attention/ops/triton_decode_attention.py +++ b/vllm/v1/attention/ops/triton_decode_attention.py @@ -31,6 +31,7 @@ import logging +import torch from packaging import version from vllm.platforms import current_platform @@ -74,6 +75,8 @@ def _fwd_kernel_stage1( stride_mid_ob, stride_mid_oh, stride_mid_os, + k_scale, + v_scale, kv_group_num: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_DV: tl.constexpr, @@ -109,6 +112,8 @@ def _fwd_kernel_stage1( acc = tl.zeros([BLOCK_DV], dtype=tl.float32) if split_kv_end > split_kv_start: + ks = tl.load(k_scale) + vs = tl.load(v_scale) for start_n in range(split_kv_start, split_kv_end, BLOCK_N): offs_n = start_n + tl.arange(0, BLOCK_N) kv_page_number = tl.load( @@ -129,6 +134,8 @@ def _fwd_kernel_stage1( mask=(offs_n[:, None] < split_kv_end) & (mask_d[None, :]), other=0.0, ) + if k.dtype.is_fp8(): + k = (k.to(tl.float32) * ks).to(q.dtype) qk = tl.sum(q[None, :] * k, 1) qk *= sm_scale @@ -147,6 +154,8 @@ def _fwd_kernel_stage1( mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]), other=0.0, ) + if v.dtype.is_fp8(): + v = (v.to(tl.float32) * vs).to(q.dtype) n_e_max = tl.maximum(tl.max(qk, 0), e_max) re_scale = tl.exp(e_max - n_e_max) @@ -194,6 +203,8 @@ def _decode_att_m_fwd( sm_scale, page_size, logit_cap, + k_scale, + v_scale, ): BLOCK = 64 if not is_hip_ else 8 @@ -231,6 +242,8 @@ def _decode_att_m_fwd( att_out.stride(0), att_out.stride(1), att_out.stride(2), + k_scale, + v_scale, kv_group_num=kv_group_num, BLOCK_DMODEL=BLOCK_DMODEL, BLOCK_DV=BLOCK_DV, @@ -264,6 +277,8 @@ def _fwd_grouped_kernel_stage1( stride_mid_ob, stride_mid_oh, stride_mid_os, + k_scale, + v_scale, kv_group_num: tl.constexpr, q_head_num: tl.constexpr, BLOCK_DMODEL: tl.constexpr, @@ -316,6 +331,8 @@ def _fwd_grouped_kernel_stage1( acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32) if split_kv_end > split_kv_start: + ks = tl.load(k_scale) + vs = tl.load(v_scale) for start_n in range(split_kv_start, split_kv_end, BLOCK_N): offs_n = start_n + tl.arange(0, BLOCK_N) kv_page_number = tl.load( @@ -336,6 +353,8 @@ def _fwd_grouped_kernel_stage1( mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]), other=0.0, ) + if k.dtype.is_fp8(): + k = (k.to(tl.float32) * ks).to(q.dtype) qk = tl.dot(q, k.to(q.dtype)) if BLOCK_DPE > 0: offs_buf_kpe = ( @@ -348,6 +367,8 @@ def _fwd_grouped_kernel_stage1( mask=(offs_n[None, :] < split_kv_end) & (mask_dpe[:, None]), other=0.0, ) + if kpe.dtype.is_fp8(): + kpe = (kpe.to(tl.float32) * ks).to(qpe.dtype) qk += tl.dot(qpe, kpe.to(qpe.dtype)) qk *= sm_scale @@ -368,6 +389,8 @@ def _fwd_grouped_kernel_stage1( mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]), other=0.0, ) + if v.dtype.is_fp8(): + v = (v.to(tl.float32) * vs).to(q.dtype) n_e_max = tl.maximum(tl.max(qk, 1), e_max) re_scale = tl.exp(e_max - n_e_max) @@ -416,6 +439,8 @@ def _decode_grouped_att_m_fwd( sm_scale, page_size, logit_cap, + k_scale, + v_scale, ): BLOCK = 32 Lk = k_buffer.shape[-1] @@ -473,6 +498,8 @@ def _decode_grouped_att_m_fwd( att_out.stride(0), att_out.stride(1), att_out.stride(2), + k_scale, + v_scale, kv_group_num=kv_group_num, q_head_num=head_num, BLOCK_DMODEL=BLOCK_DMODEL, @@ -609,6 +636,8 @@ def decode_attention_fwd_normal( sm_scale, page_size, logit_cap=0.0, + k_scale=None, + v_scale=None, ): _decode_att_m_fwd( q, @@ -621,6 +650,8 @@ def decode_attention_fwd_normal( sm_scale, page_size, logit_cap, + k_scale, + v_scale, ) _decode_softmax_reducev_fwd( attn_logits, q, o, lse, v_buffer, b_seq_len, num_kv_splits @@ -640,6 +671,8 @@ def decode_attention_fwd_grouped( sm_scale, page_size, logit_cap=0.0, + k_scale=None, + v_scale=None, ): _decode_grouped_att_m_fwd( q, @@ -652,6 +685,8 @@ def decode_attention_fwd_grouped( sm_scale, page_size, logit_cap, + k_scale, + v_scale, ) _decode_softmax_reducev_fwd( attn_logits, q, o, lse, v_buffer, b_seq_len, num_kv_splits @@ -671,8 +706,16 @@ def decode_attention_fwd( sm_scale, page_size=1, logit_cap=0.0, + k_scale=None, + v_scale=None, ): assert num_kv_splits == attn_logits.shape[2] + + if k_scale is None: + k_scale = torch.tensor(1.0, dtype=torch.float32, device=q.device) + if v_scale is None: + v_scale = torch.tensor(1.0, dtype=torch.float32, device=q.device) + kv_group_num = q.shape[1] // v_buffer.shape[-2] if kv_group_num == 1: @@ -690,6 +733,8 @@ def decode_attention_fwd( sm_scale, page_size, logit_cap, + k_scale, + v_scale, ) else: # GQA/MQA/MLA @@ -706,4 +751,6 @@ def decode_attention_fwd( sm_scale, page_size, logit_cap, + k_scale, + v_scale, )