16 changes: 4 additions & 12 deletions flashinfer/gemm/gemm_base.py
```diff
@@ -3636,23 +3636,15 @@ def _heuristic_func_mm_fp4(
     - If cuda version is 12 - use cutlass.
     - If cuda version is 13 and cudnn version is less than 9.15 - use cutlass.
     - If cuda version is 13 and cudnn version is 9.15 or greater:
-        - On SM103 (B300) - use cutlass (faster based on benchmarks).
-        - On SM100 (B200) - use cudnn (faster based on benchmarks).
+        - Use cudnn first for both SM100 (B200) and SM103 (B300).
 
     """
     cuda_major = get_cuda_version().major
-    # Get compute capability to distinguish between SM100 (10.0) and SM103 (10.3)
-    major, minor = get_compute_capability(a.device)
-    is_sm103 = major == 10 and minor == 3
 
-    # If cuda version is 13 or greater and cudnn version is 9.15 or greater:
-    #     On SM103 (B300), cutlass is more performant than cudnn.
-    #     On SM100 (B200), cudnn is more performant than cutlass.
+    # If cuda version is 13 or greater and cudnn version is 9.15 or greater,
+    # prioritize cudnn for both SM100 (B200) and SM103 (B300).
     if CUDNN_AVAILABLE and cuda_major >= 13 and cudnn.backend_version() >= 91500:
-        if is_sm103:
-            candidate_backends = ("cutlass", "cudnn")
-        else:
-            candidate_backends = ("cudnn", "cutlass")
+        candidate_backends = ("cudnn", "cutlass")
     # Otherwise, prioritize cutlass
     else:
         candidate_backends = ("cutlass", "cudnn")
```
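For readers outside the diff context, the new selection logic can be sketched as a standalone function. This is a hypothetical illustration: in the real file the inputs come from `get_cuda_version()`, `cudnn.backend_version()`, and the `CUDNN_AVAILABLE` flag, whereas here they are plain parameters so the sketch runs on its own.

```python
# Hypothetical standalone sketch of the backend-ordering heuristic after this PR.
# In flashinfer the inputs come from get_cuda_version(), cudnn.backend_version(),
# and CUDNN_AVAILABLE; here they are explicit parameters for illustration only.
def pick_candidate_backends(cudnn_available: bool, cuda_major: int, cudnn_version: int):
    # cuDNN versions are encoded as integers, so 9.15 corresponds to 91500.
    # CUDA >= 13 with cuDNN >= 9.15 available: try cudnn first, cutlass second.
    if cudnn_available and cuda_major >= 13 and cudnn_version >= 91500:
        return ("cudnn", "cutlass")
    # Otherwise: try cutlass first, cudnn second.
    return ("cutlass", "cudnn")

print(pick_candidate_backends(True, 13, 91500))   # ('cudnn', 'cutlass')
print(pick_candidate_backends(True, 12, 91500))   # ('cutlass', 'cudnn')
print(pick_candidate_backends(False, 13, 91500))  # ('cutlass', 'cudnn')
```

Note that, unlike the pre-PR code, no compute-capability check remains: the ordering depends only on the software versions.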
Comment on lines +3639 to 3650
Contributor

@coderabbitai coderabbitai bot Mar 2, 2026

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Scripts executed against flashinfer-ai/flashinfer:

```shell
# Examine the _heuristic_func_mm_fp4 function and its decorator
rg -n -B 10 'def _heuristic_func_mm_fp4' flashinfer/gemm/gemm_base.py | head -30

# Check lines around 3639-3650 to see the complete function and context
sed -n '3625,3660p' flashinfer/gemm/gemm_base.py

# Search for _match_sm_version usage to understand its applicability
rg -n '_match_sm_version' flashinfer/gemm/gemm_base.py | head -20

# Search for other uses of compute capability checks in similar heuristic functions
rg -n 'def _heuristic_func' flashinfer/gemm/gemm_base.py

# Find the decorator for mm_fp4 to confirm supported architectures
rg -n -B 5 'def mm_fp4\(' flashinfer/gemm/gemm_base.py | head -20

# Also check what the decorator looks like for similar functions
rg -n '@supported_compute_capability' flashinfer/gemm/gemm_base.py | grep -E 'mm_fp4|mm_mxfp8|mm_bf16' -A 1 -B 1

# Search more broadly for mm_fp4 with decorator
rg -n -B 15 'def mm_fp4' flashinfer/gemm/gemm_base.py | grep -E '@supported_compute_capability|def mm_fp4'

# Find the complete decorator for the mm_fp4 function
sed -n '3700,3760p' flashinfer/gemm/gemm_base.py

# Search for @supported_compute_capability in the entire file to understand the pattern
rg -n '@supported_compute_capability' flashinfer/gemm/gemm_base.py | head -20

# Get the exact decorator line(s) immediately before mm_fp4
sed -n '3740,3760p' flashinfer/gemm/gemm_base.py

# Search backward from line 3754 to find @supported_compute_capability for mm_fp4
sed -n '3600,3760p' flashinfer/gemm/gemm_base.py | grep -n '@supported_compute_capability' | tail -1

# Search for mm_fp4 in the file with more context
rg -n 'mm_fp4' flashinfer/gemm/gemm_base.py

# Check if mm_fp4 has any decorator by looking at a wider range
sed -n '3730,3765p' flashinfer/gemm/gemm_base.py

# Look at the requirement functions to see if they have @supported_compute_capability
sed -n '3020,3030p' flashinfer/gemm/gemm_base.py

# Check all lines around the requirements for mm_fp4
sed -n '3015,3145p' flashinfer/gemm/gemm_base.py | head -60

# Check all requirement functions for mm_fp4 and their decorators
sed -n '3084,3145p' flashinfer/gemm/gemm_base.py

# Check for any other architectural guards in _heuristic_func_mm_fp4
sed -n '3614,3660p' flashinfer/gemm/gemm_base.py

# Check the backend_requirement decorator to understand how suitable_backends is populated
rg -n -A 10 'def backend_requirement' flashinfer/gemm/gemm_base.py | head -30

# Verify that _match_sm_version is used for architectural checks in similar heuristics
rg -n -B 5 -A 10 '_match_sm_version' flashinfer/gemm/gemm_base.py | grep -A 10 '_heuristic_func'

# Check if there are other heuristics that use SM guards
rg -n -B 2 -A 15 'def _heuristic_func_mm_mxfp8' flashinfer/gemm/gemm_base.py
```


Heuristic lacks SM100/SM103 guard despite documented scope.

The `_heuristic_func_mm_fp4` implementation at line 3646 prioritizes cudnn whenever CUDA >= 13 and cuDNN >= 9.15, without checking the architecture. The docstring promises this behavior only for SM100/SM103, but the code applies it to all supported architectures (SM110, SM120, and SM121 also have suitable backends). This can unintentionally shift backend preferences on architectures where it wasn't intended.

The fix should add an SM version check using the existing `_match_sm_version()` utility:

Suggested fix:

```diff
+    is_sm100_or_sm103 = _match_sm_version(a.device, ["100", "103"])
     # If cuda version is 13 or greater and cudnn version is 9.15 or greater,
     # prioritize cudnn for both SM100 (B200) and SM103 (B300).
     if (
         CUDNN_AVAILABLE
+        and is_sm100_or_sm103
         and cuda_major >= 13
         and cudnn.backend_version() >= 91500
     ):
         candidate_backends = ("cudnn", "cutlass")
```
🤖 Prompt for AI Agents

```
Verify each finding against the current code and only fix it if needed.

In flashinfer/gemm/gemm_base.py around lines 3639-3650, the heuristic in
_heuristic_func_mm_fp4 currently prioritizes cudnn for CUDA>=13 and cuDNN>=9.15
regardless of GPU architecture; update the condition to only prefer cudnn when
the SM matches SM100 or SM103 by calling the existing _match_sm_version() helper
(e.g., verify _match_sm_version("100") or _match_sm_version("103") before
selecting ("cudnn","cutlass")). Keep the existing fallback to
("cutlass","cudnn") for other SMs and maintain the same CUDA/cuDNN version
checks and CUDNN_AVAILABLE guard.
```
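The guarded variant proposed above can likewise be sketched as a standalone function. Here `matches_sm` is a hypothetical stand-in for flashinfer's `_match_sm_version(a.device, [...])` helper, approximated from a `(major, minor)` compute-capability pair (SM100 is capability 10.0, SM103 is 10.3):

```python
# Hypothetical sketch of the guarded heuristic from the suggested fix.
# matches_sm stands in for flashinfer's _match_sm_version(a.device, [...]);
# we approximate it from a (major, minor) compute-capability tuple.
def matches_sm(compute_capability: tuple, sm_versions: list) -> bool:
    major, minor = compute_capability
    # e.g. (10, 0) -> "100", (10, 3) -> "103", (12, 0) -> "120"
    return f"{major}{minor}" in sm_versions

def pick_candidate_backends_guarded(
    cudnn_available: bool, cuda_major: int, cudnn_version: int, compute_capability: tuple
):
    is_sm100_or_sm103 = matches_sm(compute_capability, ["100", "103"])
    # Prefer cudnn only on SM100/SM103, the architectures with benchmark data.
    if cudnn_available and is_sm100_or_sm103 and cuda_major >= 13 and cudnn_version >= 91500:
        return ("cudnn", "cutlass")
    # All other SMs keep the existing cutlass-first ordering.
    return ("cutlass", "cudnn")

print(pick_candidate_backends_guarded(True, 13, 91500, (10, 0)))  # SM100: ('cudnn', 'cutlass')
print(pick_candidate_backends_guarded(True, 13, 91500, (12, 0)))  # SM120: ('cutlass', 'cudnn')
```

Compared with the unguarded version in the diff, only devices whose SM string matches the allow-list can flip to cudnn-first; SM110/SM120/SM121 retain their previous ordering.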

Author

I think this is right, but I basically reverted to the older heuristic, under which the cuDNN-first priority also applied to SM110/SM120/SM121. I'm not sure whether that is still the best choice for those architectures.

Contributor

@mmangkad You're right to be cautious about SM110/SM120/SM121. Looking at the PR description, the benchmark data only covers SM100 and SM103; there's no performance comparison for SM110/SM120/SM121 with the cuDNN-first heuristic.

My recommendation: add the SM100/SM103 guard as suggested. Here's why:

1. Align with tested scope: the PR's rationale ("cuDNN is reported to be better in most cases on SM103") and benchmark plots only validate SM100 and SM103. Applying the change to SM110/SM120/SM121 without data extends beyond the tested scope.

2. Conservative approach: without benchmarks, we can't be sure cuDNN-first is optimal for SM110/SM120/SM121. The guard keeps the behavior for those architectures unchanged until validated.

3. Separate decision: if you want cuDNN-first for SM110/SM120/SM121, that should be a separate PR with:
   - benchmark data showing cuDNN vs. CUTLASS on those architectures
   - clear documentation of the performance win

Trade-off: the current code (without the guard) applies a potentially unverified optimization to SM110/SM120/SM121. The guard (conservative) keeps them on the existing CUTLASS-first heuristic until you have data showing cuDNN is better.

Would you like to:

- add the guard now (keep SM110/SM120/SM121 on CUTLASS-first),
- run benchmarks for SM110/SM120/SM121 first and then decide, or
- keep the current code and document that SM110/SM120/SM121 need validation?

🧠 Learnings used
Learnt from: danisereb
Repo: flashinfer-ai/flashinfer PR: 2464
File: include/flashinfer/gemm/mxfp8_gemm_template_sm100.h:148-163
Timestamp: 2026-02-04T10:08:47.455Z
Learning: In flashinfer GEMM template implementations (e.g., include/flashinfer/gemm/fp4_gemm_template_sm100.h, mxfp8_gemm_template_sm100.h), the Sm10x11xOnly architecture check wrapper uses a pattern where only thread0() prints an error message and calls __trap() when running on unsupported architectures. This pattern is intentional and working in production code, so consistency should be maintained across similar implementations.

Learnt from: ishovkun
Repo: flashinfer-ai/flashinfer PR: 2591
File: flashinfer/aot.py:588-599
Timestamp: 2026-02-19T21:59:36.542Z
Learning: When reviewing changes to conditional blocks (e.g., `if has_sm90:` → `if has_sm90 or has_sm100:`), distinguish between code the PR author wrote versus pre-existing code that happens to be in the modified block. Do not ask the PR author to fix potential issues in pre-existing code unless it's directly related to their changes.

Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
