Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/design/attention_backends.md
Original file line number Diff line number Diff line change
Expand Up @@ -213,4 +213,4 @@ configuration.
| `ROCM_AITER_MLA` | fp16, bf16 | `auto` | 1 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
| `ROCM_AITER_MLA_SPARSE` | fp16, bf16 | `auto` | Any | 576 | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
| `ROCM_AITER_TRITON_MLA` | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
| `TRITON_MLA` | fp16, bf16 | `auto`, `bfloat16` | Any | Any | ❌ | ❌ | ❌ | ✅ | Decoder | Any |
| `TRITON_MLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3` | Any | Any | ❌ | ❌ | ❌ | ✅ | Decoder | Any |
31 changes: 20 additions & 11 deletions vllm/v1/attention/backends/mla/triton_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.platforms import current_platform
from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backend import (
AttentionLayer,
AttentionType,
is_quantized_kv_cache,
)
from vllm.v1.attention.ops.triton_decode_attention import decode_attention_fwd

Expand All @@ -31,6 +31,8 @@ class TritonMLABackend(MLACommonBackend):
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
"bfloat16",
"fp8",
"fp8_e4m3",
]

@staticmethod
Expand Down Expand Up @@ -93,11 +95,6 @@ def __init__(
"TritonMLAImpl"
)

if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError(
"TritonMLA V1 with FP8 KV cache not yet supported"
)

def _flash_attn_varlen_diff_headdims(
self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs
):
Expand All @@ -120,19 +117,24 @@ def forward_mqa(
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None

if self.kv_cache_dtype.startswith("fp8"):
raise NotImplementedError("FP8 Triton MLA not yet supported")
# Determine if we're using FP8 KV cache
use_fp8 = self.kv_cache_dtype.startswith("fp8")

if type(q) is tuple:
q = torch.cat(q, dim=-1)

assert isinstance(q, torch.Tensor)
B = q.shape[0]
q_num_heads = q.shape[1]
# Cast FP8 q to bfloat16 for Blackwell compatibility (FP8 tl.dot
# may produce invalid instructions) and for V up-proj torch.bmm.
if use_fp8 and q.dtype not in (torch.float16, torch.bfloat16):
q = q.to(torch.bfloat16)
out_dtype = q.dtype
o = torch.zeros(
B, q_num_heads, self.kv_lora_rank, dtype=q.dtype, device=q.device
B, q_num_heads, self.kv_lora_rank, dtype=out_dtype, device=q.device
)
lse = torch.zeros(B, q_num_heads, dtype=q.dtype, device=q.device)
lse = torch.zeros(B, q_num_heads, dtype=out_dtype, device=q.device)

# For batch invariance, use only 1 split to ensure deterministic reduction
num_kv_splits = 1 if vllm_is_batch_invariant() else 4
Expand All @@ -151,12 +153,18 @@ def forward_mqa(
device=q.device,
)

# View as FP8 and pass scale for in-kernel dequantization.
if use_fp8:
kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view(current_platform.fp8_dtype())
k_scale = layer._k_scale
else:
k_scale = None

# Add a head dim of 1
kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2)
kv_c_cache = kv_c_and_k_pe_cache[..., : self.kv_lora_rank]
PAGE_SIZE = kv_c_and_k_pe_cache.size(1)

