From b1fc962ff3359542a4f1d7e5c3d5077ccb0eab0e Mon Sep 17 00:00:00 2001 From: Brian Ryu Date: Fri, 24 Oct 2025 18:27:06 +0000 Subject: [PATCH 01/19] Add first draft of mm_fp4 backend --- flashinfer/gemm/gemm_base.py | 41 ++++++++++++++++++++++++++++++------ 1 file changed, 35 insertions(+), 6 deletions(-) diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py index ac0fbab4a0..1c4662840a 100644 --- a/flashinfer/gemm/gemm_base.py +++ b/flashinfer/gemm/gemm_base.py @@ -93,6 +93,10 @@ def _match_sm_version(device: torch.device, sm_version: list[str]): return device_arch in sm_version +def get_cuda_version(device: torch.device): + return tuple(map(int, torch.version.cuda.split("."))) # (major, minor) + + @functools.cache def get_gemm_module(): module = gen_gemm_module().build_and_load() @@ -1853,7 +1857,7 @@ def mm_fp4( out: Optional[torch.Tensor] = None, block_size: int = 16, use_8x4_sf_layout: bool = False, - backend: Literal["cudnn", "trtllm", "cutlass"] = "cudnn", + backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "auto", use_nvfp4: bool = True, ) -> torch.Tensor: r"""MM FP4 @@ -1887,8 +1891,8 @@ def mm_fp4( use_8x4_sf_layout: bool Whether to use 8x4 scale factor layout or 128x4 scale factor layout, defaults to False. - backend: Literal["cudnn", "trtllm", "cutlass"] - Backend to use, defaults to "cudnn". + backend: Literal["cudnn", "trtllm", "cutlass", "auto"] + Backend to use, defaults to "auto", which automatically selects the best backend between cudnn and cutlass. use_nvfp4: bool Whether to use nvfp4 quantization or mxfp4 quantization, defaults to False. @@ -1930,7 +1934,32 @@ def mm_fp4( "mm_fp4_workspace", DEFAULT_WORKSPACE_SIZE, a.device ) - if backend == "cudnn": + # Auto-select the best backend + if backend == "auto": + cuda_major, _ = get_cuda_version(a.device) + cc_major, cc_minor = get_compute_capability(a.device) + cc_arch = cc_major * 10 + cc_minor + # If cuda version is 13 or greater AND cudnn version is 9.X or greater, prioritize cudnn. + if cuda_major >= 13: # to-do add cudnn version threshold + candidate_backends = ["cudnn", "cutlass"] + # Otherwise, prioritize cutlass + else: + candidate_backends = ["cutlass", "cudnn"] + + # Support check + backends_to_delete = [] + for candidate_backend in candidate_backends: + if not mm_fp4.is_backend_supported(candidate_backend, cc_arch): + backends_to_delete.append(candidate_backend) + for backend_to_delete in backends_to_delete: + candidate_backends.remove(backend_to_delete) + selected_backend = candidate_backends[0] + print( + f"Selected backend: {selected_backend} for cuda version {cuda_major} and compute capability {cc_major}{cc_minor}" + ) + else: + selected_backend = backend + if selected_backend == "cudnn": # the fp4 cudnn graph will be shared for both mm and bmm, so # here we need to get the 3d shape and stride including the # batch dimension for both input and block scale tensors. @@ -1966,7 +1995,7 @@ def mm_fp4( execute_cudnn_gemm_fp4_graph( graph, a, b, a_descale, b_descale, alpha, out, workspace_buffer ) - elif backend == "trtllm": + elif selected_backend == "trtllm": get_trtllm_fp4_gemm_module().trtllm_fp4_gemm( a, b.T, @@ -1977,7 +2006,7 @@ def mm_fp4( use_8x4_sf_layout=use_8x4_sf_layout, workspace_buffer=workspace_buffer, ) - elif backend == "cutlass": + elif selected_backend == "cutlass": # cutlass require uint8 scale when a/b is fp4 packed uint8. 
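        # Note: .view(torch.uint8) below only reinterprets the one-byte
        # float8_e4m3fn scale elements in place (same storage, shape and stride,
        # no copy); the cutlass kernel consumes the scale factors as raw bytes.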
if a.dtype == torch.uint8 and a_descale.dtype == torch.float8_e4m3fn: a_descale = a_descale.view(torch.uint8) From d71943f331639f92a2bb23a7507590bfd9356181 Mon Sep 17 00:00:00 2001 From: Brian Ryu Date: Fri, 24 Oct 2025 21:02:03 +0000 Subject: [PATCH 02/19] Add second draft of mm_fp4 backend --- .../routines/flashinfer_benchmark_utils.py | 8 +- benchmarks/routines/gemm.py | 13 ++- flashinfer/gemm/gemm_base.py | 96 +++++++++++++++---- tests/gemm/test_mm_fp4.py | 15 +++ 4 files changed, 106 insertions(+), 26 deletions(-) diff --git a/benchmarks/routines/flashinfer_benchmark_utils.py b/benchmarks/routines/flashinfer_benchmark_utils.py index 520029f0ec..3f4811ceb1 100644 --- a/benchmarks/routines/flashinfer_benchmark_utils.py +++ b/benchmarks/routines/flashinfer_benchmark_utils.py @@ -241,10 +241,10 @@ def dtype_str_to_torch_dtype(dtype_str): "8.6": [], "8.9": [], "9.0": [], - "10.0": ["cudnn", "trtllm", "cutlass"], - "10.3": ["cudnn", "trtllm", "cutlass"], - "12.0": ["cudnn", "cutlass"], - "12.1": ["cudnn", "cutlass"], + "10.0": ["cudnn", "trtllm", "cutlass", "auto"], + "10.3": ["cudnn", "trtllm", "cutlass", "auto"], + "12.0": ["cudnn", "cutlass", "auto"], + "12.1": ["cudnn", "cutlass", "auto"], }, # MOE "trtllm_fp4_block_scale_moe": { diff --git a/benchmarks/routines/gemm.py b/benchmarks/routines/gemm.py index 17336189d0..0cc16ad23e 100644 --- a/benchmarks/routines/gemm.py +++ b/benchmarks/routines/gemm.py @@ -131,7 +131,7 @@ def parse_gemm_args(line, parser): required=False, nargs="+", default=["cudnn"], - choices=["cudnn", "cublas", "trtllm", "cutlass"], + choices=["cudnn", "cublas", "trtllm", "cutlass", "auto"], help="Kernel backends to test. Default: cudnn", ) parser.add_argument( @@ -823,7 +823,7 @@ def testMmFp4(args): print( "[INFO] cutlass backend does not support mxfp4 quantization (use_nvfp4=False)" ) - backends.remove("cutlass") + remove_cutlass = True if remove_cutlass: backends.remove("cutlass") if "cudnn" in backends: @@ -833,6 +833,13 @@ def testMmFp4(args): remove_cudnn = True if remove_cudnn: backends.remove("cudnn") + if "auto" in backends: + remove_auto = False + if not use_128x4_sf_layout: + print("[INFO] auto backend does not support use_128x4_sf_layout=False") + remove_auto = True + if remove_auto: + backends.remove("auto") if getattr(args, "autotune", False): backends_to_remove = [] for cur_backend in backends: @@ -889,7 +896,7 @@ def testMmFp4(args): # res = torch.empty([m, n], device="cuda", dtype=res_dtype) def run_backend(backend): - if backend in ["cudnn", "trtllm", "cutlass"]: + if backend in ["cudnn", "trtllm", "cutlass", "auto"]: return flashinfer.gemm.mm_fp4( a=input_fp4, b=mat2_fp4.T if backend != "trtllm" else mat2_fp4_trtllm.T, diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py index 1c4662840a..b99e9f4499 100644 --- a/flashinfer/gemm/gemm_base.py +++ b/flashinfer/gemm/gemm_base.py @@ -17,7 +17,7 @@ import functools from enum import Enum from types import SimpleNamespace -from typing import List, Literal, Optional, Tuple +from typing import List, Literal, Optional, Tuple, cast from flashinfer.trtllm_low_latency_gemm import trtllm_low_latency_gemm import torch @@ -1691,7 +1691,7 @@ def _check_mm_fp4_problem_size( out: Optional[torch.Tensor] = None, block_size: int = 16, use_8x4_sf_layout: bool = False, - backend: Literal["cudnn", "trtllm", "cutlass"] = "cudnn", + backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "auto", use_nvfp4: bool = True, ): # Generic checks @@ -1731,8 +1731,8 @@ def _check_mm_fp4_problem_size( if backend != 
"trtllm" and use_8x4_sf_layout: raise ValueError("Only TRTLLM FP4 GEMM supports 8x4 scale factor layout.") - if backend != "cudnn" and not use_nvfp4: - raise ValueError("Only cudnn FP4 GEMM supports mxfp4 quantization.") + if backend not in ["cudnn", "auto"] and not use_nvfp4: + raise ValueError("Only cudnn and auto FP4 GEMM supports mxfp4 quantization.") if use_nvfp4 and block_size != 16: raise ValueError("nvfp4 only supports block_size = 16.") @@ -1753,7 +1753,7 @@ def _cudnn_gemm_fp4_requirement( out: Optional[torch.Tensor] = None, block_size: int = 16, use_8x4_sf_layout: bool = False, - backend: Literal["cudnn", "trtllm", "cutlass"] = "cudnn", + backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "auto", use_nvfp4: bool = True, ): if ( @@ -1811,7 +1811,7 @@ def _trtllm_gemm_fp4_requirement( out: Optional[torch.Tensor] = None, block_size: int = 16, use_8x4_sf_layout: bool = False, - backend: Literal["cudnn", "trtllm", "cutlass"] = "cudnn", + backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "auto", use_nvfp4: bool = True, ): if out_dtype != torch.bfloat16: @@ -1833,17 +1833,57 @@ def _cutlass_gemm_fp4_requirement( out: Optional[torch.Tensor] = None, block_size: int = 16, use_8x4_sf_layout: bool = False, - backend: Literal["cudnn", "trtllm", "cutlass"] = "cudnn", + backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "auto", use_nvfp4: bool = True, ): return True +@supported_compute_capability([100, 103, 110, 120]) +def _auto_gemm_fp4_requirement( + a: torch.Tensor, + b: torch.Tensor, + a_descale: torch.Tensor, + b_descale: torch.Tensor, + alpha: Optional[torch.Tensor] = None, + out_dtype: torch.dtype = torch.bfloat16, + out: Optional[torch.Tensor] = None, + block_size: int = 16, + use_8x4_sf_layout: bool = False, + backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "auto", + use_nvfp4: bool = True, +): + # Auto backend requires at least one backend to be supported on the current device + cc_major, cc_minor = get_compute_capability(a.device) + cc_arch = cc_major * 10 + cc_minor + + # Check if at least one backend is supported for this compute capability + candidate_backends = ["cudnn", "cutlass", "trtllm"] + backend_checkers = { + "cudnn": _cudnn_gemm_fp4_requirement, + "cutlass": _cutlass_gemm_fp4_requirement, + # Does not consider trtllm due to different interface. + } + + for candidate in candidate_backends: + checker = backend_checkers[candidate] + if hasattr( + checker, "is_compute_capability_supported" + ) and checker.is_compute_capability_supported(cc_arch): + # At least one backend is supported + print(f"Backend {candidate} is supported on this device.") + return True + + # No backend is supported on this device + return False + + @backend_requirement( { "cudnn": _cudnn_gemm_fp4_requirement, # Each backend has its own requirement function "trtllm": _trtllm_gemm_fp4_requirement, "cutlass": _cutlass_gemm_fp4_requirement, + "auto": _auto_gemm_fp4_requirement, # Auto backend requires at least one backend to be supported on the current device }, common_check=_check_mm_fp4_problem_size, # Shape checks common to all backends ) @@ -1938,22 +1978,40 @@ def mm_fp4( if backend == "auto": cuda_major, _ = get_cuda_version(a.device) cc_major, cc_minor = get_compute_capability(a.device) - cc_arch = cc_major * 10 + cc_minor # If cuda version is 13 or greater AND cudnn version is 9.X or greater, prioritize cudnn. 
if cuda_major >= 13: # to-do add cudnn version threshold - candidate_backends = ["cudnn", "cutlass"] + candidate_backends = ("cudnn", "cutlass") # Otherwise, prioritize cutlass else: - candidate_backends = ["cutlass", "cudnn"] - - # Support check - backends_to_delete = [] - for candidate_backend in candidate_backends: - if not mm_fp4.is_backend_supported(candidate_backend, cc_arch): - backends_to_delete.append(candidate_backend) - for backend_to_delete in backends_to_delete: - candidate_backends.remove(backend_to_delete) - selected_backend = candidate_backends[0] + candidate_backends = ("cutlass", "cudnn") + + # Filter to only supported backends for this compute capability + # Note: The requirement function already validated that at least one backend is supported + supported_backends = [] + for candidate in candidate_backends: + # mypy requires explicit type casting for the backend literal + backend_literal = cast( + Literal["cudnn", "trtllm", "cutlass", "auto"], candidate + ) + try: + _check_mm_fp4_problem_size( + a, + b, + a_descale, + b_descale, + alpha, + out_dtype, + out, + block_size, + use_8x4_sf_layout, + backend_literal, + use_nvfp4, + ) + supported_backends.append(candidate) + except Exception: + pass + print(f"Supported backends: {supported_backends}") + selected_backend = supported_backends[0] print( f"Selected backend: {selected_backend} for cuda version {cuda_major} and compute capability {cc_major}{cc_minor}" ) diff --git a/tests/gemm/test_mm_fp4.py b/tests/gemm/test_mm_fp4.py index 9d7a7abbbd..8429baf47a 100644 --- a/tests/gemm/test_mm_fp4.py +++ b/tests/gemm/test_mm_fp4.py @@ -105,5 +105,20 @@ def test_mm_fp4( pytest.fail(str(e)) +# Split tests for checking auto functionality +@pytest.mark.parametrize("m", [1, 48, 256, 512]) +@pytest.mark.parametrize("n", [256, 512]) +@pytest.mark.parametrize("k", [256, 512]) +@pytest.mark.parametrize("res_dtype", [torch.bfloat16, torch.float16]) +@pytest.mark.parametrize("backend", ["auto"]) +@pytest.mark.parametrize("use_128x4_sf_layout", [False, True]) +@pytest.mark.parametrize("auto_tuning", [False, True]) +@pytest.mark.parametrize("fp4_type", ["nvfp4", "mxfp4", "mxfp4_alpha"]) +def test_mm_fp4_backend_auto( + m, n, k, res_dtype, backend, use_128x4_sf_layout, auto_tuning, fp4_type +): + test_mm_fp4(m, n, k, res_dtype, "auto", use_128x4_sf_layout, auto_tuning, fp4_type) + + if __name__ == "__main__": pytest.main([__file__]) From e3b68ddffe1659610b9acac0e3ee41ced7c73e37 Mon Sep 17 00:00:00 2001 From: Brian Ryu Date: Fri, 24 Oct 2025 21:53:46 +0000 Subject: [PATCH 03/19] Add third draft of mm_fp4 backend -- no audotune --- benchmarks/routines/gemm.py | 2 +- flashinfer/gemm/gemm_base.py | 10 +++------- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/benchmarks/routines/gemm.py b/benchmarks/routines/gemm.py index 0cc16ad23e..d86ca1cd0a 100644 --- a/benchmarks/routines/gemm.py +++ b/benchmarks/routines/gemm.py @@ -790,7 +790,7 @@ def testMmFp4(args): run_refcheck = args.refcheck use_128x4_sf_layout = args.use_128x4_sf_layout use_nvfp4 = args.use_nvfp4 - autotune_supported_backends = ["cutlass", "trtllm"] + autotune_supported_backends = ["cutlass", "trtllm", "auto"] res = [] backends = filter_backends_by_compute_capability(backends, args.routine, device) diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py index b99e9f4499..859ecd1b6f 100644 --- a/flashinfer/gemm/gemm_base.py +++ b/flashinfer/gemm/gemm_base.py @@ -1871,7 +1871,6 @@ def _auto_gemm_fp4_requirement( checker, "is_compute_capability_supported" ) 
and checker.is_compute_capability_supported(cc_arch): # At least one backend is supported - print(f"Backend {candidate} is supported on this device.") return True # No backend is supported on this device @@ -1978,8 +1977,9 @@ def mm_fp4( if backend == "auto": cuda_major, _ = get_cuda_version(a.device) cc_major, cc_minor = get_compute_capability(a.device) - # If cuda version is 13 or greater AND cudnn version is 9.X or greater, prioritize cudnn. - if cuda_major >= 13: # to-do add cudnn version threshold + # If cuda version is 13 or greater: + # cudnn is more performant if cudnn version is 9.14 or greater. + if cuda_major >= 13 and cudnn.backend_version() >= 91400: candidate_backends = ("cudnn", "cutlass") # Otherwise, prioritize cutlass else: @@ -2010,11 +2010,7 @@ def mm_fp4( supported_backends.append(candidate) except Exception: pass - print(f"Supported backends: {supported_backends}") selected_backend = supported_backends[0] - print( - f"Selected backend: {selected_backend} for cuda version {cuda_major} and compute capability {cc_major}{cc_minor}" - ) else: selected_backend = backend if selected_backend == "cudnn": From 34f3ecab8470e43f5570489803fefbfce7f09626 Mon Sep 17 00:00:00 2001 From: Brian Ryu Date: Sat, 25 Oct 2025 00:17:30 +0000 Subject: [PATCH 04/19] Add 4th draft of mm_fp4 backend -- enable cross-backend autotune on auto, but no cudnn autotune --- benchmarks/routines/gemm.py | 2 +- flashinfer/gemm/gemm_base.py | 549 ++++++++++++++++++----------------- 2 files changed, 283 insertions(+), 268 deletions(-) diff --git a/benchmarks/routines/gemm.py b/benchmarks/routines/gemm.py index d86ca1cd0a..2ab8838a06 100644 --- a/benchmarks/routines/gemm.py +++ b/benchmarks/routines/gemm.py @@ -790,7 +790,7 @@ def testMmFp4(args): run_refcheck = args.refcheck use_128x4_sf_layout = args.use_128x4_sf_layout use_nvfp4 = args.use_nvfp4 - autotune_supported_backends = ["cutlass", "trtllm", "auto"] + autotune_supported_backends = ["cudnn", "cutlass", "trtllm", "auto"] res = [] backends = filter_backends_by_compute_capability(backends, args.routine, device) diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py index 859ecd1b6f..bab75cc74b 100644 --- a/flashinfer/gemm/gemm_base.py +++ b/flashinfer/gemm/gemm_base.py @@ -410,87 +410,46 @@ def fp8_gemm_sm100( def _create_cutlass_fp4_gemm_module(module, op_name: str, tuner_name: str): """Helper function to create cutlass FP4 GEMM module.""" - class CutlassFp4GemmRunner(TunableRunner): - def __init__(self): - self._fp4_gemm_runner = module.fp4_gemm + def cutlass_fp4_gemm_runner(): + class CutlassFp4GemmRunner(TunableRunner): + def __init__(self): + self._fp4_gemm_runner = module.fp4_gemm - def get_valid_tactics( - self, - inputs: List[torch.Tensor], - profile: OptimizationProfile, - ) -> List[int]: - return list(range(module.fp4_gemm_tactic_num())) - - def forward( - self, - inputs: List[torch.Tensor], - tactic: int = -1, - do_preparation: bool = False, - **kwargs, - ): - a, b, a_descale, b_descale, alpha, out, workspace_buffer = inputs - module.fp4_gemm( - a, b, a_descale, b_descale, alpha, out, workspace_buffer, tactic - ) - return out - - @register_custom_op( - op_name, - mutates_args=(""), - ) - def cutlass_fp4_gemm( - a: torch.Tensor, - b: torch.Tensor, - a_descale: torch.Tensor, - b_descale: torch.Tensor, - alpha: torch.Tensor, - out: torch.Tensor, - workspace_buffer: torch.Tensor, - ): - tuner = AutoTuner.get() - - a_tensor_index = 0 - a_scale_tensor_index = 2 - out_tensor_index = 5 - - def pad_up(x, y): - return ((x + y - 1) // 
y) * y - - tuning_config = TuningConfig( - dynamic_tensor_specs=( - DynamicTensorSpec( - (a_tensor_index,), - (0,), - get_last_power_of_2_num_tokens_buckets, - last_positive_power_of_2, - ), - ), - constraint_specs=( - ConstraintSpec( - a_scale_tensor_index, - 0, - lambda shapes: pad_up(shapes[a_tensor_index][0], 128), - ), - ConstraintSpec( - out_tensor_index, 0, lambda shapes: shapes[a_tensor_index][0] - ), - ), - ) - - fp4_runner = CutlassFp4GemmRunner() + def get_valid_tactics( + self, + inputs: List[torch.Tensor], + profile: OptimizationProfile, + ) -> List[int]: + return list(range(module.fp4_gemm_tactic_num())) - inputs = [a, b, a_descale, b_descale, alpha, out, workspace_buffer] - _, tactic = tuner.choose_one( - tuner_name, - [fp4_runner], - tuning_config, - inputs, - ) + def forward( + self, + inputs: List[torch.Tensor], + tactic: int = -1, + do_preparation: bool = False, + **kwargs, + ): + ( + a, + b, + a_descale, + b_descale, + alpha, + _, + out, + _, + _, + workspace_buffer, + ) = inputs + module.fp4_gemm( + a, b.T, a_descale, b_descale.T, alpha, out, workspace_buffer, tactic + ) + return out - fp4_runner(inputs=inputs, tactic=tactic) + return CutlassFp4GemmRunner() return SimpleNamespace( - cutlass_fp4_gemm=cutlass_fp4_gemm, + cutlass_fp4_gemm_runner=cutlass_fp4_gemm_runner, ) @@ -1681,6 +1640,101 @@ def mm_fp8( return out +def _cudnn_gemm_fp4( + a: torch.Tensor, + b: torch.Tensor, + a_descale: torch.Tensor, + b_descale: torch.Tensor, + alpha: Optional[torch.Tensor] = None, + out_dtype: torch.dtype = torch.bfloat16, + out: Optional[torch.Tensor] = None, + block_size: int = 16, + use_nvfp4: bool = True, + workspace_buffer: torch.Tensor = None, +): + _check_cudnn_availability() + # the fp4 cudnn graph will be shared for both mm and bmm, so + # here we need to get the 3d shape and stride including the + # batch dimension for both input and block scale tensors. 
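    # Packed FP4 stores two e2m1 values per uint8 byte, so the helpers below
    # recover the logical (unpacked) shapes and strides, expanded to the 3-D
    # batched layout that the shared mm/bmm cuDNN graph expects.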
+ real_a_shape, real_a_stride = _get_real_fp4_shape_from_packed_uint8(a) + real_b_shape, real_b_stride = _get_real_fp4_shape_from_packed_uint8(b) + batch = real_a_shape[0] + expanded_a_descale_shape, expanded_a_descale_stride = ( + _expand_block_scale_tensor_shape(a_descale, batch) + ) + expanded_b_descale_shape, expanded_b_descale_stride = ( + _expand_block_scale_tensor_shape(b_descale, batch) + ) + + # build the fp4 cudnn graph + graph = build_plans_cudnn_fp4_gemm_graph( + real_a_shape, + real_a_stride, + real_b_shape, + real_b_stride, + expanded_a_descale_shape, + expanded_a_descale_stride, + expanded_b_descale_shape, + expanded_b_descale_stride, + cudnn.data_type.FP4_E2M1, + _torch_data_type_to_cudnn_data_type(out_dtype), + block_size, + a.device, + alpha is not None, + use_nvfp4, + ) + + # execute the fp4 cudnn graph + execute_cudnn_gemm_fp4_graph( + graph, a, b, a_descale, b_descale, alpha, out, workspace_buffer + ) + + +def _cudnn_gemm_fp4_runner(): + class CudnnFp4GemmRunner(TunableRunner): + def get_valid_tactics( + self, + inputs: List[torch.Tensor], + profile: OptimizationProfile, + ) -> List[int]: + # cudnn has heuristic for fp4 gemm, so we only need to use the default tactic + return [0] + + def forward( + self, + inputs: List[torch.Tensor], + tactic: int = -1, + do_preparation: bool = False, + **kwargs, + ) -> torch.Tensor: + ( + a, + b, + a_descale, + b_descale, + alpha, + out_dtype, + out, + block_size, + use_nvfp4, + workspace_buffer, + ) = inputs + _cudnn_gemm_fp4( + a, + b, + a_descale, + b_descale, + alpha, + out_dtype, + out, + block_size, + use_nvfp4, + workspace_buffer, + ) + + return CudnnFp4GemmRunner() + + def _check_mm_fp4_problem_size( a: torch.Tensor, b: torch.Tensor, @@ -1987,7 +2041,7 @@ def mm_fp4( # Filter to only supported backends for this compute capability # Note: The requirement function already validated that at least one backend is supported - supported_backends = [] + backends = [] for candidate in candidate_backends: # mypy requires explicit type casting for the backend literal backend_literal = cast( @@ -2007,76 +2061,94 @@ def mm_fp4( backend_literal, use_nvfp4, ) - supported_backends.append(candidate) + backends.append(candidate) except Exception: pass - selected_backend = supported_backends[0] else: - selected_backend = backend - if selected_backend == "cudnn": - # the fp4 cudnn graph will be shared for both mm and bmm, so - # here we need to get the 3d shape and stride including the - # batch dimension for both input and block scale tensors. 
- real_a_shape, real_a_stride = _get_real_fp4_shape_from_packed_uint8(a) - real_b_shape, real_b_stride = _get_real_fp4_shape_from_packed_uint8(b) - batch = real_a_shape[0] - expanded_a_descale_shape, expanded_a_descale_stride = ( - _expand_block_scale_tensor_shape(a_descale, batch) - ) - expanded_b_descale_shape, expanded_b_descale_stride = ( - _expand_block_scale_tensor_shape(b_descale, batch) - ) - - # build the fp4 cudnn graph - graph = build_plans_cudnn_fp4_gemm_graph( - real_a_shape, - real_a_stride, - real_b_shape, - real_b_stride, - expanded_a_descale_shape, - expanded_a_descale_stride, - expanded_b_descale_shape, - expanded_b_descale_stride, - cudnn.data_type.FP4_E2M1, - _torch_data_type_to_cudnn_data_type(out_dtype), - block_size, - a.device, - alpha is not None, - use_nvfp4, - ) + backends = [backend] - # execute the fp4 cudnn graph - execute_cudnn_gemm_fp4_graph( - graph, a, b, a_descale, b_descale, alpha, out, workspace_buffer - ) - elif selected_backend == "trtllm": - get_trtllm_fp4_gemm_module().trtllm_fp4_gemm( - a, - b.T, - a_descale, - b_descale.T, - alpha, - out, - use_8x4_sf_layout=use_8x4_sf_layout, - workspace_buffer=workspace_buffer, - ) - elif selected_backend == "cutlass": - # cutlass require uint8 scale when a/b is fp4 packed uint8. - if a.dtype == torch.uint8 and a_descale.dtype == torch.float8_e4m3fn: - a_descale = a_descale.view(torch.uint8) - if b.dtype == torch.uint8 and b_descale.dtype == torch.float8_e4m3fn: - b_descale = b_descale.view(torch.uint8) - - # Dispatch to the correct module based on device architecture - major, _ = get_compute_capability(a.device) - if major == 12: - gemm_module = get_gemm_sm120_module_cutlass_fp4() + # At this point, backends contains a supported backend if specified, or all supported backends if backend='auto'. + runners = [] + for cur_backend in backends: + if cur_backend == "cudnn": + runners.append(_cudnn_gemm_fp4_runner()) + elif cur_backend == "trtllm": + runners.append( + get_trtllm_fp4_gemm_module().trtllm_fp4_gemm_runner(use_8x4_sf_layout) + ) + elif cur_backend == "cutlass": + if a.dtype == torch.uint8 and a_descale.dtype == torch.float8_e4m3fn: + a_descale = a_descale.view(torch.uint8) + if b.dtype == torch.uint8 and b_descale.dtype == torch.float8_e4m3fn: + b_descale = b_descale.view(torch.uint8) + + # Dispatch to the correct module based on device architecture + major, _ = get_compute_capability(a.device) + if major == 12: + runners.append( + get_gemm_sm120_module_cutlass_fp4().cutlass_fp4_gemm_runner() + ) + else: + runners.append( + get_gemm_sm100_module_cutlass_fp4().cutlass_fp4_gemm_runner() + ) else: - gemm_module = get_gemm_sm100_module_cutlass_fp4() + # Should not reach this + raise ValueError(f"Unsupported backend: {cur_backend}") - gemm_module.cutlass_fp4_gemm( - a, b.T, a_descale, b_descale.T, alpha, out, workspace_buffer - ) + # Now we have a list of runners for desired & supported backends. 
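    # When autotuning is enabled, the tuner below profiles every surviving
    # runner (and each runner's tactics), and later calls reuse the cached
    # choice for a given shape.  Illustrative sketch only: a_fp4, b_fp4,
    # a_scale, b_scale and alpha are placeholder fp4-quantized operands
    # prepared as in the benchmarks, and the autotune import path is assumed.
    #
    #   import torch, flashinfer
    #   from flashinfer.autotuner import autotune
    #
    #   with autotune(True):  # warm-up calls populate the tactic cache
    #       out = flashinfer.gemm.mm_fp4(a_fp4, b_fp4.T, a_scale, b_scale.T,
    #                                    alpha, torch.bfloat16, backend="auto")
    #   out = flashinfer.gemm.mm_fp4(a_fp4, b_fp4.T, a_scale, b_scale.T,
    #                                alpha, torch.bfloat16, backend="auto")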
+ tuner = AutoTuner.get() + + a_tensor_index = 0 + a_scale_tensor_index = 2 + out_tensor_index = 6 + + def pad_up(x, y): + return ((x + y - 1) // y) * y + + tuning_config = TuningConfig( + dynamic_tensor_specs=( + DynamicTensorSpec( + (a_tensor_index,), + (0,), + get_last_power_of_2_num_tokens_buckets, + last_positive_power_of_2, + ), + ), + constraint_specs=( + ConstraintSpec( + a_scale_tensor_index, + 0, + lambda shapes: pad_up( + shapes[a_tensor_index][0], 8 if use_8x4_sf_layout else 128 + ), + ), + ConstraintSpec( + out_tensor_index, 0, lambda shapes: shapes[a_tensor_index][0] + ), + ), + ) + + inputs = [ + a, + b, + a_descale, + b_descale, + alpha, + out_dtype, + out, + block_size, + use_nvfp4, + workspace_buffer, + ] + runner, tactic = tuner.choose_one( + "fp4_gemm", + runners, + tuning_config, + inputs, + ) + + runner(inputs=inputs, tactic=tactic) return out @@ -2438,139 +2510,82 @@ def get_trtllm_fp4_gemm_module(): op = mod.build_and_load() setup_cubin_loader(mod.get_library_path()) - class TrtllmFp4GemmRunner(TunableRunner): - def __init__(self, use_8x4_sf_layout: bool = True): - self._fp4_gemm_runner = op.trtllm_gemm - self._use_8x4_sf_layout = use_8x4_sf_layout + def trtllm_fp4_gemm_runner(use_8x4_sf_layout: bool = True): + class TrtllmFp4GemmRunner(TunableRunner): + def __init__(self, use_8x4_sf_layout: bool = True): + self._fp4_gemm_runner = op.trtllm_gemm + self._use_8x4_sf_layout = use_8x4_sf_layout - def get_valid_tactics( - self, - inputs: List[torch.Tensor], - profile: OptimizationProfile, - ) -> List[int]: - a_tensor_index = 1 - b_tensor_index = 2 - - a = profile.get_opt_shapes()[a_tensor_index] - b = profile.get_opt_shapes()[b_tensor_index] - m = a[0] - n = b[0] - k = a[1] * 2 - ( - workspace_buffer, - a, - b, - a_descale, - b_descale, - alpha, - out, - ) = inputs - type_e2m1 = 0 - type_bf16 = 2 - return list( - op.trtllm_gemm_tactics( - m, n, k, type_e2m1, type_bf16, self._use_8x4_sf_layout + def get_valid_tactics( + self, + inputs: List[torch.Tensor], + profile: OptimizationProfile, + ) -> List[int]: + a_tensor_index = 1 + b_tensor_index = 2 + + a = profile.get_opt_shapes()[a_tensor_index] + b = profile.get_opt_shapes()[b_tensor_index] + m = a[0] + n = b[0] + k = a[1] * 2 + ( + a, + b, + a_descale, + b_descale, + alpha, + _, + out, + _, + _, + workspace_buffer, + ) = inputs + type_e2m1 = 0 + type_bf16 = 2 + return list( + op.trtllm_gemm_tactics( + m, n, k, type_e2m1, type_bf16, self._use_8x4_sf_layout + ) ) - ) - - def forward( - self, - inputs: List[torch.Tensor], - tactic: int = -1, - do_preparation: bool = False, - **kwargs, - ): - ( - workspace_buffer, - a, - b, - a_descale, - b_descale, - alpha, - out, - ) = inputs - op.trtllm_gemm( - workspace_buffer, - a, - b, - a_descale, - b_descale, - alpha, - out, - self._use_8x4_sf_layout, - tactic, - ) - return out - - @register_custom_op( - "flashinfer::trtllm_fp4_gemm", - mutates_args=(""), - ) - def trtllm_fp4_gemm( - a: torch.Tensor, - b: torch.Tensor, - a_descale: torch.Tensor, - b_descale: torch.Tensor, - alpha: torch.Tensor, - out: torch.Tensor, - use_8x4_sf_layout: bool, - workspace_buffer: torch.Tensor, - ): - tuner = AutoTuner.get() - - a_tensor_index = 1 - a_scale_tensor_index = 3 - out_tensor_index = 6 - - def pad_up(x, y): - return ((x + y - 1) // y) * y - - tuning_config = TuningConfig( - dynamic_tensor_specs=( - DynamicTensorSpec( - (a_tensor_index,), - (0,), - get_last_power_of_2_num_tokens_buckets, - last_positive_power_of_2, - ), - ), - constraint_specs=( - ConstraintSpec( - a_scale_tensor_index, - 0, - 
lambda shapes: pad_up( - shapes[a_tensor_index][0], 8 if use_8x4_sf_layout else 128 - ), - ), - ConstraintSpec( - out_tensor_index, 0, lambda shapes: shapes[a_tensor_index][0] - ), - ), - ) - - fp4_runner = TrtllmFp4GemmRunner(use_8x4_sf_layout) - inputs = [ - workspace_buffer, - a, - b, - a_descale, - b_descale, - alpha, - out, - ] - _, tactic = tuner.choose_one( - "trtllm_fp4_gemm_8x4" if use_8x4_sf_layout else "trtllm_fp4_gemm_128x4", - [fp4_runner], - tuning_config, - inputs, - ) + def forward( + self, + inputs: List[torch.Tensor], + tactic: int = -1, + do_preparation: bool = False, + **kwargs, + ): + ( + a, + b, + a_descale, + b_descale, + alpha, + _, + out, + _, + _, + workspace_buffer, + ) = inputs + self._fp4_gemm_runner( + workspace_buffer, + a, + b.T, + a_descale, + b_descale.T, + alpha, + out, + self._use_8x4_sf_layout, + tactic, + ) + return out - fp4_runner(inputs=inputs, tactic=tactic) + return TrtllmFp4GemmRunner(use_8x4_sf_layout) # Register the module return SimpleNamespace( - trtllm_fp4_gemm=trtllm_fp4_gemm, + trtllm_fp4_gemm_runner=trtllm_fp4_gemm_runner, ) From 7def07884afdc4895189448c754919b76fd2ef96 Mon Sep 17 00:00:00 2001 From: Brian Ryu Date: Sat, 25 Oct 2025 00:38:51 +0000 Subject: [PATCH 05/19] Add 5th draft of mm_fp4 backend -- enable cudnn autotune --- flashinfer/gemm/gemm_base.py | 62 +++++++++++++++++++++++++++++++++--- tests/gemm/test_mm_fp4.py | 2 -- 2 files changed, 58 insertions(+), 6 deletions(-) diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py index bab75cc74b..a069200d6f 100644 --- a/flashinfer/gemm/gemm_base.py +++ b/flashinfer/gemm/gemm_base.py @@ -1255,6 +1255,7 @@ def build_plans_cudnn_fp4_gemm_graph( device, alpha, use_nvfp4, + tactic: int = -1, ): graph = create_cudnn_execution_plans_fp4_gemm( a_shape, @@ -1274,7 +1275,10 @@ def build_plans_cudnn_fp4_gemm_graph( ) graph.check_support() - graph.build_plans() + if tactic != -1: + graph.build_plan_at_index(tactic) + else: + graph.build_plans() return graph @@ -1287,6 +1291,7 @@ def execute_cudnn_gemm_fp4_graph( alpha, c_final, workspace_buffer, + tactic: int = -1, ): variant_pack = { UIDs.A_UID.value: a.view(get_native_fp4_dtype()), @@ -1306,7 +1311,12 @@ def execute_cudnn_gemm_fp4_graph( stream = torch.cuda.current_stream(a.device) - graph.execute(variant_pack, workspace_buffer, handle=_get_cudnn_handle(stream)) + if tactic == -1: + graph.execute(variant_pack, workspace_buffer, handle=_get_cudnn_handle(stream)) + else: + graph.execute_plan_at_index( + variant_pack, workspace_buffer, tactic, handle=_get_cudnn_handle(stream) + ) @functools.cache @@ -1651,6 +1661,7 @@ def _cudnn_gemm_fp4( block_size: int = 16, use_nvfp4: bool = True, workspace_buffer: torch.Tensor = None, + tactic: int = -1, ): _check_cudnn_availability() # the fp4 cudnn graph will be shared for both mm and bmm, so @@ -1682,11 +1693,12 @@ def _cudnn_gemm_fp4( a.device, alpha is not None, use_nvfp4, + tactic=tactic, ) # execute the fp4 cudnn graph execute_cudnn_gemm_fp4_graph( - graph, a, b, a_descale, b_descale, alpha, out, workspace_buffer + graph, a, b, a_descale, b_descale, alpha, out, workspace_buffer, tactic=tactic ) @@ -1698,7 +1710,48 @@ def get_valid_tactics( profile: OptimizationProfile, ) -> List[int]: # cudnn has heuristic for fp4 gemm, so we only need to use the default tactic - return [0] + _check_cudnn_availability() + ( + a, + b, + a_descale, + b_descale, + alpha, + out_dtype, + out, + block_size, + use_nvfp4, + workspace_buffer, + ) = inputs + + real_a_shape, real_a_stride = 
_get_real_fp4_shape_from_packed_uint8(a) + real_b_shape, real_b_stride = _get_real_fp4_shape_from_packed_uint8(b) + batch = real_a_shape[0] + expanded_a_descale_shape, expanded_a_descale_stride = ( + _expand_block_scale_tensor_shape(a_descale, batch) + ) + expanded_b_descale_shape, expanded_b_descale_stride = ( + _expand_block_scale_tensor_shape(b_descale, batch) + ) + + graph = build_plans_cudnn_fp4_gemm_graph( + real_a_shape, + real_a_stride, + real_b_shape, + real_b_stride, + expanded_a_descale_shape, + expanded_a_descale_stride, + expanded_b_descale_shape, + expanded_b_descale_stride, + cudnn.data_type.FP4_E2M1, + _torch_data_type_to_cudnn_data_type(out_dtype), + block_size, + a.device, + alpha is not None, + use_nvfp4, + ) + num_plans = graph.get_execution_plan_count() + return list(range(num_plans)) def forward( self, @@ -1730,6 +1783,7 @@ def forward( block_size, use_nvfp4, workspace_buffer, + tactic=tactic, ) return CudnnFp4GemmRunner() diff --git a/tests/gemm/test_mm_fp4.py b/tests/gemm/test_mm_fp4.py index 8429baf47a..a618583094 100644 --- a/tests/gemm/test_mm_fp4.py +++ b/tests/gemm/test_mm_fp4.py @@ -40,8 +40,6 @@ def test_mm_fp4( pytest.skip("trtllm gemm does not support SM110/SM120/SM121 GPUs.") if not use_128x4_sf_layout and backend != "trtllm": pytest.skip("Skipping test for non-trtllm fp4 with use_128x4_sf_layout=False") - if auto_tuning and backend == "cudnn": - pytest.skip("Skipping test for cudnn fp4 with auto_tuning=True") if not use_nvfp4 and backend != "cudnn": pytest.skip("mx_fp4 is only supported for cudnn backend") From 0a4560d3ca9377f5d9b77e840f301f39068f784f Mon Sep 17 00:00:00 2001 From: Brian Ryu Date: Wed, 29 Oct 2025 00:16:10 +0000 Subject: [PATCH 06/19] Refactor test_mm_fp4.py --- tests/gemm/test_mm_fp4.py | 41 ++++++++++++++++++++++++--------------- 1 file changed, 25 insertions(+), 16 deletions(-) diff --git a/tests/gemm/test_mm_fp4.py b/tests/gemm/test_mm_fp4.py index a618583094..cc85f6126a 100644 --- a/tests/gemm/test_mm_fp4.py +++ b/tests/gemm/test_mm_fp4.py @@ -12,16 +12,7 @@ from flashinfer.gemm.gemm_base import CUDNN_FP4_MXFP4_SM120_CUDNN_VERSION_ERROR -# TODO: Consdier splitting this function up for the various backends -@pytest.mark.parametrize("m", [1, 48, 128, 256, 512]) -@pytest.mark.parametrize("n", [128, 256, 512]) -@pytest.mark.parametrize("k", [128, 256, 512]) -@pytest.mark.parametrize("res_dtype", [torch.bfloat16, torch.float16]) -@pytest.mark.parametrize("backend", ["trtllm", "cudnn", "cutlass"]) -@pytest.mark.parametrize("use_128x4_sf_layout", [False, True]) -@pytest.mark.parametrize("auto_tuning", [False, True]) -@pytest.mark.parametrize("fp4_type", ["nvfp4", "mxfp4", "mxfp4_alpha"]) -def test_mm_fp4( +def _test_mm_fp4( m, n, k, res_dtype, backend, use_128x4_sf_layout, auto_tuning, fp4_type ): use_nvfp4 = fp4_type == "nvfp4" @@ -40,8 +31,8 @@ def test_mm_fp4( pytest.skip("trtllm gemm does not support SM110/SM120/SM121 GPUs.") if not use_128x4_sf_layout and backend != "trtllm": pytest.skip("Skipping test for non-trtllm fp4 with use_128x4_sf_layout=False") - if not use_nvfp4 and backend != "cudnn": - pytest.skip("mx_fp4 is only supported for cudnn backend") + if not use_nvfp4 and backend not in ["cudnn", "auto"]: + pytest.skip("mx_fp4 is only supported for cudnn and auto backends") input = torch.randn([m, k], device="cuda", dtype=torch.bfloat16) mat2 = torch.randn([n, k], device="cuda", dtype=torch.bfloat16) @@ -103,19 +94,37 @@ def test_mm_fp4( pytest.fail(str(e)) +# TODO: Consdier splitting this function up for the various backends 
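# The shared test body lives in _test_mm_fp4 above; the wrappers below only vary
# the parameter grids.  To run just the auto-backend cases, something like this
# (hypothetical invocation) should work:
#
#   pytest tests/gemm/test_mm_fp4.py -k backend_auto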
+@pytest.mark.parametrize("m", [1, 48, 128, 256, 512]) +@pytest.mark.parametrize("n", [128, 256, 512]) +@pytest.mark.parametrize("k", [128, 256, 512]) +@pytest.mark.parametrize("res_dtype", [torch.bfloat16, torch.float16]) +@pytest.mark.parametrize("backend", ["trtllm", "cudnn", "cutlass"]) +@pytest.mark.parametrize("use_128x4_sf_layout", [False, True]) +@pytest.mark.parametrize("auto_tuning", [False, True]) +@pytest.mark.parametrize("fp4_type", ["nvfp4", "mxfp4", "mxfp4_alpha"]) +def test_mm_fp4( + m, n, k, res_dtype, backend, use_128x4_sf_layout, auto_tuning, fp4_type +): + # Non-auto backends + _test_mm_fp4( + m, n, k, res_dtype, backend, use_128x4_sf_layout, auto_tuning, fp4_type + ) + + # Split tests for checking auto functionality @pytest.mark.parametrize("m", [1, 48, 256, 512]) @pytest.mark.parametrize("n", [256, 512]) @pytest.mark.parametrize("k", [256, 512]) @pytest.mark.parametrize("res_dtype", [torch.bfloat16, torch.float16]) -@pytest.mark.parametrize("backend", ["auto"]) -@pytest.mark.parametrize("use_128x4_sf_layout", [False, True]) +@pytest.mark.parametrize("use_128x4_sf_layout", [True]) @pytest.mark.parametrize("auto_tuning", [False, True]) @pytest.mark.parametrize("fp4_type", ["nvfp4", "mxfp4", "mxfp4_alpha"]) def test_mm_fp4_backend_auto( - m, n, k, res_dtype, backend, use_128x4_sf_layout, auto_tuning, fp4_type + m, n, k, res_dtype, use_128x4_sf_layout, auto_tuning, fp4_type ): - test_mm_fp4(m, n, k, res_dtype, "auto", use_128x4_sf_layout, auto_tuning, fp4_type) + # Some test cases for auto backend. + _test_mm_fp4(m, n, k, res_dtype, "auto", use_128x4_sf_layout, auto_tuning, fp4_type) if __name__ == "__main__": From 7551fde8f53c31c67f6742b161ae24a339fc8989 Mon Sep 17 00:00:00 2001 From: Brian Ryu Date: Fri, 31 Oct 2025 18:09:07 +0000 Subject: [PATCH 07/19] Address comments --- .../routines/flashinfer_benchmark_utils.py | 12 +- benchmarks/routines/gemm.py | 137 ++++++++++-------- flashinfer/gemm/gemm_base.py | 88 +++++------ 3 files changed, 123 insertions(+), 114 deletions(-) diff --git a/benchmarks/routines/flashinfer_benchmark_utils.py b/benchmarks/routines/flashinfer_benchmark_utils.py index 3f4811ceb1..d5f363839a 100644 --- a/benchmarks/routines/flashinfer_benchmark_utils.py +++ b/benchmarks/routines/flashinfer_benchmark_utils.py @@ -235,17 +235,7 @@ def dtype_str_to_torch_dtype(dtype_str): "10.3": ["cudnn", "cublas", "cutlass"], "12.0": ["cudnn", "cublas"], }, - "mm_fp4": { - "7.5": [], - "8.0": [], - "8.6": [], - "8.9": [], - "9.0": [], - "10.0": ["cudnn", "trtllm", "cutlass", "auto"], - "10.3": ["cudnn", "trtllm", "cutlass", "auto"], - "12.0": ["cudnn", "cutlass", "auto"], - "12.1": ["cudnn", "cutlass", "auto"], - }, + # Note: mm_fp4 uses support checkers to filter backends, so it is not listed here # MOE "trtllm_fp4_block_scale_moe": { "7.5": [], diff --git a/benchmarks/routines/gemm.py b/benchmarks/routines/gemm.py index 2ab8838a06..833917bf76 100644 --- a/benchmarks/routines/gemm.py +++ b/benchmarks/routines/gemm.py @@ -793,65 +793,11 @@ def testMmFp4(args): autotune_supported_backends = ["cudnn", "cutlass", "trtllm", "auto"] res = [] - backends = filter_backends_by_compute_capability(backends, args.routine, device) - res_dtype = dtype_str_to_torch_dtype(args.out_dtype) if res_dtype not in [torch.bfloat16, torch.float16]: raise ValueError( f"Unsupported res dtype: {res_dtype}. Supported dtypes are bfloat16 and float16." 
) - ## Done parsing input arguments - - if "trtllm" in backends: - remove_trtllm = False - if res_dtype == torch.float16: - print("[INFO] trtllm backend does not support float16 output") - remove_trtllm = True - if remove_trtllm: - backends.remove("trtllm") - if not use_nvfp4: - print( - "[INFO] trtllm backend does not support mxfp4 quantization (use_nvfp4=False)" - ) - backends.remove("trtllm") - if "cutlass" in backends: - remove_cutlass = False - if not use_128x4_sf_layout: - print("[INFO] cutlass backend does not support use_128x4_sf_layout=False") - remove_cutlass = True - if not use_nvfp4: - print( - "[INFO] cutlass backend does not support mxfp4 quantization (use_nvfp4=False)" - ) - remove_cutlass = True - if remove_cutlass: - backends.remove("cutlass") - if "cudnn" in backends: - remove_cudnn = False - if not use_128x4_sf_layout: - print("[INFO] cudnn backend does not support use_128x4_sf_layout=False") - remove_cudnn = True - if remove_cudnn: - backends.remove("cudnn") - if "auto" in backends: - remove_auto = False - if not use_128x4_sf_layout: - print("[INFO] auto backend does not support use_128x4_sf_layout=False") - remove_auto = True - if remove_auto: - backends.remove("auto") - if getattr(args, "autotune", False): - backends_to_remove = [] - for cur_backend in backends: - if cur_backend not in autotune_supported_backends: - print(f"[INFO] {cur_backend} backend does not support autotune") - backends_to_remove.append(cur_backend) - for cur_backend in backends_to_remove: - backends.remove(cur_backend) - - if len(backends) == 0: - print("[ERROR] No backends to test. Exiting.") - return input = torch.randn([m, k], device=device, dtype=torch.bfloat16) mat2 = torch.randn([n, k], device=device, dtype=torch.bfloat16) @@ -893,7 +839,77 @@ def testMmFp4(args): print(f"[VVERBOSE] {mat2_fp4.dtype = }") alpha = 1.0 / (global_sf_input * global_sf_mat2) if use_nvfp4 else None - # res = torch.empty([m, n], device="cuda", dtype=res_dtype) + # Completed preparing inputs. 
Now programmatically filter backends + block_size = 16 if use_nvfp4 else 32 + backends_to_remove = [] + + for backend in backends: + # Skip autotune check for now (handled separately below) + if ( + getattr(args, "autotune", False) + and backend not in autotune_supported_backends + ): + print(f"[INFO] {backend} backend does not support autotune") + backends_to_remove.append(backend) + continue + + try: + from flashinfer.gemm import ( + _mm_fp4_backend_checkers, + _check_mm_fp4_problem_size, + ) + + # Choose correct tensors for this backend + if backend == "trtllm": + b_tensor = mat2_fp4_trtllm.T + b_descale = mat2_inv_s_trtllm.T + else: + b_tensor = mat2_fp4.T + b_descale = mat2_inv_s.T + + # Validate common requirements + _check_mm_fp4_problem_size( + input_fp4, + b_tensor, + input_inv_s, + b_descale, + alpha, + res_dtype, + None, # out + block_size, + not use_128x4_sf_layout, # use_8x4_sf_layout + backend, + use_nvfp4, + ) + + # Validate backend-specific requirements + if backend in _mm_fp4_backend_checkers: + _mm_fp4_backend_checkers[backend]( + input_fp4, + b_tensor, + input_inv_s, + b_descale, + alpha, + res_dtype, + None, # out + block_size, + not use_128x4_sf_layout, + backend, + use_nvfp4, + ) + except Exception as e: + print( + f"[INFO] {backend} backend does not support this configuration: {type(e).__name__}: {e}" + ) + backends_to_remove.append(backend) + + # Remove unsupported backends + for backend in backends_to_remove: + backends.remove(backend) + + if len(backends) == 0: + print("[ERROR] No backends passed validation. Exiting.") + return def run_backend(backend): if backend in ["cudnn", "trtllm", "cutlass", "auto"]: @@ -924,12 +940,11 @@ def run_backend(backend): args.dry_run_iters if args.dry_run_iters and args.dry_run_iters > 0 else 10 ) for cur_backend in backends: - if cur_backend in autotune_supported_backends: - if args.verbose >= 1: - print(f"[INFO] Autotune warmup for mm_fp4: {warmup_iters} iters") - with autotune(True): - for _ in range(warmup_iters): - run_backend(cur_backend) + if args.verbose >= 1: + print(f"[INFO] Autotune warmup for mm_fp4: {warmup_iters} iters") + with autotune(True): + for _ in range(warmup_iters): + run_backend(cur_backend) # Storage for timing results and outputs backend_times = {backend: [] for backend in backends} diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py index a069200d6f..74f9d8ee41 100644 --- a/flashinfer/gemm/gemm_base.py +++ b/flashinfer/gemm/gemm_base.py @@ -441,6 +441,10 @@ def forward( _, workspace_buffer, ) = inputs + if a.dtype == torch.uint8 and a_descale.dtype == torch.float8_e4m3fn: + a_descale = a_descale.view(torch.uint8) + if b.dtype == torch.uint8 and b_descale.dtype == torch.float8_e4m3fn: + b_descale = b_descale.view(torch.uint8) module.fp4_gemm( a, b.T, a_descale, b_descale.T, alpha, out, workspace_buffer, tactic ) @@ -1947,7 +1951,7 @@ def _cutlass_gemm_fp4_requirement( return True -@supported_compute_capability([100, 103, 110, 120]) +@supported_compute_capability([100, 103, 110, 120, 121]) def _auto_gemm_fp4_requirement( a: torch.Tensor, b: torch.Tensor, @@ -1985,14 +1989,16 @@ def _auto_gemm_fp4_requirement( return False +_mm_fp4_backend_checkers = { + "cudnn": _cudnn_gemm_fp4_requirement, + "trtllm": _trtllm_gemm_fp4_requirement, + "cutlass": _cutlass_gemm_fp4_requirement, + "auto": _auto_gemm_fp4_requirement, +} + + @backend_requirement( - { - "cudnn": _cudnn_gemm_fp4_requirement, # Each backend has its own requirement function - "trtllm": _trtllm_gemm_fp4_requirement, - "cutlass": 
_cutlass_gemm_fp4_requirement, - "auto": _auto_gemm_fp4_requirement, # Auto backend requires at least one backend to be supported on the current device - }, - common_check=_check_mm_fp4_problem_size, # Shape checks common to all backends + backend_checks=_mm_fp4_backend_checkers, common_check=_check_mm_fp4_problem_size ) def mm_fp4( a: torch.Tensor, @@ -2087,7 +2093,7 @@ def mm_fp4( cc_major, cc_minor = get_compute_capability(a.device) # If cuda version is 13 or greater: # cudnn is more performant if cudnn version is 9.14 or greater. - if cuda_major >= 13 and cudnn.backend_version() >= 91400: + if CUDNN_AVAILABLE and cuda_major >= 13 and cudnn.backend_version() >= 91400: candidate_backends = ("cudnn", "cutlass") # Otherwise, prioritize cutlass else: @@ -2098,11 +2104,11 @@ def mm_fp4( backends = [] for candidate in candidate_backends: # mypy requires explicit type casting for the backend literal - backend_literal = cast( - Literal["cudnn", "trtllm", "cutlass", "auto"], candidate - ) + backend_literal = cast(Literal["cudnn", "trtllm", "cutlass"], candidate) try: - _check_mm_fp4_problem_size( + # Check both common constraints and backend-specific requirements + # to find all compatible backends for this problem instance + if _check_mm_fp4_problem_size( a, b, a_descale, @@ -2114,41 +2120,39 @@ def mm_fp4( use_8x4_sf_layout, backend_literal, use_nvfp4, - ) - backends.append(candidate) + ) and _mm_fp4_backend_checkers[candidate]( + a, + b, + a_descale, + b_descale, + alpha, + out_dtype, + out, + block_size, + use_8x4_sf_layout, + backend_literal, + use_nvfp4, + ): + backends.append(candidate) except Exception: pass else: backends = [backend] # At this point, backends contains a supported backend if specified, or all supported backends if backend='auto'. - runners = [] - for cur_backend in backends: - if cur_backend == "cudnn": - runners.append(_cudnn_gemm_fp4_runner()) - elif cur_backend == "trtllm": - runners.append( - get_trtllm_fp4_gemm_module().trtllm_fp4_gemm_runner(use_8x4_sf_layout) - ) - elif cur_backend == "cutlass": - if a.dtype == torch.uint8 and a_descale.dtype == torch.float8_e4m3fn: - a_descale = a_descale.view(torch.uint8) - if b.dtype == torch.uint8 and b_descale.dtype == torch.float8_e4m3fn: - b_descale = b_descale.view(torch.uint8) - - # Dispatch to the correct module based on device architecture - major, _ = get_compute_capability(a.device) - if major == 12: - runners.append( - get_gemm_sm120_module_cutlass_fp4().cutlass_fp4_gemm_runner() - ) - else: - runners.append( - get_gemm_sm100_module_cutlass_fp4().cutlass_fp4_gemm_runner() - ) - else: - # Should not reach this - raise ValueError(f"Unsupported backend: {cur_backend}") + # Lazy initialization of runners to avoid overhead of creating a new runner that will not be used + major, _ = get_compute_capability(a.device) + + backend_to_runner_factory = { + "cudnn": lambda: _cudnn_gemm_fp4_runner(), + "trtllm": lambda: get_trtllm_fp4_gemm_module().trtllm_fp4_gemm_runner( + use_8x4_sf_layout + ), + "cutlass": lambda: get_gemm_sm120_module_cutlass_fp4().cutlass_fp4_gemm_runner() + if major == 12 + else get_gemm_sm100_module_cutlass_fp4().cutlass_fp4_gemm_runner(), + } + runners = [backend_to_runner_factory[cur_backend]() for cur_backend in backends] # Now we have a list of runners for desired & supported backends. 
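    # The factory lambdas above keep construction lazy: a backend's runner, and
    # any JIT module it builds and loads, is only created for candidates that
    # survived the support filtering.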
tuner = AutoTuner.get() From da049f5a663ba49635d73ea7f083bf3505d29dc6 Mon Sep 17 00:00:00 2001 From: Brian Ryu Date: Tue, 11 Nov 2025 22:14:13 +0000 Subject: [PATCH 08/19] Rebase main --- benchmarks/routines/gemm.py | 55 +++++--------------- flashinfer/gemm/gemm_base.py | 98 ++++++++++++++++-------------------- 2 files changed, 55 insertions(+), 98 deletions(-) diff --git a/benchmarks/routines/gemm.py b/benchmarks/routines/gemm.py index 833917bf76..cae1d55bf1 100644 --- a/benchmarks/routines/gemm.py +++ b/benchmarks/routines/gemm.py @@ -854,49 +854,20 @@ def testMmFp4(args): continue try: - from flashinfer.gemm import ( - _mm_fp4_backend_checkers, - _check_mm_fp4_problem_size, - ) - - # Choose correct tensors for this backend - if backend == "trtllm": - b_tensor = mat2_fp4_trtllm.T - b_descale = mat2_inv_s_trtllm.T - else: - b_tensor = mat2_fp4.T - b_descale = mat2_inv_s.T - - # Validate common requirements - _check_mm_fp4_problem_size( - input_fp4, - b_tensor, - input_inv_s, - b_descale, - alpha, - res_dtype, - None, # out - block_size, - not use_128x4_sf_layout, # use_8x4_sf_layout - backend, - use_nvfp4, + flashinfer.gemm.mm_fp4( + a=input_fp4, + b=mat2_fp4.T if backend != "trtllm" else mat2_fp4_trtllm.T, + a_descale=input_inv_s, + b_descale=mat2_inv_s.T if backend != "trtllm" else mat2_inv_s_trtllm.T, + alpha=alpha, + out_dtype=res_dtype, + block_size=16 + if use_nvfp4 + else 32, # nvfp4 only supports 16; mxfp4 only supports 32. + use_8x4_sf_layout=not use_128x4_sf_layout, + backend=backend, + use_nvfp4=use_nvfp4, ) - - # Validate backend-specific requirements - if backend in _mm_fp4_backend_checkers: - _mm_fp4_backend_checkers[backend]( - input_fp4, - b_tensor, - input_inv_s, - b_descale, - alpha, - res_dtype, - None, # out - block_size, - not use_128x4_sf_layout, - backend, - use_nvfp4, - ) except Exception as e: print( f"[INFO] {backend} backend does not support this configuration: {type(e).__name__}: {e}" diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py index 74f9d8ee41..91f28def36 100644 --- a/flashinfer/gemm/gemm_base.py +++ b/flashinfer/gemm/gemm_base.py @@ -17,7 +17,7 @@ import functools from enum import Enum from types import SimpleNamespace -from typing import List, Literal, Optional, Tuple, cast +from typing import List, Literal, Optional, Tuple from flashinfer.trtllm_low_latency_gemm import trtllm_low_latency_gemm import torch @@ -1989,16 +1989,48 @@ def _auto_gemm_fp4_requirement( return False -_mm_fp4_backend_checkers = { - "cudnn": _cudnn_gemm_fp4_requirement, - "trtllm": _trtllm_gemm_fp4_requirement, - "cutlass": _cutlass_gemm_fp4_requirement, - "auto": _auto_gemm_fp4_requirement, -} +def _heuristic_func_mm_fp4( + suitable_backends: List[str], + a: torch.Tensor, + b: torch.Tensor, + a_descale: torch.Tensor, + b_descale: torch.Tensor, + alpha: Optional[torch.Tensor] = None, + out_dtype: torch.dtype = torch.bfloat16, + out: Optional[torch.Tensor] = None, + block_size: int = 16, + use_8x4_sf_layout: bool = False, + backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "cudnn", + use_nvfp4: bool = True, +): + cuda_major, _ = get_cuda_version(a.device) + cc_major, cc_minor = get_compute_capability(a.device) + # If cuda version is 13 or greater: + # cudnn is more performant if cudnn version is 9.14 or greater. 
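    # cudnn.backend_version() reports the loaded cuDNN release as an integer
    # (e.g. 9.14.0 -> 91400), hence the numeric threshold used below.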
+ if CUDNN_AVAILABLE and cuda_major >= 13 and cudnn.backend_version() >= 91400: + candidate_backends = ("cudnn", "cutlass") + # Otherwise, prioritize cutlass + else: + candidate_backends = ("cutlass", "cudnn") + + # Filter to only supported backends for this compute capability + # Note: The requirement function already validated that at least one backend is supported + heuristic_backends = [] + for candidate in candidate_backends: + # mypy requires explicit type casting for the backend literal + if candidate in suitable_backends: + heuristic_backends.append(candidate) + return heuristic_backends @backend_requirement( - backend_checks=_mm_fp4_backend_checkers, common_check=_check_mm_fp4_problem_size + { + "cudnn": _cudnn_gemm_fp4_requirement, + "trtllm": _trtllm_gemm_fp4_requirement, + "cutlass": _cutlass_gemm_fp4_requirement, + }, + common_check=_check_mm_fp4_problem_size, + heuristic_func=_heuristic_func_mm_fp4, ) def mm_fp4( a: torch.Tensor, @@ -2010,7 +2042,7 @@ def mm_fp4( out: Optional[torch.Tensor] = None, block_size: int = 16, use_8x4_sf_layout: bool = False, - backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "auto", + backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "cudnn", use_nvfp4: bool = True, ) -> torch.Tensor: r"""MM FP4 @@ -2089,53 +2121,7 @@ def mm_fp4( # Auto-select the best backend if backend == "auto": - cuda_major, _ = get_cuda_version(a.device) - cc_major, cc_minor = get_compute_capability(a.device) - # If cuda version is 13 or greater: - # cudnn is more performant if cudnn version is 9.14 or greater. - if CUDNN_AVAILABLE and cuda_major >= 13 and cudnn.backend_version() >= 91400: - candidate_backends = ("cudnn", "cutlass") - # Otherwise, prioritize cutlass - else: - candidate_backends = ("cutlass", "cudnn") - - # Filter to only supported backends for this compute capability - # Note: The requirement function already validated that at least one backend is supported - backends = [] - for candidate in candidate_backends: - # mypy requires explicit type casting for the backend literal - backend_literal = cast(Literal["cudnn", "trtllm", "cutlass"], candidate) - try: - # Check both common constraints and backend-specific requirements - # to find all compatible backends for this problem instance - if _check_mm_fp4_problem_size( - a, - b, - a_descale, - b_descale, - alpha, - out_dtype, - out, - block_size, - use_8x4_sf_layout, - backend_literal, - use_nvfp4, - ) and _mm_fp4_backend_checkers[candidate]( - a, - b, - a_descale, - b_descale, - alpha, - out_dtype, - out, - block_size, - use_8x4_sf_layout, - backend_literal, - use_nvfp4, - ): - backends.append(candidate) - except Exception: - pass + backends = mm_fp4.suitable_auto_backends else: backends = [backend] From b685e9ec9dce0305fec4589136451357385bb9b8 Mon Sep 17 00:00:00 2001 From: Brian Ryu Date: Tue, 11 Nov 2025 23:01:02 +0000 Subject: [PATCH 09/19] Cleanup --- benchmarks/routines/gemm.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/benchmarks/routines/gemm.py b/benchmarks/routines/gemm.py index cae1d55bf1..9f95f17fb4 100644 --- a/benchmarks/routines/gemm.py +++ b/benchmarks/routines/gemm.py @@ -891,9 +891,7 @@ def run_backend(backend): b_descale=mat2_inv_s.T if backend != "trtllm" else mat2_inv_s_trtllm.T, alpha=alpha, out_dtype=res_dtype, - block_size=16 - if use_nvfp4 - else 32, # nvfp4 only supports 16; mxfp4 only supports 32. 
+ block_size=block_size, use_8x4_sf_layout=not use_128x4_sf_layout, backend=backend, use_nvfp4=use_nvfp4, From aa40f5e76ef6e9d47a13e359bd7c277be8ffabd9 Mon Sep 17 00:00:00 2001 From: Brian Ryu Date: Tue, 11 Nov 2025 23:01:22 +0000 Subject: [PATCH 10/19] Cleanup --- flashinfer/gemm/gemm_base.py | 54 ++++++++---------------------------- 1 file changed, 12 insertions(+), 42 deletions(-) diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py index 91f28def36..3e0551fa09 100644 --- a/flashinfer/gemm/gemm_base.py +++ b/flashinfer/gemm/gemm_base.py @@ -93,7 +93,7 @@ def _match_sm_version(device: torch.device, sm_version: list[str]): return device_arch in sm_version -def get_cuda_version(device: torch.device): +def get_cuda_version(): return tuple(map(int, torch.version.cuda.split("."))) # (major, minor) @@ -1951,44 +1951,6 @@ def _cutlass_gemm_fp4_requirement( return True -@supported_compute_capability([100, 103, 110, 120, 121]) -def _auto_gemm_fp4_requirement( - a: torch.Tensor, - b: torch.Tensor, - a_descale: torch.Tensor, - b_descale: torch.Tensor, - alpha: Optional[torch.Tensor] = None, - out_dtype: torch.dtype = torch.bfloat16, - out: Optional[torch.Tensor] = None, - block_size: int = 16, - use_8x4_sf_layout: bool = False, - backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "auto", - use_nvfp4: bool = True, -): - # Auto backend requires at least one backend to be supported on the current device - cc_major, cc_minor = get_compute_capability(a.device) - cc_arch = cc_major * 10 + cc_minor - - # Check if at least one backend is supported for this compute capability - candidate_backends = ["cudnn", "cutlass", "trtllm"] - backend_checkers = { - "cudnn": _cudnn_gemm_fp4_requirement, - "cutlass": _cutlass_gemm_fp4_requirement, - # Does not consider trtllm due to different interface. - } - - for candidate in candidate_backends: - checker = backend_checkers[candidate] - if hasattr( - checker, "is_compute_capability_supported" - ) and checker.is_compute_capability_supported(cc_arch): - # At least one backend is supported - return True - - # No backend is supported on this device - return False - - def _heuristic_func_mm_fp4( suitable_backends: List[str], a: torch.Tensor, @@ -2003,8 +1965,16 @@ def _heuristic_func_mm_fp4( backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "cudnn", use_nvfp4: bool = True, ): - cuda_major, _ = get_cuda_version(a.device) - cc_major, cc_minor = get_compute_capability(a.device) + r""" + Heuristic function for mm_fp4 backend selection. Routes to either cudnn or cutlass, but not trtllm. + + Logic for which comes first: + - If cuda version is 12 - use cutlass. + - If cuda version is 13 and cudnn version is less than 9.14 - use cutlass. + - If cuda version is 13 and cudnn version is 9.14 or greater - use cudnn. + + """ + cuda_major, _ = get_cuda_version() # If cuda version is 13 or greater: # cudnn is more performant if cudnn version is 9.14 or greater. 
if CUDNN_AVAILABLE and cuda_major >= 13 and cudnn.backend_version() >= 91400: @@ -2042,7 +2012,7 @@ def mm_fp4( out: Optional[torch.Tensor] = None, block_size: int = 16, use_8x4_sf_layout: bool = False, - backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "cudnn", + backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "auto", use_nvfp4: bool = True, ) -> torch.Tensor: r"""MM FP4 From 17a1f28136fcd70d53921be69d099944b344651c Mon Sep 17 00:00:00 2001 From: Brian Ryu Date: Tue, 11 Nov 2025 23:26:24 +0000 Subject: [PATCH 11/19] Correctly apply common check --- flashinfer/utils.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/flashinfer/utils.py b/flashinfer/utils.py index 76689bab84..b116e7458d 100644 --- a/flashinfer/utils.py +++ b/flashinfer/utils.py @@ -1069,9 +1069,17 @@ def suitable_auto_backends(cc, *args, **kwargs): for backend in backend_checks: req_checker = backend_checks[backend] try: - if req_checker( - *args, **kwargs - ) and req_checker.is_compute_capability_supported(cc): + # Remove 'backend' from kwargs to explicitly pass it to the common check. + kwargs_without_backend = { + k: v for k, v in kwargs.items() if k != "backend" + } + if ( + req_checker(*args, **kwargs) + and common_check( + *args, backend=backend, **kwargs_without_backend + ) + and req_checker.is_compute_capability_supported(cc) + ): suitable_backends.append(backend) except ValueError: continue From b69cadf8bfb477cbe543ec714f4534336e135836 Mon Sep 17 00:00:00 2001 From: Brian Ryu Date: Wed, 12 Nov 2025 01:45:30 +0000 Subject: [PATCH 12/19] Decorator allow cases when no common_check --- flashinfer/utils.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/flashinfer/utils.py b/flashinfer/utils.py index b116e7458d..b2ad0f2e76 100644 --- a/flashinfer/utils.py +++ b/flashinfer/utils.py @@ -1075,10 +1075,13 @@ def suitable_auto_backends(cc, *args, **kwargs): } if ( req_checker(*args, **kwargs) - and common_check( - *args, backend=backend, **kwargs_without_backend - ) and req_checker.is_compute_capability_supported(cc) + and ( + (common_check is None) + or common_check( + *args, backend=backend, **kwargs_without_backend + ) + ) ): suitable_backends.append(backend) except ValueError: From 8b4bda8f21dd24cb8039cb86e46ab9919533b13e Mon Sep 17 00:00:00 2001 From: Brian Ryu Date: Thu, 13 Nov 2025 00:41:12 +0000 Subject: [PATCH 13/19] Address comments about code redundancy and cleanliness --- flashinfer/gemm/gemm_base.py | 80 +++++++++++++++++++++++------------- 1 file changed, 51 insertions(+), 29 deletions(-) diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py index 3e0551fa09..1fb72af846 100644 --- a/flashinfer/gemm/gemm_base.py +++ b/flashinfer/gemm/gemm_base.py @@ -475,6 +475,17 @@ def get_gemm_sm120_module_cutlass_fp4(): ) +def get_cutlass_fp4_gemm_module( + sm_major: int, +): + if sm_major in [10, 11]: + return get_gemm_sm100_module_cutlass_fp4() + elif sm_major == 12: + return get_gemm_sm120_module_cutlass_fp4() + else: + raise ValueError(f"Unsupported SM major version: {sm_major}") + + @functools.cache def get_tgv_gemm_sm10x_module( dtype: torch.dtype = torch.bfloat16, use_sm_100f: bool = False @@ -1654,7 +1665,7 @@ def mm_fp8( return out -def _cudnn_gemm_fp4( +def _get_cudnn_fp4_gemm_graph( a: torch.Tensor, b: torch.Tensor, a_descale: torch.Tensor, @@ -1664,7 +1675,6 @@ def _cudnn_gemm_fp4( out: Optional[torch.Tensor] = None, block_size: int = 16, use_nvfp4: bool = True, - workspace_buffer: torch.Tensor = None, tactic: int = 
-1, ): _check_cudnn_availability() @@ -1699,7 +1709,34 @@ def _cudnn_gemm_fp4( use_nvfp4, tactic=tactic, ) + return graph + +def _cudnn_gemm_fp4( + a: torch.Tensor, + b: torch.Tensor, + a_descale: torch.Tensor, + b_descale: torch.Tensor, + alpha: Optional[torch.Tensor] = None, + out_dtype: torch.dtype = torch.bfloat16, + out: Optional[torch.Tensor] = None, + block_size: int = 16, + use_nvfp4: bool = True, + workspace_buffer: torch.Tensor = None, + tactic: int = -1, +): + graph = _get_cudnn_fp4_gemm_graph( + a=a, + b=b, + a_descale=a_descale, + b_descale=b_descale, + alpha=alpha, + out_dtype=out_dtype, + out=out, + block_size=block_size, + use_nvfp4=use_nvfp4, + tactic=tactic, + ) # execute the fp4 cudnn graph execute_cudnn_gemm_fp4_graph( graph, a, b, a_descale, b_descale, alpha, out, workspace_buffer, tactic=tactic @@ -1728,32 +1765,19 @@ def get_valid_tactics( workspace_buffer, ) = inputs - real_a_shape, real_a_stride = _get_real_fp4_shape_from_packed_uint8(a) - real_b_shape, real_b_stride = _get_real_fp4_shape_from_packed_uint8(b) - batch = real_a_shape[0] - expanded_a_descale_shape, expanded_a_descale_stride = ( - _expand_block_scale_tensor_shape(a_descale, batch) - ) - expanded_b_descale_shape, expanded_b_descale_stride = ( - _expand_block_scale_tensor_shape(b_descale, batch) + graph = _get_cudnn_fp4_gemm_graph( + a=a, + b=b, + a_descale=a_descale, + b_descale=b_descale, + alpha=alpha, + out_dtype=out_dtype, + out=out, + block_size=block_size, + use_nvfp4=use_nvfp4, + tactic=-1, ) - graph = build_plans_cudnn_fp4_gemm_graph( - real_a_shape, - real_a_stride, - real_b_shape, - real_b_stride, - expanded_a_descale_shape, - expanded_a_descale_stride, - expanded_b_descale_shape, - expanded_b_descale_stride, - cudnn.data_type.FP4_E2M1, - _torch_data_type_to_cudnn_data_type(out_dtype), - block_size, - a.device, - alpha is not None, - use_nvfp4, - ) num_plans = graph.get_execution_plan_count() return list(range(num_plans)) @@ -2104,9 +2128,7 @@ def mm_fp4( "trtllm": lambda: get_trtllm_fp4_gemm_module().trtllm_fp4_gemm_runner( use_8x4_sf_layout ), - "cutlass": lambda: get_gemm_sm120_module_cutlass_fp4().cutlass_fp4_gemm_runner() - if major == 12 - else get_gemm_sm100_module_cutlass_fp4().cutlass_fp4_gemm_runner(), + "cutlass": lambda: get_cutlass_fp4_gemm_module(major).cutlass_fp4_gemm_runner(), } runners = [backend_to_runner_factory[cur_backend]() for cur_backend in backends] From 24d17d32d1f09231bbfdac216bbc4e8305343e44 Mon Sep 17 00:00:00 2001 From: Brian Ryu Date: Thu, 13 Nov 2025 01:04:38 +0000 Subject: [PATCH 14/19] Address Jimmy's comment on moving problem size check to backend-specific check --- flashinfer/gemm/gemm_base.py | 13 ++++++++----- flashinfer/utils.py | 17 +++-------------- 2 files changed, 11 insertions(+), 19 deletions(-) diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py index 1fb72af846..4f37088a04 100644 --- a/flashinfer/gemm/gemm_base.py +++ b/flashinfer/gemm/gemm_base.py @@ -1865,11 +1865,6 @@ def _check_mm_fp4_problem_size( f"Only torch.bfloat16 and torch.float16 are supported for FP4 GEMM operations." 
) - if backend != "trtllm" and use_8x4_sf_layout: - raise ValueError("Only TRTLLM FP4 GEMM supports 8x4 scale factor layout.") - if backend not in ["cudnn", "auto"] and not use_nvfp4: - raise ValueError("Only cudnn and auto FP4 GEMM supports mxfp4 quantization.") - if use_nvfp4 and block_size != 16: raise ValueError("nvfp4 only supports block_size = 16.") if not use_nvfp4 and block_size != 32: @@ -1892,6 +1887,8 @@ def _cudnn_gemm_fp4_requirement( backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "auto", use_nvfp4: bool = True, ): + if use_8x4_sf_layout: + raise ValueError("Only TRTLLM FP4 GEMM supports 8x4 scale factor layout.") if ( not use_nvfp4 and _match_sm_version(a.device, ["120"]) @@ -1950,6 +1947,8 @@ def _trtllm_gemm_fp4_requirement( backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "auto", use_nvfp4: bool = True, ): + if not use_nvfp4: + raise ValueError("Only cudnn and auto FP4 GEMM supports mxfp4 quantization.") if out_dtype != torch.bfloat16: raise ValueError( f"Unsupported output dtype: {out_dtype}. " @@ -1972,6 +1971,10 @@ def _cutlass_gemm_fp4_requirement( backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "auto", use_nvfp4: bool = True, ): + if use_8x4_sf_layout: + raise ValueError("Only TRTLLM FP4 GEMM supports 8x4 scale factor layout.") + if not use_nvfp4: + raise ValueError("Only cudnn and auto FP4 GEMM supports mxfp4 quantization.") return True diff --git a/flashinfer/utils.py b/flashinfer/utils.py index b2ad0f2e76..76689bab84 100644 --- a/flashinfer/utils.py +++ b/flashinfer/utils.py @@ -1069,20 +1069,9 @@ def suitable_auto_backends(cc, *args, **kwargs): for backend in backend_checks: req_checker = backend_checks[backend] try: - # Remove 'backend' from kwargs to explicitly pass it to the common check. - kwargs_without_backend = { - k: v for k, v in kwargs.items() if k != "backend" - } - if ( - req_checker(*args, **kwargs) - and req_checker.is_compute_capability_supported(cc) - and ( - (common_check is None) - or common_check( - *args, backend=backend, **kwargs_without_backend - ) - ) - ): + if req_checker( + *args, **kwargs + ) and req_checker.is_compute_capability_supported(cc): suitable_backends.append(backend) except ValueError: continue From 837b29031de49a1ed344b110d99b5cd600d42621 Mon Sep 17 00:00:00 2001 From: Brian Ryu Date: Fri, 14 Nov 2025 23:20:46 +0000 Subject: [PATCH 15/19] Final cleanup --- flashinfer/gemm/gemm_base.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py index 4f37088a04..9f9be68b41 100644 --- a/flashinfer/gemm/gemm_base.py +++ b/flashinfer/gemm/gemm_base.py @@ -2003,21 +2003,15 @@ def _heuristic_func_mm_fp4( """ cuda_major, _ = get_cuda_version() # If cuda version is 13 or greater: - # cudnn is more performant if cudnn version is 9.14 or greater. - if CUDNN_AVAILABLE and cuda_major >= 13 and cudnn.backend_version() >= 91400: + # cudnn is more performant if cudnn version is 9.15 or greater. 
+ if CUDNN_AVAILABLE and cuda_major >= 13 and cudnn.backend_version() >= 91500: candidate_backends = ("cudnn", "cutlass") # Otherwise, prioritize cutlass else: candidate_backends = ("cutlass", "cudnn") - # Filter to only supported backends for this compute capability - # Note: The requirement function already validated that at least one backend is supported - heuristic_backends = [] - for candidate in candidate_backends: - # mypy requires explicit type casting for the backend literal - if candidate in suitable_backends: - heuristic_backends.append(candidate) - return heuristic_backends + # Filter and return only supported backends + return [c for c in candidate_backends if c in suitable_backends] @backend_requirement( @@ -2027,7 +2021,7 @@ def _heuristic_func_mm_fp4( "cutlass": _cutlass_gemm_fp4_requirement, }, common_check=_check_mm_fp4_problem_size, - heuristic_func=_heuristic_func_mm_fp4, + heuristic_func=_heuristic_func_mm_fp4, # result stored in mm_fp4.suitable_auto_backends ) def mm_fp4( a: torch.Tensor, From 72eaf90cdbb99aa9292a190bc9ddeccdb0594671 Mon Sep 17 00:00:00 2001 From: Brian Ryu Date: Mon, 17 Nov 2025 20:00:05 +0000 Subject: [PATCH 16/19] Add comments in response to feedback --- flashinfer/gemm/gemm_base.py | 9 ++++++--- flashinfer/utils.py | 6 ++++++ 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py index 9f9be68b41..fb0484f0fa 100644 --- a/flashinfer/gemm/gemm_base.py +++ b/flashinfer/gemm/gemm_base.py @@ -1993,12 +1993,15 @@ def _heuristic_func_mm_fp4( use_nvfp4: bool = True, ): r""" - Heuristic function for mm_fp4 backend selection. Routes to either cudnn or cutlass, but not trtllm. + Heuristic function for mm_fp4 backend selection. Routes to either cudnn or cutlass. + Note: trtllm is not considered in the backend selection because it requires a specific + input quantization (swizzling/shuffling) that differs from the preparation used + for cudnn and cutlass backends. Logic for which comes first: - If cuda version is 12 - use cutlass. - - If cuda version is 13 and cudnn version is less than 9.14 - use cutlass. - - If cuda version is 13 and cudnn version is 9.14 or greater - use cudnn. + - If cuda version is 13 and cudnn version is less than 9.15 - use cutlass. + - If cuda version is 13 and cudnn version is 9.15 or greater - use cudnn. """ cuda_major, _ = get_cuda_version() diff --git a/flashinfer/utils.py b/flashinfer/utils.py index 76689bab84..3fad2f008e 100644 --- a/flashinfer/utils.py +++ b/flashinfer/utils.py @@ -921,6 +921,12 @@ def backend_requirement( backends. Should accept the same arguments as the decorated function and return True if requirements are met, False otherwise. In the case where the kernel function does not have any specific backends, this can be decorated with @supported_compute_capability to specify the function's supported compute capabilities. + heuristic_func : callable, optional + An optional function that performs heuristic backend selection when backend is "auto". Does not do anything if backend is not "auto". + Should accept the same arguments as the decorated function. + Should return an ordered list of runnable backends with the most preferred backend first. + When decorated function is not autotuned, the first backend in the heuristic list will be run. + When decorated function is autotuned, the backends in the heuristic list will be autotuned over to find the best backend. 
Returns ------- From 0be7217f337873442a71af23e0706c27a3520ab3 Mon Sep 17 00:00:00 2001 From: Brian Ryu Date: Wed, 19 Nov 2025 18:55:24 +0000 Subject: [PATCH 17/19] Don't reinvent get_cuda_version --- flashinfer/gemm/gemm_base.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py index fb0484f0fa..a36930e579 100644 --- a/flashinfer/gemm/gemm_base.py +++ b/flashinfer/gemm/gemm_base.py @@ -54,6 +54,7 @@ from ..jit.gemm import gen_trtllm_gen_gemm_module from ..jit.gemm import gen_tgv_gemm_sm10x_module from ..jit.gemm import gen_deepgemm_sm100_module +from ..jit.cpp_ext import get_cuda_version CUDNN_AVAILABLE = False @@ -93,10 +94,6 @@ def _match_sm_version(device: torch.device, sm_version: list[str]): return device_arch in sm_version -def get_cuda_version(): - return tuple(map(int, torch.version.cuda.split("."))) # (major, minor) - - @functools.cache def get_gemm_module(): module = gen_gemm_module().build_and_load() @@ -2004,7 +2001,7 @@ def _heuristic_func_mm_fp4( - If cuda version is 13 and cudnn version is 9.15 or greater - use cudnn. """ - cuda_major, _ = get_cuda_version() + cuda_major = get_cuda_version().major # If cuda version is 13 or greater: # cudnn is more performant if cudnn version is 9.15 or greater. if CUDNN_AVAILABLE and cuda_major >= 13 and cudnn.backend_version() >= 91500: From 2f15fc41c6412f1c109314f3c38163e279538320 Mon Sep 17 00:00:00 2001 From: Brian Ryu Date: Thu, 20 Nov 2025 18:22:01 +0000 Subject: [PATCH 18/19] Address comments --- flashinfer/gemm/gemm_base.py | 57 ++++++++++++++++++------------------ flashinfer/utils.py | 7 +++-- 2 files changed, 33 insertions(+), 31 deletions(-) diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py index a36930e579..589c651aca 100644 --- a/flashinfer/gemm/gemm_base.py +++ b/flashinfer/gemm/gemm_base.py @@ -1114,7 +1114,6 @@ def _check_cudnn_fp4_availability(): def _is_cublas_fp4_available_in_cudnn(): """Check if cuBLAS backend for FP4 GEMM is available in cuDNN.""" - _check_cudnn_availability() # Check cuDNN backend version for FP4 support (requires cudnn_version == 9.11.1 or cudnn_version >= 9.13) backend_version = cudnn.backend_version() @@ -1166,7 +1165,6 @@ def create_cudnn_execution_plans_fp4_gemm( alpha_is_not_none, use_nvfp4, ): - _check_cudnn_availability() stream = torch.cuda.current_stream(device) with cudnn.graph(_get_cudnn_handle(stream)) as (graph, _): scale_type = cudnn.data_type.FP8_E4M3 if use_nvfp4 else cudnn.data_type.FP8_E8M0 @@ -1269,6 +1267,7 @@ def build_plans_cudnn_fp4_gemm_graph( use_nvfp4, tactic: int = -1, ): + # Graph should have been already cached, when we ran _cudnn_gemm_fp4_requirement graph = create_cudnn_execution_plans_fp4_gemm( a_shape, a_stride, @@ -1674,7 +1673,6 @@ def _get_cudnn_fp4_gemm_graph( use_nvfp4: bool = True, tactic: int = -1, ): - _check_cudnn_availability() # the fp4 cudnn graph will be shared for both mm and bmm, so # here we need to get the 3d shape and stride including the # batch dimension for both input and block scale tensors. @@ -1689,6 +1687,7 @@ def _get_cudnn_fp4_gemm_graph( ) # build the fp4 cudnn graph + # Constructed graph is cached, via @functools.cache decorator. 
graph = build_plans_cudnn_fp4_gemm_graph( real_a_shape, real_a_stride, @@ -1722,6 +1721,7 @@ def _cudnn_gemm_fp4( workspace_buffer: torch.Tensor = None, tactic: int = -1, ): + # Graph should have been already cached, when we ran _cudnn_gemm_fp4_requirement graph = _get_cudnn_fp4_gemm_graph( a=a, b=b, @@ -1748,7 +1748,6 @@ def get_valid_tactics( profile: OptimizationProfile, ) -> List[int]: # cudnn has heuristic for fp4 gemm, so we only need to use the default tactic - _check_cudnn_availability() ( a, b, @@ -1762,6 +1761,7 @@ def get_valid_tactics( workspace_buffer, ) = inputs + # Graph should have been already cached, when we ran _cudnn_gemm_fp4_requirement graph = _get_cudnn_fp4_gemm_graph( a=a, b=b, @@ -1821,10 +1821,10 @@ def _check_mm_fp4_problem_size( b_descale: torch.Tensor, alpha: Optional[torch.Tensor] = None, out_dtype: torch.dtype = torch.bfloat16, - out: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, # unused block_size: int = 16, - use_8x4_sf_layout: bool = False, - backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "auto", + use_8x4_sf_layout: bool = False, # unused + backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "auto", # unused use_nvfp4: bool = True, ): # Generic checks @@ -1878,10 +1878,10 @@ def _cudnn_gemm_fp4_requirement( b_descale: torch.Tensor, alpha: Optional[torch.Tensor] = None, out_dtype: torch.dtype = torch.bfloat16, - out: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, # unused block_size: int = 16, use_8x4_sf_layout: bool = False, - backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "auto", + backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "auto", # unused use_nvfp4: bool = True, ): if use_8x4_sf_layout: @@ -1908,7 +1908,8 @@ def _cudnn_gemm_fp4_requirement( _expand_block_scale_tensor_shape(b_descale, batch) ) - # build the fp4 cudnn graph + # build the fp4 cudnn graph. 
This graph will be cached & reused in mm_fp4() + # because the graph is constructed with @functools.cache decorator graph = create_cudnn_execution_plans_fp4_gemm( real_a_shape, real_a_stride, @@ -1932,16 +1933,16 @@ def _cudnn_gemm_fp4_requirement( @supported_compute_capability([100, 103]) def _trtllm_gemm_fp4_requirement( - a: torch.Tensor, - b: torch.Tensor, - a_descale: torch.Tensor, - b_descale: torch.Tensor, - alpha: Optional[torch.Tensor] = None, + a: torch.Tensor, # unused + b: torch.Tensor, # unused + a_descale: torch.Tensor, # unused + b_descale: torch.Tensor, # unused + alpha: Optional[torch.Tensor] = None, # unused out_dtype: torch.dtype = torch.bfloat16, - out: Optional[torch.Tensor] = None, - block_size: int = 16, - use_8x4_sf_layout: bool = False, - backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "auto", + out: Optional[torch.Tensor] = None, # unused + block_size: int = 16, # unused + use_8x4_sf_layout: bool = False, # unused + backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "auto", # unused use_nvfp4: bool = True, ): if not use_nvfp4: @@ -1956,16 +1957,16 @@ def _trtllm_gemm_fp4_requirement( @supported_compute_capability([100, 103, 110, 120, 121]) def _cutlass_gemm_fp4_requirement( - a: torch.Tensor, - b: torch.Tensor, - a_descale: torch.Tensor, - b_descale: torch.Tensor, - alpha: Optional[torch.Tensor] = None, - out_dtype: torch.dtype = torch.bfloat16, - out: Optional[torch.Tensor] = None, - block_size: int = 16, + a: torch.Tensor, # unused + b: torch.Tensor, # unused + a_descale: torch.Tensor, # unused + b_descale: torch.Tensor, # unused + alpha: Optional[torch.Tensor] = None, # unused + out_dtype: torch.dtype = torch.bfloat16, # unused + out: Optional[torch.Tensor] = None, # unused + block_size: int = 16, # unused use_8x4_sf_layout: bool = False, - backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "auto", + backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "auto", # unused use_nvfp4: bool = True, ): if use_8x4_sf_layout: diff --git a/flashinfer/utils.py b/flashinfer/utils.py index 3fad2f008e..e323125efa 100644 --- a/flashinfer/utils.py +++ b/flashinfer/utils.py @@ -922,7 +922,8 @@ def backend_requirement( True if requirements are met, False otherwise. In the case where the kernel function does not have any specific backends, this can be decorated with @supported_compute_capability to specify the function's supported compute capabilities. heuristic_func : callable, optional - An optional function that performs heuristic backend selection when backend is "auto". Does not do anything if backend is not "auto". + A function that performs heuristic backend selection when backend is "auto". + Must be provided if backend is "auto". Does not do anything if backend is not "auto". Should accept the same arguments as the decorated function. Should return an ordered list of runnable backends with the most preferred backend first. When decorated function is not autotuned, the first backend in the heuristic list will be run. 
@@ -1082,8 +1083,8 @@ def suitable_auto_backends(cc, *args, **kwargs): except ValueError: continue # If a heuristic function is provided, filter the suitable backends based on the heuristic function - if heuristic_func is not None: - suitable_backends = heuristic_func(suitable_backends, *args, **kwargs) + assert heuristic_func is not None, "Heuristic function must be provided" + suitable_backends = heuristic_func(suitable_backends, *args, **kwargs) if not suitable_backends: return False wrapper.suitable_auto_backends = suitable_backends From fe2070b0c98d42bf80b22d6d40aa24824d4bce50 Mon Sep 17 00:00:00 2001 From: Brian Ryu Date: Thu, 20 Nov 2025 21:49:03 +0000 Subject: [PATCH 19/19] Updated test_decorators to include some heuristic backend --- tests/utils/test_decorators.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/tests/utils/test_decorators.py b/tests/utils/test_decorators.py index ebbda781fb..f8659c2e44 100644 --- a/tests/utils/test_decorators.py +++ b/tests/utils/test_decorators.py @@ -344,7 +344,27 @@ def _cutlass_check(x, backend): def _cudnn_check(x, backend): return x.shape[0] > 5 - @backend_requirement({"cutlass": _cutlass_check, "cudnn": _cudnn_check}) + # When using an auto backend, some heuristic function must exist + def _heuristic_func(suitable_backends, x, backend): + candidate_backends = None + if x.shape[0] > 5: + candidate_backends = ["cudnn", "cutlass"] + else: + candidate_backends = ["cutlass", "cudnn"] + + heuristic_backends = [] + for backend in candidate_backends: + if backend in suitable_backends: + heuristic_backends.append(backend) + return heuristic_backends + + @backend_requirement( + backend_checks={ + "cutlass": _cutlass_check, + "cudnn": _cudnn_check, + }, + heuristic_func=_heuristic_func, + ) def my_kernel(x, backend="auto"): backends = my_kernel.suitable_auto_backends if x.shape[0] > 5:
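# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch series, and not flashinfer API):
# the ordering rule that _heuristic_func_mm_fp4 applies after patches 15-18,
# written as a standalone function. `sketch_preferred_order` and its
# parameters are hypothetical names; the patches themselves derive the CUDA
# version via get_cuda_version() and the cuDNN version via
# cudnn.backend_version().
import torch


def sketch_preferred_order(suitable_backends, cudnn_available, cudnn_backend_version):
    # Approximate the CUDA major version from the torch build (None on CPU-only builds).
    cuda_major = int(torch.version.cuda.split(".")[0]) if torch.version.cuda else 0
    if cudnn_available and cuda_major >= 13 and cudnn_backend_version >= 91500:
        # CUDA 13+ with cuDNN 9.15+: prefer cuDNN, keep CUTLASS as fallback.
        candidates = ("cudnn", "cutlass")
    else:
        # Older CUDA or older cuDNN: prefer CUTLASS.
        candidates = ("cutlass", "cudnn")
    # Keep only the backends that passed the per-backend requirement checks,
    # preserving the preference order above.
    return [c for c in candidates if c in suitable_backends]


# Example: on a CUDA 12 machine where both requirement checks pass,
# sketch_preferred_order(["cudnn", "cutlass"], True, 91200) returns
# ["cutlass", "cudnn"], so mm_fp4(..., backend="auto") (the default after
# patch 10) would dispatch to CUTLASS first; the decorator exposes the
# filtered, ordered list as mm_fp4.suitable_auto_backends before dispatch.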