Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
52 commits
Select commit Hold shift + click to select a range
06ef37d
use FA2DetermineCtaTileQ for pod
Edenzzzz May 20, 2025
14e3f13
modify wrapper..
Edenzzzz Jun 6, 2025
22dc539
fix
Edenzzzz Jun 8, 2025
3257899
bench against persistent
Edenzzzz Jun 25, 2025
82f1550
rename xsize to num_qo_tiles
Edenzzzz Jun 27, 2025
06fee31
fix
Edenzzzz Jun 30, 2025
f47f73e
fix
Edenzzzz Jun 30, 2025
c51cecc
fix
Edenzzzz Jun 30, 2025
e73566c
add mixed scheduler
Edenzzzz Jul 3, 2025
1b2d4c0
rename to num_to_merge_qo_len
Edenzzzz Jul 3, 2025
78e1266
add params
Edenzzzz Jul 3, 2025
4979a2a
plan to use one reduction kernel for prefill and decode
Edenzzzz Jul 4, 2025
2102e22
fix
Edenzzzz Jul 4, 2025
fab82ae
use unified qkv indptr
Edenzzzz Jul 6, 2025
13b6b19
fix
Edenzzzz Jul 6, 2025
7d29232
fix plan func upper call interface
Edenzzzz Jul 6, 2025
106bfdc
rename new_batch_size to real_batch_size
Edenzzzz Jul 7, 2025
5e3e896
concat request_indices
Edenzzzz Jul 7, 2025
ac07253
unified indices in wrapper.plan
Edenzzzz Jul 8, 2025
eb8f719
fixes
Edenzzzz Jul 8, 2025
e8b266d
fix params
Edenzzzz Jul 8, 2025
560918b
fix some indices and params
Edenzzzz Jul 8, 2025
2105101
update PODWithKVCacheTensorRun args
Edenzzzz Jul 9, 2025
1a82b17
add paged kv params
Edenzzzz Jul 9, 2025
dd80a06
complete PODWithKVCacheTensorRun params
Edenzzzz Jul 9, 2025
0bb164b
share lse
Edenzzzz Jul 10, 2025
32d762b
templatize CTA_TILE_Q_P
Edenzzzz Jul 10, 2025
870b0b2
update dispatch logic
Edenzzzz Jul 11, 2025
a7eb44b
add pod template inst and .cu gen
Edenzzzz Jul 11, 2025
f57fde0
fixes
Edenzzzz Jul 11, 2025
0832bac
trying to fix template
Edenzzzz Jul 11, 2025
a570bff
fix PODSplitQOKVIndptr param
Edenzzzz Jul 13, 2025
965d073
fix template and vec type errors
Edenzzzz Jul 15, 2025
c2fcc54
use get_pod_module jit in wrapper
Edenzzzz Jul 15, 2025
708d1be
attempt to fix some templates
Edenzzzz Jul 16, 2025
9b9f8f4
fix
Edenzzzz Jul 16, 2025
6f75dae
Fix compilation errors
AKKamath Jul 16, 2025
332e4bf
Various other fixes for POD
AKKamath Jul 16, 2025
d7707f5
update pod test
Edenzzzz Jul 17, 2025
766e8a8
update pod test
Edenzzzz Jul 18, 2025
ac408f4
fix
Edenzzzz Jul 18, 2025
cb36898
fix tests
Edenzzzz Jul 18, 2025
ff0e4c3
remove kv_len_arr_p
Edenzzzz Jul 19, 2025
1a1323d
fix negative kv len bug
Edenzzzz Jul 19, 2025
ab5d4f1
add chunk_size check
Edenzzzz Jul 20, 2025
9927dd2
fix
Edenzzzz Jul 20, 2025
19cad42
remove mistaken causal
Edenzzzz Jul 20, 2025
f39be0b
use int32 for indices
Edenzzzz Jul 21, 2025
875f170
add some checks
Edenzzzz Jul 21, 2025
81fc9a7
fix dispatch logic when prefill request is empty
Edenzzzz Jul 22, 2025
b604637
Merge main Oct 30
Edenzzzz Oct 31, 2025
458bcce
fix 3rdparty
Edenzzzz Oct 31, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions 3rdparty/googletest
Submodule googletest added at 5a37b5
1 change: 1 addition & 0 deletions 3rdparty/nvbench
Submodule nvbench added at 555d62
388 changes: 388 additions & 0 deletions aot_build_utils/generate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,388 @@
"""
Copyright (c) 2024 by FlashInfer team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import argparse
from itertools import product
from pathlib import Path
from typing import List

from . import (
generate_aot_default_additional_params_header,
generate_batch_paged_decode_inst,
generate_batch_paged_prefill_inst,
generate_batch_ragged_prefill_inst,
generate_pod_inst,
generate_single_decode_inst,
generate_single_prefill_inst,
)


def get_instantiation_cu(args: argparse.Namespace) -> List[str]:
    """Generate AOT kernel-instantiation ``.cu`` files under ``args.path``.

    Emits one ``.cu`` file per template instantiation for single/batch
    decode, single/batch prefill (paged and ragged), and POD kernels, plus
    the shared ``aot_default_additional_params.h`` header.

    Args:
        args: Parsed command-line namespace carrying ``path`` (output
            directory), the ``head_dims`` / ``pos_encoding_modes`` /
            ``use_fp16_qk_reductions`` / ``mask_modes`` sweep lists, and the
            ``enable_f16`` / ``enable_bf16`` / ``enable_fp8_e4m3`` /
            ``enable_fp8_e5m2`` dtype toggles.

    Returns:
        The concatenated list of module URIs describing every generated
        kernel configuration (decode + prefill + POD).
    """

    def write_if_different(path: Path, content: str) -> None:
        # Skip the write when content is unchanged so file mtimes are not
        # touched needlessly (keeps incremental rebuilds fast).
        if path.exists() and path.read_text() == content:
            return
        path.write_text(content)

    path: Path = args.path
    head_dims: List[int] = args.head_dims
    pos_encoding_modes: List[int] = args.pos_encoding_modes
    use_fp16_qk_reductions: List[int] = args.use_fp16_qk_reductions
    mask_modes: List[int] = args.mask_modes
    enable_f16: bool = args.enable_f16
    enable_bf16: bool = args.enable_bf16
    enable_fp8_e4m3: bool = args.enable_fp8_e4m3
    enable_fp8_e5m2: bool = args.enable_fp8_e5m2

    path.mkdir(parents=True, exist_ok=True)

    write_if_different(
        path / "aot_default_additional_params.h",
        generate_aot_default_additional_params_header.get_aot_default_additional_params_header_str(),
    )

    idtypes = ["i32"]
    prefill_dtypes = []
    decode_dtypes = []
    fp16_dtypes = []
    fp8_dtypes = []
    if enable_f16:
        prefill_dtypes.append("f16")
        decode_dtypes.append("f16")
        fp16_dtypes.append("f16")
    if enable_bf16:
        prefill_dtypes.append("bf16")
        decode_dtypes.append("bf16")
        fp16_dtypes.append("bf16")
    if enable_fp8_e4m3:
        fp8_dtypes.extend(["e4m3"])
        decode_dtypes.extend(["e4m3"])
    if enable_fp8_e5m2:
        fp8_dtypes.extend(["e5m2"])
        decode_dtypes.extend(["e5m2"])

    # The (dtype_q, dtype_kv) pair lists are loop-invariant; build them once
    # instead of re-running zip/product inside every sweep iteration.
    decode_dtype_pairs = list(zip(decode_dtypes, decode_dtypes)) + list(
        product(fp16_dtypes, fp8_dtypes)
    )
    prefill_dtype_pairs = list(zip(prefill_dtypes, prefill_dtypes)) + list(
        product(prefill_dtypes, fp8_dtypes)
    )

    single_decode_uris = []
    # single decode files
    for head_dim, pos_encoding_mode in product(head_dims, pos_encoding_modes):
        for dtype_q, dtype_kv in decode_dtype_pairs:
            dtype_out = dtype_q
            fname = f"single_decode_head_qk_{head_dim}_head_vo_{head_dim}_posenc_{pos_encoding_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_out}.cu"
            content = generate_single_decode_inst.get_cu_file_str(
                head_dim,  # head_dim_qk
                head_dim,  # head_dim_vo
                pos_encoding_mode,
                dtype_q,
                dtype_kv,
                dtype_out,
            )
            for use_sliding_window in [True, False]:
                for use_logits_soft_cap in [True, False]:
                    single_decode_uris.append(
                        f"single_decode_with_kv_cache_dtype_q_{dtype_q}_"
                        f"dtype_kv_{dtype_kv}_"
                        f"dtype_o_{dtype_out}_"
                        f"head_dim_qk_{head_dim}_"
                        f"head_dim_vo_{head_dim}_"
                        f"posenc_{pos_encoding_mode}_"
                        f"use_swa_{use_sliding_window}_"
                        f"use_logits_cap_{use_logits_soft_cap}"
                    )
            write_if_different(path / fname, content)

    # batch decode files
    batch_decode_uris = []
    for (
        head_dim,
        pos_encoding_mode,
    ) in product(
        head_dims,
        pos_encoding_modes,
    ):
        for idtype in idtypes:
            for dtype_q, dtype_kv in decode_dtype_pairs:
                dtype_out = dtype_q
                fname = f"batch_paged_decode_head_qk_{head_dim}_head_vo_{head_dim}_posenc_{pos_encoding_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_out}_idtype_{idtype}.cu"
                content = generate_batch_paged_decode_inst.get_cu_file_str(
                    head_dim,  # head_dim_qk
                    head_dim,  # head_dim_vo
                    pos_encoding_mode,
                    dtype_q,
                    dtype_kv,
                    dtype_out,
                    idtype,
                )
                for use_sliding_window in [True, False]:
                    for use_logits_soft_cap in [True, False]:
                        batch_decode_uris.append(
                            f"batch_decode_with_kv_cache_dtype_q_{dtype_q}_"
                            f"dtype_kv_{dtype_kv}_"
                            f"dtype_o_{dtype_out}_"
                            f"dtype_idx_{idtype}_"
                            f"head_dim_qk_{head_dim}_"
                            f"head_dim_vo_{head_dim}_"
                            f"posenc_{pos_encoding_mode}_"
                            f"use_swa_{use_sliding_window}_"
                            f"use_logits_cap_{use_logits_soft_cap}"
                        )
                write_if_different(path / fname, content)

    # single prefill files
    single_prefill_uris = []
    for (
        head_dim,
        pos_encoding_mode,
        use_fp16_qk_reduction,
        mask_mode,
    ) in product(
        head_dims,
        pos_encoding_modes,
        use_fp16_qk_reductions,
        mask_modes,
    ):
        for dtype_q, dtype_kv in prefill_dtype_pairs:
            fname = f"single_prefill_head_qk_{head_dim}_head_vo_{head_dim}_posenc_{pos_encoding_mode}_fp16qkred_{use_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}.cu"
            content = generate_single_prefill_inst.get_cu_file_str(
                head_dim,  # head_dim_qk
                head_dim,  # head_dim_vo
                pos_encoding_mode,
                use_fp16_qk_reduction,
                mask_mode,
                dtype_q,  # dtype_q
                dtype_kv,  # dtype_kv
                dtype_q,  # dtype_out
            )
            # NOTE(Zihao): uri do not contain mask, avoid duplicate uris.
            # The check is loop-invariant, so it is hoisted above the
            # sliding-window / soft-cap sweep.
            if mask_mode == 0:
                for use_sliding_window in [True, False]:
                    for use_logits_soft_cap in [True, False]:
                        single_prefill_uris.append(
                            f"single_prefill_with_kv_cache_dtype_q_{dtype_q}_"
                            f"dtype_kv_{dtype_kv}_"
                            f"dtype_o_{dtype_q}_"
                            f"head_dim_qk_{head_dim}_"
                            f"head_dim_vo_{head_dim}_"
                            f"posenc_{pos_encoding_mode}_"
                            f"use_swa_{use_sliding_window}_"
                            f"use_logits_cap_{use_logits_soft_cap}_"
                            f"f16qk_{bool(use_fp16_qk_reduction)}"
                        )
            write_if_different(path / fname, content)

    # batch prefill files
    batch_prefill_uris = []
    for (
        head_dim,
        pos_encoding_mode,
        use_fp16_qk_reduction,
        mask_mode,
        idtype,
    ) in product(
        head_dims,
        pos_encoding_modes,
        use_fp16_qk_reductions,
        mask_modes,
        idtypes,
    ):
        for dtype_q, dtype_kv in prefill_dtype_pairs:
            fname = f"batch_paged_prefill_head_qk_{head_dim}_head_vo_{head_dim}_posenc_{pos_encoding_mode}_fp16qkred_{use_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_idtype_{idtype}.cu"
            content = generate_batch_paged_prefill_inst.get_cu_file_str(
                head_dim,  # head_dim_qk
                head_dim,  # head_dim_vo
                pos_encoding_mode,
                use_fp16_qk_reduction,
                mask_mode,
                dtype_q,  # dtype_q
                dtype_kv,  # dtype_kv
                dtype_q,  # dtype_out
                idtype,
            )
            write_if_different(path / fname, content)

            fname = f"batch_ragged_prefill_head_qk_{head_dim}_head_vo_{head_dim}_posenc_{pos_encoding_mode}_fp16qkred_{use_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_idtype_{idtype}.cu"
            content = generate_batch_ragged_prefill_inst.get_cu_file_str(
                head_dim,  # head_dim_qk
                head_dim,  # head_dim_vo
                pos_encoding_mode,
                use_fp16_qk_reduction,
                mask_mode,
                dtype_q,  # dtype_q
                dtype_kv,  # dtype_kv
                dtype_q,  # dtype_out
                idtype,
            )
            write_if_different(path / fname, content)

            # NOTE(Zihao): uri do not contain mask, avoid duplicate uris.
            if mask_mode == 0:
                for sliding_window in [True, False]:
                    for logits_soft_cap in [True, False]:
                        batch_prefill_uris.append(
                            f"batch_prefill_with_kv_cache_dtype_q_{dtype_q}_"
                            f"dtype_kv_{dtype_kv}_"
                            f"dtype_o_{dtype_q}_"
                            f"dtype_idx_{idtype}_"
                            f"head_dim_qk_{head_dim}_"
                            f"head_dim_vo_{head_dim}_"
                            f"posenc_{pos_encoding_mode}_"
                            f"use_swa_{sliding_window}_"
                            f"use_logits_cap_{logits_soft_cap}_"
                            f"f16qk_{bool(use_fp16_qk_reduction)}"
                        )

    # POD files
    pod_uris = []
    for (
        head_dim,
        pos_encoding_mode,
        use_fp16_qk_reduction,
        mask_mode_p,
        mask_mode_d,
        idtype,
    ) in product(
        head_dims,
        pos_encoding_modes,
        use_fp16_qk_reductions,
        mask_modes,
        mask_modes,  # mask_mode_d can be different from mask_mode_p
        idtypes,
    ):
        for dtype_q, dtype_kv in prefill_dtype_pairs:
            fname = f"pod_head_qk_{head_dim}_head_vo_{head_dim}_posenc_{pos_encoding_mode}_fp16qkred_{use_fp16_qk_reduction}_maskp_{mask_mode_p}_maskd_{mask_mode_d}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_idtype_{idtype}.cu"
            content = generate_pod_inst.get_cu_file_str(
                head_dim,  # head_dim_qk
                head_dim,  # head_dim_vo
                pos_encoding_mode,
                use_fp16_qk_reduction,
                mask_mode_p,
                mask_mode_d,
                dtype_q,  # dtype_q
                dtype_kv,  # dtype_kv
                dtype_q,  # dtype_out
                idtype,
            )
            write_if_different(path / fname, content)

            # NOTE(Zihao): uri do not contain mask, avoid duplicate uris.
            # Hoisted out of the four nested boolean sweeps below.
            if mask_mode_p == 0 and mask_mode_d == 0:
                for sliding_window_p in [True, False]:
                    for sliding_window_d in [True, False]:
                        for logits_soft_cap_p in [True, False]:
                            for logits_soft_cap_d in [True, False]:
                                pod_uris.append(
                                    f"pod_with_kv_cache_dtype_q_{dtype_q}_"
                                    f"dtype_kv_{dtype_kv}_"
                                    f"dtype_o_{dtype_q}_"
                                    f"dtype_idx_{idtype}_"
                                    f"head_dim_qk_{head_dim}_"
                                    f"head_dim_vo_{head_dim}_"
                                    f"posenc_{pos_encoding_mode}_"
                                    f"use_swa_p_{sliding_window_p}_"
                                    f"use_swa_d_{sliding_window_d}_"
                                    f"use_logits_cap_p_{logits_soft_cap_p}_"
                                    f"use_logits_cap_d_{logits_soft_cap_d}_"
                                    f"f16qk_{bool(use_fp16_qk_reduction)}"
                                )

    return (
        single_decode_uris
        + batch_decode_uris
        + single_prefill_uris
        + batch_prefill_uris
        + pod_uris
    )


def parse_bool(x) -> bool:
    """argparse ``type=`` converter for boolean flags.

    Accepts "true", "on", and "1" (case-insensitive) as truthy; anything
    else is falsy. Non-string values (e.g. an int/bool default) are passed
    through ``bool()``. Fixes the previous inline lambdas, which parsed the
    string "1" as falsy.
    """
    if not isinstance(x, str):
        return bool(x)
    return x.lower() in ("true", "on", "1")


def parse_bool_int(x) -> int:
    """argparse ``type=`` converter returning 0/1 ints.

    Used for list-valued flags (``nargs="+"``) whose downstream consumers
    expect ints. Ints pass through unchanged, matching the original lambda.
    """
    return x if isinstance(x, int) else int(parse_bool(x))


if __name__ == "__main__":
    parser = argparse.ArgumentParser("Generate cuda files")
    parser.add_argument(
        "--path", type=Path, required=True, help="Path to the dispatch inc file"
    )
    parser.add_argument(
        "--head_dims", type=int, required=True, nargs="+", help="Head dimensions"
    )
    parser.add_argument(
        "--pos_encoding_modes",
        type=int,
        required=True,
        nargs="+",
        help="Position encoding modes",
    )
    parser.add_argument(
        "--use_fp16_qk_reductions",
        type=parse_bool_int,
        required=True,
        nargs="+",
        help="Allow fp16 qk reductions",
    )
    parser.add_argument(
        "--mask_modes",
        type=int,
        required=True,
        nargs="+",
        help="Mask modes",
    )
    parser.add_argument(
        "--enable_f16",
        type=parse_bool,
        required=True,
        nargs="?",
        help="Enable fp16",
    )
    parser.add_argument(
        "--enable_bf16",
        type=parse_bool,
        required=True,
        nargs="?",
        help="Enable bf16",
    )
    parser.add_argument(
        "--enable_fp8_e4m3",
        type=parse_bool,
        default=True,
        nargs="?",
        help="Enable fp8_e4m3",
    )
    parser.add_argument(
        "--enable_fp8_e5m2",
        type=parse_bool,
        default=True,
        nargs="?",
        help="Enable fp8_e5m2",
    )
    args = parser.parse_args()
    get_instantiation_cu(args)
Loading
Loading