# Run MQA
decode_attention_fwd(
q,
kv_c_and_k_pe_cache,
Expand All @@ -169,6 +177,7 @@ def forward_mqa(
num_kv_splits,
self.scale,
PAGE_SIZE,
k_scale=k_scale,
)

return o, lse
62 changes: 61 additions & 1 deletion vllm/v1/attention/ops/triton_decode_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

# Changes:
# - Add support for page size >= 1.
# - Add support for FP8 quantized KV cache.

# Copyright 2025 vLLM Team
# Copyright 2023-2024 SGLang Team
Expand All @@ -27,10 +28,12 @@
"""
Memory-efficient attention for decoding.
It supports page size >= 1.
It supports FP8 quantized KV cache: K and V are dequantized on the fly inside
the stage-1 kernels, using one scalar scale that is shared by both.
"""

import logging

import torch
from packaging import version

from vllm.platforms import current_platform
Expand Down Expand Up @@ -64,6 +67,7 @@ def _fwd_kernel_stage1(
Req_to_tokens,
B_Seqlen,
Att_Out,
K_scale, # FP8 dequantization scale for K/V cache
stride_req_to_tokens_b,
stride_qbs,
stride_qh,
Expand All @@ -83,6 +87,7 @@ def _fwd_kernel_stage1(
logit_cap: tl.constexpr,
Lk: tl.constexpr,
Lv: tl.constexpr,
USE_FP8: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
Expand Down Expand Up @@ -129,6 +134,11 @@ def _fwd_kernel_stage1(
mask=(offs_n[:, None] < split_kv_end) & (mask_d[None, :]),
other=0.0,
)
# Dequantize FP8 KV cache on-the-fly
if USE_FP8:
k_scale = tl.load(K_scale)
k = k.to(tl.float32) * k_scale

qk = tl.sum(q[None, :] * k, 1)
qk *= sm_scale

Expand All @@ -147,6 +157,9 @@ def _fwd_kernel_stage1(
mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]),
other=0.0,
)
# Dequantize FP8 KV cache on-the-fly
if USE_FP8:
v = v.to(tl.float32) * k_scale

n_e_max = tl.maximum(tl.max(qk, 0), e_max)
re_scale = tl.exp(e_max - n_e_max)
Expand Down Expand Up @@ -194,6 +207,7 @@ def _decode_att_m_fwd(
sm_scale,
page_size,
logit_cap,
k_scale=None,
):
BLOCK = 64 if not is_hip_ else 8

Expand All @@ -213,6 +227,15 @@ def _decode_att_m_fwd(
BLOCK_DMODEL = triton.next_power_of_2(Lk)
BLOCK_DV = triton.next_power_of_2(Lv)

# Determine if we're using FP8 KV cache
use_fp8 = k_scale is not None

# Create a dummy scale tensor if not using FP8 (Triton requires valid tensor)
if k_scale is None:
k_scale = torch.empty(1, dtype=torch.float32, device=q.device)

num_stages = 1 if use_fp8 else 2

_fwd_kernel_stage1[grid](
q,
k_buffer,
Expand All @@ -221,6 +244,7 @@ def _decode_att_m_fwd(
Req_to_tokens,
B_Seqlen,
att_out,
k_scale,
Req_to_tokens.stride(0),
q.stride(0),
q.stride(1),
Expand All @@ -239,9 +263,10 @@ def _decode_att_m_fwd(
PAGE_SIZE=page_size,
logit_cap=logit_cap,
num_warps=num_warps,
num_stages=2,
num_stages=num_stages,
Lk=Lk,
Lv=Lv,
USE_FP8=use_fp8,
)


Expand All @@ -254,6 +279,7 @@ def _fwd_grouped_kernel_stage1(
Req_to_tokens,
B_Seqlen,
Att_Out,
K_scale, # FP8 dequantization scale for K/V cache
stride_req_to_tokens_b,
stride_qbs,
stride_qh,
Expand All @@ -276,6 +302,7 @@ def _fwd_grouped_kernel_stage1(
logit_cap: tl.constexpr,
Lk: tl.constexpr,
Lv: tl.constexpr,
USE_FP8: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head_id = tl.program_id(1)
Expand Down Expand Up @@ -336,6 +363,11 @@ def _fwd_grouped_kernel_stage1(
mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]),
other=0.0,
)
# Dequantize FP8 KV cache on-the-fly
if USE_FP8:
k_scale = tl.load(K_scale)
k = k.to(tl.float32) * k_scale

qk = tl.dot(q, k.to(q.dtype))
if BLOCK_DPE > 0:
offs_buf_kpe = (
Expand All @@ -348,6 +380,9 @@ def _fwd_grouped_kernel_stage1(
mask=(offs_n[None, :] < split_kv_end) & (mask_dpe[:, None]),
other=0.0,
)
# Dequantize FP8 KV cache on-the-fly
if USE_FP8:
kpe = kpe.to(tl.float32) * k_scale
qk += tl.dot(qpe, kpe.to(qpe.dtype))
qk *= sm_scale

Expand All @@ -368,6 +403,9 @@ def _fwd_grouped_kernel_stage1(
mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]),
other=0.0,
)
# Dequantize FP8 KV cache on-the-fly
if USE_FP8:
v = v.to(tl.float32) * k_scale

n_e_max = tl.maximum(tl.max(qk, 1), e_max)
re_scale = tl.exp(e_max - n_e_max)
Expand Down Expand Up @@ -416,6 +454,7 @@ def _decode_grouped_att_m_fwd(
sm_scale,
page_size,
logit_cap,
k_scale=None,
):
BLOCK = 32
Lk = k_buffer.shape[-1]
Expand Down Expand Up @@ -455,6 +494,18 @@ def _decode_grouped_att_m_fwd(
extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2}
num_stages = 1

# Determine if we're using FP8 KV cache
use_fp8 = k_scale is not None

# Reduce pipeline stages for FP8 to avoid shared memory overflow
# from float32 intermediates during dequantization.
if use_fp8:
num_stages = 1

# Create a dummy scale tensor if not using FP8 (Triton requires valid tensor)
if k_scale is None:
k_scale = torch.empty(1, dtype=torch.float32, device=q.device)

_fwd_grouped_kernel_stage1[grid](
q,
k_buffer,
Expand All @@ -463,6 +514,7 @@ def _decode_grouped_att_m_fwd(
Req_to_tokens,
B_Seqlen,
att_out,
k_scale,
Req_to_tokens.stride(0),
q.stride(0),
q.stride(1),
Expand All @@ -487,6 +539,7 @@ def _decode_grouped_att_m_fwd(
num_stages=num_stages,
Lk=Lk,
Lv=Lv,
USE_FP8=use_fp8,
**extra_kargs,
)

Expand Down Expand Up @@ -609,6 +662,7 @@ def decode_attention_fwd_normal(
sm_scale,
page_size,
logit_cap=0.0,
k_scale=None,
):
_decode_att_m_fwd(
q,
Expand All @@ -621,6 +675,7 @@ def decode_attention_fwd_normal(
sm_scale,
page_size,
logit_cap,
k_scale,
)
_decode_softmax_reducev_fwd(
attn_logits, q, o, lse, v_buffer, b_seq_len, num_kv_splits
Expand All @@ -640,6 +695,7 @@ def decode_attention_fwd_grouped(
sm_scale,
page_size,
logit_cap=0.0,
k_scale=None,
):
_decode_grouped_att_m_fwd(
q,
Expand All @@ -652,6 +708,7 @@ def decode_attention_fwd_grouped(
sm_scale,
page_size,
logit_cap,
k_scale,
)
_decode_softmax_reducev_fwd(
attn_logits, q, o, lse, v_buffer, b_seq_len, num_kv_splits
Expand All @@ -671,6 +728,7 @@ def decode_attention_fwd(
sm_scale,
page_size=1,
logit_cap=0.0,
k_scale=None,
):
assert num_kv_splits == attn_logits.shape[2]
kv_group_num = q.shape[1] // v_buffer.shape[-2]
Expand All @@ -690,6 +748,7 @@ def decode_attention_fwd(
sm_scale,
page_size,
logit_cap,
k_scale,
)
else:
# GQA/MQA/MLA
Expand All @@ -706,4 +765,5 @@ def decode_attention_fwd(
sm_scale,
page_size,
logit_cap,
k_scale,
)