Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions benchmarks/routines/flashinfer_benchmark_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,8 +378,8 @@ def dtype_str_to_torch_dtype(dtype_str):
"8.6": [],
"8.9": [],
"9.0": [],
"10.0": ["cutlass", "cute-dsl"],
"10.3": ["cutlass", "cute-dsl"],
"10.0": ["cudnn", "cutlass", "cute-dsl"],
"10.3": ["cudnn", "cutlass", "cute-dsl"],
"11.0": ["cutlass"],
"12.0": [],
},
Expand Down
4 changes: 2 additions & 2 deletions benchmarks/routines/gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1296,7 +1296,7 @@ def testMmMxfp8(args):
res_dtype = args.out_dtype
is_cuda_graph_compatible = not args.no_cuda_graph
run_refcheck = args.refcheck
autotune_supported_backends = ["cutlass", "cute-dsl", "auto"]
autotune_supported_backends = ["cudnn", "cutlass", "cute-dsl", "auto"]
res = []

backends = filter_backends_by_compute_capability(backends, args.routine, device)
Expand Down Expand Up @@ -1349,7 +1349,7 @@ def testMmMxfp8(args):
print(f"[VVERBOSE] {mat2_scale.dtype = }")

def run_backend(backend, input_mxfp8, mat2_mxfp8, input_scale, mat2_scale):
if backend in ["cutlass", "cute-dsl", "auto"]:
if backend in ["cudnn", "cutlass", "cute-dsl", "auto"]:
return flashinfer.gemm.mm_mxfp8(
a=input_mxfp8,
b=mat2_mxfp8.t(), # mm_mxfp8 expects b.t()
Expand Down
77 changes: 69 additions & 8 deletions flashinfer/gemm/gemm_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2518,7 +2518,7 @@ def _check_mm_mxfp8_problem_size(
b_descale: torch.Tensor,
out: Optional[torch.Tensor] = None,
out_dtype: torch.dtype = torch.bfloat16,
backend: Literal["cutlass", "cute-dsl", "auto"] = "auto", # unused
backend: Literal["cudnn", "cutlass", "cute-dsl", "auto"] = "auto", # unused
) -> bool:
# Generic checks
## pre-check the input tensors and block scale tensors
Expand Down Expand Up @@ -2632,11 +2632,30 @@ def _cutlass_gemm_mxfp8_requirement(
b_descale: torch.Tensor,
out: Optional[torch.Tensor] = None,
out_dtype: torch.dtype = torch.bfloat16,
backend: Literal["cutlass", "cute-dsl", "auto"] = "auto",
backend: Literal["cudnn", "cutlass", "cute-dsl", "auto"] = "auto",
):
return True


@supported_compute_capability([100, 103])
def _cudnn_mm_mxfp8_requirement(
    a: torch.Tensor,
    b: torch.Tensor,
    a_descale: torch.Tensor,
    b_descale: torch.Tensor,
    out: Optional[torch.Tensor] = None,  # unused
    out_dtype: torch.dtype = torch.bfloat16,
    backend: Literal["cudnn", "cutlass", "cute-dsl", "auto"] = "auto",  # unused
):
    """Check whether the cuDNN backend can run mm_mxfp8 on these inputs.

    Raises ``ValueError`` if the scale tensors are not swizzled 1D, raises
    from ``_check_cudnn_availability()`` if cuDNN is not installed, and
    raises from ``graph.check_support()`` if cuDNN cannot execute a plan
    for these exact shapes/dtypes. Returns ``True`` otherwise.
    """
    # cuDNN MXFP8 path currently expects swizzled 1D scale tensors.
    if a_descale.ndim != 1 or b_descale.ndim != 1:
        raise ValueError(
            "cudnn mm_mxfp8 requires swizzled 1D scale tensors for a_descale and b_descale."
        )
    _check_cudnn_availability()
    # Validate graph support for the specific shapes up front (mirrors the
    # FP4 requirement helper). Without this, backend="auto" can select cuDNN
    # for shapes that only fail at execution time, preventing fallback to
    # CUTLASS / CuTe DSL. mm_mxfp8 maps to a batch-1 bmm inside cuDNN.
    a_3d = a.unsqueeze(0)
    b_3d = b.unsqueeze(0)
    graph = create_cudnn_execution_plans_mxfp8_gemm(
        a_shape=a_3d.shape,
        a_stride=a_3d.stride(),
        b_shape=b_3d.shape,
        b_stride=b_3d.stride(),
        a_type=_torch_data_type_to_cudnn_data_type(a.dtype),
        b_type=_torch_data_type_to_cudnn_data_type(b.dtype),
        block_size=32,  # mm_mxfp8 is fixed at MXFP8 block size 32
        o_type=_torch_data_type_to_cudnn_data_type(out_dtype),
        device=a.device,
    )
    graph.check_support()
    return True
Comment on lines +2640 to +2656
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# First, let's look at the similar FP4 requirement function pattern
rg -A 25 "_cudnn_gemm_fp4_requirement" flashinfer/gemm/gemm_base.py

Repository: flashinfer-ai/flashinfer

Length of output: 3909


🏁 Script executed:

# Check how _cudnn_mm_mxfp8_requirement is currently implemented
rg -B 2 -A 20 "def _cudnn_mm_mxfp8_requirement" flashinfer/gemm/gemm_base.py

Repository: flashinfer-ai/flashinfer

Length of output: 861


🏁 Script executed:

# Find the cuDNN graph creation function mentioned in the proposed fix
rg -A 10 "create_cudnn_execution_plans_mxfp8_gemm" flashinfer/gemm/gemm_base.py

Repository: flashinfer-ai/flashinfer

Length of output: 759


🏁 Script executed:

# Check how backend_requirement decorator wires the requirement helper
rg -B 3 -A 15 "backend_requirement.*_cudnn_mm_mxfp8_requirement" flashinfer/gemm/gemm_base.py

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Look for _get_cudnn_mxfp8_gemm_graph to see how graph checking happens during execution
rg -A 30 "_get_cudnn_mxfp8_gemm_graph" flashinfer/gemm/gemm_base.py

Repository: flashinfer-ai/flashinfer

Length of output: 1766


🏁 Script executed:

# Check the complete FP4 requirement function to see full validation logic
rg -A 50 "def _cudnn_gemm_fp4_requirement" flashinfer/gemm/gemm_base.py | head -80

Repository: flashinfer-ai/flashinfer

Length of output: 1980


🏁 Script executed:

# Look at how mm_fp4 backend_requirement is set up
rg -B 5 -A 10 '@backend_requirement.*mm_fp4' flashinfer/gemm/gemm_base.py

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Check mm_mxfp8 backend_requirement setup
rg -B 5 -A 10 '@backend_requirement.*mm_mxfp8' flashinfer/gemm/gemm_base.py

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Check if there are any graph caching mechanisms between requirement and execution
rg "graph.*cache" flashinfer/gemm/gemm_base.py

Repository: flashinfer-ai/flashinfer

Length of output: 285


🏁 Script executed:

# Verify the actual function signature and parameters for MXFP8 graph creation
rg -B 2 -A 15 "def create_cudnn_execution_plans_mxfp8_gemm" flashinfer/gemm/gemm_base.py

Repository: flashinfer-ai/flashinfer

Length of output: 545


🏁 Script executed:

# Check what block_size is used in mm_mxfp8 public API
rg -A 5 "def mm_mxfp8" flashinfer/gemm/gemm_base.py | grep -A 5 "block_size"

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Verify FP4 uses graph caching with `@functools.cache`
rg -B 5 "def create_cudnn_execution_plans_fp4_gemm" flashinfer/gemm/gemm_base.py | head -10

Repository: flashinfer-ai/flashinfer

Length of output: 232


🏁 Script executed:

# Find where MXFP8 tensors are converted to 3D in the actual execution path
rg -B 10 -A 5 "_get_cudnn_mxfp8_gemm_graph" flashinfer/gemm/gemm_base.py | head -40

Repository: flashinfer-ai/flashinfer

Length of output: 1041


🏁 Script executed:

# Look for unsqueeze or reshape before calling _get_cudnn_mxfp8_gemm_graph
rg -B 15 "_get_cudnn_mxfp8_gemm_graph" flashinfer/gemm/gemm_base.py | grep -E "(unsqueeze|reshape|view|3d|batch)"

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Check the mm_mxfp8 public API to see how tensors flow into the requirement and runners
rg -A 100 "@backend_requirement\(" flashinfer/gemm/gemm_base.py | grep -A 80 "mm_mxfp8"

Repository: flashinfer-ai/flashinfer

Length of output: 12485


🏁 Script executed:

# Check if there's shape conversion happening elsewhere in the MXFP8 execution path
rg -B 5 -A 20 "_cudnn_gemm_mxfp8" flashinfer/gemm/gemm_base.py | head -60

Repository: flashinfer-ai/flashinfer

Length of output: 1467


🏁 Script executed:

# Look for where MXFP8 handles 2D to 3D conversion before graph creation
rg -B 10 "a_shape=a.shape" flashinfer/gemm/gemm_base.py

Repository: flashinfer-ai/flashinfer

Length of output: 383


🏁 Script executed:

# Check if FP4 requirement also provides parameters like block_size
rg -A 70 "def _cudnn_gemm_fp4_requirement" flashinfer/gemm/gemm_base.py | grep -E "(block_size|graph.check_support)"

Repository: flashinfer-ai/flashinfer

Length of output: 139


🏁 Script executed:

# Get the full _get_cudnn_mxfp8_gemm_graph implementation to see shape handling
rg -A 30 "def _get_cudnn_mxfp8_gemm_graph" flashinfer/gemm/gemm_base.py

Repository: flashinfer-ai/flashinfer

Length of output: 924


🏁 Script executed:

# Check if there's any shape unsqueezing in the MXFP8 path
rg -B 5 -A 15 "create_cudnn_execution_plans_mxfp8_gemm" flashinfer/gemm/gemm_base.py

Repository: flashinfer-ai/flashinfer

Length of output: 1287


🏁 Script executed:

# Get full context of _cudnn_gemm_mxfp8 to see if shapes are modified
rg -B 5 -A 35 "def _cudnn_gemm_mxfp8" flashinfer/gemm/gemm_base.py

Repository: flashinfer-ai/flashinfer

Length of output: 2083


🏁 Script executed:

# Check if there's shape unsqueezing happening in the bmm_mxfp8_requirement or elsewhere
rg -A 30 "_cudnn_bmm_mxfp8_requirement\|_cudnn_mm_mxfp8_requirement" flashinfer/gemm/gemm_base.py | head -100

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Check the API documentation and flow between mm_mxfp8 and the actual runner
rg -B 10 "def mm_mxfp8" flashinfer/gemm/gemm_base.py | tail -20

Repository: flashinfer-ai/flashinfer

Length of output: 430


🏁 Script executed:

# Check if create_cudnn_execution_plans_mxfp8_gemm actually expects 3D or handles 2D
rg -A 50 "def create_cudnn_execution_plans_mxfp8_gemm" flashinfer/gemm/gemm_base.py | head -80

Repository: flashinfer-ai/flashinfer

Length of output: 2109


🏁 Script executed:

# Verify if there's explicit shape validation and see how bmm (which should be 3D) uses it
rg -B 5 "_cudnn_bmm_mxfp8_requirement" flashinfer/gemm/gemm_base.py

Repository: flashinfer-ai/flashinfer

Length of output: 461


🏁 Script executed:

# Check the _cudnn_mm_mxfp8_runner to see how it handles shapes
rg -A 40 "def _cudnn_mm_mxfp8_runner" flashinfer/gemm/gemm_base.py

Repository: flashinfer-ai/flashinfer

Length of output: 1271


🏁 Script executed:

# Check the complete mm_mxfp8 function body to see where shapes might be converted
rg -A 200 "def mm_mxfp8" flashinfer/gemm/gemm_base.py | head -250

Repository: flashinfer-ai/flashinfer

Length of output: 7237


🏁 Script executed:

# Check if scale tensors need special handling in MXFP8 requirement
rg -B 5 -A 20 "_expand_block_scale_tensor_shape" flashinfer/gemm/gemm_base.py | head -50

Repository: flashinfer-ai/flashinfer

Length of output: 2052


🏁 Script executed:

# Look for where scale tensor validation happens in MXFP8
rg "a_descale\|b_descale" flashinfer/gemm/gemm_base.py | grep -E "(shape|stride|ndim)" | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Check if mm_mxfp8 requirement gets the a_descale, b_descale parameters or just checks availability
rg -A 15 "def _cudnn_mm_mxfp8_requirement" flashinfer/gemm/gemm_base.py

Repository: flashinfer-ai/flashinfer

Length of output: 705


Add cuDNN graph support validation to the requirement helper, matching FP4 pattern.

The FP4 requirement function validates the cuDNN graph during requirement checking via graph.check_support(), ensuring unsupported shapes fail fast. The MXFP8 requirement currently only checks scale tensor layout and installation, allowing shapes that fail graph.check_support() at execution time to still mark cuDNN as suitable for backend="auto". This causes unnecessary fallback to CUTLASS or CuTe DSL.

Call create_cudnn_execution_plans_mxfp8_gemm() and graph.check_support() in the requirement function to validate that the specific shapes are compatible with cuDNN before marking it as a viable backend.

Proposed fix
 def _cudnn_mm_mxfp8_requirement(
-    a: torch.Tensor,  # unused
-    b: torch.Tensor,  # unused
+    a: torch.Tensor,
+    b: torch.Tensor,
     a_descale: torch.Tensor,
     b_descale: torch.Tensor,
-    out: Optional[torch.Tensor] = None,  # unused
+    out: Optional[torch.Tensor] = None,
-    out_dtype: torch.dtype = torch.bfloat16,  # unused
+    out_dtype: torch.dtype = torch.bfloat16,
     backend: Literal["cudnn", "cutlass", "cute-dsl", "auto"] = "auto",  # unused
 ):
     # cuDNN MXFP8 path currently expects swizzled 1D scale tensors.
     if a_descale.ndim != 1 or b_descale.ndim != 1:
         raise ValueError(
             "cudnn mm_mxfp8 requires swizzled 1D scale tensors for a_descale and b_descale."
         )
     _check_cudnn_availability()
+    # Validate the graph is supported for these specific shapes (batch=1 for mm_mxfp8)
+    a_3d = a.unsqueeze(0)
+    b_3d = b.unsqueeze(0)
+    graph = create_cudnn_execution_plans_mxfp8_gemm(
+        a_shape=a_3d.shape,
+        a_stride=a_3d.stride(),
+        b_shape=b_3d.shape,
+        b_stride=b_3d.stride(),
+        a_type=_torch_data_type_to_cudnn_data_type(a.dtype),
+        b_type=_torch_data_type_to_cudnn_data_type(b.dtype),
+        block_size=32,
+        o_type=_torch_data_type_to_cudnn_data_type(out_dtype),
+        device=a.device,
+    )
+    graph.check_support()
     return True
πŸ€– Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gemm/gemm_base.py` around lines 2640 - 2656, The
_cudnn_mm_mxfp8_requirement helper currently only verifies scale tensor layout
and cuDNN presence; update it to also validate cuDNN graph support exactly like
the FP4 helper by constructing execution plans and calling graph.check_support()
so shapes that cuDNN cannot handle are rejected early: invoke
create_cudnn_execution_plans_mxfp8_gemm(...) with the given tensor
descriptors/scales (matching the shape/stride assumptions used at runtime),
obtain the resulting graph(s), and call graph.check_support() (handling any
exceptions or returning False when unsupported) before returning True from
_cudnn_mm_mxfp8_requirement.



@supported_compute_capability([100, 103])
def _cute_dsl_gemm_mxfp8_requirement(
a: torch.Tensor, # unused
Expand All @@ -2645,7 +2664,7 @@ def _cute_dsl_gemm_mxfp8_requirement(
b_descale: torch.Tensor,
out: Optional[torch.Tensor] = None, # unused
out_dtype: torch.dtype = torch.bfloat16, # unused
backend: Literal["cutlass", "cute-dsl", "auto"] = "auto", # unused
backend: Literal["cudnn", "cutlass", "cute-dsl", "auto"] = "auto", # unused
):
# CuTe DSL MXFP8 path currently expects swizzled 1D block scales
# in F8_128x4 layout for both A and B.
Expand Down Expand Up @@ -3050,15 +3069,18 @@ def _heuristic_func_mm_mxfp8(
b_descale: torch.Tensor,
out: Optional[torch.Tensor] = None,
out_dtype: torch.dtype = torch.bfloat16,
backend: Literal["cutlass", "cute-dsl", "auto"] = "auto",
backend: Literal["cudnn", "cutlass", "cute-dsl", "auto"] = "auto",
) -> List[str]:
if CUDNN_AVAILABLE and "cudnn" in suitable_backends:
return ["cudnn"]
if "cutlass" in suitable_backends:
return ["cutlass"]
return []


@backend_requirement(
{
"cudnn": _cudnn_mm_mxfp8_requirement,
"cutlass": _cutlass_gemm_mxfp8_requirement,
"cute-dsl": _cute_dsl_gemm_mxfp8_requirement,
},
Expand All @@ -3073,7 +3095,7 @@ def mm_mxfp8(
b_descale: torch.Tensor,
out: Optional[torch.Tensor] = None,
out_dtype: torch.dtype = torch.bfloat16,
backend: Literal["cutlass", "cute-dsl", "auto"] = "auto",
backend: Literal["cudnn", "cutlass", "cute-dsl", "auto"] = "auto",
) -> torch.Tensor:
r"""MM MXFP8 (block size 32)

Expand All @@ -3100,14 +3122,16 @@ def mm_mxfp8(
For 1D swizzled format, it's flattened from (N_padded, K_padded) layout.

out: Optional[torch.Tensor]
Out tensor, shape (m, n), bf16 or fp16. If provided, can only be used with the CUTLASS backend. Defaults to ``None``.
Out tensor, shape (m, n), bf16 or fp16. Defaults to ``None``.

out_dtype: torch.dtype
Output dtype, bf16 or fp16. Defaults to ``torch.bfloat16``.

backend: Literal["cutlass", "cute-dsl", "auto"]
backend: Literal["cudnn", "cutlass", "cute-dsl", "auto"]
The backend to use for the operation. Defaults to ``"auto"``.
``"auto"`` selects the CUTLASS backend.
``"auto"`` selects a supported backend (currently cuDNN or CUTLASS).
``"cudnn"`` requires swizzled 1D scales produced by
``mxfp8_quantize(..., is_sf_swizzled_layout=True)``.
The ``"cute-dsl"`` backend currently requires swizzled 1D scales
(``mxfp8_quantize(..., is_sf_swizzled_layout=True)``).

Expand Down Expand Up @@ -3183,6 +3207,7 @@ def mm_mxfp8(
major, minor = get_compute_capability(a.device)

backend_to_runner_factory = {
"cudnn": lambda: _cudnn_mm_mxfp8_runner(),
"cutlass": lambda: get_cutlass_mxfp8_gemm_module(
major
).cutlass_mxfp8_gemm_runner(),
Expand Down Expand Up @@ -5922,6 +5947,42 @@ def forward(
return CudnnMxfp8GemmRunner()


def _cudnn_mm_mxfp8_runner():
    """Build a :class:`TunableRunner` that dispatches mm_mxfp8 to cuDNN."""

    class CudnnMmMxfp8GemmRunner(TunableRunner):
        def get_valid_tactics(
            self,
            inputs: List[torch.Tensor],
            profile: OptimizationProfile,
        ) -> List[int]:
            # Tactic selection is delegated to cuDNN's internal heuristics,
            # so only the single default tactic entry is exposed.
            return [0]

        def forward(
            self,
            inputs: List[torch.Tensor],
            tactic: int = -1,
            do_preparation: bool = False,
            **kwargs,
        ) -> torch.Tensor:
            a, b, scale_a, scale_b, _, out, workspace_buffer = inputs
            # The cuDNN gemm entry point operates on batched (3D) tensors;
            # view the 2D mm operands and output as a batch of one.
            _cudnn_gemm_mxfp8(
                a=a.unsqueeze(0),
                b=b.unsqueeze(0),
                a_descale=scale_a,
                b_descale=scale_b,
                out=out.unsqueeze(0),
                out_dtype=out.dtype,
                workspace_buffer=workspace_buffer,
                tactic=tactic,
            )
            # out was written in place through its batch-1 view.
            return out

    return CudnnMmMxfp8GemmRunner()


def mxfp8_gemm_sm100(
a: torch.Tensor,
b: torch.Tensor,
Expand Down
28 changes: 28 additions & 0 deletions tests/gemm/test_mm_mxfp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,34 @@ def test_mm_mxfp8_small_m(m, n, k):
)


def test_mm_mxfp8_cudnn_swizzled_single_gemm():
    """Single GEMM through the explicit cuDNN backend with swizzled scales."""
    # Skip on hardware without MXFP8 support, matching the guard used by
    # test_mm_mxfp8_invalid_input_dtype (harmless if _run_mm_mxfp8 also
    # skips internally).
    _skip_if_unsupported()
    _run_mm_mxfp8(
        320,  # m
        384,  # n
        224,  # k
        torch.bfloat16,
        True,  # cuDNN path currently requires swizzled 1D scales
        torch.bfloat16,
        "cudnn",
        auto_tuning=False,
        provide_out=True,
    )


def test_mm_mxfp8_auto_swizzled_single_gemm():
    """Single GEMM through backend="auto" with swizzled scales."""
    # Skip on hardware without MXFP8 support, matching the guard used by
    # test_mm_mxfp8_invalid_input_dtype (harmless if _run_mm_mxfp8 also
    # skips internally).
    _skip_if_unsupported()
    _run_mm_mxfp8(
        384,  # m
        512,  # n
        256,  # k
        torch.bfloat16,
        True,  # auto path should select a supported swizzled-scale backend
        torch.bfloat16,
        "auto",
        auto_tuning=False,
        provide_out=True,
    )


def test_mm_mxfp8_invalid_input_dtype():
_skip_if_unsupported()
m, n, k = 128, 128, 128
Expand Down