diff --git a/benchmarks/bench_blackwell_attention.py b/benchmarks/bench_blackwell_attention.py
index 52452e05a8..80e42d802a 100644
--- a/benchmarks/bench_blackwell_attention.py
+++ b/benchmarks/bench_blackwell_attention.py
@@ -14,6 +14,7 @@ limitations under the License.
 """
 
+import csv
 import numpy as np
 import torch
 
@@ -24,20 +25,34 @@ def bench_fmha_blackwell(
     batch_size,
     qkv_len,
-    num_heads,
-    head_dim,
+    num_qo_heads,
+    num_kv_heads,
+    head_dim_qk,
+    head_dim_vo,
     causal,
     dtype,
 ):
-    q = torch.randn(
-        batch_size * qkv_len, num_heads, head_dim, dtype=dtype, device="cuda"
-    )
-    k = torch.randn(
-        batch_size * qkv_len, num_heads, head_dim, dtype=dtype, device="cuda"
-    )
-    v = torch.randn(
-        batch_size * qkv_len, num_heads, head_dim, dtype=dtype, device="cuda"
-    )
+    # torch.randn does not support FP8 (itemsize == 1): sample in half precision, then cast to dtype
+    if dtype.itemsize == 1:
+        q = torch.randn(
+            batch_size * qkv_len, num_qo_heads, head_dim_qk, dtype=torch.half, device="cuda"
+        ).to(dtype)
+        k = torch.randn(
+            batch_size * qkv_len, num_kv_heads, head_dim_qk, dtype=torch.half, device="cuda"
+        ).to(dtype)
+        v = torch.randn(
+            batch_size * qkv_len, num_kv_heads, head_dim_vo, dtype=torch.half, device="cuda"
+        ).to(dtype)
+    else:
+        q = torch.randn(
+            batch_size * qkv_len, num_qo_heads, head_dim_qk, dtype=dtype, device="cuda"
+        )
+        k = torch.randn(
+            batch_size * qkv_len, num_kv_heads, head_dim_qk, dtype=dtype, device="cuda"
+        )
+        v = torch.randn(
+            batch_size * qkv_len, num_kv_heads, head_dim_vo, dtype=dtype, device="cuda"
+        )
     qo_segment_offsets = (
         torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * qkv_len
@@ -53,10 +68,10 @@ def bench_fmha_blackwell(
     wrapper.plan(
         qo_segment_offsets,
         kv_segment_offsets,
-        num_heads,
-        num_heads,
-        head_dim,
-        head_dim_vo=head_dim,
+        num_qo_heads,
+        num_kv_heads,
+        head_dim_qk,
+        head_dim_vo=head_dim_vo,
         causal=causal,
         q_data_type=dtype,
         kv_data_type=dtype,
@@ -71,50 +86,80 @@ def bench_fmha_blackwell(
     def flops(ms):
         if causal:
-            return batch_size * qkv_len * qkv_len * num_heads * head_dim * 2 / ms / 1e9
+            return batch_size * qkv_len * qkv_len * num_qo_heads * (head_dim_qk + head_dim_vo) / ms / 1e9
         else:
-            return batch_size * qkv_len * qkv_len * num_heads * head_dim * 4 / ms / 1e9
+            return batch_size * qkv_len * qkv_len * num_qo_heads * (head_dim_qk + head_dim_vo) * 2 / ms / 1e9
 
-    print(
-        f"bench_fmha_blackwell (batch_size={batch_size}, qkv_len={qkv_len}, num_heads={num_heads}, head_dim={head_dim}, causal={causal}), flops: {flops(ms):.3f} TFLOPs/s"
-    )
+    tflops = flops(ms)
+
+    return {
+        "batch_size": batch_size,
+        "qkv_len": qkv_len,
+        "num_qo_heads": num_qo_heads,
+        "num_kv_heads": num_kv_heads,
+        "head_dim_qk": head_dim_qk,
+        "head_dim_vo": head_dim_vo,
+        "causal": causal,
+        "dtype": str(dtype),
+        "time_ms": ms,
+        "tflops": tflops,
+    }
 
 
 if __name__ == "__main__":
-    print("\n === head_dim=128 ===")
-    bench_fmha_blackwell(128, 512, 32, 128, False, torch.bfloat16)
-    bench_fmha_blackwell(64, 1024, 32, 128, False, torch.bfloat16)
-    bench_fmha_blackwell(32, 2048, 32, 128, False, torch.bfloat16)
-    bench_fmha_blackwell(16, 4096, 32, 128, False, torch.bfloat16)
-    bench_fmha_blackwell(8, 8192, 32, 128, False, torch.bfloat16)
-    bench_fmha_blackwell(4, 16384, 32, 128, False, torch.bfloat16)
-    bench_fmha_blackwell(2, 32768, 32, 128, False, torch.bfloat16)
-    bench_fmha_blackwell(1, 65536, 32, 128, False, torch.bfloat16)
-
-    bench_fmha_blackwell(128, 512, 32, 128, True, torch.bfloat16)
-    bench_fmha_blackwell(64, 1024, 32, 128, True, torch.bfloat16)
-    bench_fmha_blackwell(32, 2048, 32, 128, True, torch.bfloat16)
-    bench_fmha_blackwell(16, 4096, 32, 128, True, torch.bfloat16)
-    bench_fmha_blackwell(8, 8192, 32, 128, True, torch.bfloat16)
-    bench_fmha_blackwell(4, 16384, 32, 128, True, torch.bfloat16)
-    bench_fmha_blackwell(2, 32768, 32, 128, True, torch.bfloat16)
-    bench_fmha_blackwell(1, 65536, 32, 128, True, torch.bfloat16)
-
-    print("\n === head_dim=64 ===")
-    bench_fmha_blackwell(128, 512, 32, 64, False, torch.bfloat16)
-    bench_fmha_blackwell(64, 1024, 32, 64, False, torch.bfloat16)
-    bench_fmha_blackwell(32, 2048, 32, 64, False, torch.bfloat16)
-    bench_fmha_blackwell(16, 4096, 32, 64, False, torch.bfloat16)
-    bench_fmha_blackwell(8, 8192, 32, 64, False, torch.bfloat16)
-    bench_fmha_blackwell(4, 16384, 32, 64, False, torch.bfloat16)
-    bench_fmha_blackwell(2, 32768, 32, 64, False, torch.bfloat16)
-    bench_fmha_blackwell(1, 65536, 32, 64, False, torch.bfloat16)
-
-    bench_fmha_blackwell(128, 512, 32, 64, True, torch.bfloat16)
-    bench_fmha_blackwell(64, 1024, 32, 64, True, torch.bfloat16)
-    bench_fmha_blackwell(32, 2048, 32, 64, True, torch.bfloat16)
-    bench_fmha_blackwell(16, 4096, 32, 64, True, torch.bfloat16)
-    bench_fmha_blackwell(8, 8192, 32, 64, True, torch.bfloat16)
-    bench_fmha_blackwell(4, 16384, 32, 64, True, torch.bfloat16)
-    bench_fmha_blackwell(2, 32768, 32, 64, True, torch.bfloat16)
-    bench_fmha_blackwell(1, 65536, 32, 64, True, torch.bfloat16)
+    results = []
+
+    # Define configurations: (batch_size, qkv_len, num_qo_heads, num_kv_heads, head_dim_qk, head_dim_vo, config_name)
+    # DeepSeek-R1 uses MLA (Multi-head Latent Attention) with 128 heads;
+    # in MHA-style prefill, head_dim_qk=192 (128 nope + 64 rope) and head_dim_vo=128
+    configs = [
+        (16, 512, 128, 128, 192, 128, "DeepSeek-R1"),
+        (8, 1024, 128, 128, 192, 128, "DeepSeek-R1"),
+        (4, 2048, 128, 128, 192, 128, "DeepSeek-R1"),
+        (2, 4096, 128, 128, 192, 128, "DeepSeek-R1"),
+        (1, 8192, 128, 128, 192, 128, "DeepSeek-R1"),
+    ]
+
+    # Run benchmarks: causal first, then non-causal.
+    # For each config: bfloat16 then fp8.
+    for causal in [True, False]:
+        print(f"\n{'='*80}")
+        print(f"Running {'CAUSAL' if causal else 'NON-CAUSAL'} benchmarks")
+        print(f"{'='*80}")
+
+        for batch_size, qkv_len, num_qo_heads, num_kv_heads, head_dim_qk, head_dim_vo, config_name in configs:
+            # Run bfloat16
+            print(f"\n[{config_name}] BS={batch_size}, SeqLen={qkv_len}, Causal={causal}, BF16")
+            result_bf16 = bench_fmha_blackwell(
+                batch_size, qkv_len, num_qo_heads, num_kv_heads,
+                head_dim_qk, head_dim_vo, causal, torch.bfloat16
+            )
+            result_bf16["config_name"] = config_name
+            results.append(result_bf16)
+            print(f"  → {result_bf16['tflops']:.2f} TFLOPs/s, {result_bf16['time_ms']:.3f} ms")
+
+            # Run fp8
+            print(f"[{config_name}] BS={batch_size}, SeqLen={qkv_len}, Causal={causal}, FP8")
+            result_fp8 = bench_fmha_blackwell(
+                batch_size, qkv_len, num_qo_heads, num_kv_heads,
+                head_dim_qk, head_dim_vo, causal, torch.float8_e4m3fn
+            )
+            result_fp8["config_name"] = config_name
+            results.append(result_fp8)
+            speedup = result_fp8['tflops'] / result_bf16['tflops']
+            print(f"  → {result_fp8['tflops']:.2f} TFLOPs/s, {result_fp8['time_ms']:.3f} ms (speedup: {speedup:.2f}x)")
+
+    # Write results to CSV (relative path, next to where the benchmark is run)
+    csv_filename = "fp8_attention_deepseek_benchmark.csv"
+    fieldnames = ["config_name", "batch_size", "qkv_len", "num_qo_heads", "num_kv_heads",
+                  "head_dim_qk", "head_dim_vo", "causal", "dtype", "time_ms", "tflops"]
+
+    with open(csv_filename, 'w', newline='') as csvfile:
+        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
+        writer.writeheader()
+        for result in results:
+            writer.writerow(result)
+
+    print(f"\n{'='*80}")
+    print(f"Results saved to: {csv_filename}")
+    print(f"{'='*80}")
diff --git a/csrc/fmha_cutlass_sm100.cu b/csrc/fmha_cutlass_sm100.cu
index c50116fa7f..d10fe45b8c 100644
--- a/csrc/fmha_cutlass_sm100.cu
+++ b/csrc/fmha_cutlass_sm100.cu
@@ -58,6 +58,11 @@ using tvm::ffi::Optional;
       using c_type_out = c_type_in; \
       return __VA_ARGS__(); \
     }); \
+  } else { \
+    return DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8(in_dtype, c_type_in, [&] { \
+      using c_type_out = nv_bfloat16; \
+      return __VA_ARGS__(); \
+    }); \
   } \
   return false; \
   }()
@@ -80,14 +85,17 @@ void FMHACutlassSM100Run(ffi::TensorView workspace_buffer, ffi::TensorView q, ff
                          ffi::TensorView qo_tile_indices, ffi::TensorView qo_head_indices,
                          ffi::TensorView batch_indices, ffi::TensorView o,
                          Optional<ffi::TensorView> maybe_lse, int64_t mask_mode_code,
-                         double sm_scale, int64_t num_qo_heads, int64_t num_kv_heads,
-                         int64_t head_dim_qk, int64_t head_dim_vo, int64_t max_qo_len) {
+                         double sm_scale, int64_t max_qo_len) {
   TVM_FFI_ICHECK_EQ(q.dtype(), k.dtype());
   auto scalar_type_in = q.dtype();
   auto scalar_type_out = o.dtype();
   MaskMode mask_mode = static_cast<MaskMode>(mask_mode_code);
   int total_qo_len = q.size(0);
   int total_kv_len = k.size(0);
+  int num_qo_heads = q.size(1);
+  int num_kv_heads = k.size(1);
+  int head_dim_qk = q.size(2);
+  int head_dim_vo = v.size(2);
   int batch_size = qo_segment_offsets.size(0) - 1;
   int q_stride_n = q.stride(0);
   int q_stride_h = q.stride(1);
diff --git a/csrc/fmha_cutlass_sm100_binding.cu b/csrc/fmha_cutlass_sm100_binding.cu
index ddb3b8d9cd..69fa341b72 100644
--- a/csrc/fmha_cutlass_sm100_binding.cu
+++ b/csrc/fmha_cutlass_sm100_binding.cu
@@ -22,8 +22,7 @@ void FMHACutlassSM100Run(TensorView workspace_buffer, TensorView q, TensorView k
                          TensorView work_indptr, TensorView qo_tile_indices,
                          TensorView qo_head_indices, TensorView batch_indices, TensorView o,
                          Optional<TensorView> maybe_lse, int64_t mask_mode_code, double sm_scale,
-                         int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim_qk,
-                         int64_t head_dim_vo, int64_t max_qo_len);
+                         int64_t max_qo_len);
 
 void blackwell_fmha_plan(TensorView qo_segment_offsets, TensorView kv_segment_offsets,
                          TensorView work_indptr, TensorView qo_tile_indices,
diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py
index 49abe60897..bd6aa78d55 100755
--- a/flashinfer/prefill.py
+++ b/flashinfer/prefill.py
@@ -2906,8 +2906,10 @@ def run(
                 lse, (q.size(0), q.size(1)), torch.float32, q.device, "lse"
             )
         if out is None:
+            # when the input dtype is fp8, use bf16 for the output
+            out_dtype = torch.bfloat16 if q.dtype.itemsize == 1 else q.dtype
             out = torch.empty(
-                q.shape[:-1] + v.shape[-1:], dtype=q.dtype, device=q.device
+                q.shape[:-1] + v.shape[-1:], dtype=out_dtype, device=q.device
             )
         else:
             check_shape_dtype_device(
@@ -3145,12 +3147,14 @@ def fmha_varlen(
     ) = plan_info
 
     if out is None:
+        # when the input dtype is fp8, use bf16 for the output
+        out_dtype = torch.bfloat16 if q.dtype.itemsize == 1 else q.dtype
         out = torch.empty(
             qo_total_len + max(max_qo_len, 128),
             num_qo_heads,
             head_dim_vo,
             device=q.device,
-            dtype=q.dtype,
+            dtype=out_dtype,
         )[max(max_qo_len, 128) :]
 
     if lse is None and return_lse:
@@ -3173,10 +3177,6 @@ def fmha_varlen(
         lse,
         mask_mode_code,
        sm_scale,
-        num_qo_heads,
-        num_kv_heads,
-        head_dim_qk,
-        head_dim_vo,
         max_qo_len,
     )
diff --git a/include/flashinfer/attention/blackwell/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp b/include/flashinfer/attention/blackwell/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp
index 640b876b49..a973117548 100644
--- a/include/flashinfer/attention/blackwell/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp
+++ b/include/flashinfer/attention/blackwell/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp
@@ -64,9 +64,10 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
   using Mask = Mask_;
 
   static constexpr int StageCountQ = 2;
-  static constexpr int StageCountKV = (get<2>(TileShapeQK{}) == 128 || get<2>(TileShapeQK{}) == 64)
-                                          ? 2
-                                          : 1;  // sizeof(Element_) == 1 ? 2 : 2;
+  static constexpr int StageCountKV =
+      (sizeof(Element_) == 1)
+          ? (get<2>(TileShapeQK{}) == 128 ? 4 : 2)
+          : (get<2>(TileShapeQK{}) == 128 || get<2>(TileShapeQK{}) == 64 ? 2 : 1);
 
   using StagesQ = cutlass::gemm::collective::StageCount<StageCountQ>;
   using StagesKV = cutlass::gemm::collective::StageCount<StageCountKV>;
diff --git a/include/flashinfer/attention/blackwell/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp b/include/flashinfer/attention/blackwell/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp
index d6e913a319..ebf38d7347 100644
--- a/include/flashinfer/attention/blackwell/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp
+++ b/include/flashinfer/attention/blackwell/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp
@@ -66,8 +66,8 @@ struct Sm100FmhaCtxKernelWarpspecializedSchedule {
   static const bool kDebugUsingPrintf = false;
 
   static const int NumRegsSoftmax = 192;
-  static const int NumRegsCorrection = 96 - (kDebugUsingPrintf ? 16 : 0);
-  static const int NumRegsOther = 32 + (kDebugUsingPrintf ? 16 : 0);
+  static const int NumRegsCorrection = 64;  // 96 - (kDebugUsingPrintf ? 16 : 0)
+  static const int NumRegsOther = 64;  // 32 + (kDebugUsingPrintf ? 16 : 0)
   static const int NumRegsEmpty = 24;
 
   static const int NumWarps = 16;
diff --git a/tests/attention/test_blackwell_fmha.py b/tests/attention/test_blackwell_fmha.py
index 298bfa5db4..5b885a8589 100644
--- a/tests/attention/test_blackwell_fmha.py
+++ b/tests/attention/test_blackwell_fmha.py
@@ -347,6 +347,131 @@ def test_blackwell_cutlass_qo_kv_varlen(
     torch.testing.assert_close(lse, lse_ref, rtol=1e-3, atol=1e-3)
 
 
+@pytest.mark.parametrize("batch_size", [1, 2, 9, 12])
+@pytest.mark.parametrize("qo_len", [177, 377])
+@pytest.mark.parametrize("kv_len", [544, 977])
+@pytest.mark.parametrize(
+    "num_qo_heads,num_kv_heads",
+    [
+        (128, 128),  # DeepSeek-R1 MHA (Multi-head Attention for Prefill)
+    ],
+)
+@pytest.mark.parametrize(
+    "head_dim_qk,head_dim_vo,sm_scale",
+    [
+        (192, 128, 1.0 / math.sqrt(192)),  # DeepSeek-R1: qk_nope(128) + qk_rope(64) = 192, v=128
+    ],
+)
+@pytest.mark.parametrize("causal", [False, True])
+def test_blackwell_cutlass_fmha_fp8(
+    batch_size,
+    qo_len,
+    kv_len,
+    num_qo_heads,
+    num_kv_heads,
+    head_dim_qk,
+    head_dim_vo,
+    sm_scale,
+    causal,
+):
+    if qo_len > kv_len and causal:
+        pytest.skip("qo_len > kv_len and causal is not supported")
+
+    if not is_sm100a_supported(torch.device("cuda")) and not is_sm110a_supported(
+        torch.device("cuda")
+    ):
+        pytest.skip("only SM100A and SM110A are supported on this device")
+
+    torch.manual_seed(42)
+    dtype_in = torch.float8_e4m3fn
+    dtype_out = torch.bfloat16
+
+    # Create FP8 tensors by generating half precision then converting
+    q = torch.randn(
+        batch_size * qo_len, num_qo_heads, head_dim_qk, dtype=torch.half, device="cuda"
+    ).to(dtype_in)
+    qo_indptr = (
+        torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * qo_len
+    )
+    k = torch.randn(
+        batch_size * kv_len, num_kv_heads, head_dim_qk, dtype=torch.half, device="cuda"
+    ).to(dtype_in)
+    v = torch.randn(
+        batch_size * kv_len, num_kv_heads, head_dim_vo, dtype=torch.half, device="cuda"
+    ).to(dtype_in)
+    kv_indptr = (
+        torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * kv_len
+    )
+
+    wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper(
+        torch.empty(128 * 1024 * 1024, device="cuda", dtype=torch.uint8),
+        kv_layout="NHD",
+        backend="cutlass",
+    )
+    wrapper.plan(
+        qo_indptr,
+        kv_indptr,
+        num_qo_heads,
+        num_kv_heads,
+        head_dim_qk,
+        head_dim_vo=head_dim_vo,
+        causal=causal,
+        sm_scale=sm_scale,
+        q_data_type=dtype_in,
+        kv_data_type=dtype_in,
+    )
+    o, lse = wrapper.run(q, k, v, return_lse=True)
+
+    # Verify output is bfloat16
+    assert o.dtype == dtype_out, f"Expected output dtype {dtype_out}, got {o.dtype}"
+
+    gqa_group_ratio = num_qo_heads // num_kv_heads
+    k_repeated = torch.repeat_interleave(k, gqa_group_ratio, dim=1)
+    v_repeated = torch.repeat_interleave(v, gqa_group_ratio, dim=1)
+
+    # Reference implementation with FP8 inputs, upcast to float32, output as bfloat16
+    qo_len_ref = q.shape[0] // batch_size
+    kv_len_ref = k_repeated.shape[0] // batch_size
+    num_qo_heads_ref = q.shape[1]
+    head_dim_qk_ref = q.shape[2]
+    head_dim_vo_ref = v_repeated.shape[2]
+
+    logits = (
+        torch.einsum(
+            "bmhd,bnhd->bhmn",
+            q.view(batch_size, qo_len_ref, num_qo_heads_ref, head_dim_qk_ref).float(),
+            k_repeated.view(batch_size, kv_len_ref, num_qo_heads_ref, head_dim_qk_ref).float(),
+        )
+        * sm_scale
+    )
+
+    if causal:
+        mask = torch.arange(kv_len_ref - qo_len_ref, kv_len_ref, device=q.device).unsqueeze(
+            1
+        ) >= torch.arange(0, kv_len_ref, device=q.device).unsqueeze(0)
+    else:
+        mask = torch.ones(qo_len_ref, kv_len_ref, device=q.device)
+
+    logits = logits.masked_fill(mask.unsqueeze(0).unsqueeze(0) == 0, float("-inf"))
+    lse_ref = torch.logsumexp(logits, -1).transpose(-1, -2)
+    p = torch.softmax(logits, dim=-1)
+    o_ref = (
+        torch.einsum(
+            "bhmn,bnhd->bmhd",
+            p,
+            v_repeated.view(batch_size, kv_len_ref, num_qo_heads_ref, head_dim_vo_ref).float(),
+        )
+        .contiguous()
+        .view(batch_size * qo_len_ref, num_qo_heads_ref, head_dim_vo_ref)
+        .to(dtype_out)  # Convert to bfloat16 for FP8 output
+    )
+    lse_ref = (lse_ref * math.log2(math.e)).flatten(0, 1)
+
+    # FP8 has lower precision, use relaxed tolerances
+    torch.testing.assert_close(o, o_ref, rtol=5e-2, atol=5e-2)
+    torch.testing.assert_close(lse, lse_ref, rtol=1e-3, atol=1e-3)
+
+
 if __name__ == "__main__":
     test_blackwell_cutlass_fmha(
         9,