diff --git a/flashinfer/utils.py b/flashinfer/utils.py
index 18167c1ad2..e6c2bd836d 100644
--- a/flashinfer/utils.py
+++ b/flashinfer/utils.py
@@ -586,6 +586,19 @@ def is_sm12x_supported(device: torch.device) -> bool:
     return version_at_least(torch.version.cuda, min_cuda)
 
 
+def is_cvt_rs_supported(device: torch.device = None) -> bool:
+    """Check if the GPU supports the PTX cvt.rs.f16x2.f32 instruction.
+
+    This is a non-forward-compatible SM100a feature — not all SM >= 100 have it.
+    In particular, SM120 (Blackwell lite) does NOT support it.
+    """
+    if device is None:
+        device = torch.device("cuda")
+    major, _ = get_compute_capability(device)
+    # SM100a and SM110a support cvt.rs; SM120 does not.
+    return major in (10, 11)
+
+
 def determine_mla_backend(device: torch.device) -> str:
     return "fa3" if is_sm90a_supported(device) else "fa2"
 
diff --git a/tests/mamba/test_philox_rounding.py b/tests/mamba/test_philox_rounding.py
index 66c344c87f..9e55edf71f 100644
--- a/tests/mamba/test_philox_rounding.py
+++ b/tests/mamba/test_philox_rounding.py
@@ -12,7 +12,7 @@
 import triton.language as tl
 from torch.utils.cpp_extension import load_inline
 
-from flashinfer.utils import get_compute_capability
+from flashinfer.utils import get_compute_capability, is_cvt_rs_supported
 
 
 # ---------------------------------------------------------------------------
@@ -222,8 +222,8 @@ def philox_module():
 def stochastic_round_module():
     """Compile cvt_rs_f16x2_f32 test kernel with sm_100a (hardware PTX path)."""
     major, minor = get_compute_capability(torch.device("cuda"))
-    if major < 10:
-        pytest.skip("cvt.rs.f16x2.f32 requires sm_100a (Blackwell or later)")
+    if not is_cvt_rs_supported(torch.device("cuda")):
+        pytest.skip("cvt.rs.f16x2.f32 requires sm_100a; not supported on this GPU")
     # Append 'a' suffix for SM >= 9, matching flashinfer/compilation_context.py:44-45
     minor_str = f"{minor}a" if major >= 9 else str(minor)
     gencode = f"-gencode=arch=compute_{major}{minor_str},code=sm_{major}{minor_str}"
diff --git a/tests/mamba/test_selective_state_update_mtp.py b/tests/mamba/test_selective_state_update_mtp.py
index 0dc91e444c..3ec0c9c0bf 100644
--- a/tests/mamba/test_selective_state_update_mtp.py
+++ b/tests/mamba/test_selective_state_update_mtp.py
@@ -10,7 +10,7 @@
 import torch
 
 import flashinfer
-from flashinfer.utils import get_compute_capability
+from flashinfer.utils import is_cvt_rs_supported
 
 from .triton_reference.selective_state_update import selective_state_update_triton
 from .utils import create_test_inputs, clone_preserving_strides
@@ -1114,11 +1114,11 @@ def make_inputs(
     def make_reference_output(self, inputs):
         """Compute reference output using Triton with stochastic rounding."""
         state_ref = clone_preserving_strides(inputs["state_cache"])
-        major, _ = get_compute_capability(torch.device("cuda"))
-        # Triton cvt.rs.f16x2.f32 requires SM100a+; on older GPUs the Triton
-        # reference falls back to regular rounding while the CUDA kernel still
-        # exercises its software stochastic rounding path.
-        rand_seed = self.RAND_SEED if major >= 10 else None
+        # Triton cvt.rs.f16x2.f32 requires SM100a (non-forward-compatible);
+        # on unsupported GPUs the Triton reference falls back to regular
+        # rounding while the CUDA kernel still exercises its software
+        # stochastic rounding path.
+        rand_seed = self.RAND_SEED if is_cvt_rs_supported() else None
         y_ref = selective_state_update_triton(
             state_ref,
             inputs["x"],
@@ -1249,11 +1249,11 @@
     def make_reference_output(self, inputs):
         """Compute reference output using Triton with SR and intermediate states."""
         state_ref = clone_preserving_strides(inputs["state_cache"])
         intermediate_states_ref = inputs["intermediate_states_buffer"].clone()
-        major, _ = get_compute_capability(torch.device("cuda"))
-        # Triton cvt.rs.f16x2.f32 requires SM100a+; on older GPUs the Triton
-        # reference falls back to regular rounding while the CUDA kernel still
-        # exercises its software stochastic rounding path.
-        rand_seed = self.RAND_SEED if major >= 10 else None
+        # Triton cvt.rs.f16x2.f32 requires SM100a (non-forward-compatible);
+        # on unsupported GPUs the Triton reference falls back to regular
+        # rounding while the CUDA kernel still exercises its software
+        # stochastic rounding path.
+        rand_seed = self.RAND_SEED if is_cvt_rs_supported() else None
         y_ref = selective_state_update_triton(
             state_ref,
diff --git a/tests/mamba/test_selective_state_update_stp.py b/tests/mamba/test_selective_state_update_stp.py
index 7da054d18f..dfb1794520 100644
--- a/tests/mamba/test_selective_state_update_stp.py
+++ b/tests/mamba/test_selective_state_update_stp.py
@@ -3,7 +3,7 @@
 import torch
 
 import flashinfer
-from flashinfer.utils import get_compute_capability
+from flashinfer.utils import get_compute_capability, is_cvt_rs_supported
 
 from .triton_reference.selective_state_update import selective_state_update_triton
 from .utils import create_test_inputs, clone_preserving_strides
@@ -666,11 +666,11 @@ def make_inputs(self, batch, nheads, dim, dstate, _state_dtype, weight_dtype):
     def make_reference_output(self, inputs):
         """Compute reference output using Triton with stochastic rounding."""
         state_ref = inputs["state_cache"].clone()
-        major, _ = get_compute_capability(torch.device("cuda"))
-        # Triton cvt.rs.f16x2.f32 requires SM100a+; on older GPUs the Triton
-        # reference falls back to regular rounding while the CUDA kernel still
-        # exercises its software stochastic rounding path.
-        rand_seed = self.RAND_SEED if major >= 10 else None
+        # Triton cvt.rs.f16x2.f32 requires SM100a (non-forward-compatible);
+        # on unsupported GPUs the Triton reference falls back to regular
+        # rounding while the CUDA kernel still exercises its software
+        # stochastic rounding path.
+        rand_seed = self.RAND_SEED if is_cvt_rs_supported() else None
         y_ref = selective_state_update_triton(
             state_ref,
             inputs["x"],