diff --git a/flashinfer/jit/attention/modules.py b/flashinfer/jit/attention/modules.py index bb6962b791..ed3d2b8d0f 100755 --- a/flashinfer/jit/attention/modules.py +++ b/flashinfer/jit/attention/modules.py @@ -1751,7 +1751,7 @@ def gen_fmha_cutlass_sm100a_module( ] nvcc_flags = current_compilation_context.get_nvcc_flags_list( - supported_major_versions=[10, 11, 12] + supported_major_versions=[10, 11] ) return gen_jit_spec( uri, diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index ab6881b786..8c18d68041 100755 --- a/flashinfer/prefill.py +++ b/flashinfer/prefill.py @@ -53,7 +53,6 @@ canonicalize_torch_dtype, determine_attention_backend, device_support_pdl, - get_compute_capability, get_device_sm_count, is_float8, is_sm100a_supported, @@ -100,11 +99,7 @@ def get_fmha_module( device: torch.device, use_fp16_qk_reduction: bool = False, ): - if ( - is_sm100a_supported(device) - or is_sm110a_supported(device) - or is_sm12x_supported(device) - ): + if is_sm100a_supported(device) or is_sm110a_supported(device): return gen_fmha_cutlass_sm100a_module( dtype_q, dtype_kv, @@ -117,15 +112,10 @@ def get_fmha_module( use_logits_soft_cap, ).build_and_load() else: - major, minor = get_compute_capability(device) - if major == 12: - min_cuda = "13.0" if minor >= 1 else "12.8" - raise ValueError( - f"SM12x detected but CUDA version is too old; " - f"SM12{minor}x requires CUDA >= {min_cuda}." - ) raise ValueError( - "This device is not supported; requires SM100a, SM110a, or SM12x." + "CUTLASS FMHA requires SM100a (B200/GB200) or SM110a. " + "SM12x (RTX 5090/DGX Spark) lacks tcgen05 MMA required by this kernel. " + "Use backend='fa2' instead." ) @@ -3838,13 +3828,6 @@ def fmha_v2_prefill_deepseek( If return_lse is False, the output will be a single tensor. """ if not is_sm12x_supported(query.device): - major, minor = get_compute_capability(query.device) - if major == 12: - min_cuda = "13.0" if minor >= 1 else "12.8" - raise ValueError( - f"fmha_v2_prefill_deepseek requires CUDA >= {min_cuda} " - f"for SM12{minor}x GPUs." - ) raise ValueError("fmha_v2_prefill_deepseek is only supported on SM12x GPUs.") assert query.shape[3] == 192 and key.shape[3] == 192 and value.shape[3] == 128, ( "currently only support deepseek r1 192 query and 128 value"