3 files changed: +26 -8 lines changed
File 1 of 3:

@@ -363,25 +363,37 @@ def gen_xqa(
     head_grp_size_: List[int],
     use_sliding_window_: List[bool],
     has_sm90: bool,
+    has_sm100: bool,
+    has_sm120: bool,
 ) -> Iterator[JitSpec]:
     """Generate XQA modules for various configurations."""
-    if not has_sm90:
+    if not has_sm90 and not has_sm100 and not has_sm120:
         return  # XQA requires SM90+
 
+    sm_versions = []
+    if has_sm90:
+        sm_versions.append(90)
+    if has_sm100:
+        sm_versions.append(100)
+    if has_sm120:
+        sm_versions.append(120)
+
     for (
         fp16_input,
         fp8_kv_cache,
         token_per_page,
         head_size,
         head_grp_size,
         use_sliding_window,
+        sm_version,
     ) in product(
         fp16_input_,
         fp8_kv_cache_,
         token_per_page_,
         head_size_,
         head_grp_size_,
         use_sliding_window_,
+        sm_versions,
     ):
         # Skip invalid configurations
         if head_size % 16 != 0 or head_size > 256 or head_size < 16:
@@ -396,6 +408,7 @@ def gen_xqa(
             head_size=head_size,
             head_grp_size=head_grp_size,
             use_sliding_window=use_sliding_window,
+            sm_version=sm_version,
         )
 
 
@@ -527,6 +540,8 @@ def gen_all_modules(
             xqa_head_grp_size_,
             use_sliding_window_,
             has_sm90,
+            has_sm100,
+            has_sm120,
         )
     )
 
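The net effect of this file is that XQA module generation is no longer limited to SM90: every enabled architecture contributes its own entry to the Cartesian product, so one JIT spec is emitted per (configuration, SM version) pair. A minimal sketch of that expansion follows; the availability flags and head sizes are made-up values, and print() stands in for the real yield of gen_xqa_module:

# Illustrative sketch only: shows how the new sm_versions list widens the
# Cartesian product. Flag values and head sizes are assumptions.
from itertools import product

has_sm90, has_sm100, has_sm120 = True, False, True  # assumed build targets

sm_versions = []
if has_sm90:
    sm_versions.append(90)
if has_sm100:
    sm_versions.append(100)
if has_sm120:
    sm_versions.append(120)

for head_size, sm_version in product([128, 256], sm_versions):
    # Mirror of the diff's validity check on head_size.
    if head_size % 16 != 0 or head_size > 256 or head_size < 16:
        continue
    print(f"gen_xqa_module(head_size={head_size}, sm_version={sm_version})")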
File 2 of 3:

@@ -22,8 +22,6 @@
     sm100a_nvcc_flags,
     sm120a_nvcc_flags,
 )
-from ..utils import get_compute_capability
-import torch
 
 xqa_nvcc_flags = [
     "-DNDEBUG=1",
@@ -42,6 +40,7 @@ def gen_xqa_module(
     head_size: int,
     head_grp_size: int,
     use_sliding_window: bool,
+    sm_version: int = 90,
 ) -> JitSpec:
     if fp16_input:
         flag_data_type = ["-DINPUT_FP16=1", "-DDTYPE=__half"]
@@ -72,15 +71,15 @@ def gen_xqa_module(
     else:
         flag_sliding_window = ["-DSLIDING_WINDOW=0"]
 
-    if get_compute_capability(torch.device(device="cuda"))[0] == 10:
+    if sm_version == 100:
         sm_nvcc_flags = sm100a_nvcc_flags
-    elif get_compute_capability(torch.device(device="cuda"))[0] == 12:
+    elif sm_version == 120:
         sm_nvcc_flags = sm120a_nvcc_flags
     else:
         sm_nvcc_flags = sm90a_nvcc_flags
 
     return gen_jit_spec(
-        f"xqa_fp16_input_{fp16_input}_fp8_kv_cache_{fp8_kv_cache}_token_per_page_{token_per_page}_head_size_{head_size}_head_grp_size_{head_grp_size}_use_sliding_window_{use_sliding_window}_sm_{get_compute_capability(torch.device(device='cuda'))[0]}0",
+        f"xqa_fp16_input_{fp16_input}_fp8_kv_cache_{fp8_kv_cache}_token_per_page_{token_per_page}_head_size_{head_size}_head_grp_size_{head_grp_size}_use_sliding_window_{use_sliding_window}_sm_{sm_version}",
         [
             jit_env.FLASHINFER_CSRC_DIR / "xqa/mha.cu",
             jit_env.FLASHINFER_CSRC_DIR / "xqa/mha_sm90.cu",
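With sm_version passed in explicitly, gen_xqa_module no longer queries get_compute_capability at spec-construction time, so flag selection and the spec name become pure functions of the argument and the spec can be built for a GPU that is not present. A hedged sketch of that mapping, using placeholder flag lists and a shortened stand-in name rather than flashinfer's real values:

# Sketch of the selection logic above with placeholder flag lists; the real
# sm90a/sm100a/sm120a flags come from flashinfer's jit core module.
SM_FLAGS = {
    90: ["<sm90a_nvcc_flags>"],    # placeholder
    100: ["<sm100a_nvcc_flags>"],  # placeholder
    120: ["<sm120a_nvcc_flags>"],  # placeholder
}

def select_build_config(sm_version: int = 90) -> tuple:
    # Anything other than 100/120 falls back to the SM90a flags, as in the diff.
    nvcc_flags = SM_FLAGS.get(sm_version, SM_FLAGS[90])
    # The module name embeds the requested sm_version directly, so specs built
    # for different architectures do not collide in the JIT cache.
    name = f"xqa_example_sm_{sm_version}"  # shortened stand-in for the full name
    return name, nvcc_flags

print(select_build_config(120))  # ('xqa_example_sm_120', ['<sm120a_nvcc_flags>'])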
File 3 of 3:

@@ -35,6 +35,7 @@ def get_xqa_module(
     head_size: int,
     head_grp_size: int,
     use_sliding_window: bool,
+    sm_version: int = 90,
 ):
     module = gen_xqa_module(
         fp16_input,
@@ -43,10 +44,11 @@ def get_xqa_module(
         head_size,
         head_grp_size,
         use_sliding_window,
+        sm_version,
     ).build_and_load()
 
     @register_custom_op(
-        f"flashinfer::xqa_fp16_input_{fp16_input}_fp8_kv_cache_{fp8_kv_cache}_token_per_page_{token_per_page}_head_size_{head_size}_head_grp_size_{head_grp_size}_use_sliding_window_{use_sliding_window}",
+        f"flashinfer::xqa_fp16_input_{fp16_input}_fp8_kv_cache_{fp8_kv_cache}_token_per_page_{token_per_page}_head_size_{head_size}_head_grp_size_{head_grp_size}_use_sliding_window_{use_sliding_window}_sm_{sm_version}",
         mutates_args=("output", "scratch"),
     )
     def xqa(
@@ -87,7 +89,7 @@ def xqa(
         )
 
     @register_fake_op(
-        f"flashinfer::xqa_fp16_input_{fp16_input}_fp8_kv_cache_{fp8_kv_cache}_token_per_page_{token_per_page}_head_size_{head_size}_head_grp_size_{head_grp_size}_use_sliding_window_{use_sliding_window}"
+        f"flashinfer::xqa_fp16_input_{fp16_input}_fp8_kv_cache_{fp8_kv_cache}_token_per_page_{token_per_page}_head_size_{head_size}_head_grp_size_{head_grp_size}_use_sliding_window_{use_sliding_window}_sm_{sm_version}"
     )
     def _fake_xqa(
         run_fp8_mha: bool,
@@ -140,13 +142,15 @@ def xqa(
 ) -> None:
     if get_compute_capability(torch.device(device="cuda"))[0] not in [9, 10, 12]:
         raise RuntimeError("XQA is only supported on SM90, SM100, SM120 GPUs")
+    sm_version = int(get_compute_capability(torch.device(device="cuda"))[0] * 10)
     xqa_module = get_xqa_module(
         fp16_input,
         fp8_kv_cache,
         token_per_page,
         head_size,
         head_grp_size,
         use_sliding_window,
+        sm_version,
     )
     xqa_module.xqa(
         run_fp8_mha,
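At runtime the wrapper still inspects the current device, but it now converts the compute-capability major into an sm_version (9 -> 90, 10 -> 100, 12 -> 120) and forwards it, so the op registered as ..._sm_{sm_version} matches the module compiled for that GPU. A hedged sketch of that derivation, using torch.cuda.get_device_capability in place of flashinfer's get_compute_capability helper that the diff calls:

# Sketch only: the whitelist and the *10 conversion mirror the diff; the
# helper name derive_sm_version is hypothetical.
import torch

def derive_sm_version(device: str = "cuda") -> int:
    major, _minor = torch.cuda.get_device_capability(torch.device(device))
    if major not in [9, 10, 12]:
        raise RuntimeError("XQA is only supported on SM90, SM100, SM120 GPUs")
    return major * 10  # 9 -> 90, 10 -> 100, 12 -> 120

# sm_version would then be passed to get_xqa_module(..., sm_version) so the
# registered custom-op name is unique per architecture.
if torch.cuda.is_available():
    print(derive_sm_version())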