diff --git a/flashinfer/cute_dsl/blockscaled_gemm.py b/flashinfer/cute_dsl/blockscaled_gemm.py
index 7c4ecc1fc7..d0d26da4c3 100644
--- a/flashinfer/cute_dsl/blockscaled_gemm.py
+++ b/flashinfer/cute_dsl/blockscaled_gemm.py
@@ -529,8 +529,9 @@ def __init__(
         :param cluster_shape_mn: Tuple (ClusterM, ClusterN) shape of the cluster.
         :type cluster_shape_mn: Tuple[int, int]
         """
-        assert sm_version == "sm_100", (
-            "sm_100 is the only supported SM version for cute-dsl backend."
+        supported_sm_versions = ["sm_100", "sm_103"]
+        assert sm_version in supported_sm_versions, (
+            f"{supported_sm_versions} are the only supported SM versions for cute-dsl backend, but encountered {sm_version}"
         )
 
         self.acc_dtype = cutlass.Float32
@@ -561,7 +562,12 @@ def __init__(
         self.cta_sync_bar_id = 0
         self.epilog_sync_bar_id = 1
         self.tmem_ptr_sync_bar_id = 2
-        self.smem_capacity = utils.get_smem_capacity_in_bytes(sm_version)
+
+        # HACK "sm_103" doesn't work yet for the query
+        # https://github.com/NVIDIA/cutlass/blob/5016493cc0d8650d5b2f6d2c2751cf49bc217e86/python/CuTeDSL/cutlass/utils/smem_allocator.py#L19
+        # self.smem_capacity = utils.get_smem_capacity_in_bytes(sm_version)
+        self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100")
+
         SM100_TMEM_CAPACITY_COLUMNS = 512
         self.num_tmem_alloc_cols = SM100_TMEM_CAPACITY_COLUMNS
 
diff --git a/requirements.txt b/requirements.txt
index a31b6ebdc8..44ee69683d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,7 +4,7 @@ einops
 ninja
 numpy
 nvidia-cudnn-frontend>=1.13.0
-nvidia-cutlass-dsl>=4.2.1
+nvidia-cutlass-dsl>=4.3.1
 nvidia-ml-py
 packaging>=24.2
 requests
diff --git a/tests/gemm/test_cute_dsl_blockscaled_gemm.py b/tests/gemm/test_cute_dsl_blockscaled_gemm.py
index 2eb5abc832..30a59260d2 100644
--- a/tests/gemm/test_cute_dsl_blockscaled_gemm.py
+++ b/tests/gemm/test_cute_dsl_blockscaled_gemm.py
@@ -80,10 +80,12 @@ def test_blockscaled_gemm_python_interface(
 ):
     torch.manual_seed(42)
     device = torch.device("cuda:0")
-    major, minor = torch.cuda.get_device_capability(device)
-
-    if not (major == 10 and minor == 0):
-        pytest.skip("Cute-dsl backend is only supported on SM100.")
+    device_ver = torch.cuda.get_device_capability(device)
+    supported_device_vers = [(10, 0), (10, 3)]
+    if device_ver not in supported_device_vers:
+        pytest.skip(
+            f"Cute-dsl backend is only supported on {supported_device_vers}, skipping {device_ver}."
+        )
 
     l, m = lm
     k, n = kn