diff --git a/benchmarks/routines/flashinfer_benchmark_utils.py b/benchmarks/routines/flashinfer_benchmark_utils.py
index 95fe833b3c..3a8a28d18a 100644
--- a/benchmarks/routines/flashinfer_benchmark_utils.py
+++ b/benchmarks/routines/flashinfer_benchmark_utils.py
@@ -378,8 +378,8 @@ def dtype_str_to_torch_dtype(dtype_str):
         "8.6": [],
         "8.9": [],
         "9.0": [],
-        "10.0": ["cutlass", "cute-dsl"],
-        "10.3": ["cutlass", "cute-dsl"],
+        "10.0": ["cudnn", "cutlass", "cute-dsl"],
+        "10.3": ["cudnn", "cutlass", "cute-dsl"],
         "11.0": ["cutlass"],
         "12.0": [],
     },
diff --git a/benchmarks/routines/gemm.py b/benchmarks/routines/gemm.py
index 89c031fc76..511c72a2a5 100644
--- a/benchmarks/routines/gemm.py
+++ b/benchmarks/routines/gemm.py
@@ -1296,7 +1296,7 @@ def testMmMxfp8(args):
     res_dtype = args.out_dtype
     is_cuda_graph_compatible = not args.no_cuda_graph
     run_refcheck = args.refcheck
-    autotune_supported_backends = ["cutlass", "cute-dsl", "auto"]
+    autotune_supported_backends = ["cudnn", "cutlass", "cute-dsl", "auto"]
 
     res = []
     backends = filter_backends_by_compute_capability(backends, args.routine, device)
@@ -1349,7 +1349,7 @@ def testMmMxfp8(args):
         print(f"[VVERBOSE] {mat2_scale.dtype = }")
 
     def run_backend(backend, input_mxfp8, mat2_mxfp8, input_scale, mat2_scale):
-        if backend in ["cutlass", "cute-dsl", "auto"]:
+        if backend in ["cudnn", "cutlass", "cute-dsl", "auto"]:
             return flashinfer.gemm.mm_mxfp8(
                 a=input_mxfp8,
                 b=mat2_mxfp8.t(),  # mm_mxfp8 expects b.t()
diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py
index 60bc5eb76f..a535bd948c 100644
--- a/flashinfer/gemm/gemm_base.py
+++ b/flashinfer/gemm/gemm_base.py
@@ -2518,7 +2518,7 @@ def _check_mm_mxfp8_problem_size(
     b_descale: torch.Tensor,
     out: Optional[torch.Tensor] = None,
     out_dtype: torch.dtype = torch.bfloat16,
-    backend: Literal["cutlass", "cute-dsl", "auto"] = "auto",  # unused
+    backend: Literal["cudnn", "cutlass", "cute-dsl", "auto"] = "auto",  # unused
 ) -> bool:
     # Generic checks
     ## pre-check the input tensors and block scale tensors
@@ -2632,11 +2632,30 @@ def _cutlass_gemm_mxfp8_requirement(
     b_descale: torch.Tensor,
     out: Optional[torch.Tensor] = None,
     out_dtype: torch.dtype = torch.bfloat16,
-    backend: Literal["cutlass", "cute-dsl", "auto"] = "auto",
+    backend: Literal["cudnn", "cutlass", "cute-dsl", "auto"] = "auto",
 ):
     return True
 
 
+@supported_compute_capability([100, 103])
+def _cudnn_mm_mxfp8_requirement(
+    a: torch.Tensor,  # unused
+    b: torch.Tensor,  # unused
+    a_descale: torch.Tensor,
+    b_descale: torch.Tensor,
+    out: Optional[torch.Tensor] = None,  # unused
+    out_dtype: torch.dtype = torch.bfloat16,  # unused
+    backend: Literal["cudnn", "cutlass", "cute-dsl", "auto"] = "auto",  # unused
+):
+    # cuDNN MXFP8 path currently expects swizzled 1D scale tensors.
+    if a_descale.ndim != 1 or b_descale.ndim != 1:
+        raise ValueError(
+            "cudnn mm_mxfp8 requires swizzled 1D scale tensors for a_descale and b_descale."
+        )
+    _check_cudnn_availability()
+    return True
+
+
 @supported_compute_capability([100, 103])
 def _cute_dsl_gemm_mxfp8_requirement(
     a: torch.Tensor,  # unused
@@ -2645,7 +2664,7 @@ def _cute_dsl_gemm_mxfp8_requirement(
     b_descale: torch.Tensor,
     out: Optional[torch.Tensor] = None,  # unused
     out_dtype: torch.dtype = torch.bfloat16,  # unused
-    backend: Literal["cutlass", "cute-dsl", "auto"] = "auto",  # unused
+    backend: Literal["cudnn", "cutlass", "cute-dsl", "auto"] = "auto",  # unused
 ):
     # CuTe DSL MXFP8 path currently expects swizzled 1D block scales
     # in F8_128x4 layout for both A and B.
@@ -3050,8 +3069,10 @@ def _heuristic_func_mm_mxfp8(
     b_descale: torch.Tensor,
     out: Optional[torch.Tensor] = None,
     out_dtype: torch.dtype = torch.bfloat16,
-    backend: Literal["cutlass", "cute-dsl", "auto"] = "auto",
+    backend: Literal["cudnn", "cutlass", "cute-dsl", "auto"] = "auto",
 ) -> List[str]:
+    if CUDNN_AVAILABLE and "cudnn" in suitable_backends:
+        return ["cudnn"]
     if "cutlass" in suitable_backends:
         return ["cutlass"]
     return []
@@ -3059,6 +3080,7 @@
 
 @backend_requirement(
     {
+        "cudnn": _cudnn_mm_mxfp8_requirement,
         "cutlass": _cutlass_gemm_mxfp8_requirement,
         "cute-dsl": _cute_dsl_gemm_mxfp8_requirement,
     },
@@ -3073,7 +3095,7 @@ def mm_mxfp8(
     b_descale: torch.Tensor,
     out: Optional[torch.Tensor] = None,
     out_dtype: torch.dtype = torch.bfloat16,
-    backend: Literal["cutlass", "cute-dsl", "auto"] = "auto",
+    backend: Literal["cudnn", "cutlass", "cute-dsl", "auto"] = "auto",
 ) -> torch.Tensor:
     r"""MM MXFP8 (block size 32)
 
@@ -3100,14 +3122,16 @@ def mm_mxfp8(
         For 1D swizzled format, it's flattened from (N_padded, K_padded) layout.
 
     out: Optional[torch.Tensor]
-        Out tensor, shape (m, n), bf16 or fp16. If provided, can only be used with the CUTLASS backend. Defaults to ``None``.
+        Out tensor, shape (m, n), bf16 or fp16. Defaults to ``None``.
 
     out_dtype: torch.dtype
         Output dtype, bf16 or fp16. Defaults to ``torch.bfloat16``.
 
-    backend: Literal["cutlass", "cute-dsl", "auto"]
+    backend: Literal["cudnn", "cutlass", "cute-dsl", "auto"]
         The backend to use for the operation. Defaults to ``"auto"``.
-        ``"auto"`` selects the CUTLASS backend.
+        ``"auto"`` selects a supported backend (currently cuDNN or CUTLASS).
+        ``"cudnn"`` requires swizzled 1D scales produced by
+        ``mxfp8_quantize(..., is_sf_swizzled_layout=True)``.
 
         The ``"cute-dsl"`` backend currently requires swizzled 1D scales
         (``mxfp8_quantize(..., is_sf_swizzled_layout=True)``).
@@ -3183,6 +3207,7 @@ def mm_mxfp8(
     major, minor = get_compute_capability(a.device)
 
     backend_to_runner_factory = {
+        "cudnn": lambda: _cudnn_mm_mxfp8_runner(),
         "cutlass": lambda: get_cutlass_mxfp8_gemm_module(
             major
         ).cutlass_mxfp8_gemm_runner(),
@@ -5922,6 +5947,42 @@ def forward(
         return CudnnMxfp8GemmRunner()
 
 
+def _cudnn_mm_mxfp8_runner():
+    class CudnnMmMxfp8GemmRunner(TunableRunner):
+        def get_valid_tactics(
+            self,
+            inputs: List[torch.Tensor],
+            profile: OptimizationProfile,
+        ) -> List[int]:
+            # cuDNN provides internal heuristics; use the default tactic entry.
+            return [0]
+
+        def forward(
+            self,
+            inputs: List[torch.Tensor],
+            tactic: int = -1,
+            do_preparation: bool = False,
+            **kwargs,
+        ) -> torch.Tensor:
+            a, b, scale_a, scale_b, _, out, workspace_buffer = inputs
+            a_3d = a.unsqueeze(0)
+            b_3d = b.unsqueeze(0)
+            out_3d = out.unsqueeze(0)
+            _cudnn_gemm_mxfp8(
+                a=a_3d,
+                b=b_3d,
+                a_descale=scale_a,
+                b_descale=scale_b,
+                out=out_3d,
+                out_dtype=out.dtype,
+                workspace_buffer=workspace_buffer,
+                tactic=tactic,
+            )
+            return out
+
+    return CudnnMmMxfp8GemmRunner()
+
+
 def mxfp8_gemm_sm100(
     a: torch.Tensor,
     b: torch.Tensor,
diff --git a/tests/gemm/test_mm_mxfp8.py b/tests/gemm/test_mm_mxfp8.py
index ea87645e1d..0c34509052 100644
--- a/tests/gemm/test_mm_mxfp8.py
+++ b/tests/gemm/test_mm_mxfp8.py
@@ -207,6 +207,34 @@ def test_mm_mxfp8_small_m(m, n, k):
     )
 
 
+def test_mm_mxfp8_cudnn_swizzled_single_gemm():
+    _run_mm_mxfp8(
+        320,
+        384,
+        224,
+        torch.bfloat16,
+        True,  # cuDNN path currently requires swizzled 1D scales
+        torch.bfloat16,
+        "cudnn",
+        auto_tuning=False,
+        provide_out=True,
+    )
+
+
+def test_mm_mxfp8_auto_swizzled_single_gemm():
+    _run_mm_mxfp8(
+        384,
+        512,
+        256,
+        torch.bfloat16,
+        True,  # auto path should select a supported swizzled-scale backend
+        torch.bfloat16,
+        "auto",
+        auto_tuning=False,
+        provide_out=True,
+    )
+
+
 def test_mm_mxfp8_invalid_input_dtype():
     _skip_if_unsupported()
     m, n, k = 128, 128, 128
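Illustrative usage sketch (not part of the diff): how the new cuDNN backend could be exercised end to end. The mm_mxfp8 parameter names, the backend="cudnn" value, and the is_sf_swizzled_layout=True requirement come from the changes above; the top-level flashinfer.mxfp8_quantize import path and its (quantized tensor, scale tensor) return order are assumptions made for this example.

    import torch
    import flashinfer

    m, n, k = 320, 384, 224
    a = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(n, k, device="cuda", dtype=torch.bfloat16)

    # The cuDNN path requires swizzled 1D scales (see _cudnn_mm_mxfp8_requirement above).
    a_fp8, a_sf = flashinfer.mxfp8_quantize(a, is_sf_swizzled_layout=True)  # assumed return order
    b_fp8, b_sf = flashinfer.mxfp8_quantize(b, is_sf_swizzled_layout=True)  # assumed return order

    # mm_mxfp8 expects b transposed, matching the benchmark call above.
    out = flashinfer.gemm.mm_mxfp8(
        a=a_fp8,
        b=b_fp8.t(),
        a_descale=a_sf,
        b_descale=b_sf,
        out_dtype=torch.bfloat16,
        backend="cudnn",
    )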