Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 127 additions & 0 deletions benchmarks/bench_flash_attention_fp8_output.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
"""FA4 fused FP8 output vs. BF16 attn + torch.compile'd post-quant cast.

Requires SM100/SM110 (Blackwell). FA4 ships no FP8 quant op, so the unfused
baseline uses torch.compile to fuse the divide+clamp+cast into one kernel.

python benchmarks/bench_flash_attention_fp8_output.py [--shape ...] [--rep N]
"""

import argparse
import time

import torch
from triton.testing import do_bench

from flash_attn.cute.bench_utils import flops
from flash_attn.cute.interface import flash_attn_func


# (name, batch, seqlen_q, seqlen_k, num_heads, num_kv_heads, head_dim, head_dim_v, causal)
# Naming convention: <mode>_<attn>_<seqlen>, where:
# mode = prefill (sq == sk) | decode (sq == 1, sk large)
# attn = mla (qk=192, v=128) | mha (h_q == h_kv) | gqa (h_q > h_kv)
# seqlen = the K-side context length
SHAPES = {
# DeepSeek-V3 MLA prefill — the primary target of this PR.
"prefill_mla_4k": (2, 4096, 4096, 16, 1, 192, 128, True),
# Standard MHA prefill, 4K context.
"prefill_mha_4k": (2, 4096, 4096, 32, 32, 128, 128, True),
# Llama-style GQA prefill (8:1 ratio), 8K context.
"prefill_gqa_8k": (2, 8192, 8192, 32, 4, 128, 128, True),
# GQA decode (sq=1, h=16/1), 8K context — common decode shape.
"decode_gqa_8k": (16, 1, 8192, 16, 1, 128, 128, True),
# MHA decode, 8K context.
"decode_mha_8k": (16, 1, 8192, 16, 16, 128, 128, True),
}


def static_fp8_quant_eager(out_bf16: torch.Tensor, inv_scale: float) -> torch.Tensor:
"""Stand-in for vLLM's `static_scaled_fp8_quant`."""
finfo = torch.finfo(torch.float8_e4m3fn)
return out_bf16.float().mul(inv_scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)



_static_fp8_quant_compiled = torch.compile(static_fp8_quant_eager, mode="reduce-overhead")


def bench_one(name, shape, warmup, rep):
batch, sq, sk, nh, nkv, dq, dv, causal = shape
device = torch.device("cuda")
dtype = torch.bfloat16

q = torch.randn(batch, sq, nh, dq, dtype=dtype, device=device)
k = torch.randn(batch, sk, nkv, dq, dtype=dtype, device=device)
v = torch.randn(batch, sk, nkv, dv, dtype=dtype, device=device)

# Pick a representative scale (peak of one BF16 forward).
ref_out, _ = flash_attn_func(q, k, v, causal=causal)
finfo = torch.finfo(torch.float8_e4m3fn)
out_scale = max(float(ref_out.float().abs().amax().item()) / finfo.max, 1e-4)
inv_scale = 1.0 / out_scale
out_scale_t = torch.tensor(out_scale, dtype=torch.float32, device=device)

fp8_buf = torch.empty(batch, sq, nh, dv, dtype=torch.float8_e4m3fn, device=device)

def fwd_bf16():
return flash_attn_func(q, k, v, causal=causal)

def fwd_bf16_then_quant():
out, _ = flash_attn_func(q, k, v, causal=causal)
return _static_fp8_quant_compiled(out, inv_scale)

def fwd_fp8_fused():
return flash_attn_func(
q, k, v, causal=causal,
out=fp8_buf,
output_scale=out_scale_t,
)

time.sleep(1.0)
ms_bf16 = do_bench(fwd_bf16, warmup=warmup, rep=rep) * 1e-3
time.sleep(1.0)
ms_unfused = do_bench(fwd_bf16_then_quant, warmup=warmup, rep=rep) * 1e-3
time.sleep(1.0)
ms_fused = do_bench(fwd_fp8_fused, warmup=warmup, rep=rep) * 1e-3

n_flops = flops(batch, nh, sq, sk, dq, dv, causal=causal)
def tflops(s): return n_flops / s * 1e-12

saved = ms_unfused - ms_fused
speedup = ms_unfused / ms_fused

