Commit 82957fc

fix aot
Signed-off-by: Qidi Sang <[email protected]>
1 parent e89c7ec commit 82957fc

File tree

flashinfer/aot.py
flashinfer/jit/xqa.py
flashinfer/xqa.py

3 files changed: +26 -8 lines changed


flashinfer/aot.py
Lines changed: 16 additions & 1 deletion

@@ -363,25 +363,37 @@ def gen_xqa(
     head_grp_size_: List[int],
     use_sliding_window_: List[bool],
     has_sm90: bool,
+    has_sm100: bool,
+    has_sm120: bool,
 ) -> Iterator[JitSpec]:
     """Generate XQA modules for various configurations."""
-    if not has_sm90:
+    if not has_sm90 and not has_sm100 and not has_sm120:
         return  # XQA requires SM90+
 
+    sm_versions = []
+    if has_sm90:
+        sm_versions.append(90)
+    if has_sm100:
+        sm_versions.append(100)
+    if has_sm120:
+        sm_versions.append(120)
+
     for (
         fp16_input,
         fp8_kv_cache,
         token_per_page,
         head_size,
         head_grp_size,
         use_sliding_window,
+        sm_version,
     ) in product(
         fp16_input_,
         fp8_kv_cache_,
         token_per_page_,
         head_size_,
         head_grp_size_,
         use_sliding_window_,
+        sm_versions,
     ):
         # Skip invalid configurations
         if head_size % 16 != 0 or head_size > 256 or head_size < 16:

@@ -396,6 +408,7 @@ def gen_xqa(
             head_size=head_size,
             head_grp_size=head_grp_size,
             use_sliding_window=use_sliding_window,
+            sm_version=sm_version,
         )
 
 
@@ -527,6 +540,8 @@ def gen_all_modules(
            xqa_head_grp_size_,
            use_sliding_window_,
            has_sm90,
+           has_sm100,
+           has_sm120,
        )
    )
 
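For context, a minimal standalone sketch of the enumeration pattern the updated gen_xqa uses: each detected architecture contributes an entry to sm_versions, and product() then yields one configuration per (parameter combination, SM version) pair. The flag values and head sizes below are illustrative only, not the AOT defaults.

from itertools import product

# Example detection results; the real values come from the AOT build flags.
has_sm90, has_sm100, has_sm120 = True, False, True

sm_versions = []
if has_sm90:
    sm_versions.append(90)
if has_sm100:
    sm_versions.append(100)
if has_sm120:
    sm_versions.append(120)

# One XQA module would be generated per combination.
for head_size, sm_version in product([64, 128], sm_versions):
    print(f"head_size={head_size}, sm_{sm_version}")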
flashinfer/jit/xqa.py
Lines changed: 4 additions & 5 deletions

@@ -22,8 +22,6 @@
     sm100a_nvcc_flags,
     sm120a_nvcc_flags,
 )
-from ..utils import get_compute_capability
-import torch
 
 xqa_nvcc_flags = [
     "-DNDEBUG=1",

@@ -42,6 +40,7 @@ def gen_xqa_module(
     head_size: int,
     head_grp_size: int,
     use_sliding_window: bool,
+    sm_version: int = 90,
 ) -> JitSpec:
     if fp16_input:
         flag_data_type = ["-DINPUT_FP16=1", "-DDTYPE=__half"]

@@ -72,15 +71,15 @@ def gen_xqa_module(
     else:
         flag_sliding_window = ["-DSLIDING_WINDOW=0"]
 
-    if get_compute_capability(torch.device(device="cuda"))[0] == 10:
+    if sm_version == 100:
         sm_nvcc_flags = sm100a_nvcc_flags
-    elif get_compute_capability(torch.device(device="cuda"))[0] == 12:
+    elif sm_version == 120:
         sm_nvcc_flags = sm120a_nvcc_flags
     else:
         sm_nvcc_flags = sm90a_nvcc_flags
 
     return gen_jit_spec(
-        f"xqa_fp16_input_{fp16_input}_fp8_kv_cache_{fp8_kv_cache}_token_per_page_{token_per_page}_head_size_{head_size}_head_grp_size_{head_grp_size}_use_sliding_window_{use_sliding_window}_sm_{get_compute_capability(torch.device(device='cuda'))[0]}0",
+        f"xqa_fp16_input_{fp16_input}_fp8_kv_cache_{fp8_kv_cache}_token_per_page_{token_per_page}_head_size_{head_size}_head_grp_size_{head_grp_size}_use_sliding_window_{use_sliding_window}_sm_{sm_version}",
         [
             jit_env.FLASHINFER_CSRC_DIR / "xqa/mha.cu",
             jit_env.FLASHINFER_CSRC_DIR / "xqa/mha_sm90.cu",
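A minimal usage sketch of the reworked gen_xqa_module, assuming a flashinfer installation with the XQA sources and that the parameters shown in the diff are the full keyword set; the argument values here are illustrative, not repository defaults. Passing sm_version explicitly selects the matching sm*a_nvcc_flags instead of probing the current CUDA device, which is what allows ahead-of-time generation for multiple architectures.

from flashinfer.jit.xqa import gen_xqa_module

# Build a JitSpec targeting SM100 without touching the local GPU.
spec = gen_xqa_module(
    fp16_input=True,
    fp8_kv_cache=False,
    token_per_page=16,
    head_size=128,
    head_grp_size=8,
    use_sliding_window=False,
    sm_version=100,  # selects sm100a_nvcc_flags; defaults to 90
)
# spec.build_and_load() would then compile the module for SM100.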
flashinfer/xqa.py
Lines changed: 6 additions & 2 deletions

@@ -35,6 +35,7 @@ def get_xqa_module(
     head_size: int,
     head_grp_size: int,
     use_sliding_window: bool,
+    sm_version: int = 90,
 ):
     module = gen_xqa_module(
         fp16_input,

@@ -43,10 +44,11 @@ def get_xqa_module(
         head_size,
         head_grp_size,
         use_sliding_window,
+        sm_version,
     ).build_and_load()
 
     @register_custom_op(
-        f"flashinfer::xqa_fp16_input_{fp16_input}_fp8_kv_cache_{fp8_kv_cache}_token_per_page_{token_per_page}_head_size_{head_size}_head_grp_size_{head_grp_size}_use_sliding_window_{use_sliding_window}",
+        f"flashinfer::xqa_fp16_input_{fp16_input}_fp8_kv_cache_{fp8_kv_cache}_token_per_page_{token_per_page}_head_size_{head_size}_head_grp_size_{head_grp_size}_use_sliding_window_{use_sliding_window}_sm_{sm_version}",
         mutates_args=("output", "scratch"),
     )
     def xqa(

@@ -87,7 +89,7 @@ def xqa(
         )
 
     @register_fake_op(
-        f"flashinfer::xqa_fp16_input_{fp16_input}_fp8_kv_cache_{fp8_kv_cache}_token_per_page_{token_per_page}_head_size_{head_size}_head_grp_size_{head_grp_size}_use_sliding_window_{use_sliding_window}"
+        f"flashinfer::xqa_fp16_input_{fp16_input}_fp8_kv_cache_{fp8_kv_cache}_token_per_page_{token_per_page}_head_size_{head_size}_head_grp_size_{head_grp_size}_use_sliding_window_{use_sliding_window}_sm_{sm_version}"
     )
     def _fake_xqa(
         run_fp8_mha: bool,

@@ -140,13 +142,15 @@ def xqa(
 ) -> None:
     if get_compute_capability(torch.device(device="cuda"))[0] not in [9, 10, 12]:
         raise RuntimeError("XQA is only supported on SM90, SM100, SM120 GPUs")
+    sm_version = int(get_compute_capability(torch.device(device="cuda"))[0] * 10)
     xqa_module = get_xqa_module(
         fp16_input,
         fp8_kv_cache,
         token_per_page,
         head_size,
         head_grp_size,
         use_sliding_window,
+        sm_version,
     )
     xqa_module.xqa(
         run_fp8_mha,
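At runtime, flashinfer/xqa.py now derives sm_version from the device's compute capability (major 9 -> 90, 10 -> 100, 12 -> 120). A standalone sketch of that mapping follows, using torch.cuda.get_device_capability in place of flashinfer's get_compute_capability helper; a CUDA device is assumed.

import torch

def current_sm_version(device: torch.device = torch.device("cuda")) -> int:
    # Major compute capability: 9 -> SM90, 10 -> SM100, 12 -> SM120.
    major, _minor = torch.cuda.get_device_capability(device)
    if major not in (9, 10, 12):
        raise RuntimeError("XQA is only supported on SM90, SM100, SM120 GPUs")
    return major * 10

# e.g. on an H100 (compute capability 9.0) this returns 90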