
feat: add cudnn support for single-GEMM MXFP8#2782

Open
scottyokim wants to merge 1 commit into flashinfer-ai:main from scottyokim:scottyokim/cudnn_single_GEMM_MXFP8

Conversation


@scottyokim scottyokim commented Mar 13, 2026

📌 Description

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

Release Notes

  • New Features

    • Added cuDNN backend support for matrix multiplication operations, enabling broader hardware compatibility.
    • System now automatically detects and utilizes cuDNN when available.
  • Tests

    • Added test coverage for cuDNN backend operations and automatic backend selection scenarios.

Signed-off-by: Scott Yokim <syokim@nvidia.com>

coderabbitai bot commented Mar 13, 2026

📝 Walkthrough

Walkthrough

This pull request adds cuDNN as a supported backend for MXFP8 GEMM operations. Changes span benchmark utilities, core GEMM logic, and test cases to integrate cuDNN backend selection, requirement validation, and execution alongside the existing CUTLASS and CuTe-DSL paths.

Changes

Benchmark Configuration (benchmarks/routines/flashinfer_benchmark_utils.py, benchmarks/routines/gemm.py):
Added "cudnn" to backend lists for mm_mxfp8 routine entries in compute versions 10.0 and 10.3, expanding supported backends from ["cutlass", "cute-dsl"] to ["cudnn", "cutlass", "cute-dsl"].

Core GEMM Implementation (flashinfer/gemm/gemm_base.py):
Introduced cuDNN backend support for the MXFP8 GEMM path: updated type annotations to include "cudnn" in Literal types across _check_mm_mxfp8_problem_size, _cutlass_gemm_mxfp8_requirement, _heuristic_func_mm_mxfp8, and mm_mxfp8; new _cudnn_mm_mxfp8_requirement function for availability and constraint validation; new _cudnn_mm_mxfp8_runner function providing cuDNN-based MXFP8 execution; and updated backend factory mapping to dispatch to the cuDNN runner when selected.

Test Coverage (tests/gemm/test_mm_mxfp8.py):
Added two new test functions, test_mm_mxfp8_cudnn_swizzled_single_gemm and test_mm_mxfp8_auto_swizzled_single_gemm, to validate swizzled-scale MXFP8 operations under cuDNN and auto backend selection modes.
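The backend-factory dispatch described above can be sketched in plain Python. This is a hypothetical stand-in: the runner names mirror the PR, but the bodies and the `dispatch` helper are illustrative, not FlashInfer's real implementation.

```python
# Hypothetical sketch of a backend -> runner factory mapping.
# Runner names mirror the PR; bodies are stand-ins for the real kernels.

def _cudnn_mm_mxfp8_runner(a, b):
    return ("cudnn", a, b)      # would invoke the cuDNN MXFP8 GEMM here

def _cutlass_mm_mxfp8_runner(a, b):
    return ("cutlass", a, b)    # would invoke the CUTLASS MXFP8 GEMM here

backend_to_runner_factory = {
    "cudnn": _cudnn_mm_mxfp8_runner,
    "cutlass": _cutlass_mm_mxfp8_runner,
}

def dispatch(backend, a, b):
    try:
        runner = backend_to_runner_factory[backend]
    except KeyError:
        raise ValueError(f"unsupported backend: {backend!r}") from None
    return runner(a, b)

print(dispatch("cudnn", 1, 2))  # ('cudnn', 1, 2)
```

Adding a backend then reduces to registering one more entry in the mapping, which is why the diff touches the factory table rather than the call sites.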

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Possibly related PRs

Suggested labels

op: gemm, run-ci

Suggested reviewers

  • aleozlx
  • yongwww
  • nvmbreughe
  • jimmyzho
  • jiahanc
  • bkryu

Poem

🐰 A new backend hops into sight,
cuDNN blazing fast and right!
MXFP8 GEMM paths expand with glee,
From cutlass roots to cudnn spree! ✨
Code reviews hop—let's verify with care! 🎯

🚥 Pre-merge checks | ✅ 1 | ❌ 2

❌ Failed checks (2 warnings)

Description check: ⚠️ Warning
Explanation: The description is entirely composed of the template with no actual content filled in; the Description and Related Issues sections are empty, and all checklist items are unchecked.
Resolution: Fill in the Description section explaining the changes and motivation, link any related issues, and check off completed items from the Pre-commit Checks and Tests sections.

Docstring Coverage: ⚠️ Warning
Explanation: Docstring coverage is 11.76%, which is insufficient; the required threshold is 80.00%.
Resolution: Write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (1 passed)

Title check: ✅ Passed
Explanation: The title accurately describes the main change: adding cuDNN support for single-GEMM MXFP8 operations, which aligns with the code changes across multiple files.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.


Comment @coderabbitai help to get the list of available commands and usage tips.

Tip

CodeRabbit can use your project's `ruff` configuration to improve the quality of Python code reviews.

Add a Ruff configuration file to your project to customize how CodeRabbit runs ruff.

@gemini-code-assist

Summary of Changes

Hello, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the mm_mxfp8 functionality by integrating cuDNN support for mixed-precision matrix multiplication. This expansion allows the system to leverage cuDNN's optimized routines for GEMM operations on compatible NVIDIA architectures, providing an additional high-performance backend. The changes involve updating backend registration, implementing a specific cuDNN runner, and adjusting the backend selection logic to intelligently utilize cuDNN when appropriate, particularly with swizzled 1D scales.

Highlights

  • cuDNN Backend Integration: Added cuDNN as a supported backend for MXFP8 GEMM operations, specifically for compute capabilities 10.0 and 10.3.
  • cuDNN Runner Implementation: Implemented a new _cudnn_mm_mxfp8_runner to facilitate cuDNN-specific MXFP8 GEMM execution, handling tensor preparation and calling the underlying cuDNN GEMM function.
  • Backend Selection Logic and Documentation: Updated the mm_mxfp8 function's backend selection heuristic to prioritize cuDNN when available and revised its documentation to detail cuDNN support and its requirement for swizzled 1D scales.
  • New Test Cases: Introduced dedicated test cases to verify the correct functionality of the cuDNN backend and the 'auto' backend selection for MXFP8 GEMM with swizzled scales.
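The "prioritize cuDNN when available" selection described above can be sketched as an ordered candidate scan. This is a hypothetical stand-in for _heuristic_func_mm_mxfp8, assuming requirement checks reduce to two boolean conditions; FlashInfer's actual logic checks more than this.

```python
# Hypothetical sketch of an auto-backend heuristic: walk candidates in
# preference order and return the first whose requirement check passes.
# The requirement functions are simplified stand-ins.

def _cudnn_ok(scales_are_swizzled_1d: bool, cudnn_available: bool) -> bool:
    # The cuDNN path needs swizzled 1D scales and an installed cuDNN.
    return scales_are_swizzled_1d and cudnn_available

def _cutlass_ok(scales_are_swizzled_1d: bool, cudnn_available: bool) -> bool:
    return True  # assumed baseline fallback for this sketch

# cuDNN listed first, so it wins whenever its requirements hold.
CANDIDATES = [("cudnn", _cudnn_ok), ("cutlass", _cutlass_ok)]

def pick_backend(scales_are_swizzled_1d: bool, cudnn_available: bool) -> str:
    for name, requirement in CANDIDATES:
        if requirement(scales_are_swizzled_1d, cudnn_available):
            return name
    raise RuntimeError("no suitable backend")

print(pick_backend(True, True))   # cudnn
print(pick_backend(True, False))  # cutlass
```

Ordering the candidate list is what makes backend="auto" prefer cuDNN without changing behavior when cuDNN is unavailable or the scales are not swizzled.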


Changelog
  • benchmarks/routines/flashinfer_benchmark_utils.py
    • Added "cudnn" to the list of supported backends for compute capabilities 10.0 and 10.3.
  • benchmarks/routines/gemm.py
    • Included "cudnn" in the autotune_supported_backends list.
    • Modified the run_backend function to recognize and utilize the "cudnn" backend.
  • flashinfer/gemm/gemm_base.py
    • Extended the Literal type for the backend parameter in several functions to include "cudnn".
    • Introduced _cudnn_mm_mxfp8_requirement to define conditions for cuDNN MXFP8 usage, specifically requiring swizzled 1D scale tensors.
    • Updated _heuristic_func_mm_mxfp8 to prioritize the "cudnn" backend if available and suitable.
    • Added "cudnn" to the backend_to_runner_factory mapping.
    • Implemented _cudnn_mm_mxfp8_runner as a TunableRunner for cuDNN MXFP8 GEMM operations, handling input tensor reshaping and calling _cudnn_gemm_mxfp8.
    • Revised the docstring for mm_mxfp8 to clarify the backend options, including cuDNN's requirements for swizzled 1D scales.
  • tests/gemm/test_mm_mxfp8.py
    • Added test_mm_mxfp8_cudnn_swizzled_single_gemm to specifically test the cuDNN backend with swizzled scales.
    • Added test_mm_mxfp8_auto_swizzled_single_gemm to verify the "auto" backend selection with swizzled scales, expecting it to pick cuDNN.
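As a rough illustration of the swizzled-1D-scale requirement these tests exercise, the sketch below computes how many entries a swizzled MXFP8 scale tensor would hold. It assumes the 128x4 tile swizzle with block_size 32 that is common for block-scaled GEMM kernels; the exact FlashInfer layout may differ, so treat the formula as an assumption.

```python
import math

def swizzled_scale_count(m: int, k: int, block_size: int = 32) -> int:
    # One scale per block_size elements along K; tiles are padded to
    # 128 rows by 4 scale columns under the assumed swizzle layout.
    k_groups = math.ceil(k / block_size)
    padded_m = math.ceil(m / 128) * 128
    padded_groups = math.ceil(k_groups / 4) * 4
    return padded_m * padded_groups

print(swizzled_scale_count(256, 512))  # 256 * 16 = 4096
```

Because the padded layout no longer matches the logical (M, K/32) grid, the scales are passed as a flat 1D tensor, which is what the cuDNN requirement helper checks for.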
Activity
  • No human activity has been recorded on this pull request yet.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

  • Code Review (/gemini review): Performs a code review for the current pull request in its current state.
  • Pull Request Summary (/gemini summary): Provides a summary of the current pull request in its current state.
  • Comment (@gemini-code-assist): Responds in comments when explicitly tagged, both in pull request comments and review comments.
  • Help (/gemini help): Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces support for the cuDNN backend for mixed-precision FP8 (MXFP8) GEMM operations. The changes involve updating backend lists and type hints to include "cudnn", defining specific requirements for the cuDNN MXFP8 path (such as swizzled 1D scale tensors), modifying the backend selection heuristic to prioritize cuDNN, implementing a dedicated cuDNN MXFP8 runner, and updating documentation to reflect these changes. New test cases have also been added to validate the cuDNN backend and the auto-selection mechanism with swizzled scales.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/gemm/gemm_base.py`:
- Around line 2640-2656: The _cudnn_mm_mxfp8_requirement helper currently only
verifies scale tensor layout and cuDNN presence; update it to also validate
cuDNN graph support exactly like the FP4 helper by constructing execution plans
and calling graph.check_support() so shapes that cuDNN cannot handle are
rejected early: invoke create_cudnn_execution_plans_mxfp8_gemm(...) with the
given tensor descriptors/scales (matching the shape/stride assumptions used at
runtime), obtain the resulting graph(s), and call graph.check_support()
(handling any exceptions or returning False when unsupported) before returning
True from _cudnn_mm_mxfp8_requirement.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 35654514-f0f3-459f-85cc-a037b21803dd

📥 Commits

Reviewing files that changed from the base of the PR and between 74e99e8 and 7cc5ff6.

📒 Files selected for processing (4)
  • benchmarks/routines/flashinfer_benchmark_utils.py
  • benchmarks/routines/gemm.py
  • flashinfer/gemm/gemm_base.py
  • tests/gemm/test_mm_mxfp8.py

Comment on lines +2640 to +2656
@supported_compute_capability([100, 103])
def _cudnn_mm_mxfp8_requirement(
    a: torch.Tensor,  # unused
    b: torch.Tensor,  # unused
    a_descale: torch.Tensor,
    b_descale: torch.Tensor,
    out: Optional[torch.Tensor] = None,  # unused
    out_dtype: torch.dtype = torch.bfloat16,  # unused
    backend: Literal["cudnn", "cutlass", "cute-dsl", "auto"] = "auto",  # unused
):
    # cuDNN MXFP8 path currently expects swizzled 1D scale tensors.
    if a_descale.ndim != 1 or b_descale.ndim != 1:
        raise ValueError(
            "cudnn mm_mxfp8 requires swizzled 1D scale tensors for a_descale and b_descale."
        )
    _check_cudnn_availability()
    return True

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# First, let's look at the similar FP4 requirement function pattern
rg -A 25 "_cudnn_gemm_fp4_requirement" flashinfer/gemm/gemm_base.py

Repository: flashinfer-ai/flashinfer

🏁 Script executed:

# Check how _cudnn_mm_mxfp8_requirement is currently implemented
rg -B 2 -A 20 "def _cudnn_mm_mxfp8_requirement" flashinfer/gemm/gemm_base.py

🏁 Script executed:

# Find the cuDNN graph creation function mentioned in the proposed fix
rg -A 10 "create_cudnn_execution_plans_mxfp8_gemm" flashinfer/gemm/gemm_base.py

🏁 Script executed:

# Check how backend_requirement decorator wires the requirement helper
rg -B 3 -A 15 "backend_requirement.*_cudnn_mm_mxfp8_requirement" flashinfer/gemm/gemm_base.py

🏁 Script executed:

# Look for _get_cudnn_mxfp8_gemm_graph to see how graph checking happens during execution
rg -A 30 "_get_cudnn_mxfp8_gemm_graph" flashinfer/gemm/gemm_base.py

🏁 Script executed:

# Check the complete FP4 requirement function to see full validation logic
rg -A 50 "def _cudnn_gemm_fp4_requirement" flashinfer/gemm/gemm_base.py | head -80

🏁 Script executed:

# Look at how mm_fp4 backend_requirement is set up
rg -B 5 -A 10 '@backend_requirement.*mm_fp4' flashinfer/gemm/gemm_base.py

🏁 Script executed:

# Check mm_mxfp8 backend_requirement setup
rg -B 5 -A 10 '@backend_requirement.*mm_mxfp8' flashinfer/gemm/gemm_base.py

🏁 Script executed:

# Check if there are any graph caching mechanisms between requirement and execution
rg "graph.*cache" flashinfer/gemm/gemm_base.py

🏁 Script executed:

# Verify the actual function signature and parameters for MXFP8 graph creation
rg -B 2 -A 15 "def create_cudnn_execution_plans_mxfp8_gemm" flashinfer/gemm/gemm_base.py

🏁 Script executed:

# Check what block_size is used in mm_mxfp8 public API
rg -A 5 "def mm_mxfp8" flashinfer/gemm/gemm_base.py | grep -A 5 "block_size"

🏁 Script executed:

# Verify FP4 uses graph caching with `@functools.cache`
rg -B 5 "def create_cudnn_execution_plans_fp4_gemm" flashinfer/gemm/gemm_base.py | head -10

🏁 Script executed:

# Find where MXFP8 tensors are converted to 3D in the actual execution path
rg -B 10 -A 5 "_get_cudnn_mxfp8_gemm_graph" flashinfer/gemm/gemm_base.py | head -40

🏁 Script executed:

# Look for unsqueeze or reshape before calling _get_cudnn_mxfp8_gemm_graph
rg -B 15 "_get_cudnn_mxfp8_gemm_graph" flashinfer/gemm/gemm_base.py | grep -E "(unsqueeze|reshape|view|3d|batch)"

🏁 Script executed:

# Check the mm_mxfp8 public API to see how tensors flow into the requirement and runners
rg -A 100 "@backend_requirement\(" flashinfer/gemm/gemm_base.py | grep -A 80 "mm_mxfp8"

🏁 Script executed:

# Check if there's shape conversion happening elsewhere in the MXFP8 execution path
rg -B 5 -A 20 "_cudnn_gemm_mxfp8" flashinfer/gemm/gemm_base.py | head -60

🏁 Script executed:

# Look for where MXFP8 handles 2D to 3D conversion before graph creation
rg -B 10 "a_shape=a.shape" flashinfer/gemm/gemm_base.py

🏁 Script executed:

# Check if FP4 requirement also provides parameters like block_size
rg -A 70 "def _cudnn_gemm_fp4_requirement" flashinfer/gemm/gemm_base.py | grep -E "(block_size|graph.check_support)"

🏁 Script executed:

# Get the full _get_cudnn_mxfp8_gemm_graph implementation to see shape handling
rg -A 30 "def _get_cudnn_mxfp8_gemm_graph" flashinfer/gemm/gemm_base.py

🏁 Script executed:

# Check if there's any shape unsqueezing in the MXFP8 path
rg -B 5 -A 15 "create_cudnn_execution_plans_mxfp8_gemm" flashinfer/gemm/gemm_base.py

🏁 Script executed:

# Get full context of _cudnn_gemm_mxfp8 to see if shapes are modified
rg -B 5 -A 35 "def _cudnn_gemm_mxfp8" flashinfer/gemm/gemm_base.py

🏁 Script executed:

# Check if there's shape unsqueezing happening in the bmm_mxfp8_requirement or elsewhere
rg -A 30 "_cudnn_bmm_mxfp8_requirement\|_cudnn_mm_mxfp8_requirement" flashinfer/gemm/gemm_base.py | head -100

🏁 Script executed:

# Check the API documentation and flow between mm_mxfp8 and the actual runner
rg -B 10 "def mm_mxfp8" flashinfer/gemm/gemm_base.py | tail -20

🏁 Script executed:

# Check if create_cudnn_execution_plans_mxfp8_gemm actually expects 3D or handles 2D
rg -A 50 "def create_cudnn_execution_plans_mxfp8_gemm" flashinfer/gemm/gemm_base.py | head -80

🏁 Script executed:

# Verify if there's explicit shape validation and see how bmm (which should be 3D) uses it
rg -B 5 "_cudnn_bmm_mxfp8_requirement" flashinfer/gemm/gemm_base.py

🏁 Script executed:

# Check the _cudnn_mm_mxfp8_runner to see how it handles shapes
rg -A 40 "def _cudnn_mm_mxfp8_runner" flashinfer/gemm/gemm_base.py

🏁 Script executed:

# Check the complete mm_mxfp8 function body to see where shapes might be converted
rg -A 200 "def mm_mxfp8" flashinfer/gemm/gemm_base.py | head -250

🏁 Script executed:

# Check if scale tensors need special handling in MXFP8 requirement
rg -B 5 -A 20 "_expand_block_scale_tensor_shape" flashinfer/gemm/gemm_base.py | head -50

🏁 Script executed:

# Look for where scale tensor validation happens in MXFP8
rg "a_descale\|b_descale" flashinfer/gemm/gemm_base.py | grep -E "(shape|stride|ndim)" | head -20

🏁 Script executed:

# Check if mm_mxfp8 requirement gets the a_descale, b_descale parameters or just checks availability
rg -A 15 "def _cudnn_mm_mxfp8_requirement" flashinfer/gemm/gemm_base.py


Add cuDNN graph support validation to the requirement helper, matching FP4 pattern.

The FP4 requirement function validates the cuDNN graph during requirement checking via graph.check_support(), ensuring unsupported shapes fail fast. The MXFP8 requirement currently only checks scale tensor layout and installation, allowing shapes that fail graph.check_support() at execution time to still mark cuDNN as suitable for backend="auto". This causes unnecessary fallback to CUTLASS or CuTe DSL.

Call create_cudnn_execution_plans_mxfp8_gemm() and graph.check_support() in the requirement function to validate that the specific shapes are compatible with cuDNN before marking it as a viable backend.

Proposed fix
 def _cudnn_mm_mxfp8_requirement(
-    a: torch.Tensor,  # unused
-    b: torch.Tensor,  # unused
+    a: torch.Tensor,
+    b: torch.Tensor,
     a_descale: torch.Tensor,
     b_descale: torch.Tensor,
-    out: Optional[torch.Tensor] = None,  # unused
+    out: Optional[torch.Tensor] = None,
-    out_dtype: torch.dtype = torch.bfloat16,  # unused
+    out_dtype: torch.dtype = torch.bfloat16,
     backend: Literal["cudnn", "cutlass", "cute-dsl", "auto"] = "auto",  # unused
 ):
     # cuDNN MXFP8 path currently expects swizzled 1D scale tensors.
     if a_descale.ndim != 1 or b_descale.ndim != 1:
         raise ValueError(
             "cudnn mm_mxfp8 requires swizzled 1D scale tensors for a_descale and b_descale."
         )
     _check_cudnn_availability()
+    # Validate the graph is supported for these specific shapes (batch=1 for mm_mxfp8)
+    a_3d = a.unsqueeze(0)
+    b_3d = b.unsqueeze(0)
+    graph = create_cudnn_execution_plans_mxfp8_gemm(
+        a_shape=a_3d.shape,
+        a_stride=a_3d.stride(),
+        b_shape=b_3d.shape,
+        b_stride=b_3d.stride(),
+        a_type=_torch_data_type_to_cudnn_data_type(a.dtype),
+        b_type=_torch_data_type_to_cudnn_data_type(b.dtype),
+        block_size=32,
+        o_type=_torch_data_type_to_cudnn_data_type(out_dtype),
+        device=a.device,
+    )
+    graph.check_support()
     return True
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gemm/gemm_base.py` around lines 2640 - 2656, The
_cudnn_mm_mxfp8_requirement helper currently only verifies scale tensor layout
and cuDNN presence; update it to also validate cuDNN graph support exactly like
the FP4 helper by constructing execution plans and calling graph.check_support()
so shapes that cuDNN cannot handle are rejected early: invoke
create_cudnn_execution_plans_mxfp8_gemm(...) with the given tensor
descriptors/scales (matching the shape/stride assumptions used at runtime),
obtain the resulting graph(s), and call graph.check_support() (handling any
exceptions or returning False when unsupported) before returning True from
_cudnn_mm_mxfp8_requirement.