print(
f"{name:<14} b={batch} sq={sq:>5} sk={sk:>5} h={nh:>3}/{nkv:<3} d={dq}-{dv:<3} "
f"bf16={ms_bf16*1e6:>7.1f}us/{tflops(ms_bf16):>4.0f}TF "
f"bf16+quant={ms_unfused*1e6:>7.1f}us/{tflops(ms_unfused):>4.0f}TF "
f"fused-fp8={ms_fused*1e6:>7.1f}us/{tflops(ms_fused):>4.0f}TF "
f"saved={saved*1e6:>+6.1f}us ({speedup:.2f}x)"
)


def main():
parser = argparse.ArgumentParser(description="FA4 fused FP8 output benchmark")
parser.add_argument("--shape", action="append", choices=list(SHAPES) + ["all"],
default=None, help="Shape preset to run (repeatable). Default: all.")
parser.add_argument("--warmup", type=int, default=5)
parser.add_argument("--rep", type=int, default=10)
args = parser.parse_args()

cap = torch.cuda.get_device_capability()
if cap[0] not in (10, 11):
raise SystemExit(
f"Fused FP8 output requires SM100/SM110 (Blackwell). "
f"Detected sm{cap[0]}{cap[1]}; aborting."
)

shapes = list(SHAPES) if not args.shape or "all" in args.shape else args.shape
print(f"Device: {torch.cuda.get_device_name(0)}")
print(f"Warmup={args.warmup}, rep={args.rep}\n")

for name in shapes:
torch.cuda.empty_cache()
bench_one(name, SHAPES[name], args.warmup, args.rep)


