Closed
Changes from 31 of 52 commits:
06ef37d: use FA2DetermineCtaTileQ for pod (Edenzzzz, May 20, 2025)
14e3f13: modify wrapper.. (Edenzzzz, Jun 6, 2025)
22dc539: fix (Edenzzzz, Jun 8, 2025)
3257899: bench against persistent (Edenzzzz, Jun 25, 2025)
82f1550: rename xsize to num_qo_tiles (Edenzzzz, Jun 27, 2025)
06fee31: fix (Edenzzzz, Jun 30, 2025)
f47f73e: fix (Edenzzzz, Jun 30, 2025)
c51cecc: fix (Edenzzzz, Jun 30, 2025)
e73566c: add mixed scheduler (Edenzzzz, Jul 3, 2025)
1b2d4c0: rename to num_to_merge_qo_len (Edenzzzz, Jul 3, 2025)
78e1266: add params (Edenzzzz, Jul 3, 2025)
4979a2a: plan to use one reduction kernel for prefill and decode (Edenzzzz, Jul 4, 2025)
2102e22: fix (Edenzzzz, Jul 4, 2025)
fab82ae: use unifed qkv indptr (Edenzzzz, Jul 6, 2025)
13b6b19: fix (Edenzzzz, Jul 6, 2025)
7d29232: fix plan func upper call interface (Edenzzzz, Jul 6, 2025)
106bfdc: rename new_batch_size to real_batch_size (Edenzzzz, Jul 7, 2025)
5e3e896: concat request_indices (Edenzzzz, Jul 7, 2025)
ac07253: unifed indices in wrapper.plan (Edenzzzz, Jul 8, 2025)
eb8f719: fixes (Edenzzzz, Jul 8, 2025)
e8b266d: fix params (Edenzzzz, Jul 8, 2025)
560918b: fix some indices and params (Edenzzzz, Jul 8, 2025)
2105101: update PODWithKVCacheTensorRun args (Edenzzzz, Jul 9, 2025)
1a82b17: add paged kv params (Edenzzzz, Jul 9, 2025)
dd80a06: complete PODWithKVCacheTensorRun params (Edenzzzz, Jul 9, 2025)
0bb164b: share lse (Edenzzzz, Jul 10, 2025)
32d762b: templaterize CTA_TILE_Q_P (Edenzzzz, Jul 10, 2025)
870b0b2: update dispatch logic (Edenzzzz, Jul 11, 2025)
a7eb44b: add pod template inst and .cu gen (Edenzzzz, Jul 11, 2025)
f57fde0: fixes (Edenzzzz, Jul 11, 2025)
0832bac: trying to fix template (Edenzzzz, Jul 11, 2025)
a570bff: fix PODSplitQOKVIndptr param (Edenzzzz, Jul 13, 2025)
965d073: fix template and vec type errors (Edenzzzz, Jul 15, 2025)
c2fcc54: use get_pod_module jit in wrapper (Edenzzzz, Jul 15, 2025)
708d1be: attempt to fix some templates (Edenzzzz, Jul 16, 2025)
9b9f8f4: fix (Edenzzzz, Jul 16, 2025)
6f75dae: Fix compilation errors (AKKamath, Jul 16, 2025)
332e4bf: Various other fixes for POD (AKKamath, Jul 16, 2025)
d7707f5: update pod test (Edenzzzz, Jul 17, 2025)
766e8a8: update pod test (Edenzzzz, Jul 18, 2025)
ac408f4: fix (Edenzzzz, Jul 18, 2025)
cb36898: fix tests (Edenzzzz, Jul 18, 2025)
ff0e4c3: remove kv_len_arr_p (Edenzzzz, Jul 19, 2025)
1a1323d: fix negative kv len bug (Edenzzzz, Jul 19, 2025)
ab5d4f1: add chunk_size check (Edenzzzz, Jul 20, 2025)
9927dd2: fix (Edenzzzz, Jul 20, 2025)
19cad42: remove mistaken causal (Edenzzzz, Jul 20, 2025)
f39be0b: use int32 for indices (Edenzzzz, Jul 21, 2025)
875f170: add some checks (Edenzzzz, Jul 21, 2025)
81fc9a7: fix dispatch logic when prefill request is empty (Edenzzzz, Jul 22, 2025)
b604637: Merge main Oct 30 (Edenzzzz, Oct 31, 2025)
458bcce: fix 3rdparty (Edenzzzz, Oct 31, 2025)
59 changes: 59 additions & 0 deletions aot_build_utils/generate.py
@@ -24,6 +24,7 @@
generate_batch_paged_decode_inst,
generate_batch_paged_prefill_inst,
generate_batch_ragged_prefill_inst,
generate_pod_inst,
generate_single_decode_inst,
generate_single_prefill_inst,
)
@@ -250,11 +251,69 @@ def write_if_different(path: Path, content: str) -> None:
f"f16qk_{bool(use_fp16_qk_reduction)}"
)

# POD files
pod_uris = []
for (
head_dim,
pos_encoding_mode,
use_fp16_qk_reduction,
mask_mode_p,
mask_mode_d,
idtype,
) in product(
head_dims,
pos_encoding_modes,
use_fp16_qk_reductions,
mask_modes,
mask_modes, # mask_mode_d can be different from mask_mode_p
idtypes,
):
for dtype_q, dtype_kv in list(zip(prefill_dtypes, prefill_dtypes)) + list(
product(prefill_dtypes, fp8_dtypes)
):
fname = f"pod_head_qk_{head_dim}_head_vo_{head_dim}_posenc_{pos_encoding_mode}_fp16qkred_{use_fp16_qk_reduction}_maskp_{mask_mode_p}_maskd_{mask_mode_d}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_idtype_{idtype}.cu"
content = generate_pod_inst.get_cu_file_str(
head_dim, # head_dim_qk
head_dim, # head_dim_vo
pos_encoding_mode,
use_fp16_qk_reduction,
mask_mode_p,
mask_mode_d,
dtype_q, # dtype_q
dtype_kv, # dtype_kv
dtype_q, # dtype_out
idtype,
)
write_if_different(path / fname, content)

for sliding_window_p in [True, False]:
for sliding_window_d in [True, False]:
for logits_soft_cap_p in [True, False]:
for logits_soft_cap_d in [True, False]:
if (
mask_mode_p == 0 and mask_mode_d == 0
): # NOTE(Zihao): URIs do not encode the mask mode, so append them only once to avoid duplicates
pod_uris.append(
f"pod_with_kv_cache_dtype_q_{dtype_q}_"
f"dtype_kv_{dtype_kv}_"
f"dtype_o_{dtype_q}_"
f"dtype_idx_{idtype}_"
f"head_dim_qk_{head_dim}_"
f"head_dim_vo_{head_dim}_"
f"posenc_{pos_encoding_mode}_"
f"use_swa_p_{sliding_window_p}_"
f"use_swa_d_{sliding_window_d}_"
f"use_logits_cap_p_{logits_soft_cap_p}_"
f"use_logits_cap_d_{logits_soft_cap_d}_"
f"f16qk_{bool(use_fp16_qk_reduction)}"
)

return (
single_decode_uris
+ batch_decode_uris
+ single_prefill_uris
+ batch_prefill_uris
+ pod_uris
)


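To make the naming concrete, the snippet below is a hedged sketch of the URI that the loop above would append for one parameter combination; the values are illustrative and not taken from the AOT build configuration.

# Illustrative only: one hypothetical combination, mirroring the f-string above.
dtype_q = dtype_kv = "f16"
idtype = "i32"
head_dim = 128
pos_encoding_mode = 0
use_fp16_qk_reduction = 0
sliding_window_p = sliding_window_d = False
logits_soft_cap_p = logits_soft_cap_d = False
uri = (
    f"pod_with_kv_cache_dtype_q_{dtype_q}_"
    f"dtype_kv_{dtype_kv}_"
    f"dtype_o_{dtype_q}_"
    f"dtype_idx_{idtype}_"
    f"head_dim_qk_{head_dim}_"
    f"head_dim_vo_{head_dim}_"
    f"posenc_{pos_encoding_mode}_"
    f"use_swa_p_{sliding_window_p}_"
    f"use_swa_d_{sliding_window_d}_"
    f"use_logits_cap_p_{logits_soft_cap_p}_"
    f"use_logits_cap_d_{logits_soft_cap_d}_"
    f"f16qk_{bool(use_fp16_qk_reduction)}"
)
# uri == "pod_with_kv_cache_dtype_q_f16_dtype_kv_f16_dtype_o_f16_dtype_idx_i32_"
#        "head_dim_qk_128_head_dim_vo_128_posenc_0_use_swa_p_False_use_swa_d_False_"
#        "use_logits_cap_p_False_use_logits_cap_d_False_f16qk_False"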
121 changes: 121 additions & 0 deletions aot_build_utils/generate_pod_inst.py
@@ -0,0 +1,121 @@
"""
Copyright (c) 2024 by FlashInfer team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import re
import sys
from pathlib import Path

from .literal_map import (
dtype_literal,
idtype_literal,
mask_mode_literal,
pos_encoding_mode_literal,
)


def get_cu_file_str(
head_dim_qk,
head_dim_vo,
pos_encoding_mode,
use_fp16_qk_reduction,
mask_mode_p,
mask_mode_d,
dtype_q,
dtype_kv,
dtype_out,
idtype,
):
cta_tile_q_choice = [128, 64, 16]
cta_tile_q_d = 16

def get_insts(attention_variant_p, attention_variant_d, dtype_out):
return "\n".join(
[
"""template cudaError_t PODWithKVCacheTensorDispatched<{head_dim_qk}, {head_dim_vo}, {pos_encoding_mode}, {use_fp16_qk_reduction}, {mask_mode_p}, {cta_tile_q_p}, {cta_tile_q_d}, {mask_mode_d}, {attention_variant_p}, {attention_variant_d}, PrefillParams, DecodeParams>(
PrefillParams prefill_params, DecodeParams decode_params,
{dtype_out}* tmp_v, float* tmp_s, bool enable_pdl, cudaStream_t stream);
""".format(
head_dim_qk=head_dim_qk,
head_dim_vo=head_dim_vo,
pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)],
use_fp16_qk_reduction=use_fp16_qk_reduction,
mask_mode_p=mask_mode_literal[int(mask_mode_p)],
cta_tile_q_p=cta_tile_q_p,
cta_tile_q_d=cta_tile_q_d,
mask_mode_d=mask_mode_literal[int(mask_mode_d)],
attention_variant_p=attention_variant_p,
attention_variant_d=attention_variant_d,
dtype_out=dtype_out,
)
for cta_tile_q_p in cta_tile_q_choice
]
)

use_custom_mask_p = "true" if int(mask_mode_p) == 2 else "false"
use_custom_mask_d = "true" if int(mask_mode_d) == 2 else "false"
dtype_q = dtype_literal[dtype_q]
dtype_kv = dtype_literal[dtype_kv]
dtype_out = dtype_literal[dtype_out]
idtype = idtype_literal[idtype]

content = f"""#include <flashinfer/attention/pod.cuh>
#include "pod_config.inc"

namespace flashinfer {{

constexpr auto use_custom_mask_p = {use_custom_mask_p};
constexpr auto use_custom_mask_d = {use_custom_mask_d};

using PrefillParams = BatchPrefillPagedParams<{dtype_q}, {dtype_kv}, {dtype_out}>;
using DecodeParams = BatchPrefillPagedParams<{dtype_q}, {dtype_kv}, {dtype_out}, {idtype}>;

using AttentionVariant1_P = DefaultAttention<use_custom_mask_p, /*use_sliding_window=*/true, /*use_logits_soft_cap=*/false, /*use_alibi_bias=*/false>;
using AttentionVariant1_D = DefaultAttention<use_custom_mask_d, /*use_sliding_window=*/true, /*use_logits_soft_cap=*/false, /*use_alibi_bias=*/false>;

{get_insts("AttentionVariant1_P", "AttentionVariant1_D", dtype_out)}

using AttentionVariant2_P = DefaultAttention<use_custom_mask_p, /*use_sliding_window=*/true, /*use_logits_soft_cap=*/true, /*use_alibi_bias=*/false>;
using AttentionVariant2_D = DefaultAttention<use_custom_mask_d, /*use_sliding_window=*/true, /*use_logits_soft_cap=*/true, /*use_alibi_bias=*/false>;

{get_insts("AttentionVariant2_P", "AttentionVariant2_D", dtype_out)}

using AttentionVariant3_P = DefaultAttention<use_custom_mask_p, /*use_sliding_window=*/false, /*use_logits_soft_cap=*/false, /*use_alibi_bias=*/false>;
using AttentionVariant3_D = DefaultAttention<use_custom_mask_d, /*use_sliding_window=*/false, /*use_logits_soft_cap=*/false, /*use_alibi_bias=*/false>;

{get_insts("AttentionVariant3_P", "AttentionVariant3_D", dtype_out)}

using AttentionVariant4_P = DefaultAttention<use_custom_mask_p, /*use_sliding_window=*/false, /*use_logits_soft_cap=*/true, /*use_alibi_bias=*/false>;
using AttentionVariant4_D = DefaultAttention<use_custom_mask_d, /*use_sliding_window=*/false, /*use_logits_soft_cap=*/true, /*use_alibi_bias=*/false>;

{get_insts("AttentionVariant4_P", "AttentionVariant4_D", dtype_out)}

}}"""
return content


if __name__ == "__main__":
pattern = (
r"pod_head_qk_([0-9]+)_head_vo_([0-9]+)_posenc_([0-9]+)_"
r"fp16qkred_([a-z]+)_maskp_([0-9]+)_maskd_([0-9]+)_"
r"dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)\.cu"
)
compiled_pattern = re.compile(pattern)
path = Path(sys.argv[1])
fname = path.name
match = compiled_pattern.match(fname)

with open(path, "w") as f:
f.write(get_cu_file_str(*match.groups()))
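As a quick sanity check, the generator above can also be called directly. The snippet below is a hedged sketch: the argument values are hypothetical, and it assumes the "f16" and "i32" keys exist in literal_map, as they do for the other AOT generators in this directory.

# Hypothetical direct call; argument order follows get_cu_file_str's signature,
# with string-typed values matching what the __main__ regex path would pass.
from aot_build_utils.generate_pod_inst import get_cu_file_str

cu_source = get_cu_file_str(
    128,      # head_dim_qk
    128,      # head_dim_vo
    "0",      # pos_encoding_mode (PosEncodingMode::kNone)
    "false",  # use_fp16_qk_reduction
    "0",      # mask_mode_p (MaskMode::kNone)
    "0",      # mask_mode_d (MaskMode::kNone)
    "f16",    # dtype_q
    "f16",    # dtype_kv
    "f16",    # dtype_out
    "i32",    # idtype
)
# Three CTA_TILE_Q_P choices times four attention-variant pairs gives twelve
# explicit instantiations of PODWithKVCacheTensorDispatched.
print(cu_source.count("PODWithKVCacheTensorDispatched"))  # expect 12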
@@ -17,48 +17,67 @@ def run_bench(
device=0,
causal=True,
):
# POD Attention only supports page size = 1 due to use of single prefill kernel
page_block_size = 1
# if page size > 1, prefill kv len must be divisible by page size to ensure
# an identical workload as in BatchAttention
page_size = 1
seq_lens = torch.tensor(d_kv_lens + p_kv_lens, dtype=torch.int32)
q_lens = torch.tensor(d_qo_lens + p_qo_lens, dtype=torch.int32)

seq_lens_blocks = torch.ceil(seq_lens / page_block_size).int()
d_seq_lens_blocks = (
torch.tensor(d_kv_lens, dtype=torch.int32) / page_block_size
).int()
seq_lens_blocks = torch.ceil(seq_lens / page_size).int()
p_seq_lens = torch.tensor(p_kv_lens, dtype=torch.int32) / page_size
d_seq_lens = (torch.tensor(d_kv_lens, dtype=torch.int32) / page_size).int()

q_indptr = torch.cat([torch.tensor([0]), torch.cumsum(q_lens, 0)], dim=0).int()
# General params
qo_indptr = torch.cat([torch.tensor([0]), torch.cumsum(q_lens, 0)], dim=0).int()
kv_indptr = torch.cat(
[torch.tensor([0]), torch.cumsum(seq_lens_blocks, 0)], dim=0
).int()
d_q_indptr = torch.cat(
[torch.tensor([0]), torch.cumsum(torch.tensor(d_qo_lens), 0)], dim=0
).int()
d_kv_indptr = torch.cat(
[torch.tensor([0]), torch.cumsum(d_seq_lens_blocks, 0)], dim=0
).int()
num_blocks = kv_indptr[-1].item()

q = torch.rand(q_indptr[-1].item(), num_qo_heads, head_dim).to(
num_pages = kv_indptr[-1].item()
q = torch.rand(qo_indptr[-1].item(), num_qo_heads, head_dim).to(
device, dtype=torch.bfloat16
)
kv_data = torch.randn(num_blocks, 2, page_block_size, num_kv_heads, head_dim).to(
kv_data = torch.randn(num_pages, 2, page_size, num_kv_heads, head_dim).to(
device, dtype=torch.bfloat16
)

# Prefill params
seq_lens_blocks_p = torch.ceil(
torch.tensor(p_kv_lens, dtype=torch.int32) / page_size
).int()
qo_indptr_p = torch.cat(
[torch.tensor([0]), torch.cumsum(torch.tensor(p_qo_lens), 0)], dim=0
).int()
kv_indptr_p = torch.cat(
[torch.tensor([0]), torch.cumsum(p_seq_lens, 0)], dim=0
).int()
num_pages_p = seq_lens_blocks_p[-1].item()
kv_indices_p = torch.arange(num_pages_p, device=device, dtype=torch.int32)
last_page_len_p = (p_seq_lens - 1) % page_size + 1

# Decode params

qo_indptr_d = torch.cat(
[torch.tensor([0]), torch.cumsum(torch.tensor(d_qo_lens), 0)], dim=0
).int()
kv_indptr_d = torch.cat(
[torch.tensor([0]), torch.cumsum(d_seq_lens, 0)], dim=0
).int()
num_pages_d = kv_indptr_d[-1].item()
kv_indices_d = torch.arange(num_pages_d, device=device, dtype=torch.int32)
last_page_len_d = (d_seq_lens - 1) % page_size + 1

workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
kv_layout = "NHD"

wrapper_old = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
workspace_buffer,
kv_layout=kv_layout,
backend="fa2",
)
last_page_len = (seq_lens - 1) % page_block_size + 1
last_page_len = (seq_lens - 1) % page_size + 1
wrapper_old.plan(
q_indptr.to(device),
qo_indptr.to(device),
kv_indptr.to(device),
torch.arange(num_blocks).int().to(device),
torch.arange(num_pages, dtype=torch.int32, device=device),
last_page_len,
num_qo_heads,
num_kv_heads,
@@ -72,28 +91,28 @@
ms_old = do_bench(lambda: wrapper_old.run(q, kv_data))

if len(p_kv_lens) == 1:
q_d = q[: d_q_indptr[-1]]
kv_d = kv_data[: d_kv_indptr[-1]].unbind(1)
q_p = q[d_q_indptr[-1] :]
k_p, v_p = kv_data[d_kv_indptr[-1] :].unbind(1)
k_p, v_p = k_p.squeeze(1), v_p.squeeze(1)
kv_indices_d = torch.arange(
0, d_kv_indptr[-1], device=device, dtype=torch.int32
)
q_d = q[: qo_indptr_d[-1]]
q_p = q[qo_indptr_d[-1] :]
kv_indices_d = torch.arange(0, num_pages_d, device=device, dtype=torch.int32)

last_page_len_d = (d_seq_lens_blocks - 1) % page_block_size + 1
last_page_len_d = (d_seq_lens - 1) % page_size + 1
wrapper_pod = flashinfer.PODWithPagedKVCacheWrapper(
workspace_buffer,
kv_layout=kv_layout,
)
wrapper_pod.plan(
d_kv_indptr.to(device),
qo_indptr_p.to(device),
kv_indptr_p.to(device),
kv_indices_p.to(device),
last_page_len_p,
qo_indptr_d.to(device),
kv_indptr_d.to(device),
kv_indices_d.to(device),
last_page_len=last_page_len_d,
last_page_len_d,
num_qo_heads=num_qo_heads,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
page_size=page_block_size,
page_size=page_size,
q_data_type=torch.bfloat16,
kv_data_type=torch.bfloat16,
)
@@ -113,29 +132,47 @@
ms_pod = do_bench(
lambda: wrapper_pod.run(
q_p,
k_p,
v_p,
q_d,
kv_d,
paged_kv_cache=kv_data,
causal_p=causal,
causal_d=causal,
)
)
# Persistent attention
wrapper = flashinfer.BatchAttention(kv_layout="NHD")
wrapper.plan(
qo_indptr.to(device),
kv_indptr.to(device),
torch.arange(num_pages, dtype=torch.int32, device=device),
seq_lens.to(device),
num_qo_heads,
num_kv_heads,
head_dim,
head_dim,
page_size,
causal=causal,
q_data_type=torch.bfloat16,
kv_data_type=torch.bfloat16,
)
ms_persistent = do_bench(lambda: wrapper.run(q, kv_data))

print(f"Elapsed time (Batched Prefill): {ms_old:.2f} ms")
if len(p_kv_lens) == 1:
print(f"Elapsed time (POD Attention): {ms_pod:.2f} ms")
total_bytes = (
q.numel() * q.element_size() + kv_data.numel() * kv_data.element_size()
)
print(f"Elapsed time (Batched Prefill): {ms_old:.2f} ms")
if len(p_kv_lens) == 1:
print(f"Elapsed time (POD Attention): {ms_pod:.2f} ms")
print(f"Elapsed time (Persistent Attention): {ms_persistent:.2f} ms")

print(f"Loading memory size (MB): {total_bytes / (1024**2):.2f} MB")

bandwidth_old_gb_s = total_bytes / (ms_old * 1e-3) / (1024**3)

bandwidth_new_gb_s = total_bytes / (ms_persistent * 1e-3) / (1024**3)
print(f"Memory bandwidth (Batched Prefill): {bandwidth_old_gb_s:.2f} GB/s")
if len(p_kv_lens) == 1:
bandwidth_pod_gb_s = total_bytes / (ms_pod * 1e-3) / (1024**3)
print(f"Memory bandwidth (POD Attention): {bandwidth_pod_gb_s:.2f} GB/s")
print(f"Memory bandwidth (Persistent Attention): {bandwidth_new_gb_s:.2f} GB/s")


if __name__ == "__main__":
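The benchmark's __main__ block is collapsed in the diff above. The call below is a hedged sketch of how run_bench might be driven: the parameter names are inferred from the function body, and the workload sizes are illustrative rather than the PR's actual benchmark configuration.

# Hypothetical driver call; keyword arguments are used because the real
# signature of run_bench is collapsed in the diff above.
run_bench(
    p_qo_lens=[4096],       # one long prefill request...
    p_kv_lens=[4096],
    d_qo_lens=[1] * 64,     # ...mixed with a batch of short decode requests
    d_kv_lens=[1024] * 64,
    num_qo_heads=32,
    num_kv_heads=8,
    head_dim=128,
    device=0,
    causal=True,
)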
2 changes: 1 addition & 1 deletion csrc/batch_attention.cu
@@ -154,7 +154,7 @@ void BatchPagedAttentionRun(at::Tensor float_workspace_buffer, at::Tensor int_wo
GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.merge_indptr_offset);
params[i].merge_o_indices =
GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.merge_o_indices_offset);
params[i].num_packed_qo_len =
params[i].num_to_merge_qo_len =
GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.num_qo_len_offset);

params[i].num_kv_heads = num_kv_heads;