Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions flashinfer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,6 +586,19 @@ def is_sm12x_supported(device: torch.device) -> bool:
return version_at_least(torch.version.cuda, min_cuda)


def is_cvt_rs_supported(device: torch.device = None) -> bool:
    """Check if the GPU supports the PTX cvt.rs.f16x2.f32 instruction.

    This is a non-forward-compatible SM100a feature — not all SM >= 100 have it.
    In particular, SM120 (Blackwell lite) does NOT support it.

    Args:
        device: CUDA device to query. Defaults to the current CUDA device
            when ``None``.

    Returns:
        True if both the GPU architecture and the CUDA toolkit version
        support ``cvt.rs``; False otherwise.
    """
    if device is None:
        device = torch.device("cuda")
    major, _ = get_compute_capability(device)
    # SM100a and SM110a support cvt.rs; SM120 does not. Additionally require
    # a CUDA toolkit new enough to target the arch-specific feature set,
    # matching the convention of the other is_sm*_supported helpers in this
    # file (e.g. is_sm12x_supported above).
    if major == 10:
        # sm_100a requires CUDA 12.8+.
        return version_at_least(torch.version.cuda, "12.8")
    if major == 11:
        # sm_110a requires CUDA 13.0+.
        return version_at_least(torch.version.cuda, "13.0")
    return False
Comment on lines +597 to +599
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The current implementation of is_cvt_rs_supported only checks the major compute capability. However, support for specific architecture features like sm_100a also depends on the CUDA toolkit version. For consistency with other support-check functions in this file (e.g., is_sm100a_supported), this function should also verify the minimum required CUDA version. This ensures that the check is robust and prevents runtime errors if an older CUDA toolkit is used with a compatible GPU.

Suggested change
major, _ = get_compute_capability(device)
# SM100a and SM110a support cvt.rs; SM120 does not.
return major in (10, 11)
major, _ = get_compute_capability(device)
# SM100a and SM110a support cvt.rs; SM120 does not.
if major == 10:
return version_at_least(torch.version.cuda, "12.8")
if major == 11:
return version_at_least(torch.version.cuda, "13.0")
return False



def determine_mla_backend(device: torch.device) -> str:
    """Pick the MLA attention backend for *device*.

    Returns ``"fa3"`` when the device supports SM90a, otherwise falls
    back to ``"fa2"``.
    """
    if is_sm90a_supported(device):
        return "fa3"
    return "fa2"

Expand Down
6 changes: 3 additions & 3 deletions tests/mamba/test_philox_rounding.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import triton.language as tl
from torch.utils.cpp_extension import load_inline

from flashinfer.utils import get_compute_capability
from flashinfer.utils import get_compute_capability, is_cvt_rs_supported


# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -222,8 +222,8 @@ def philox_module():
def stochastic_round_module():
"""Compile cvt_rs_f16x2_f32 test kernel with sm_100a (hardware PTX path)."""
major, minor = get_compute_capability(torch.device("cuda"))
if major < 10:
pytest.skip("cvt.rs.f16x2.f32 requires sm_100a (Blackwell or later)")
if not is_cvt_rs_supported(torch.device("cuda")):
pytest.skip("cvt.rs.f16x2.f32 requires sm_100a; not supported on this GPU")
# Append 'a' suffix for SM >= 9, matching flashinfer/compilation_context.py:44-45
minor_str = f"{minor}a" if major >= 9 else str(minor)
gencode = f"-gencode=arch=compute_{major}{minor_str},code=sm_{major}{minor_str}"
Expand Down
22 changes: 11 additions & 11 deletions tests/mamba/test_selective_state_update_mtp.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import torch

import flashinfer
from flashinfer.utils import get_compute_capability
from flashinfer.utils import is_cvt_rs_supported

from .triton_reference.selective_state_update import selective_state_update_triton
from .utils import create_test_inputs, clone_preserving_strides
Expand Down Expand Up @@ -1114,11 +1114,11 @@ def make_inputs(
def make_reference_output(self, inputs):
"""Compute reference output using Triton with stochastic rounding."""
state_ref = clone_preserving_strides(inputs["state_cache"])
major, _ = get_compute_capability(torch.device("cuda"))
# Triton cvt.rs.f16x2.f32 requires SM100a+; on older GPUs the Triton
# reference falls back to regular rounding while the CUDA kernel still
# exercises its software stochastic rounding path.
rand_seed = self.RAND_SEED if major >= 10 else None
# Triton cvt.rs.f16x2.f32 requires SM100a (non-forward-compatible);
# on unsupported GPUs the Triton reference falls back to regular
# rounding while the CUDA kernel still exercises its software
# stochastic rounding path.
rand_seed = self.RAND_SEED if is_cvt_rs_supported() else None
y_ref = selective_state_update_triton(
state_ref,
inputs["x"],
Expand Down Expand Up @@ -1249,11 +1249,11 @@ def make_reference_output(self, inputs):
"""Compute reference output using Triton with SR and intermediate states."""
state_ref = clone_preserving_strides(inputs["state_cache"])
intermediate_states_ref = inputs["intermediate_states_buffer"].clone()
major, _ = get_compute_capability(torch.device("cuda"))
# Triton cvt.rs.f16x2.f32 requires SM100a+; on older GPUs the Triton
# reference falls back to regular rounding while the CUDA kernel still
# exercises its software stochastic rounding path.
rand_seed = self.RAND_SEED if major >= 10 else None
# Triton cvt.rs.f16x2.f32 requires SM100a (non-forward-compatible);
# on unsupported GPUs the Triton reference falls back to regular
# rounding while the CUDA kernel still exercises its software
# stochastic rounding path.
rand_seed = self.RAND_SEED if is_cvt_rs_supported() else None

y_ref = selective_state_update_triton(
state_ref,
Expand Down
12 changes: 6 additions & 6 deletions tests/mamba/test_selective_state_update_stp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch

import flashinfer
from flashinfer.utils import get_compute_capability
from flashinfer.utils import get_compute_capability, is_cvt_rs_supported

from .triton_reference.selective_state_update import selective_state_update_triton
from .utils import create_test_inputs, clone_preserving_strides
Expand Down Expand Up @@ -666,11 +666,11 @@ def make_inputs(self, batch, nheads, dim, dstate, _state_dtype, weight_dtype):
def make_reference_output(self, inputs):
"""Compute reference output using Triton with stochastic rounding."""
state_ref = inputs["state_cache"].clone()
major, _ = get_compute_capability(torch.device("cuda"))
# Triton cvt.rs.f16x2.f32 requires SM100a+; on older GPUs the Triton
# reference falls back to regular rounding while the CUDA kernel still
# exercises its software stochastic rounding path.
rand_seed = self.RAND_SEED if major >= 10 else None
# Triton cvt.rs.f16x2.f32 requires SM100a (non-forward-compatible);
# on unsupported GPUs the Triton reference falls back to regular
# rounding while the CUDA kernel still exercises its software
# stochastic rounding path.
rand_seed = self.RAND_SEED if is_cvt_rs_supported() else None
y_ref = selective_state_update_triton(
state_ref,
inputs["x"],
Expand Down
Loading