if __name__ == "__main__":
main()
15 changes: 15 additions & 0 deletions flash_attn/cute/flash_fwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def __init__(
mask_mod: Optional[cutlass.Constexpr] = None,
has_aux_tensors: bool = False,
q_subtile_factor: int | None = None,
output_quant_key: Optional[cutlass.Constexpr[str]] = None,
):
"""Initializes the configuration for a flash attention kernel.

Expand All @@ -75,6 +76,10 @@ def __init__(
Callable signature: ``score_mod(scores, batch_idx, head_idx, q_idx, kv_idx, aux_tensors) -> Any``
:param mask_mod: A callable that takes the attention scores and returns a boolean representing whether that score should be masked.
Callable signature: ``mask_mod(batch_idx, head_idx, q_idx, kv_idx, aux_tensors) -> Boolean``
:param output_quant_key: compile-time tag represents specialized fused quant output in epilogue.
Used as a tag in compile_key. Inspired by quant keys in vLLM and derived from output scales args.
Supported: ``"kFp8StaticTensorSym"`` (per-tensor static FP8 e4m3fn)
TODO: ``"kFp8Dynamic128Sym"``, ``"kFp8Dynamic64Sym"``, ``"kNvfp4Dynamic"``
"""
self.dtype = dtype
# padding head_dim to a multiple of 16 as k_block_size
Expand All @@ -98,6 +103,7 @@ def __init__(
self.Q_in_regs = Q_in_regs
self.score_mod = score_mod
self.mask_mod = mask_mod
self.output_quant_key = output_quant_key
self.qk_acc_dtype = Float32
self.vec_size: cutlass.Constexpr = getattr(
score_mod, "__vec_size__", 1 if cutlass.const_expr(has_aux_tensors) else 2
Expand Down Expand Up @@ -604,6 +610,12 @@ def load_V(


class FlashAttentionForwardSm80(FlashAttentionForwardBase):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
assert self.output_quant_key is None, (
f"Fused quant output not implemented for {type(self).__name__}"
)

def _get_smem_layout_atom(self):
sQ_layout_atom = sm80_utils.get_smem_layout_atom(self.dtype, self.tile_hdim)
sK_layout_atom = sQ_layout_atom
Expand Down Expand Up @@ -665,6 +677,7 @@ def __call__(
learnable_sink: Optional[cute.Tensor] = None,
blocksparse_tensors: Optional[BlockSparseTensors] = None,
aux_tensors=None,
output_scale: Optional[cute.Tensor] = None,
# Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI).
stream: cuda.CUstream = None,
):
Expand Down Expand Up @@ -762,6 +775,7 @@ def __call__(
TileScheduler,
aux_tensors,
fastdiv_mods,
output_scale,
).launch(
grid=grid_dim,
block=[self.num_threads, 1, 1],
Expand Down Expand Up @@ -801,6 +815,7 @@ def kernel(
TileScheduler: cutlass.Constexpr[Callable],
aux_tensors=None,
fastdiv_mods=None,
output_scale: Optional[cute.Tensor] = None,
):
# Thread index, block index
tidx, _, _ = cute.arch.thread_idx()
Expand Down
21 changes: 19 additions & 2 deletions flash_attn/cute/flash_fwd_combine.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def __init__(
log_max_splits: int = 4,
num_threads: int = 256,
stages: int = 4,
output_quant_key: Optional[cutlass.Constexpr[str]] = None,
):
"""
Forward combine kernel for split attention computation.
Expand All @@ -42,6 +43,8 @@ def __init__(
:param num_threads: number of threads
:param varlen: whether using variable length sequences
:param stages: number of pipeline stages
:param output_quant_key: compile-time tag for fused quant output,
see FlashAttentionForwardBase for more details.
"""
self.dtype = dtype
self.dtype_partial = dtype_partial
Expand All @@ -52,6 +55,7 @@ def __init__(
self.num_threads = num_threads
self.is_even_k = head_dim % k_block_size == 0
self.stages = stages
self.output_quant_key = output_quant_key

@staticmethod
def can_implement(
Expand All @@ -64,7 +68,9 @@ def can_implement(
num_threads,
) -> bool:
"""Check if the kernel can be implemented with the given parameters."""
if dtype not in [cutlass.Float16, cutlass.BFloat16, cutlass.Float32]:
if dtype not in [
cutlass.Float16, cutlass.BFloat16, cutlass.Float32, cutlass.Float8E4M3FN,
]:
return False
if dtype_partial not in [cutlass.Float16, cutlass.BFloat16, Float32]:
return False
Expand Down Expand Up @@ -199,6 +205,7 @@ def __call__(
num_splits_dynamic_ptr: Optional[cute.Tensor] = None,
varlen_batch_idx: Optional[cute.Tensor] = None,
semaphore_to_reset: Optional[cute.Tensor] = None,
output_scale: Optional[cute.Tensor] = None,
# Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI).
stream: cuda.CUstream = None,
):
Expand Down Expand Up @@ -313,6 +320,7 @@ class SharedStorage:
seqlen_divmod,
head_divmod,
varlen,
output_scale,
).launch(
grid=grid_dim,
block=[self.num_threads, 1, 1],
Expand Down Expand Up @@ -342,11 +350,16 @@ def kernel(
seqlen_divmod: FastDivmodDivisor,
head_divmod: FastDivmodDivisor,
varlen: cutlass.Constexpr[bool],
output_scale: Optional[cute.Tensor] = None,
):
# Thread and block indices
tidx, _, _ = cute.arch.thread_idx()
m_block, k_block, maybe_virtual_batch = cute.arch.block_idx()

# Load FP8 output scale and invert in-kernel.
if const_expr(self.output_quant_key == "kFp8StaticTensorSym"):
output_scale_inv = Float32(1.0) / Float32(output_scale[0])

# Map virtual batch index to real batch index (for persistent tile schedulers)
batch_idx = (
varlen_batch_idx[maybe_virtual_batch]
Expand Down Expand Up @@ -643,7 +656,11 @@ def kernel(
# ===============================

rO = cute.make_rmem_tensor_like(tOrO, self.dtype)
rO.store(tOrO.load().to(self.dtype))
# Fold per-tensor output scale into the cast (fused FP8 out).
if const_expr(self.output_quant_key == "kFp8StaticTensorSym"):
rO.store((tOrO.load() * output_scale_inv).to(self.dtype))
else:
rO.store(tOrO.load().to(self.dtype))
mO_cur = seqlen_info.offset_batch(mO, batch_idx, dim=3)
if const_expr(cu_seqlens is None):
mO_cur = mO[None, None, None, batch_idx]
Expand Down
17 changes: 16 additions & 1 deletion flash_attn/cute/flash_fwd_sm100.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,8 @@ def __init__(
is_varlen_q: bool = False,
use_2cta_instrs: bool = False,
use_clc_scheduler: bool = False,
# Derived tag for fused quant output, also see FlashAttentionForwardBase
output_quant_key: cutlass.Constexpr[str] | None = None,
):
self.use_tma_KV = not paged_kv_non_tma
# self.dtype = dtype
Expand Down Expand Up @@ -185,6 +187,7 @@ def __init__(
)
self.score_mod = score_mod
self.mask_mod = mask_mod
self.output_quant_key = output_quant_key
self.vec_size: cutlass.Constexpr = getattr(
score_mod, "__vec_size__", 1 if cutlass.const_expr(has_aux_tensors) else 2
)
Expand Down Expand Up @@ -374,6 +377,7 @@ def __call__(
descale_tensors: Optional[DescaleTensors] = None,
blocksparse_tensors: Optional[BlockSparseTensors] = None,
aux_tensors: Optional[list] = None,
output_scale: Optional[cute.Tensor] = None,
# Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI).
stream: cuda.CUstream = None,
):
Expand Down Expand Up @@ -760,6 +764,7 @@ class SharedStorage:
aux_tensors,
fastdiv_mods,
head_divmod,
output_scale,
).launch(
grid=grid_dim,
block=[self.threads_per_cta, 1, 1],
Expand Down Expand Up @@ -807,6 +812,7 @@ def kernel(
aux_tensors: Optional[list] = None,
fastdiv_mods=(None, None),
head_divmod=None,
output_scale: Optional[cute.Tensor] = None,
):
"""The device kernel implementation of the Fused Multi-Head Attention.

Expand Down Expand Up @@ -1284,7 +1290,8 @@ def kernel(
num_splits,
SeqlenInfoCls,
blocksparse_tensors,
tile_scheduler=tile_scheduler,
tile_scheduler,
output_scale,
)
tmem_alloc_barrier.arrive()

Expand Down Expand Up @@ -2345,6 +2352,7 @@ def correction_loop(
SeqlenInfoCls: Callable,
blocksparse_tensors: Optional[BlockSparseTensors] = None,
tile_scheduler=None,
output_scale: Optional[cute.Tensor] = None,
):
tidx = cute.arch.thread_idx()[0] % (cute.arch.WARP_SIZE * len(self.correction_warp_ids))
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4
Expand All @@ -2365,6 +2373,10 @@ def correction_loop(
tStScales_t2r = [thr_tmem_load_vec.partition_S(tStScales[stage]) for stage in range(self.q_stage)]
tSrScale_t2r_shape = thr_tmem_load_vec.partition_D(tScScale).shape

# Load FP8 output scale and invert in-kernel.
if const_expr(self.output_quant_key == "kFp8StaticTensorSym"):
output_scale_inv = Float32(1.0) / Float32(output_scale[0])

# First iter: no correction is required
# Notify mma warp that O has been rescaled
for stage in cutlass.range(self.q_stage):
Expand Down Expand Up @@ -2500,6 +2512,9 @@ def correction_loop(
stats[stage] = (row_sum, row_max, acc_O_mn_row_is_zero_or_nan)
scale = cute.arch.rcp_approx(row_sum if not acc_O_mn_row_is_zero_or_nan else 1.0)
scale = scale * v_descale
# Fold per-tensor output scale into the existing per-row scale.
if const_expr(self.output_quant_key == "kFp8StaticTensorSym"):
scale = scale * output_scale_inv
# Wait for the last O to be ready from the MMA warp
pipeline_o_acc.consumer_wait_w_index_phase(stage, o_corr_consumer_phase)
if const_expr(not self.use_correction_warps_for_epi):
Expand Down
7 changes: 6 additions & 1 deletion flash_attn/cute/flash_fwd_sm90.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ def __init__(
**kwargs,
):
super().__init__(*args, **kwargs)
assert self.output_quant_key is None, (
f"Fused quant output not implemented for {type(self).__name__}"
)
self.intra_wg_overlap = intra_wg_overlap
self.mma_pv_is_rs = mma_pv_is_rs
self.buffer_align_bytes = 1024
Expand Down Expand Up @@ -179,6 +182,7 @@ def __call__(
learnable_sink: Optional[cute.Tensor] = None,
blocksparse_tensors: Optional[BlockSparseTensors] = None,
aux_tensors: Optional[list] = None,
output_scale: Optional[cute.Tensor] = None,
# Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI).
stream: cuda.CUstream = None,
):
Expand All @@ -187,7 +191,6 @@ def __call__(
mQ/mK/mV/mO has same data types(supports fp16 and bf16) and same layout:
(batch_size, seqlen_q, num_head, head_dim):(_, _, _, 1)
"""

self._check_type(
*(
t.element_type if t is not None else None
Expand Down Expand Up @@ -411,6 +414,7 @@ def __call__(
num_splits,
aux_tensors,
fastdiv_mods,
output_scale,
).launch(
grid=grid_dim,
block=[self.num_threads, 1, 1],
Expand Down Expand Up @@ -458,6 +462,7 @@ def kernel(
num_splits: Int32 = Int32(1),
aux_tensors=Optional[list[cute.Tensor]],
fastdiv_mods=None,
output_scale: Optional[cute.Tensor] = None,
):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
# Prefetch tma descriptor
Expand Down
Loading