Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion flashinfer/jit/attention/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -1751,7 +1751,7 @@ def gen_fmha_cutlass_sm100a_module(
]

nvcc_flags = current_compilation_context.get_nvcc_flags_list(
supported_major_versions=[10, 11, 12]
supported_major_versions=[10, 11]
)
return gen_jit_spec(
uri,
Expand Down
25 changes: 4 additions & 21 deletions flashinfer/prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@
canonicalize_torch_dtype,
determine_attention_backend,
device_support_pdl,
get_compute_capability,
get_device_sm_count,
is_float8,
is_sm100a_supported,
Expand Down Expand Up @@ -100,11 +99,7 @@ def get_fmha_module(
device: torch.device,
use_fp16_qk_reduction: bool = False,
):
if (
is_sm100a_supported(device)
or is_sm110a_supported(device)
or is_sm12x_supported(device)
):
if is_sm100a_supported(device) or is_sm110a_supported(device):
return gen_fmha_cutlass_sm100a_module(
dtype_q,
dtype_kv,
Expand All @@ -117,15 +112,10 @@ def get_fmha_module(
use_logits_soft_cap,
).build_and_load()
else:
major, minor = get_compute_capability(device)
if major == 12:
min_cuda = "13.0" if minor >= 1 else "12.8"
raise ValueError(
f"SM12x detected but CUDA version is too old; "
f"SM12{minor}x requires CUDA >= {min_cuda}."
)
raise ValueError(
"This device is not supported; requires SM100a, SM110a, or SM12x."
"CUTLASS FMHA requires SM100a (B200/GB200) or SM110a. "
"SM12x (RTX 5090/DGX Spark) lacks tcgen05 MMA required by this kernel. "
"Use backend='fa2' instead."
)


Expand Down Expand Up @@ -3838,13 +3828,6 @@ def fmha_v2_prefill_deepseek(
If return_lse is False, the output will be a single tensor.
"""
if not is_sm12x_supported(query.device):
major, minor = get_compute_capability(query.device)
if major == 12:
min_cuda = "13.0" if minor >= 1 else "12.8"
raise ValueError(
f"fmha_v2_prefill_deepseek requires CUDA >= {min_cuda} "
f"for SM12{minor}x GPUs."
)
raise ValueError("fmha_v2_prefill_deepseek is only supported on SM12x GPUs.")
assert query.shape[3] == 192 and key.shape[3] == 192 and value.shape[3] == 128, (
"currently only support deepseek r1 192 query and 128 value"
Expand Down
Loading