feat: add cudnn support for single-GEMM MXFP8 #2782

scottyokim wants to merge 1 commit into flashinfer-ai:main
Conversation
Signed-off-by: Scott Yokim <syokim@nvidia.com>
📝 Walkthrough

This pull request adds cuDNN as a supported backend for MXFP8 GEMM operations. Changes span benchmark utilities, core GEMM logic, and test cases to integrate cuDNN backend selection, requirement validation, and execution alongside the existing CUTLASS and CuTe-DSL paths.

Changes
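The "backend selection" flow described above can be sketched generically as a priority list of backends, each guarded by a requirement check that passes or raises. This is an illustrative stand-alone sketch, not FlashInfer's actual `backend_requirement` machinery; all names are hypothetical:

```python
def select_backend(requirements, preferred=("cudnn", "cutlass", "cute-dsl")):
    """Return the first backend whose requirement check passes.

    `requirements` maps backend name -> a zero-argument callable that
    returns True when the backend can handle the inputs, and raises
    when it cannot (mirroring how requirement helpers signal failure).
    """
    errors = {}
    for name in preferred:
        check = requirements.get(name)
        if check is None:
            continue
        try:
            if check():
                return name
        except Exception as exc:  # a raising check means "not suitable"
            errors[name] = exc
    raise RuntimeError(f"no suitable backend: {errors}")


def _cudnn_check():
    # Stand-in for a failing cuDNN requirement (e.g. non-swizzled scales).
    raise ValueError("swizzled 1D scale tensors required")


# cuDNN's check raises, so auto-selection falls through to CUTLASS.
picked = select_backend({"cudnn": _cudnn_check, "cutlass": lambda: True})
```

With `backend="auto"`, prioritizing cuDNN simply means placing it first in the preference order; a failed requirement check is not an error, just a fallback.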
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
🚥 Pre-merge checks: 1 passed, 2 failed (2 warnings)
Gemini Code Assist
Code Review
This pull request introduces support for the cuDNN backend for microscaling FP8 (MXFP8) GEMM operations. The changes update backend lists and type hints to include "cudnn", define specific requirements for the cuDNN MXFP8 path (such as swizzled 1D scale tensors), modify the backend selection heuristic to prioritize cuDNN, implement a dedicated cuDNN MXFP8 runner, and update documentation accordingly. New test cases have also been added to validate the cuDNN backend and the auto-selection mechanism with swizzled scales.
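For context on why the scale tensors have their own layout requirements: MXFP8 is one of the OCP microscaling formats, in which each block of 32 consecutive elements along the reduction dimension shares one 8-bit scale. A rough stand-alone sketch of the scale bookkeeping (illustrative, not FlashInfer code):

```python
BLOCK_SIZE = 32  # MX formats share one scale per 32 consecutive elements


def mxfp8_scale_count(rows: int, k: int, block_size: int = BLOCK_SIZE) -> int:
    """Number of block scales for a (rows, k) operand scaled along k."""
    if k % block_size:
        raise ValueError("k must be a multiple of the block size")
    return rows * (k // block_size)


# A 128x256 operand needs 128 * (256 // 32) = 1024 block scales.
n_scales = mxfp8_scale_count(128, 256)
```

The swizzled 1D layout the cuDNN path expects is a flattened, hardware-friendly reordering of exactly these per-block scales.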
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/gemm/gemm_base.py`:
- Around line 2640-2656: The _cudnn_mm_mxfp8_requirement helper currently only
verifies scale tensor layout and cuDNN presence; update it to also validate
cuDNN graph support exactly like the FP4 helper by constructing execution plans
and calling graph.check_support() so shapes that cuDNN cannot handle are
rejected early: invoke create_cudnn_execution_plans_mxfp8_gemm(...) with the
given tensor descriptors/scales (matching the shape/stride assumptions used at
runtime), obtain the resulting graph(s), and call graph.check_support()
(handling any exceptions or returning False when unsupported) before returning
True from _cudnn_mm_mxfp8_requirement.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 35654514-f0f3-459f-85cc-a037b21803dd
📒 Files selected for processing (4)

- benchmarks/routines/flashinfer_benchmark_utils.py
- benchmarks/routines/gemm.py
- flashinfer/gemm/gemm_base.py
- tests/gemm/test_mm_mxfp8.py
```python
@supported_compute_capability([100, 103])
def _cudnn_mm_mxfp8_requirement(
    a: torch.Tensor,  # unused
    b: torch.Tensor,  # unused
    a_descale: torch.Tensor,
    b_descale: torch.Tensor,
    out: Optional[torch.Tensor] = None,  # unused
    out_dtype: torch.dtype = torch.bfloat16,  # unused
    backend: Literal["cudnn", "cutlass", "cute-dsl", "auto"] = "auto",  # unused
):
    # cuDNN MXFP8 path currently expects swizzled 1D scale tensors.
    if a_descale.ndim != 1 or b_descale.ndim != 1:
        raise ValueError(
            "cudnn mm_mxfp8 requires swizzled 1D scale tensors for a_descale and b_descale."
        )
    _check_cudnn_availability()
    return True
```
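Stripped of the torch and cuDNN dependencies, the helper's validation reduces to an ndim check that signals rejection by raising. A minimal stand-alone sketch of that pattern (the `FakeTensor` stub is illustrative, standing in for `torch.Tensor.ndim`):

```python
from dataclasses import dataclass


@dataclass
class FakeTensor:
    ndim: int  # stand-in for torch.Tensor.ndim


def check_mxfp8_scales(a_descale, b_descale) -> bool:
    """Mirror of the requirement helper: swizzled scales must be 1D."""
    if a_descale.ndim != 1 or b_descale.ndim != 1:
        raise ValueError("cudnn mm_mxfp8 requires swizzled 1D scale tensors.")
    return True


# 1D scales pass; a 2D scale tensor is rejected with ValueError.
ok = check_mxfp8_scales(FakeTensor(1), FakeTensor(1))
try:
    check_mxfp8_scales(FakeTensor(2), FakeTensor(1))
    rejected = False
except ValueError:
    rejected = True
```

When wired through a dispatcher like `backend_requirement`, the raised `ValueError` is what steers `backend="auto"` away from cuDNN for non-conforming inputs.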
🧩 Analysis chain

(Collapsed verification log: repeated `rg` searches of `flashinfer/gemm/gemm_base.py` for the FP4 requirement pattern (`_cudnn_gemm_fp4_requirement`), the current `_cudnn_mm_mxfp8_requirement`, `create_cudnn_execution_plans_mxfp8_gemm`, `_get_cudnn_mxfp8_gemm_graph`, `_cudnn_gemm_mxfp8`, `_cudnn_mm_mxfp8_runner`, graph caching, scale-tensor handling, and the `mm_mxfp8` public API; script outputs omitted by the page.)
Add cuDNN graph support validation to the requirement helper, matching the FP4 pattern.
The FP4 requirement function validates the cuDNN graph during requirement checking via graph.check_support(), ensuring unsupported shapes fail fast. The MXFP8 requirement currently only checks scale tensor layout and installation, allowing shapes that fail graph.check_support() at execution time to still mark cuDNN as suitable for backend="auto". This causes unnecessary fallback to CUTLASS or CuTe DSL.
Call create_cudnn_execution_plans_mxfp8_gemm() and graph.check_support() in the requirement function to validate that the specific shapes are compatible with cuDNN before marking it as a viable backend.
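Generically, the "fail fast, fall back" behavior being asked for reduces to wrapping graph construction and `check_support()` in a boolean helper. A stand-alone sketch with a stub graph (the `StubGraph` class is illustrative, not cuDNN's frontend API):

```python
class StubGraph:
    """Stand-in for a cuDNN frontend graph; real graphs raise on unsupported shapes."""

    def __init__(self, supported: bool):
        self._supported = supported

    def check_support(self):
        if not self._supported:
            raise RuntimeError("no execution plan supports this shape")


def requirement_with_graph_check(build_graph) -> bool:
    """Return True only if the graph builds and passes check_support()."""
    try:
        build_graph().check_support()
    except Exception:
        return False
    return True


good = requirement_with_graph_check(lambda: StubGraph(True))
bad = requirement_with_graph_check(lambda: StubGraph(False))
```

Returning False (rather than letting the exception escape) is what lets `backend="auto"` quietly try the next backend instead of failing the call.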
Proposed fix

```diff
 def _cudnn_mm_mxfp8_requirement(
-    a: torch.Tensor,  # unused
-    b: torch.Tensor,  # unused
+    a: torch.Tensor,
+    b: torch.Tensor,
     a_descale: torch.Tensor,
     b_descale: torch.Tensor,
-    out: Optional[torch.Tensor] = None,  # unused
-    out_dtype: torch.dtype = torch.bfloat16,  # unused
+    out: Optional[torch.Tensor] = None,
+    out_dtype: torch.dtype = torch.bfloat16,
     backend: Literal["cudnn", "cutlass", "cute-dsl", "auto"] = "auto",  # unused
 ):
     # cuDNN MXFP8 path currently expects swizzled 1D scale tensors.
     if a_descale.ndim != 1 or b_descale.ndim != 1:
         raise ValueError(
             "cudnn mm_mxfp8 requires swizzled 1D scale tensors for a_descale and b_descale."
         )
     _check_cudnn_availability()
+    # Validate the graph is supported for these specific shapes (batch=1 for mm_mxfp8)
+    a_3d = a.unsqueeze(0)
+    b_3d = b.unsqueeze(0)
+    graph = create_cudnn_execution_plans_mxfp8_gemm(
+        a_shape=a_3d.shape,
+        a_stride=a_3d.stride(),
+        b_shape=b_3d.shape,
+        b_stride=b_3d.stride(),
+        a_type=_torch_data_type_to_cudnn_data_type(a.dtype),
+        b_type=_torch_data_type_to_cudnn_data_type(b.dtype),
+        block_size=32,
+        o_type=_torch_data_type_to_cudnn_data_type(out_dtype),
+        device=a.device,
+    )
+    graph.check_support()
     return True
```
📌 Description
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks

- I have installed `pre-commit` by running `pip install pre-commit` (or used my preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks against all files with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- Tests have been added or updated as needed and are passing (`unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit
Release Notes
New Features
Tests