perf: prefer cuDNN first for mm_fp4 on CUDA>=13 and cuDNN>=9.15 (SM100/SM103) #2664
mmangkad wants to merge 1 commit into flashinfer-ai:main
Conversation
Summary of Changes (Gemini Code Assist): This pull request refines the backend selection logic for `mm_fp4`.
Activity
📝 Walkthrough
The change simplifies the backend selection heuristic for FP4 matrix multiplication (`mm_fp4`) on CUDA 13+ with cuDNN >= 9.15 by removing the SM100/SM103 device-branching logic and unconditionally prioritizing cuDNN first for both GPU variants.
Code Review
This pull request updates the backend selection logic for mm_fp4 to prefer cuDNN over CUTLASS on CUDA 13+ with cuDNN 9.15+ for both SM100 and SM103 architectures, based on updated performance benchmarks. The implementation correctly simplifies the logic in _heuristic_func_mm_fp4 by removing the compute capability check, which is no longer necessary. The code is clearer and aligns with the stated performance goals.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/gemm/gemm_base.py`:
- Around line 3639-3650: The heuristic in _heuristic_func_mm_fp4 currently
prioritizes cudnn for CUDA>=13 and cuDNN>=9.15 regardless of GPU architecture; update the
condition to only prefer cudnn when the SM matches SM100 or SM103 by calling the
existing _match_sm_version() helper (e.g., verify _match_sm_version("100") or
_match_sm_version("103") before selecting ("cudnn","cutlass")). Keep the
existing fallback to ("cutlass","cudnn") for other SMs and maintain the same
CUDA/cuDNN version checks and CUDNN_AVAILABLE guard.
Diff under review in `flashinfer/gemm/gemm_base.py` (`_heuristic_func_mm_fp4`):

```diff
     - Use cudnn first for both SM100 (B200) and SM103 (B300).
     """
     cuda_major = get_cuda_version().major
-    # Get compute capability to distinguish between SM100 (10.0) and SM103 (10.3)
-    major, minor = get_compute_capability(a.device)
-    is_sm103 = major == 10 and minor == 3
-
-    # If cuda version is 13 or greater and cudnn version is 9.15 or greater:
-    # On SM103 (B300), cutlass is more performant than cudnn.
-    # On SM100 (B200), cudnn is more performant than cutlass.
+    # If cuda version is 13 or greater and cudnn version is 9.15 or greater,
+    # prioritize cudnn for both SM100 (B200) and SM103 (B300).
     if CUDNN_AVAILABLE and cuda_major >= 13 and cudnn.backend_version() >= 91500:
-        if is_sm103:
-            candidate_backends = ("cutlass", "cudnn")
-        else:
-            candidate_backends = ("cudnn", "cutlass")
+        candidate_backends = ("cudnn", "cutlass")
     # Otherwise, prioritize cutlass
     else:
         candidate_backends = ("cutlass", "cudnn")
```
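Stripped of the flashinfer plumbing, the simplified heuristic reduces to a pure function. The sketch below is hypothetical: the plain bool/int parameters stand in for `CUDNN_AVAILABLE`, `get_cuda_version().major`, and `cudnn.backend_version()`, and the name `pick_mm_fp4_backends` is not the real `_heuristic_func_mm_fp4` signature.

```python
# Minimal sketch of the simplified heuristic; parameters are stand-ins for
# flashinfer's CUDNN_AVAILABLE, get_cuda_version().major, and
# cudnn.backend_version() — not the real implementation.
def pick_mm_fp4_backends(cudnn_available: bool, cuda_major: int, cudnn_version: int):
    # CUDA >= 13 with cuDNN backend >= 9.15 (encoded as 91500): cuDNN first
    # on both SM100 (B200) and SM103 (B300).
    if cudnn_available and cuda_major >= 13 and cudnn_version >= 91500:
        return ("cudnn", "cutlass")
    # Otherwise, prioritize CUTLASS.
    return ("cutlass", "cudnn")

print(pick_mm_fp4_backends(True, 13, 91901))  # ('cudnn', 'cutlass')
print(pick_mm_fp4_backends(True, 12, 91901))  # ('cutlass', 'cudnn')
```

Note that after this change the condition no longer consults the device at all, which is exactly what the review comment below flags.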
Heuristic lacks SM100/SM103 guard despite documented scope.
The _heuristic_func_mm_fp4 implementation at line 3646 prioritizes cudnn whenever CUDA>=13 and cuDNN>=9.15, without checking architecture. The docstring promises this behavior only for SM100/SM103, but the code applies it to all supported architectures (SM110, SM120, SM121 also have suitable backends). This can unintentionally shift backend preferences on architectures where it wasn't intended.
The fix should add an SM version check using the existing _match_sm_version() utility:
Suggested fix

```diff
+    is_sm100_or_sm103 = _match_sm_version(a.device, ["100", "103"])
     # If cuda version is 13 or greater and cudnn version is 9.15 or greater,
     # prioritize cudnn for both SM100 (B200) and SM103 (B300).
     if (
         CUDNN_AVAILABLE
+        and is_sm100_or_sm103
         and cuda_major >= 13
         and cudnn.backend_version() >= 91500
     ):
         candidate_backends = ("cudnn", "cutlass")
```
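As a self-contained sketch of the guarded variant, assuming the device's compute capability is available as a `(major, minor)` tuple (the function and parameter names here are hypothetical; the real code would use flashinfer's `get_compute_capability` / `_match_sm_version` helpers):

```python
# Sketch of the SM-guarded heuristic from the suggestion above. The
# `compute_capability` tuple is a hypothetical stand-in for querying the
# device, e.g. get_compute_capability(a.device) in flashinfer.
def pick_mm_fp4_backends_guarded(
    cudnn_available: bool,
    cuda_major: int,
    cudnn_version: int,
    compute_capability: tuple,
):
    # SM100 (B200) is (10, 0); SM103 (B300) is (10, 3).
    is_sm100_or_sm103 = compute_capability in ((10, 0), (10, 3))
    if (
        cudnn_available
        and is_sm100_or_sm103
        and cuda_major >= 13
        and cudnn_version >= 91500
    ):
        return ("cudnn", "cutlass")
    # Other SMs (e.g. SM110/SM120/SM121) keep the CUTLASS-first default.
    return ("cutlass", "cudnn")
```

This keeps the cuDNN-first preference confined to the two architectures the benchmarks cover, leaving every other supported SM on the existing CUTLASS-first path.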
There was a problem hiding this comment.
I think this is right, but this basically reverts to the older heuristic, under which the cuDNN-first priority also applied to SM110/SM120/SM121. I'm not sure that's still the best choice for those.
@mmangkad You're right to be cautious about SM110/SM120/SM121. Looking at the PR description, the benchmark data only covers SM100 and SM103—there's no performance comparison for SM110/SM120/SM121 with the cuDNN-first heuristic.
My recommendation: Add the SM100/SM103 guard as suggested. Here's why:
1. Align with tested scope: The PR's rationale ("cuDNN is reported to be better in most cases on SM103") and benchmark plots only validate SM100 and SM103. Applying the change to SM110/SM120/SM121 without data is extending beyond the tested scope.
2. Conservative approach: Without benchmarks, we can't be sure cuDNN-first is optimal for SM110/SM120/SM121. The guard keeps the behavior for those architectures unchanged until validated.
3. Separate decision: If you want cuDNN-first for SM110/SM120/SM121, that should be a separate PR with:
   - Benchmark data showing cuDNN vs. CUTLASS on those architectures
   - Clear documentation of the performance win
Trade-off: The current code (without guard) applies a potentially unverified optimization to SM110/SM120/SM121. The guard (conservative) keeps them on the existing CUTLASS-first heuristic until you have data showing cuDNN is better.
Would you like to:
- Add the guard now (keep SM110/SM120/SM121 on CUTLASS-first)
- Run benchmarks for SM110/SM120/SM121 first, then decide
- Keep the current code and document that SM110/SM120/SM121 need validation?
Hi @bkryu, could you please take a look at this? Thanks!
/bot run
[SUCCESS] Pipeline #45541796: 10/20 passed
PR itself looks good to me. cc @dhiraj113 and @YangXu1990uiuc because there is some ongoing work on dynamic shapes |

📌 Description
This changes the `mm_fp4` auto backend selection to prefer cuDNN first (instead of CUTLASS-first on SM103) for CUDA 13 with cuDNN backend >= 9.15, on both SM100 and SM103. This essentially reverts the previous SM103-specific choice, based on updated results: with `nvidia-cudnn-cu13==9.19.1.2`, cuDNN is better in most cases on SM103. See the comparison plots below.
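For readers matching the `91500` threshold in the code against cuDNN release numbers: the check `cudnn.backend_version() >= 91500` corresponds to cuDNN 9.15.0. The helper below is hypothetical; it simply mirrors the `major*10000 + minor*100 + patch` integer encoding that this threshold implies.

```python
# Hypothetical helper mirroring the integer encoding implied by the
# review's check `cudnn.backend_version() >= 91500` (i.e. cuDNN 9.15.0):
# major * 10000 + minor * 100 + patch.
def encode_cudnn_version(major: int, minor: int, patch: int = 0) -> int:
    return major * 10000 + minor * 100 + patch

print(encode_cudnn_version(9, 15))     # 91500 -> cuDNN 9.15.0
print(encode_cudnn_version(9, 19, 1))  # 91901 -> cuDNN 9.19.1
```

Under this encoding, the `nvidia-cudnn-cu13==9.19.1.2` wheel used for the benchmarks comfortably clears the 9.15 threshold.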