Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
341 changes: 310 additions & 31 deletions benchmarks/bench_hopper_fp8_attention.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,19 @@
"""
Copyright (c) 2024 by FlashInfer team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import numpy as np
import torch

Expand All @@ -8,39 +24,65 @@
)


def bench_single_prefill(seq_len, num_heads, causal, head_dim):
num_qo_heads = num_kv_heads = num_heads
q = torch.randn(seq_len, num_qo_heads, head_dim, dtype=torch.half, device="cuda")
k = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda")
v = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda")

sm80_ms, sm90_ms = (
np.median(
bench_gpu_time(
lambda: flashinfer.single_prefill_with_kv_cache_return_lse(
q, k, v, causal=causal, backend=backend
),
dry_run_time_ms=100,
repeat_time_ms=1000,
)
)
for backend in ["fa2", "fa3"]
def per_head_symmetric_quant(x, quant_dtype):
"""Per-head symmetric quantization to FP8."""
o_min_val, o_max_val = (
(-448.0, 448.0) if quant_dtype == torch.float8_e4m3fn else (-57344, 57344)
)
x_max_val = x.abs().amax(dim=(0, 2)).to(dtype=torch.float32)
s_out = torch.clamp(x_max_val / o_max_val, min=1e-6)
s_out_broadcast = s_out.view(1, -1, 1)
q_x_out = torch.clamp(x / s_out_broadcast, min=o_min_val, max=o_max_val).to(
dtype=quant_dtype
)
return q_x_out, s_out

q = torch.randn(

def bench_fp8_single_prefill(
seq_len, num_heads, causal, head_dim, dtype=torch.float8_e4m3fn
):
"""Benchmark FP8 single prefill attention."""
num_qo_heads = num_kv_heads = num_heads

# Create FP16 tensors first, then quantize
q_fp16 = torch.randn(
seq_len, num_qo_heads, head_dim, dtype=torch.half, device="cuda"
).to(dtype=torch.float8_e4m3fn)
k = torch.randn(
)
k_fp16 = torch.randn(
seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda"
).to(dtype=torch.float8_e4m3fn)
v = torch.randn(
)
v_fp16 = torch.randn(
seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda"
).to(dtype=torch.float8_e4m3fn)
)

# Quantize to FP8
q_fp8, s_q = per_head_symmetric_quant(q_fp16, dtype)
k_fp8, s_k = per_head_symmetric_quant(k_fp16, dtype)
v_fp8, s_v = per_head_symmetric_quant(v_fp16, dtype)

fp8_sm90_ms = np.median(
# FP16 baseline (fa3)
fp16_ms = np.median(
bench_gpu_time(
lambda: flashinfer.single_prefill_with_kv_cache_return_lse(
q, k, v, causal=causal, backend="fa3", o_dtype=torch.half
q_fp16, k_fp16, v_fp16, causal=causal, backend="fa3"
),
dry_run_time_ms=100,
repeat_time_ms=1000,
)
)

# FP8 (fa3)
fp8_ms = np.median(
bench_gpu_time(
lambda: flashinfer.single_prefill_with_kv_cache_return_lse(
q_fp8,
k_fp8,
v_fp8,
causal=causal,
backend="fa3",
scale_q=s_q,
scale_k=s_k,
scale_v=s_v,
),
dry_run_time_ms=100,
repeat_time_ms=1000,
Expand All @@ -59,7 +101,222 @@ def flops(ms):
)

print(
f"bench_single_prefill (seq_len={seq_len}, num_heads={num_heads}, causal={causal}, head_dim={head_dim}), fa2-template: {flops(sm80_ms):.3f} TFLOPs/s, fa3-template: {flops(sm90_ms):.3f} TFLOPs/s, fa3-fp8: {flops(fp8_sm90_ms):.3f} TFLOPs/s"
f"bench_fp8_single_prefill (seq_len={seq_len}, num_heads={num_heads}, causal={causal}, head_dim={head_dim}), "
f"fp16: {flops(fp16_ms):.3f} TFLOPs/s ({fp16_ms:.3f}ms), "
f"fp8: {flops(fp8_ms):.3f} TFLOPs/s ({fp8_ms:.3f}ms), "
f"speedup: {fp16_ms / fp8_ms:.2f}x"
)


def bench_fp8_batch_ragged_prefill(
batch_size, num_heads, seq_len, causal, head_dim, dtype=torch.float8_e4m3fn
):
"""Benchmark FP8 batch ragged prefill attention."""
num_qo_heads = num_kv_heads = num_heads
total_len = batch_size * seq_len

# Create FP16 tensors first
q_fp16 = torch.randn(
total_len, num_qo_heads, head_dim, dtype=torch.half, device="cuda"
)
k_fp16 = torch.randn(
total_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda"
)
v_fp16 = torch.randn(
total_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda"
)

# Quantize to FP8
q_fp8, s_q = per_head_symmetric_quant(q_fp16, dtype)
k_fp8, s_k = per_head_symmetric_quant(k_fp16, dtype)
v_fp8, s_v = per_head_symmetric_quant(v_fp16, dtype)

qo_indptr = torch.arange(
0, total_len + 1, seq_len, dtype=torch.int32, device="cuda"
)
kv_indptr = torch.arange(
0, total_len + 1, seq_len, dtype=torch.int32, device="cuda"
)

# FP16 wrapper
fp16_wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(
torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device="cuda"),
kv_layout="NHD",
backend="fa3",
)
fp16_wrapper.plan(
qo_indptr, kv_indptr, num_qo_heads, num_kv_heads, head_dim, causal=causal
)

# FP8 wrapper
fp8_wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(
torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device="cuda"),
kv_layout="NHD",
backend="fa3",
)
fp8_wrapper.plan(
qo_indptr,
kv_indptr,
num_qo_heads,
num_kv_heads,
head_dim,
q_data_type=dtype,
kv_data_type=dtype,
o_data_type=torch.half,
causal=causal,
)

fp16_ms = np.median(
bench_gpu_time(
lambda: fp16_wrapper.run(q_fp16, k_fp16, v_fp16),
dry_run_time_ms=100,
repeat_time_ms=1000,
)
)

fp8_ms = np.median(
bench_gpu_time(
lambda: fp8_wrapper.run(q_fp8, k_fp8, v_fp8, s_q, s_k, s_v),
dry_run_time_ms=100,
repeat_time_ms=1000,
)
)

def flops(ms):
return attention_tflops_per_sec_with_actual_seq_lens(
torch.full((batch_size,), seq_len),
torch.full((batch_size,), seq_len),
head_dim,
head_dim,
num_qo_heads,
causal,
ms,
)

print(
f"bench_fp8_batch_ragged_prefill (batch_size={batch_size}, num_heads={num_heads}, seq_len={seq_len}, causal={causal}, head_dim={head_dim}), "
f"fp16: {flops(fp16_ms):.3f} TFLOPs/s ({fp16_ms:.3f}ms), "
f"fp8: {flops(fp8_ms):.3f} TFLOPs/s ({fp8_ms:.3f}ms), "
f"speedup: {fp16_ms / fp8_ms:.2f}x"
)


def bench_fp8_batch_paged_prefill(
page_size,
batch_size,
num_heads,
seq_len,
causal,
head_dim,
dtype=torch.float8_e4m3fn,
):
"""Benchmark FP8 batch paged prefill attention."""
num_qo_heads = num_kv_heads = num_heads
total_qo_len = batch_size * seq_len
num_pages = batch_size * seq_len // page_size

# Create FP16 tensors first
q_fp16 = torch.randn(
total_qo_len, num_qo_heads, head_dim, dtype=torch.half, device="cuda"
)
# Paged KV cache: (num_pages, page_size, num_heads, head_dim)
k_fp16 = torch.randn(
num_pages, page_size, num_kv_heads, head_dim, dtype=torch.half, device="cuda"
)
v_fp16 = torch.randn(
num_pages, page_size, num_kv_heads, head_dim, dtype=torch.half, device="cuda"
)

# Quantize to FP8
q_fp8, s_q = per_head_symmetric_quant(q_fp16, dtype)
# For paged KV, reshape to (total_tokens, num_heads, head_dim) for quantization
k_flat = k_fp16.view(-1, num_kv_heads, head_dim)
v_flat = v_fp16.view(-1, num_kv_heads, head_dim)
k_fp8_flat, s_k = per_head_symmetric_quant(k_flat, dtype)
v_fp8_flat, s_v = per_head_symmetric_quant(v_flat, dtype)
k_fp8 = k_fp8_flat.view(num_pages, page_size, num_kv_heads, head_dim)
v_fp8 = v_fp8_flat.view(num_pages, page_size, num_kv_heads, head_dim)

qo_indptr = torch.arange(
0, total_qo_len + 1, seq_len, dtype=torch.int32, device="cuda"
)
kv_indptr = torch.arange(
0, num_pages + 1, seq_len // page_size, dtype=torch.int32, device="cuda"
)
kv_indices = torch.arange(0, num_pages, dtype=torch.int32, device="cuda")
last_page_len = torch.ones(batch_size, dtype=torch.int32, device="cuda") * page_size

# FP16 wrapper
fp16_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device="cuda"),
kv_layout="NHD",
backend="fa3",
)
fp16_wrapper.plan(
qo_indptr,
kv_indptr,
kv_indices,
last_page_len,
num_qo_heads,
num_kv_heads,
head_dim,
page_size,
causal=causal,
)

# FP8 wrapper
fp8_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device="cuda"),
kv_layout="NHD",
backend="fa3",
)
fp8_wrapper.plan(
qo_indptr,
kv_indptr,
kv_indices,
last_page_len,
num_qo_heads,
num_kv_heads,
head_dim,
page_size,
q_data_type=dtype,
kv_data_type=dtype,
o_data_type=torch.half,
causal=causal,
)

fp16_ms = np.median(
bench_gpu_time(
lambda: fp16_wrapper.run(q_fp16, (k_fp16, v_fp16)),
dry_run_time_ms=100,
repeat_time_ms=1000,
)
)

fp8_ms = np.median(
bench_gpu_time(
lambda: fp8_wrapper.run(q_fp8, (k_fp8, v_fp8), s_q, s_k, s_v),
dry_run_time_ms=100,
repeat_time_ms=1000,
)
)

def flops(ms):
return attention_tflops_per_sec_with_actual_seq_lens(
torch.full((batch_size,), seq_len),
torch.full((batch_size,), seq_len),
head_dim,
head_dim,
num_qo_heads,
causal,
ms,
)

print(
f"bench_fp8_batch_paged_prefill (page_size={page_size}, batch_size={batch_size}, num_heads={num_heads}, seq_len={seq_len}, causal={causal}, head_dim={head_dim}), "
f"fp16: {flops(fp16_ms):.3f} TFLOPs/s ({fp16_ms:.3f}ms), "
f"fp8: {flops(fp8_ms):.3f} TFLOPs/s ({fp8_ms:.3f}ms), "
f"speedup: {fp16_ms / fp8_ms:.2f}x"
)


Expand All @@ -70,8 +327,30 @@ def flops(ms):
print("Current benchmark targets capability (9, 0). Returning...")
exit()

for seq_len in [4096, 8192, 16384]:
for num_heads in [24, 32]:
for causal in [True, False]:
for head_dim in [64, 128, 256]:
bench_single_prefill(seq_len, num_heads, causal, head_dim)
# Skip single prefill for now due to compilation issues
# print("=" * 80)
# print("FP8 Single Prefill Benchmarks")
# print("=" * 80)
# for head_dim in [128, 256]:
# for seq_len in [1024, 4096, 8192]:
# bench_fp8_single_prefill(seq_len, 32, True, head_dim)

print()
print("=" * 80)
print("FP8 Batch Ragged Prefill Benchmarks")
print("=" * 80)
for head_dim in [128, 256]:
bench_fp8_batch_ragged_prefill(128, 32, 1024, True, head_dim)
bench_fp8_batch_ragged_prefill(64, 32, 2048, True, head_dim)
bench_fp8_batch_ragged_prefill(32, 32, 4096, True, head_dim)
bench_fp8_batch_ragged_prefill(16, 32, 8192, True, head_dim)

print()
print("=" * 80)
print("FP8 Batch Paged Prefill Benchmarks")
print("=" * 80)
for head_dim in [128, 256]:
bench_fp8_batch_paged_prefill(16, 128, 32, 1024, True, head_dim)
bench_fp8_batch_paged_prefill(16, 64, 32, 2048, True, head_dim)
bench_fp8_batch_paged_prefill(16, 32, 32, 4096, True, head_dim)
bench_fp8_batch_paged_prefill(16, 16, 32, 8192, True, head_dim)
16 changes: 15 additions & 1 deletion csrc/batch_prefill_fp8_ragged_sm90_kernel_inst.jinja
Original file line number Diff line number Diff line change
@@ -1 +1,15 @@
// TODO: Not implemented yet
#include <flashinfer/attention/hopper/quantization/prefill_sm90.cuh>
#include "batch_prefill_sm90_config.inc"

namespace flashinfer {

{% for same_scheduler_for_all_heads in ["true", "false"] %}
template cudaError_t BatchFP8PrefillWithRaggedKVCacheDispatched
<{{ head_dim_qk }},
{{ mask_mode }},
/*USE_SLIDING_WINDOW=*/{{ use_sliding_window }},
/*SAME_SCHEDULER_FOR_ALL_HEADS=*/{{ same_scheduler_for_all_heads }},
{{ variant_name }}, RaggedParams>(RaggedParams& params, bool enable_pdl, cudaStream_t stream);
{% endfor %}

}; // namespace flashinfer
Comment on lines +1 to +15
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟑 Minor

Fix namespace closing syntax.

Line 15 uses }; to close the namespace, but namespaces should be closed with just } (no semicolon).

-};  // namespace flashinfer
+}  // namespace flashinfer
πŸ“ Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
#include <flashinfer/attention/hopper/quantization/prefill_sm90.cuh>
#include "batch_prefill_sm90_config.inc"
namespace flashinfer {
{% for same_scheduler_for_all_heads in ["true", "false"] %}
template cudaError_t BatchFP8PrefillWithRaggedKVCacheDispatched
<{{ head_dim_qk }},
{{ mask_mode }},
/*USE_SLIDING_WINDOW=*/{{ use_sliding_window }},
/*SAME_SCHEDULER_FOR_ALL_HEADS=*/{{ same_scheduler_for_all_heads }},
{{ variant_name }}, RaggedParams>(RaggedParams& params, bool enable_pdl, cudaStream_t stream);
{% endfor %}
}; // namespace flashinfer
#include <flashinfer/attention/hopper/quantization/prefill_sm90.cuh>
#include "batch_prefill_sm90_config.inc"
namespace flashinfer {
{% for same_scheduler_for_all_heads in ["true", "false"] %}
template cudaError_t BatchFP8PrefillWithRaggedKVCacheDispatched
<{{ head_dim_qk }},
{{ mask_mode }},
/*USE_SLIDING_WINDOW=*/{{ use_sliding_window }},
/*SAME_SCHEDULER_FOR_ALL_HEADS=*/{{ same_scheduler_for_all_heads }},
{{ variant_name }}, RaggedParams>(RaggedParams& params, bool enable_pdl, cudaStream_t stream);
{% endfor %}
} // namespace flashinfer
πŸ€– Prompt for AI Agents
In csrc/batch_prefill_fp8_ragged_sm90_kernel_inst.jinja around lines 1 to 15,
the namespace is closed using "};" but C++ namespace blocks should be closed
with a plain "}" (no semicolon); remove the trailing semicolon after the closing
brace so the file ends with "}" to correctly close the flashinfer namespace.

Loading