131 changes: 69 additions & 62 deletions benchmarks/bench_mixed_attention.py
@@ -72,6 +72,24 @@ def run_bench(
    measurements = bench_gpu_time(lambda: wrapper_old.run(q, kv_data))
    ms_old = np.median(measurements)

    wrapper_persistent = flashinfer.BatchAttention(kv_layout="NHD")
    wrapper_persistent.plan(
        q_indptr.to(device),
        kv_indptr.to(device),
        torch.arange(num_blocks, dtype=torch.int32, device=device),
        seq_lens.to(device),
        num_qo_heads,
        num_kv_heads,
        head_dim,
        head_dim,
        page_block_size,
        causal=causal,
        q_data_type=torch.bfloat16,
        kv_data_type=torch.bfloat16,
    )
    o_persistent, _ = wrapper_persistent.run(q, kv_data)
    measurements_persistent = bench_gpu_time(lambda: wrapper_persistent.run(q, kv_data))
    ms_persistent = np.mean(measurements_persistent)
Contributor comment (severity: medium)
For consistency with the other measurements in this benchmark, it's better to use np.median instead of np.mean. np.median is more robust to outliers, which can be common in performance measurements.

Suggested change
ms_persistent = np.mean(measurements_persistent)
ms_persistent = np.median(measurements_persistent)
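
A small illustration of the robustness argument, with made-up timings:

import numpy as np

# Hypothetical GPU timings in ms; one run is hit by clock throttling or another transient stall.
measurements = [1.02, 1.01, 1.03, 1.02, 9.80]
print(np.mean(measurements))    # ~2.78 ms, dragged up by the single outlier
print(np.median(measurements))  # 1.02 ms, closer to typical behavior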

Contributor comment on lines +90 to +92
⚠️ Potential issue | 🟑 Minor

Drop unused persistent output.

Line 90 binds o_persistent, but the value is never read and Ruff emits RUF059. Please discard the binding (for example, call wrapper_persistent.run(q, kv_data) without assignment or bind to _) so the warm-up still happens without leaving an unused variable.

-    o_persistent, _ = wrapper_persistent.run(q, kv_data)
+    wrapper_persistent.run(q, kv_data)
πŸ“ Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
o_persistent, _ = wrapper_persistent.run(q, kv_data)
measurements_persistent = bench_gpu_time(lambda: wrapper_persistent.run(q, kv_data))
ms_persistent = np.mean(measurements_persistent)
wrapper_persistent.run(q, kv_data)
measurements_persistent = bench_gpu_time(lambda: wrapper_persistent.run(q, kv_data))
ms_persistent = np.mean(measurements_persistent)
🧰 Tools
πŸͺ› Ruff (0.14.2)

90-90: Unpacked variable o_persistent is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)
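
For reference, a minimal sketch of the underscore-prefix variant Ruff suggests (the name is illustrative; the warm-up call still runs, both outputs are intentionally discarded):

_o_persistent, _ = wrapper_persistent.run(q, kv_data)  # warm-up only; results unused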


    if len(p_kv_lens) == 1:
        q_d = q[: d_q_indptr[-1]]
        kv_d = kv_data[: d_kv_indptr[-1]].unbind(1)
@@ -123,9 +141,46 @@ def run_bench(
            )
        )
        ms_pod = np.median(measurements)

        # Sequential two kernels: single prefill + batch decode (tensor cores)
        # Prefill using single_prefill_with_kv_cache
        def _run_single_prefill():
            return flashinfer.prefill.single_prefill_with_kv_cache(
                q_p,
                k_p,
                v_p,
                causal=causal,
                pos_encoding_mode="NONE",
                backend="fa2",
            )

        measurements_prefill = bench_gpu_time(lambda: _run_single_prefill())
        ms_prefill = np.median(measurements_prefill)

        # Batch decode using tensor cores
        wrapper_decode = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
            workspace_buffer, kv_layout=kv_layout, use_tensor_cores=True
        )
        wrapper_decode.plan(
            d_kv_indptr.to(device),
            kv_indices_d.to(device),
            last_page_len_d,
            num_qo_heads,
            num_kv_heads,
            head_dim,
            page_block_size,
            data_type=torch.bfloat16,
Contributor comment (severity: medium)
The data_type parameter in BatchDecodeWithPagedKVCacheWrapper.plan is deprecated. Please use kv_data_type instead for clarity and to avoid using deprecated APIs.

Suggested change
data_type=torch.bfloat16,
kv_data_type=torch.bfloat16,
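
Applied here, the plan call would read as follows (a sketch: only the keyword changes, all other arguments stay as in the diff above):

wrapper_decode.plan(
    d_kv_indptr.to(device),
    kv_indices_d.to(device),
    last_page_len_d,
    num_qo_heads,
    num_kv_heads,
    head_dim,
    page_block_size,
    kv_data_type=torch.bfloat16,  # replaces the deprecated data_type keyword
    q_data_type=torch.bfloat16,
)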

            q_data_type=torch.bfloat16,
        )
        measurements_decode = bench_gpu_time(lambda: wrapper_decode.run(q_d, kv_d))
        ms_decode = np.median(measurements_decode)
        ms_seq_two_kernels = ms_prefill + ms_decode

Contributor comment on lines +145 to +178
⚠️ Potential issue | 🟠 Major

Measure sequential path in one benchmarked call.

Lines 158-177 derive ms_seq_two_kernels by summing medians from two completely separate benchmark runs. Because bench_gpu_time synchronizes around each callable, that sum omits the synchronization gap between kernels and hides any stream/data dependency penalties when prefill hands off to decode. As a result, the reported β€œSequential two kernels” latency is optimistic and not directly comparable to the single-call POD/persistent timings. Benchmark the sequential path inside a single callable and use that median instead so the printed number reflects the real pipeline cost.

         measurements_prefill = bench_gpu_time(lambda: _run_single_prefill())
         ms_prefill = np.median(measurements_prefill)
 
         # Batch decode using tensor cores
         wrapper_decode = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
             workspace_buffer, kv_layout=kv_layout, use_tensor_cores=True
         )
@@
         )
-        measurements_decode = bench_gpu_time(lambda: wrapper_decode.run(q_d, kv_d))
-        ms_decode = np.median(measurements_decode)
-        ms_seq_two_kernels = ms_prefill + ms_decode
+        measurements_decode = bench_gpu_time(lambda: wrapper_decode.run(q_d, kv_d))
+        ms_decode = np.median(measurements_decode)
+
+        def _run_prefill_and_decode():
+            _run_single_prefill()
+            return wrapper_decode.run(q_d, kv_d)
+
+        measurements_seq = bench_gpu_time(_run_prefill_and_decode)
+        ms_seq_two_kernels = np.median(measurements_seq)
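
As a toy illustration of the gap (numbers are made up), summing per-kernel medians can understate what a single end-to-end measurement reports:

import numpy as np

prefill = [1.00, 1.01, 1.02]    # ms, hypothetical prefill-only timings
decode = [0.50, 0.51, 0.52]     # ms, hypothetical decode-only timings
combined = [1.65, 1.66, 1.70]   # ms, hypothetical prefill+decode runs including the handoff
print(np.median(prefill) + np.median(decode))  # 1.52 ms, optimistic
print(np.median(combined))                     # 1.66 ms, closer to the real pipeline cost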
πŸ“ Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
# Sequential two kernels: single prefill + batch decode (tensor cores)
# Prefill using single_prefill_with_kv_cache
def _run_single_prefill():
return flashinfer.prefill.single_prefill_with_kv_cache(
q_p,
k_p,
v_p,
causal=causal,
pos_encoding_mode="NONE",
backend="fa2",
)
measurements_prefill = bench_gpu_time(lambda: _run_single_prefill())
ms_prefill = np.median(measurements_prefill)
# Batch decode using tensor cores
wrapper_decode = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer, kv_layout=kv_layout, use_tensor_cores=True
)
wrapper_decode.plan(
d_kv_indptr.to(device),
kv_indices_d.to(device),
last_page_len_d,
num_qo_heads,
num_kv_heads,
head_dim,
page_block_size,
data_type=torch.bfloat16,
q_data_type=torch.bfloat16,
)
measurements_decode = bench_gpu_time(lambda: wrapper_decode.run(q_d, kv_d))
ms_decode = np.median(measurements_decode)
ms_seq_two_kernels = ms_prefill + ms_decode
# Sequential two kernels: single prefill + batch decode (tensor cores)
# Prefill using single_prefill_with_kv_cache
def _run_single_prefill():
return flashinfer.prefill.single_prefill_with_kv_cache(
q_p,
k_p,
v_p,
causal=causal,
pos_encoding_mode="NONE",
backend="fa2",
)
measurements_prefill = bench_gpu_time(lambda: _run_single_prefill())
ms_prefill = np.median(measurements_prefill)
# Batch decode using tensor cores
wrapper_decode = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer, kv_layout=kv_layout, use_tensor_cores=True
)
wrapper_decode.plan(
d_kv_indptr.to(device),
kv_indices_d.to(device),
last_page_len_d,
num_qo_heads,
num_kv_heads,
head_dim,
page_block_size,
data_type=torch.bfloat16,
q_data_type=torch.bfloat16,
)
measurements_decode = bench_gpu_time(lambda: wrapper_decode.run(q_d, kv_d))
ms_decode = np.median(measurements_decode)
def _run_prefill_and_decode():
_run_single_prefill()
return wrapper_decode.run(q_d, kv_d)
measurements_seq = bench_gpu_time(_run_prefill_and_decode)
ms_seq_two_kernels = np.median(measurements_seq)

print(f"Elapsed time (Batched Prefill): {ms_old:.2f} ms")
if len(p_kv_lens) == 1:
print(f"Elapsed time (POD Attention): {ms_pod:.2f} ms")
print(f"Elapsed time (Sequential two kernels): {ms_seq_two_kernels:.2f} ms")
print(f"Elapsed time (Persistent BatchAttention): {ms_persistent:.2f} ms")
total_bytes = (
q.numel() * q.element_size() + kv_data.numel() * kv_data.element_size()
)
@@ -137,77 +192,29 @@ def run_bench(
    if len(p_kv_lens) == 1:
        bandwidth_pod_gb_s = total_bytes / (ms_pod * 1e-3) / (1024**3)
        print(f"Memory bandwidth (POD Attention): {bandwidth_pod_gb_s:.2f} GB/s")
        bandwidth_seq_gb_s = total_bytes / (ms_seq_two_kernels * 1e-3) / (1024**3)
        print(
            f"Memory bandwidth (Sequential two kernels): {bandwidth_seq_gb_s:.2f} GB/s"
        )
    bandwidth_persistent_gb_s = total_bytes / (ms_persistent * 1e-3) / (1024**3)
    print(
        f"Memory bandwidth (Persistent BatchAttention): {bandwidth_persistent_gb_s:.2f} GB/s"
    )


if __name__ == "__main__":
np.random.seed(42)
torch.random.manual_seed(42)

# Irregular sequence lengths for prefill and decode
d_q_len_configs = [[1] * 122, [1] * 128, [1] * 242, [1] * 256]
d_kv_len_configs = [[600] * 122, [10000] * 128, [400] * 242, [8192] * 256]
p_q_configs = [[17] * 1, [10000], [17] * 1, []]
p_kv_configs = [[10000] * 1, [10000], [8192] * 1, []]

# construct random length testcases
for _ in range(1):
bsz = 256
stride = 16
sparsity = 0.05

full_kv_len = np.random.randint(1000, 8192, size=bsz)
p_q_lens = []
p_kv_lens = []
d_q_lens = []
d_kv_lens = []
for i in range(bsz):
if i % stride == 0:
kv_len = full_kv_len[i]
qo_len = stride + 1
p_q_lens.append(qo_len)
p_kv_lens.append(kv_len)
else:
kv_len = int(full_kv_len[i] * sparsity)
qo_len = 1
d_q_lens.append(qo_len)
d_kv_lens.append(kv_len)

p_q_configs.append(p_q_lens)
p_kv_configs.append(p_kv_lens)
d_q_len_configs.append(d_q_lens)
d_kv_len_configs.append(d_kv_lens)

for _ in range(1):
bsz = 128
stride = 16
sparsity = 0.05

full_kv_len = np.random.randint(2000, 16000, size=bsz)
p_q_lens = []
p_kv_lens = []
d_q_lens = []
d_kv_lens = []

for i in range(bsz):
if i % stride == 0:
kv_len = full_kv_len[i]
qo_len = stride + 1
p_q_lens.append(qo_len)
p_kv_lens.append(kv_len)
else:
kv_len = int(full_kv_len[i] * sparsity)
qo_len = 1
d_q_lens.append(qo_len)
d_kv_lens.append(kv_len)

p_q_configs.append(p_q_lens)
p_kv_configs.append(p_kv_lens)
d_q_len_configs.append(d_q_lens)
d_kv_len_configs.append(d_kv_lens)
d_q_len_configs = [[1] * 128, [1] * 128, [1] * 128, [1] * 128]
d_kv_len_configs = [[2048] * 128, [4096] * 128, [8192] * 128, [8192] * 128]
p_q_configs = [[2048], [4096], [4096], [6000]]
p_kv_configs = [[2048], [4096], [4096], [7000]]

page_block_size = 1
num_kv_heads = 4
num_qo_heads = 28
num_kv_heads = 8
num_qo_heads = 32
head_dim = 128

    for idx, (p_q_lens, p_kv_lens, d_q_len, d_kv_len) in enumerate(