Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/design/attention_backends.md
Original file line number Diff line number Diff line change
Expand Up @@ -213,5 +213,5 @@ configuration.
| `ROCM_AITER_MLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 1 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
| `ROCM_AITER_MLA_SPARSE` | fp16, bf16 | `auto`, `bfloat16` | 1 | Any | ❌ | ✅ | ❌ | ❌ | Decoder | N/A |
| `ROCM_AITER_TRITON_MLA` | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
| `TRITON_MLA` | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | Any |
| `TRITON_MLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3` | %16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | Any |
| `XPU_MLA_SPARSE` | fp16, bf16 | `auto`, `bfloat16` | Any | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | Any |
134 changes: 134 additions & 0 deletions tests/kernels/attention/test_triton_decode_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,137 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
)

assert torch.allclose(o, o1)


def _quantize_to_fp8(tensor: torch.Tensor):
"""Quantize a BF16 tensor to FP8 e4m3fn with per-tensor scale.

Returns (fp8_tensor, scale) where:
fp8_tensor ≈ tensor / scale (stored as float8_e4m3fn)
tensor ≈ fp8_tensor.to(float32) * scale (dequantized)
"""
amax = tensor.abs().amax()
# float8_e4m3fn max representable value is 448.0
scale = (amax / 448.0).clamp(min=1e-12).to(torch.float32)
fp8_tensor = (
(tensor.to(torch.float32) / scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)
)
return fp8_tensor, scale


def _run_decode(
    q,
    k,
    v,
    o,
    lse,
    req_to_token,
    req_to_page,
    seq_len,
    attn_logits,
    num_kv_splits,
    sm_scale,
    page_size,
    k_scale=None,
    v_scale=None,
):
    """Invoke ``decode_attention_fwd`` for either KV-cache layout.

    For ``page_size == 1`` the contiguous token table is used; otherwise the
    K/V buffers are viewed as pages and the page table is passed instead.
    ``k_scale``/``v_scale`` are forwarded unchanged — ``None`` matches the
    kernel's default (identity scaling, i.e. no FP8 dequantization).
    """
    batch = q.shape[0]
    b_seq_len = torch.full((batch,), seq_len, device="cuda")
    if page_size == 1:
        decode_attention_fwd(
            q,
            k,
            v,
            o,
            lse,
            req_to_token,
            b_seq_len=b_seq_len,
            attn_logits=attn_logits,
            num_kv_splits=num_kv_splits,
            sm_scale=sm_scale,
            k_scale=k_scale,
            v_scale=v_scale,
        )
    else:
        cache_size, h_kv = k.shape[0], k.shape[1]
        k_paged = k.view(cache_size // page_size, page_size, h_kv, k.shape[-1])
        v_paged = v.view(cache_size // page_size, page_size, h_kv, v.shape[-1])
        decode_attention_fwd(
            q,
            k_paged,
            v_paged,
            o,
            lse,
            req_to_page,
            b_seq_len=b_seq_len,
            attn_logits=attn_logits,
            num_kv_splits=num_kv_splits,
            sm_scale=sm_scale,
            page_size=page_size,
            k_scale=k_scale,
            v_scale=v_scale,
        )


@pytest.mark.parametrize("B", [3])
@pytest.mark.parametrize("L", [1025])
@pytest.mark.parametrize("H_Q", [32])
@pytest.mark.parametrize("H_KV", [32, 8])
@pytest.mark.parametrize("D_QK", [128, 576])
@pytest.mark.parametrize("D_V", [128, 512])
@pytest.mark.parametrize("CACHE_SIZE", [16384])
@pytest.mark.parametrize("PAGE_SIZE", [1, 16])
def test_decode_attention_fp8(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
    """Test FP8 KV cache path: quantize K/V to FP8, run kernel with scales,
    and compare against BF16 reference output."""
    assert CACHE_SIZE % PAGE_SIZE == 0
    dtype = torch.bfloat16
    seq_len = L
    sm_scale = 1.0 / (D_QK**0.5)
    num_kv_splits = 8

    # Build a random page table, then expand it to a per-token index table
    # (token index = page index * PAGE_SIZE + offset within page).
    num_pages_per_batch = cdiv(seq_len, PAGE_SIZE)
    req_to_page = torch.randint(
        0, CACHE_SIZE // PAGE_SIZE, (B, num_pages_per_batch, 1), device="cuda"
    )
    req_to_token = req_to_page * PAGE_SIZE
    req_to_token = req_to_token.expand(B, num_pages_per_batch, PAGE_SIZE)
    req_to_token = req_to_token + torch.arange(PAGE_SIZE, device="cuda").view(1, 1, -1)
    req_to_token = req_to_token.view(B, -1)
    req_to_token = req_to_token[:, :seq_len].contiguous()

    q = torch.randn(B, H_Q, D_QK, dtype=dtype, device="cuda")

    # Create BF16 K/V as reference
    k_bf16 = torch.randn(CACHE_SIZE, H_KV, D_QK, dtype=dtype, device="cuda")
    v_bf16 = torch.randn(CACHE_SIZE, H_KV, D_V, dtype=dtype, device="cuda")

    # --- BF16 reference ---
    o_ref = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
    lse_ref = torch.zeros(B, H_Q, dtype=dtype, device="cuda")
    attn_logits = torch.empty(
        (B, H_Q, num_kv_splits, D_V + 1), dtype=torch.float32, device="cuda"
    )
    _run_decode(
        q,
        k_bf16,
        v_bf16,
        o_ref,
        lse_ref,
        req_to_token,
        req_to_page,
        seq_len,
        attn_logits,
        num_kv_splits,
        sm_scale,
        PAGE_SIZE,
    )

    # --- FP8 path: same inputs, quantized to FP8 with per-tensor scales ---
    k_fp8, k_scale = _quantize_to_fp8(k_bf16)
    v_fp8, v_scale = _quantize_to_fp8(v_bf16)

    o_fp8 = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
    lse_fp8 = torch.zeros(B, H_Q, dtype=dtype, device="cuda")
    attn_logits_fp8 = torch.empty(
        (B, H_Q, num_kv_splits, D_V + 1), dtype=torch.float32, device="cuda"
    )
    _run_decode(
        q,
        k_fp8,
        v_fp8,
        o_fp8,
        lse_fp8,
        req_to_token,
        req_to_page,
        seq_len,
        attn_logits_fp8,
        num_kv_splits,
        sm_scale,
        PAGE_SIZE,
        k_scale=k_scale,
        v_scale=v_scale,
    )

    # FP8 tolerances match test_mla_backends.py test_backend_correctness.
    torch.testing.assert_close(o_ref, o_fp8, atol=5e-1, rtol=1e-2)
17 changes: 10 additions & 7 deletions vllm/v1/attention/backends/mla/triton_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ class TritonMLABackend(MLACommonBackend):
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
"bfloat16",
"fp8",
"fp8_e4m3",
]

@classmethod
Expand Down Expand Up @@ -108,10 +110,11 @@ def __init__(
"TritonMLAImpl"
)

# For FP8 KV cache, we dequantize to BF16 on load inside the
# Triton kernel. Tell the common layer not to quantize queries
# to FP8 — we handle FP8 KV cache with BF16 queries (Mode 1).
if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError(
"TritonMLA V1 with FP8 KV cache not yet supported"
)
self.supports_quant_query_input = False

def _flash_attn_varlen_diff_headdims(
self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs
Expand All @@ -135,9 +138,6 @@ def forward_mqa(
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None

if self.kv_cache_dtype.startswith("fp8"):
raise NotImplementedError("FP8 Triton MLA not yet supported")

if type(q) is tuple:
q = torch.cat(q, dim=-1)

Expand Down Expand Up @@ -171,7 +171,8 @@ def forward_mqa(
kv_c_cache = kv_c_and_k_pe_cache[..., : self.kv_lora_rank]
PAGE_SIZE = kv_c_and_k_pe_cache.size(1)

# Run MQA
# Run MQA — always pass layer scales. When KV cache is
# BF16 the kernel's `if dtype.is_fp8()` check is a no-op.
decode_attention_fwd(
q,
kv_c_and_k_pe_cache,
Expand All @@ -184,6 +185,8 @@ def forward_mqa(
num_kv_splits,
self.scale,
PAGE_SIZE,
k_scale=layer._k_scale,
v_scale=layer._v_scale,
)

return o, lse
47 changes: 47 additions & 0 deletions vllm/v1/attention/ops/triton_decode_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

import logging

import torch
from packaging import version

from vllm.platforms import current_platform
Expand Down Expand Up @@ -74,6 +75,8 @@ def _fwd_kernel_stage1(
stride_mid_ob,
stride_mid_oh,
stride_mid_os,
k_scale,
v_scale,
kv_group_num: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_DV: tl.constexpr,
Expand Down Expand Up @@ -109,6 +112,8 @@ def _fwd_kernel_stage1(
acc = tl.zeros([BLOCK_DV], dtype=tl.float32)

if split_kv_end > split_kv_start:
ks = tl.load(k_scale)
vs = tl.load(v_scale)
for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
offs_n = start_n + tl.arange(0, BLOCK_N)
kv_page_number = tl.load(
Expand All @@ -129,6 +134,8 @@ def _fwd_kernel_stage1(
mask=(offs_n[:, None] < split_kv_end) & (mask_d[None, :]),
other=0.0,
)
if k.dtype.is_fp8():
k = (k.to(tl.float32) * ks).to(q.dtype)
qk = tl.sum(q[None, :] * k, 1)
qk *= sm_scale

Expand All @@ -147,6 +154,8 @@ def _fwd_kernel_stage1(
mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]),
other=0.0,
)
if v.dtype.is_fp8():
v = (v.to(tl.float32) * vs).to(q.dtype)

n_e_max = tl.maximum(tl.max(qk, 0), e_max)
re_scale = tl.exp(e_max - n_e_max)
Expand Down Expand Up @@ -194,6 +203,8 @@ def _decode_att_m_fwd(
sm_scale,
page_size,
logit_cap,
k_scale,
v_scale,
):
BLOCK = 64 if not is_hip_ else 8

Expand Down Expand Up @@ -231,6 +242,8 @@ def _decode_att_m_fwd(
att_out.stride(0),
att_out.stride(1),
att_out.stride(2),
k_scale,
v_scale,
kv_group_num=kv_group_num,
BLOCK_DMODEL=BLOCK_DMODEL,
BLOCK_DV=BLOCK_DV,
Expand Down Expand Up @@ -264,6 +277,8 @@ def _fwd_grouped_kernel_stage1(
stride_mid_ob,
stride_mid_oh,
stride_mid_os,
k_scale,
v_scale,
kv_group_num: tl.constexpr,
q_head_num: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
Expand Down Expand Up @@ -316,6 +331,8 @@ def _fwd_grouped_kernel_stage1(
acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32)

if split_kv_end > split_kv_start:
ks = tl.load(k_scale)
vs = tl.load(v_scale)
for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
offs_n = start_n + tl.arange(0, BLOCK_N)
kv_page_number = tl.load(
Expand All @@ -336,6 +353,8 @@ def _fwd_grouped_kernel_stage1(
mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]),
other=0.0,
)
if k.dtype.is_fp8():
k = (k.to(tl.float32) * ks).to(q.dtype)
qk = tl.dot(q, k.to(q.dtype))
if BLOCK_DPE > 0:
offs_buf_kpe = (
Expand All @@ -348,6 +367,8 @@ def _fwd_grouped_kernel_stage1(
mask=(offs_n[None, :] < split_kv_end) & (mask_dpe[:, None]),
other=0.0,
)
if kpe.dtype.is_fp8():
kpe = (kpe.to(tl.float32) * ks).to(qpe.dtype)
qk += tl.dot(qpe, kpe.to(qpe.dtype))
qk *= sm_scale

Expand All @@ -368,6 +389,8 @@ def _fwd_grouped_kernel_stage1(
mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]),
other=0.0,
)
if v.dtype.is_fp8():
v = (v.to(tl.float32) * vs).to(q.dtype)

n_e_max = tl.maximum(tl.max(qk, 1), e_max)
re_scale = tl.exp(e_max - n_e_max)
Expand Down Expand Up @@ -416,6 +439,8 @@ def _decode_grouped_att_m_fwd(
sm_scale,
page_size,
logit_cap,
k_scale,
v_scale,
):
BLOCK = 32
Lk = k_buffer.shape[-1]
Expand Down Expand Up @@ -473,6 +498,8 @@ def _decode_grouped_att_m_fwd(
att_out.stride(0),
att_out.stride(1),
att_out.stride(2),
k_scale,
v_scale,
kv_group_num=kv_group_num,
q_head_num=head_num,
BLOCK_DMODEL=BLOCK_DMODEL,
Expand Down Expand Up @@ -609,6 +636,8 @@ def decode_attention_fwd_normal(
sm_scale,
page_size,
logit_cap=0.0,
k_scale=None,
v_scale=None,
):
_decode_att_m_fwd(
q,
Expand All @@ -621,6 +650,8 @@ def decode_attention_fwd_normal(
sm_scale,
page_size,
logit_cap,
k_scale,
v_scale,
)
_decode_softmax_reducev_fwd(
attn_logits, q, o, lse, v_buffer, b_seq_len, num_kv_splits
Expand All @@ -640,6 +671,8 @@ def decode_attention_fwd_grouped(
sm_scale,
page_size,
logit_cap=0.0,
k_scale=None,
v_scale=None,
):
_decode_grouped_att_m_fwd(
q,
Expand All @@ -652,6 +685,8 @@ def decode_attention_fwd_grouped(
sm_scale,
page_size,
logit_cap,
k_scale,
v_scale,
)
_decode_softmax_reducev_fwd(
attn_logits, q, o, lse, v_buffer, b_seq_len, num_kv_splits
Expand All @@ -671,8 +706,16 @@ def decode_attention_fwd(
sm_scale,
page_size=1,
logit_cap=0.0,
k_scale=None,
v_scale=None,
):
assert num_kv_splits == attn_logits.shape[2]

if k_scale is None:
k_scale = torch.tensor(1.0, dtype=torch.float32, device=q.device)
if v_scale is None:
v_scale = torch.tensor(1.0, dtype=torch.float32, device=q.device)

kv_group_num = q.shape[1] // v_buffer.shape[-2]

if kv_group_num == 1:
Expand All @@ -690,6 +733,8 @@ def decode_attention_fwd(
sm_scale,
page_size,
logit_cap,
k_scale,
v_scale,
)
else:
# GQA/MQA/MLA
Expand All @@ -706,4 +751,6 @@ def decode_attention_fwd(
sm_scale,
page_size,
logit_cap,
k_scale,
v_scale,
)
Loading