From bbfd40752ce05ea97118fc593be86aeb0395b910 Mon Sep 17 00:00:00 2001 From: jimmzhou Date: Thu, 4 Dec 2025 19:43:26 +0000 Subject: [PATCH 1/2] ampere and hopper flags and cleanup --- .pre-commit-config.yaml | 0 flashinfer/jit/attention/fmha_v2/generator_utils.py | 4 ++-- flashinfer/jit/attention/modules.py | 3 ++- 3 files changed, 4 insertions(+), 3 deletions(-) mode change 100644 => 100755 .pre-commit-config.yaml mode change 100644 => 100755 flashinfer/jit/attention/modules.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml old mode 100644 new mode 100755 diff --git a/flashinfer/jit/attention/fmha_v2/generator_utils.py b/flashinfer/jit/attention/fmha_v2/generator_utils.py index 82bab5311f..439261dca6 100755 --- a/flashinfer/jit/attention/fmha_v2/generator_utils.py +++ b/flashinfer/jit/attention/fmha_v2/generator_utils.py @@ -3711,10 +3711,10 @@ def generate_files(specs_names): ] if "CUDA_PATH" in os.environ: cmd[0] = os.environ["CUDA_PATH"] + "/bin/" + cmd[0] - print('Running command "{}" to build "bin/print_traits.exe":'.format(" ".join(cmd))) + # print('Running command "{}" to build "bin/print_traits.exe":'.format(" ".join(cmd))) process = subprocess.Popen(cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE) output, error = process.communicate() - print('Running "bin/print_traits.exe":') + # print('Running "bin/print_traits.exe":') process = subprocess.Popen( "bin/print_traits.exe", stdin=subprocess.PIPE, stdout=subprocess.PIPE ) diff --git a/flashinfer/jit/attention/modules.py b/flashinfer/jit/attention/modules.py old mode 100644 new mode 100755 index 3f01d0aca5..ea1a4ce16b --- a/flashinfer/jit/attention/modules.py +++ b/flashinfer/jit/attention/modules.py @@ -1901,9 +1901,10 @@ def gen_trtllm_fmha_v2_module() -> JitSpec: source_paths = kernel_paths + [binding_source_path] nvcc_flags = current_compilation_context.get_nvcc_flags_list( - supported_major_versions=[10, 11, 12] + supported_major_versions=[8, 9, 10, 11, 12] ) nvcc_flags.append(f"-I{jit_env.FLASHINFER_CSRC_DIR / 'fmha_v2'}") + nvcc_flags.append("-Wno-deprecated-gpu-targets") return gen_jit_spec( uri, From 1e7fe36b8a4e3633ead2f0a68c8ef529b214369c Mon Sep 17 00:00:00 2001 From: jimmzhou Date: Thu, 4 Dec 2025 22:12:57 +0000 Subject: [PATCH 2/2] only sm120 --- flashinfer/jit/attention/modules.py | 2 +- flashinfer/prefill.py | 2 ++ tests/attention/test_fmha_v2_prefill_deepseek.py | 3 +++ 3 files changed, 6 insertions(+), 1 deletion(-) mode change 100644 => 100755 tests/attention/test_fmha_v2_prefill_deepseek.py diff --git a/flashinfer/jit/attention/modules.py b/flashinfer/jit/attention/modules.py index ea1a4ce16b..64b2c794a6 100755 --- a/flashinfer/jit/attention/modules.py +++ b/flashinfer/jit/attention/modules.py @@ -1901,7 +1901,7 @@ def gen_trtllm_fmha_v2_module() -> JitSpec: source_paths = kernel_paths + [binding_source_path] nvcc_flags = current_compilation_context.get_nvcc_flags_list( - supported_major_versions=[8, 9, 10, 11, 12] + supported_major_versions=[12] ) nvcc_flags.append(f"-I{jit_env.FLASHINFER_CSRC_DIR / 'fmha_v2'}") nvcc_flags.append("-Wno-deprecated-gpu-targets") diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index 5b8140ec48..ab33923728 100755 --- a/flashinfer/prefill.py +++ b/flashinfer/prefill.py @@ -3603,6 +3603,8 @@ def fmha_v2_prefill_deepseek( If return_lse is True, the output will be a tuple of two tensors, the first is the output tensor, the second is the lse tensor. If return_lse is False, the output will be a single tensor. """ + if not is_sm120a_supported(query.device): + raise ValueError("fmha_v2_prefill_deepseek is only supported on SM120 GPUs.") assert query.shape[3] == 192 and key.shape[3] == 192 and value.shape[3] == 128, ( "currently only support deepseek r1 192 query and 128 value" ) diff --git a/tests/attention/test_fmha_v2_prefill_deepseek.py b/tests/attention/test_fmha_v2_prefill_deepseek.py old mode 100644 new mode 100755 index ebb08efa4d..2dad1355ce --- a/tests/attention/test_fmha_v2_prefill_deepseek.py +++ b/tests/attention/test_fmha_v2_prefill_deepseek.py @@ -5,6 +5,7 @@ from flashinfer.prefill import fmha_v2_prefill_deepseek from tests.utils_fp8 import to_float8 +from flashinfer.utils import is_sm120a_supported def attention_ref( @@ -56,6 +57,8 @@ def attention_ref( def test_fmha_v2_prefill_deepseek( batch_size, num_heads, head_dim_qk, head_dim_v, seq_len, qkv_dtype, o_dtype ): + if not is_sm120a_supported(torch.device("cuda")): + pytest.skip("fmha_v2_prefill_deepseek is only supported on SM120 GPUs.") torch.manual_seed(42) def initialize_tensors(batch_size, num_heads, head_dim_qk, head_dim_v, seq_len):