diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py index a4ea13ea7e..fe263137cf 100644 --- a/flashinfer/gemm/gemm_base.py +++ b/flashinfer/gemm/gemm_base.py @@ -3636,23 +3636,15 @@ def _heuristic_func_mm_fp4( - If cuda version is 12 - use cutlass. - If cuda version is 13 and cudnn version is less than 9.15 - use cutlass. - If cuda version is 13 and cudnn version is 9.15 or greater: - - On SM103 (B300) - use cutlass (faster based on benchmarks). - - On SM100 (B200) - use cudnn (faster based on benchmarks). + - Use cudnn first for both SM100 (B200) and SM103 (B300). """ cuda_major = get_cuda_version().major - # Get compute capability to distinguish between SM100 (10.0) and SM103 (10.3) - major, minor = get_compute_capability(a.device) - is_sm103 = major == 10 and minor == 3 - # If cuda version is 13 or greater and cudnn version is 9.15 or greater: - # On SM103 (B300), cutlass is more performant than cudnn. - # On SM100 (B200), cudnn is more performant than cutlass. + # If cuda version is 13 or greater and cudnn version is 9.15 or greater, + # prioritize cudnn for both SM100 (B200) and SM103 (B300). if CUDNN_AVAILABLE and cuda_major >= 13 and cudnn.backend_version() >= 91500: - if is_sm103: - candidate_backends = ("cutlass", "cudnn") - else: - candidate_backends = ("cudnn", "cutlass") + candidate_backends = ("cudnn", "cutlass") # Otherwise, prioritize cutlass else: candidate_backends = ("cutlass", "cudnn")