Merged
22 commits
858919f  Add BatchPOD which supports batches of prefills instead of a single p… (Nov 12, 2025)
37ae780  Ruff fix (Nov 12, 2025)
42424b8  Fix formatting required by pre-commit (Nov 12, 2025)
74fefdd  Update benchmarks/bench_mixed_attention.py (AKKamath, Nov 12, 2025)
e8e0934  Fix bugs and issues with passed params. (Nov 12, 2025)
bc11239  Merge branch 'pod_batched_new' of github.com:AKKamath/flashinfer into… (Nov 12, 2025)
78f9467  Update benchmarks/bench_mixed_attention.py (AKKamath, Nov 12, 2025)
5f1e346  Update flashinfer/pod.py (AKKamath, Nov 12, 2025)
1e004a6  Merge branch 'pod_batched_new' of github.com:AKKamath/flashinfer into… (Nov 12, 2025)
3a396f3  'Fix' error where decode is mismatching (Nov 12, 2025)
a8123ab  Remove dead code in batch_pod.cuh (Nov 12, 2025)
5694da7  Avoid static variable for SM-aware scheduling, and move memory alloc … (Nov 12, 2025)
53233ef  Fix pre-commit issues (Nov 12, 2025)
b7b4d4c  add threshold (Edenzzzz, Nov 12, 2025)
f275715  Fixes to docstring and unused params (Nov 12, 2025)
db9038a  Remove more dead code (Nov 12, 2025)
77d20fd  Merge pull request #1 from Edenzzzz/pod_batched_new (AKKamath, Nov 12, 2025)
639f447  Clamp lowerbound of max_grid_size to 0, to prevent possible underflow… (Nov 12, 2025)
1b19890  Pre-commit (Nov 12, 2025)
6e4fedb  Add num_colocated_ctas to BlockSparseAttention (Nov 13, 2025)
970c480  Update BatchedPod documentation with demo usage of API. (Nov 13, 2025)
64721eb  Replace qo_indptr_host_p.sum() with total_num_rows_p, since qo_indptr… (Nov 13, 2025)
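
Commit 970c480 adds demo usage of the new API to the documentation. As orientation before the diff, here is a minimal standalone sketch of the plan/run flow, pieced together from the benchmark changes below. The wrapper name, argument order, and keyword names mirror the diff; the head counts, sequence lengths, and page_size=1 setup are illustrative choices, not values mandated by the API.

import torch
import flashinfer

device = "cuda"
num_qo_heads, num_kv_heads, head_dim, page_size = 32, 8, 128, 1  # illustrative sizes

# Two prefill requests of 16 query tokens each; queries are packed,
# so the indptr tensors hold cumulative per-request offsets.
p_q_indptr = torch.tensor([0, 16, 32], dtype=torch.int32, device=device)
p_kv_indptr = torch.tensor([0, 16, 32], dtype=torch.int32, device=device)
# Two decode requests of 1 query token each, with 64-token KV histories.
d_q_indptr = torch.tensor([0, 1, 2], dtype=torch.int32, device=device)
d_kv_indptr = torch.tensor([0, 64, 128], dtype=torch.int32, device=device)
# Page tables: with page_size == 1 there is one page per KV token, numbered densely.
kv_indices_p = torch.arange(0, 32, dtype=torch.int32, device=device)
kv_indices_d = torch.arange(0, 128, dtype=torch.int32, device=device)
# With page_size == 1 the last page of every request holds exactly one entry.
last_page_len_p = torch.ones(2, dtype=torch.int32, device=device)
last_page_len_d = torch.ones(2, dtype=torch.int32, device=device)

q_p = torch.randn(32, num_qo_heads, head_dim, dtype=torch.bfloat16, device=device)
q_d = torch.randn(2, num_qo_heads, head_dim, dtype=torch.bfloat16, device=device)
# Paged KV caches in NHD layout, [pages, 2, page_size, num_kv_heads, head_dim],
# unbound into (k, v) tuples the same way the benchmark does.
kv_p = torch.randn(
    32, 2, page_size, num_kv_heads, head_dim, dtype=torch.bfloat16, device=device
).unbind(1)
kv_d = torch.randn(
    128, 2, page_size, num_kv_heads, head_dim, dtype=torch.bfloat16, device=device
).unbind(1)

workspace_buffer = torch.empty(156 * 1024 * 1024, dtype=torch.uint8, device=device)
wrapper = flashinfer.BatchPODWithPagedKVCacheWrapper(workspace_buffer, kv_layout="NHD")
wrapper.plan(
    # Prefill params
    p_q_indptr, p_kv_indptr, kv_indices_p, last_page_len_p,
    # Decode params
    d_q_indptr, d_kv_indptr, kv_indices_d, last_page_len_d,
    # Common params
    num_qo_heads=num_qo_heads, num_kv_heads=num_kv_heads,
    head_dim=head_dim, page_size=page_size,
    q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16,
)
# A single call processes the prefill batch and the decode batch together,
# returning one output tensor per batch.
o_p, o_d = wrapper.run(q_p, kv_p, q_d, kv_d, causal_p=True)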
102 changes: 95 additions & 7 deletions benchmarks/bench_mixed_attention.py
@@ -23,14 +23,25 @@ def run_bench(
q_lens = torch.tensor(d_qo_lens + p_qo_lens, dtype=torch.int32)

seq_lens_blocks = torch.ceil(seq_lens / page_block_size).int()
- d_seq_lens_blocks = (
+ p_seq_lens_blocks = torch.ceil(
+ torch.tensor(p_kv_lens, dtype=torch.int32) / page_block_size
+ ).int()
+ d_seq_lens_blocks = torch.ceil(
torch.tensor(d_kv_lens, dtype=torch.int32) / page_block_size
).int()

q_indptr = torch.cat([torch.tensor([0]), torch.cumsum(q_lens, 0)], dim=0).int()
kv_indptr = torch.cat(
[torch.tensor([0]), torch.cumsum(seq_lens_blocks, 0)], dim=0
).int()

p_q_indptr = torch.cat(
[torch.tensor([0]), torch.cumsum(torch.tensor(p_qo_lens), 0)], dim=0
).int()
p_kv_indptr = torch.cat(
[torch.tensor([0]), torch.cumsum(p_seq_lens_blocks, 0)], dim=0
).int()

d_q_indptr = torch.cat(
[torch.tensor([0]), torch.cumsum(torch.tensor(d_qo_lens), 0)], dim=0
).int()
Expand All @@ -46,7 +57,7 @@ def run_bench(
device, dtype=torch.bfloat16
)

- workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
+ workspace_buffer = torch.empty(156 * 1024 * 1024, dtype=torch.uint8, device=device)
kv_layout = "NHD"

wrapper_old = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
@@ -90,7 +101,67 @@ def run_bench(
o_persistent, _ = wrapper_persistent.run(q, kv_data)
measurements_persistent = bench_gpu_time(lambda: wrapper_persistent.run(q, kv_data))
ms_persistent = np.mean(measurements_persistent)

# Batched POD Attention
q_d = q[: d_q_indptr[-1]]
kv_d = kv_data[: d_kv_indptr[-1]].unbind(1)
q_p = q[d_q_indptr[-1] :]
kv_p = kv_data[d_kv_indptr[-1] :].unbind(1)
kv_indices_d = torch.arange(0, d_kv_indptr[-1], device=device, dtype=torch.int32)
kv_indices_p = torch.arange(0, p_kv_indptr[-1], device=device, dtype=torch.int32)

last_page_len_d = (d_seq_lens_blocks - 1) % page_block_size + 1
last_page_len_p = (p_seq_lens_blocks - 1) % page_block_size + 1
wrapper_pod = flashinfer.BatchPODWithPagedKVCacheWrapper(
workspace_buffer,
kv_layout=kv_layout,
)

wrapper_pod.plan(
# Prefill params
p_q_indptr.to(device),
p_kv_indptr.to(device),
kv_indices_p.to(device),
last_page_len_p,
# Decode params
d_q_indptr.to(device),
d_kv_indptr.to(device),
kv_indices_d.to(device),
last_page_len_d,
# Common params
num_qo_heads=num_qo_heads,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
page_size=page_block_size,
q_data_type=torch.bfloat16,
kv_data_type=torch.bfloat16,
)
o_p_batch, o_d_batch = wrapper_pod.run(
q_p,
kv_p,
q_d,
kv_d,
causal_p=causal,
)
o_batch_pod = torch.cat([o_d_batch, o_p_batch], dim=0)

# Verify output matches
torch.testing.assert_close(
o_batch_pod, o, rtol=4e-3, atol=4e-3, msg="Batch POD-Attention decode mismatch!"
)
measurements = bench_gpu_time(
lambda: wrapper_pod.run(
q_p,
kv_p,
q_d,
kv_d,
causal_p=causal,
)
)
ms_batch_pod = np.median(measurements)

if len(p_kv_lens) == 1:
# Single POD attention
q_d = q[: d_q_indptr[-1]]
kv_d = kv_data[: d_kv_indptr[-1]].unbind(1)
q_p = q[d_q_indptr[-1] :]
@@ -127,7 +198,7 @@ def run_bench(
o_pod = torch.cat([o_d, o_p], dim=0)
# Verify output matches
torch.testing.assert_close(
- o, o_pod, rtol=1e-3, atol=1e-3, msg="POD-Attention output mismatch!"
+ o, o_pod, rtol=4e-3, atol=4e-3, msg="POD-Attention output mismatch!"
)
measurements = bench_gpu_time(
lambda: wrapper_pod.run(
@@ -177,10 +248,15 @@ def _run_single_prefill():
ms_seq_two_kernels = ms_prefill + ms_decode

print(f"Elapsed time (Batched Prefill): {ms_old:.2f} ms")
print(f"Elapsed time (Batched POD Attention): {ms_batch_pod:.2f} ms")
if len(p_kv_lens) == 1:
print(f"Elapsed time (POD Attention): {ms_pod:.2f} ms")
print(f"Elapsed time (Sequential two kernels): {ms_seq_two_kernels:.2f} ms")
print(f"Elapsed time (Persistent BatchAttention): {ms_persistent:.2f} ms")
print(
f"Batch POD speedup over Persistent BatchAttention: {ms_persistent / ms_batch_pod:.2f}x"
)

total_bytes = (
q.numel() * q.element_size() + kv_data.numel() * kv_data.element_size()
)
@@ -189,6 +265,10 @@ def _run_single_prefill():
bandwidth_old_gb_s = total_bytes / (ms_old * 1e-3) / (1024**3)

print(f"Memory bandwidth (Batched Prefill): {bandwidth_old_gb_s:.2f} GB/s")
bandwidth_batch_pod_gb_s = total_bytes / (ms_batch_pod * 1e-3) / (1024**3)
print(
f"Memory bandwidth (Batched POD Attention): {bandwidth_batch_pod_gb_s:.2f} GB/s"
)
if len(p_kv_lens) == 1:
bandwidth_pod_gb_s = total_bytes / (ms_pod * 1e-3) / (1024**3)
print(f"Memory bandwidth (POD Attention): {bandwidth_pod_gb_s:.2f} GB/s")
@@ -207,10 +287,18 @@ def _run_single_prefill():
torch.random.manual_seed(42)

# Irregular sequence lengths for prefill and decode
- d_q_len_configs = [[1] * 128, [1] * 128, [1] * 128, [1] * 128]
- d_kv_len_configs = [[2048] * 128, [4096] * 128, [8192] * 128, [8192] * 128]
- p_q_configs = [[2048], [4096], [4096], [6000]]
- p_kv_configs = [[2048], [4096], [4096], [7000]]
+ d_q_len_configs = [[1] * 128] * 7
+ d_kv_len_configs = [
+ [2048] * 128,
+ [2048] * 128,
+ [2048] * 128,
+ [2048] * 128,
+ [4096] * 128,
+ [8192] * 128,
+ [8192] * 128,
+ ]
+ p_q_configs = [[512], [1536], [2048] * 2, [2048], [4096], [4096], [6000]]
+ p_kv_configs = [[512], [1536], [2048] * 2, [2048], [4096], [4096], [7000]]

page_block_size = 1
num_kv_heads = 8
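
The benchmark presumably runs as a plain script (e.g. python benchmarks/bench_mixed_attention.py), sweeping the configs above. Each config first cross-checks the POD outputs against the baseline batched-prefill wrapper via torch.testing.assert_close before timing, then, per the print statements in the diff, emits a comparison of this shape (X.XX placeholders stand in for hardware-dependent measurements; the POD-Attention lines appear only for single-prefill configs):

Elapsed time (Batched Prefill): X.XX ms
Elapsed time (Batched POD Attention): X.XX ms
Elapsed time (POD Attention): X.XX ms
Elapsed time (Sequential two kernels): X.XX ms
Elapsed time (Persistent BatchAttention): X.XX ms
Batch POD speedup over Persistent BatchAttention: X.XXx
Memory bandwidth (Batched Prefill): X.XX GB/s
Memory bandwidth (Batched POD Attention): X.XX GB/s
Memory bandwidth (POD Attention): X.XX GB/s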