
perf: prefer cuDNN first for mm_fp4 on CUDA>=13 and cuDNN>=9.15 (SM100/SM103) #2664

Open

mmangkad wants to merge 1 commit into flashinfer-ai:main from mmangkad:perf/mm-fp4-cudnn-priority-cuda13

Conversation


@mmangkad mmangkad commented Mar 2, 2026

📌 Description

This changes the mm_fp4 auto backend selection to prefer cuDNN first (instead of CUTLASS first on SM103) for CUDA 13 with cuDNN backend >= 9.15, on both SM100 and SM103. This essentially reverts the previous SM103-specific choice based on updated results: with nvidia-cudnn-cu13==9.19.1.2, cuDNN is faster in most cases on SM103. See the comparison plots below.

[Comparison plots: mm_fp4 backend performance, cuDNN vs. CUTLASS, on SM100 and SM103]
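The selection rule described above can be sketched as a small standalone function. This is a hypothetical simplification for illustration, not the actual code in flashinfer/gemm/gemm_base.py; the names and the integer version encoding mirror the snippet quoted later in this thread:

```python
def pick_mm_fp4_backends(cuda_major: int,
                         cudnn_backend_version: int,
                         cudnn_available: bool) -> tuple:
    """Return candidate mm_fp4 backends in priority order.

    cudnn_backend_version uses cuDNN's integer encoding,
    e.g. version 9.15.0 is encoded as 91500.
    """
    if cudnn_available and cuda_major >= 13 and cudnn_backend_version >= 91500:
        # CUDA 13+ with cuDNN >= 9.15: prefer cuDNN on both SM100 and SM103
        return ("cudnn", "cutlass")
    # Older CUDA or cuDNN: keep CUTLASS first
    return ("cutlass", "cudnn")
```

For instance, a cuDNN 9.19.x backend encodes to 91900+, so on CUDA 13 the cuDNN-first branch is taken; on CUDA 12, or with cuDNN older than 9.15, CUTLASS stays first.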

Summary by CodeRabbit

  • Performance
    • Optimized backend selection for matrix multiplication operations on NVIDIA GPUs with CUDA 13+ and cuDNN 9.15+, streamlining library prioritization for improved efficiency.

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed.

This pull request refines the backend selection logic for mm_fp4 operations, shifting the preference from CUTLASS to cuDNN under specific conditions. This adjustment, driven by recent performance benchmarks, aims to enhance the efficiency of mm_fp4 on newer NVIDIA architectures (SM100 and SM103) when using CUDA 13 and cuDNN 9.15 or later.

Highlights

  • Backend Preference Update: Changed the mm_fp4 auto backend selection to prioritize cuDNN over CUTLASS for CUDA 13 and cuDNN versions 9.15 or greater.
  • Expanded Scope: This cuDNN preference now applies to both SM100 (B200) and SM103 (B300) architectures, removing the previous SM103-specific choice.
  • Performance Optimization: The change is based on updated benchmarks showing cuDNN is generally more performant on SM103 with nvidia-cudnn-cu13==9.19.1.2.


Changelog
  • flashinfer/gemm/gemm_base.py
    • Modified the _heuristic_func_mm_fp4 function to prioritize cuDNN for SM100 and SM103 when CUDA version is 13+ and cuDNN version is 9.15+.
    • Removed the logic for distinguishing between SM100 and SM103 for backend selection within this specific condition.
Activity
  • No specific activity (comments, reviews, progress) has been provided in the context.

@coderabbitai
Contributor

coderabbitai bot commented Mar 2, 2026

📝 Walkthrough

Walkthrough

The change simplifies backend selection heuristics for FP16/FP4 matrix multiplication in CUDA 13+ with cuDNN >= 9.15 by removing SM100/SM103 device branching logic and unconditionally prioritizing cuDNN first for both GPU variants.

Changes

Backend Heuristic Simplification — flashinfer/gemm/gemm_base.py:
Removed conditional branching for SM100 vs SM103 in the FP16/FP4 MM path; now uniformly sets cuDNN-first prioritization (("cudnn", "cutlass")) for both under CUDA 13+ and cuDNN >= 9.15. Updated related comments and docstring.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

Suggested labels

op: gemm

Suggested reviewers

  • bkryu
  • nvmbreughe
  • jimmyzho
  • nv-yunzheq

Poem

🐰 The branches once split like a garden so wide,
SM100, SM103, standing side by side,
But now cudnn leads with unified cheer,
One path forward, crystal clear! ✨

🚥 Pre-merge checks | ✅ 3/3 passed
  • Title check — Passed: the title clearly and concisely describes the main change (preferring cuDNN first for mm_fp4 on CUDA>=13 and cuDNN>=9.15 for both SM100/SM103).
  • Description check — Passed: the description covers the key change (preferring cuDNN first instead of CUTLASS-first on SM103) and includes supporting performance comparison plots.
  • Docstring coverage — Passed: 100.00%, above the required threshold of 80.00%.



Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request updates the backend selection logic for mm_fp4 to prefer cuDNN over CUTLASS on CUDA 13+ with cuDNN 9.15+ for both SM100 and SM103 architectures, based on updated performance benchmarks. The implementation correctly simplifies the logic in _heuristic_func_mm_fp4 by removing the compute capability check, which is no longer necessary. The code is clearer and aligns with the stated performance goals.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In flashinfer/gemm/gemm_base.py:
- Around lines 3639-3650: The heuristic in _heuristic_func_mm_fp4 currently prioritizes cudnn for CUDA>=13 and cuDNN>=9.15 regardless of GPU architecture; update the condition to only prefer cudnn when the SM matches SM100 or SM103 by calling the existing _match_sm_version() helper (e.g., verify _match_sm_version("100") or _match_sm_version("103") before selecting ("cudnn", "cutlass")). Keep the existing fallback to ("cutlass", "cudnn") for other SMs, and maintain the same CUDA/cuDNN version checks and the CUDNN_AVAILABLE guard.

ℹ️ Review info

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between c25bd14 and 9e7a4bb.

📒 Files selected for processing (1)
  • flashinfer/gemm/gemm_base.py

Comment on lines +3639 to 3650 of flashinfer/gemm/gemm_base.py:

```diff
+    - Use cudnn first for both SM100 (B200) and SM103 (B300).
     """
     cuda_major = get_cuda_version().major
-    # Get compute capability to distinguish between SM100 (10.0) and SM103 (10.3)
-    major, minor = get_compute_capability(a.device)
-    is_sm103 = major == 10 and minor == 3
-
-    # If cuda version is 13 or greater and cudnn version is 9.15 or greater:
-    # On SM103 (B300), cutlass is more performant than cudnn.
-    # On SM100 (B200), cudnn is more performant than cutlass.
+    # If cuda version is 13 or greater and cudnn version is 9.15 or greater,
+    # prioritize cudnn for both SM100 (B200) and SM103 (B300).
     if CUDNN_AVAILABLE and cuda_major >= 13 and cudnn.backend_version() >= 91500:
-        if is_sm103:
-            candidate_backends = ("cutlass", "cudnn")
-        else:
-            candidate_backends = ("cudnn", "cutlass")
+        candidate_backends = ("cudnn", "cutlass")
     # Otherwise, prioritize cutlass
     else:
         candidate_backends = ("cutlass", "cudnn")
```
Contributor

@coderabbitai coderabbitai bot Mar 2, 2026


⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# Examine the _heuristic_func_mm_fp4 function and its decorator
rg -n -B 10 'def _heuristic_func_mm_fp4' flashinfer/gemm/gemm_base.py | head -30

Repository: flashinfer-ai/flashinfer

Length of output: 339


🏁 Script executed:

# Check lines around 3639-3650 to see the complete function and context
sed -n '3625,3660p' flashinfer/gemm/gemm_base.py

Repository: flashinfer-ai/flashinfer

Length of output: 1462


🏁 Script executed:

# Search for _match_sm_version usage to understand its applicability
rg -n '_match_sm_version' flashinfer/gemm/gemm_base.py | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 476


🏁 Script executed:

# Search for other uses of compute capability checks in similar heuristic functions
rg -n 'def _heuristic_func' flashinfer/gemm/gemm_base.py

Repository: flashinfer-ai/flashinfer

Length of output: 272


🏁 Script executed:

# Find the decorator for mm_fp4 to confirm supported architectures
rg -n -B 5 'def mm_fp4\(' flashinfer/gemm/gemm_base.py | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 272


🏁 Script executed:

# Also check what the decorator looks like for similar functions
rg -n '@supported_compute_capability' flashinfer/gemm/gemm_base.py | grep -E 'mm_fp4|mm_mxfp8|mm_bf16' -A 1 -B 1

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Search more broadly for mm_fp4 with decorator
rg -n -B 15 'def mm_fp4' flashinfer/gemm/gemm_base.py | grep -E '@supported_compute_capability|def mm_fp4'

Repository: flashinfer-ai/flashinfer

Length of output: 84


🏁 Script executed:

# Find the complete decorator for mm_fp4 function
sed -n '3700,3760p' flashinfer/gemm/gemm_base.py

Repository: flashinfer-ai/flashinfer

Length of output: 1651


🏁 Script executed:

# Search for `@supported_compute_capability` in the entire file to understand the pattern
rg -n '@supported_compute_capability' flashinfer/gemm/gemm_base.py | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 1143


🏁 Script executed:

# Get the exact decorator line(s) immediately before mm_fp4 function
sed -n '3740,3760p' flashinfer/gemm/gemm_base.py

Repository: flashinfer-ai/flashinfer

Length of output: 657


🏁 Script executed:

# Search backward from line 3754 to find `@supported_compute_capability` for mm_fp4
sed -n '3600,3760p' flashinfer/gemm/gemm_base.py | grep -n '@supported_compute_capability' | tail -1

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Let me search for mm_fp4 in the file with more context
rg -n 'mm_fp4' flashinfer/gemm/gemm_base.py

Repository: flashinfer-ai/flashinfer

Length of output: 1973


🏁 Script executed:

# Check if mm_fp4 has any decorator by looking at a wider range
sed -n '3730,3765p' flashinfer/gemm/gemm_base.py

Repository: flashinfer-ai/flashinfer

Length of output: 1101


🏁 Script executed:

# Look at the requirement functions to see if they have `@supported_compute_capability`
sed -n '3020,3030p' flashinfer/gemm/gemm_base.py

Repository: flashinfer-ai/flashinfer

Length of output: 422


🏁 Script executed:

# Check all lines around the requirements for mm_fp4
sed -n '3015,3145p' flashinfer/gemm/gemm_base.py | head -60

Repository: flashinfer-ai/flashinfer

Length of output: 2184


🏁 Script executed:

# Check all requirement functions for mm_fp4 and their decorators
sed -n '3084,3145p' flashinfer/gemm/gemm_base.py

Repository: flashinfer-ai/flashinfer

Length of output: 2321


🏁 Script executed:

# Also check if there are any other architectural guards in _heuristic_func_mm_fp4 that I might have missed
sed -n '3614,3660p' flashinfer/gemm/gemm_base.py

Repository: flashinfer-ai/flashinfer

Length of output: 1814


🏁 Script executed:

# Check the backend_requirement decorator to understand how suitable_backends is populated
rg -n -A 10 'def backend_requirement' flashinfer/gemm/gemm_base.py | head -30

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Verify that _match_sm_version is used for architectural checks in similar heuristics
rg -n -B 5 -A 10 '_match_sm_version' flashinfer/gemm/gemm_base.py | grep -A 10 '_heuristic_func'

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Check if there are other heuristics that use SM guards
rg -n -B 2 -A 15 'def _heuristic_func_mm_mxfp8' flashinfer/gemm/gemm_base.py

Repository: flashinfer-ai/flashinfer

Length of output: 578


Heuristic lacks SM100/SM103 guard despite documented scope.

The _heuristic_func_mm_fp4 implementation at line 3646 prioritizes cudnn whenever CUDA>=13 and cuDNN>=9.15, without checking architecture. The docstring promises this behavior only for SM100/SM103, but the code applies it to all supported architectures (SM110, SM120, SM121 also have suitable backends). This can unintentionally shift backend preferences on architectures where it wasn't intended.

The fix should add an SM version check using the existing _match_sm_version() utility:

Suggested fix:

```diff
+    is_sm100_or_sm103 = _match_sm_version(a.device, ["100", "103"])
     # If cuda version is 13 or greater and cudnn version is 9.15 or greater,
     # prioritize cudnn for both SM100 (B200) and SM103 (B300).
     if (
         CUDNN_AVAILABLE
+        and is_sm100_or_sm103
         and cuda_major >= 13
         and cudnn.backend_version() >= 91500
     ):
         candidate_backends = ("cudnn", "cutlass")
```
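Taken on its own, the guard the reviewer proposes behaves like the following sketch. The match_sm_version helper here is a stub standing in for FlashInfer's _match_sm_version utility named in the review; the real helper's signature may differ (it takes a device rather than a string):

```python
def match_sm_version(sm: str, targets) -> bool:
    # Stub for FlashInfer's _match_sm_version; sm is the SM number as a
    # string, e.g. "100" for SM100 (B200) or "103" for SM103 (B300).
    return sm in targets

def pick_backends(sm: str, cuda_major: int, cudnn_backend_version: int,
                  cudnn_available: bool) -> tuple:
    """Return candidate mm_fp4 backends with the suggested SM guard applied."""
    is_sm100_or_sm103 = match_sm_version(sm, ["100", "103"])
    if (cudnn_available
            and is_sm100_or_sm103
            and cuda_major >= 13
            and cudnn_backend_version >= 91500):
        return ("cudnn", "cutlass")
    # Other SMs (e.g. SM110/SM120/SM121) stay on the CUTLASS-first heuristic
    return ("cutlass", "cudnn")
```

With the guard, SM100/SM103 on CUDA 13 with cuDNN >= 9.15 get cuDNN first, while SM120 and other architectures keep the existing CUTLASS-first ordering until benchmarks justify a change.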
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In flashinfer/gemm/gemm_base.py around lines 3639-3650: The heuristic in
_heuristic_func_mm_fp4 currently prioritizes cudnn for CUDA>=13 and cuDNN>=9.15
regardless of GPU architecture; update the condition to only prefer cudnn when
the SM matches SM100 or SM103 by calling the existing _match_sm_version()
helper (e.g., verify _match_sm_version("100") or _match_sm_version("103")
before selecting ("cudnn", "cutlass")). Keep the existing fallback to
("cutlass", "cudnn") for other SMs, and maintain the same CUDA/cuDNN version
checks and the CUDNN_AVAILABLE guard.

Author


I think this is right, but I essentially reverted to the older heuristic, under which the cuDNN-first priority also applied to SM110/SM120/SM121. I'm not sure whether that is still the best choice for those architectures.

Contributor


@mmangkad You're right to be cautious about SM110/SM120/SM121. Looking at the PR description, the benchmark data only covers SM100 and SM103—there's no performance comparison for SM110/SM120/SM121 with the cuDNN-first heuristic.

My recommendation: Add the SM100/SM103 guard as suggested. Here's why:

  1. Align with tested scope: The PR's rationale ("cuDNN is reported to be better in most cases on SM103") and benchmark plots only validate SM100 and SM103. Applying the change to SM110/SM120/SM121 without data is extending beyond the tested scope.

  2. Conservative approach: Without benchmarks, we can't be sure cuDNN-first is optimal for SM110/SM120/SM121. The guard keeps the behavior for those architectures unchanged until validated.

  3. Separate decision: If you want cuDNN-first for SM110/SM120/SM121, that should be a separate PR with:

    • Benchmark data showing cuDNN vs. CUTLASS on those architectures
    • Clear documentation of the performance win

Trade-off: The current code (without guard) applies a potentially unverified optimization to SM110/SM120/SM121. The guard (conservative) keeps them on the existing CUTLASS-first heuristic until you have data showing cuDNN is better.

Would you like to:

  • Add the guard now (keep SM110/SM120/SM121 on CUTLASS-first)
  • Run benchmarks for SM110/SM120/SM121 first, then decide
  • Keep the current code and document that SM110/SM120/SM121 need validation?

🧠 Learnings used
Learnt from: danisereb
Repo: flashinfer-ai/flashinfer PR: 2464
File: include/flashinfer/gemm/mxfp8_gemm_template_sm100.h:148-163
Timestamp: 2026-02-04T10:08:47.455Z
Learning: In flashinfer GEMM template implementations (e.g., include/flashinfer/gemm/fp4_gemm_template_sm100.h, mxfp8_gemm_template_sm100.h), the Sm10x11xOnly architecture check wrapper uses a pattern where only thread0() prints an error message and calls __trap() when running on unsupported architectures. This pattern is intentional and working in production code, so consistency should be maintained across similar implementations.

Learnt from: ishovkun
Repo: flashinfer-ai/flashinfer PR: 2591
File: flashinfer/aot.py:588-599
Timestamp: 2026-02-19T21:59:36.542Z
Learning: When reviewing changes to conditional blocks (e.g., `if has_sm90:` → `if has_sm90 or has_sm100:`), distinguish between code the PR author wrote versus pre-existing code that happens to be in the modified block. Do not ask the PR author to fix potential issues in pre-existing code unless it's directly related to their changes.

Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

@mmangkad
Author

mmangkad commented Mar 2, 2026

SM120:
[Benchmark plot: mm_fp4 cuDNN vs. CUTLASS on SM120]

@mmangkad
Author

mmangkad commented Mar 4, 2026

Hi @bkryu, could you please take a look at this? Thanks!

Collaborator

@bkryu bkryu left a comment


Hi @mmangkad, do you mind clarifying which cuBLAS version you were using in these tests?

You might be able to get it from pip3 freeze | grep nvidia-cublas. If not, maybe try querying PyTorch?

cc @b8zhong, since the results here seem to be different with what you observed in #2375

@mmangkad
Author

mmangkad commented Mar 5, 2026

@bkryu the tests used nvidia-cublas==13.2.1.1. I've also told @b8zhong about this.

@bkryu bkryu added the run-ci label Mar 5, 2026
@bkryu
Collaborator

bkryu commented Mar 6, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !387 has been created, and the CI pipeline #45541796 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[SUCCESS] Pipeline #45541796: 10/20 passed

@mmangkad mmangkad requested a review from bkryu March 9, 2026 04:51
@bkryu
Collaborator

bkryu commented Mar 9, 2026

PR itself looks good to me. cc @dhiraj113 and @YangXu1990uiuc because there is some ongoing work on dynamic shapes
