
fix: guard CUTLASS FMHA against SM12x and fix fmha_v2 SM121a check #2560

Merged
jimmyzho merged 2 commits into flashinfer-ai:main from blake-snc:fix/cutlass-fmha-sm120-guard
Mar 13, 2026

Conversation

Contributor

@blake-snc blake-snc commented Feb 13, 2026

Summary

  • Remove SM12x from CUTLASS FMHA support: get_fmha_module() and gen_fmha_cutlass_sm100a_module() incorrectly included SM12x GPUs (RTX 5090, DGX Spark) in their support checks. SM12x lacks the tcgen05 MMA instructions required by the CUTLASS FMHA SM100 kernel (SM100_MMA_F16BF16_SS/TS, SM100_MMA_F8F6F4_SS/TS), causing compile failures when using backend="cutlass" or fmha_varlen(). Changed supported_major_versions from [10, 11, 12] to [10, 11] and added a clear error message for SM12x users pointing them to backend='fa2'.

  • Fix fmha_v2_prefill_deepseek SM121a check: The SM12x guard only checked is_sm120a_supported() (SM120 = RTX 5090, minor=0) but not is_sm121a_supported() (SM121 = DGX Spark, minor=1). DGX Spark users were incorrectly rejected from using the fmha_v2 DeepSeek prefill kernel.
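The guard described in these two bullets can be sketched as follows. This is a minimal illustration, not flashinfer's actual code: `check_fmha_cutlass_support` and the bare `(major, minor)` arguments are hypothetical simplifications of the device-capability plumbing inside `get_fmha_module()`.

```python
# Hedged sketch of the support check this PR adds. The function name and
# the bare (major, minor) arguments are illustrative only; the real
# get_fmha_module() inspects torch devices and CUDA toolkit versions.
def check_fmha_cutlass_support(major: int, minor: int) -> None:
    """Reject unsupported architectures from the CUTLASS FMHA path."""
    if major == 12:
        # SM12x (RTX 5090 = SM120, DGX Spark = SM121) lacks the tcgen05
        # MMA instructions required by the SM100 CUTLASS FMHA kernel.
        raise ValueError(
            f"CUTLASS FMHA is not supported on SM{major}{minor} GPUs; "
            "use backend='fa2' instead."
        )
    if major not in (10, 11):  # supported_major_versions after this PR
        raise ValueError(f"Unsupported compute capability {major}.{minor}")


check_fmha_cutlass_support(10, 0)  # SM100a (B200/GB200): accepted
try:
    check_fmha_cutlass_support(12, 1)  # SM121 (DGX Spark): clear error
except ValueError as exc:
    print(exc)
```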

Validated on NVIDIA GB10 (DGX Spark, SM 12.1)

| Test | Result |
|------|--------|
| CUTLASS FMHA correctly rejects SM12x with clear error | PASS |
| FA2 prefill works (max_diff=0.0078 vs SDPA reference) | PASS |
| XQA decode works (no NaN) | PASS |
| `determine_attention_backend()` returns "fa2" for SM12x | PASS |
| `fmha_v2_prefill_deepseek` accepts SM121a | PASS |

Test plan

  • Verify CUTLASS FMHA still works on SM100a (B200/GB200)
  • Verify fmha_varlen() raises clear error on SM12x instead of compile failure
  • Verify fmha_v2_prefill_deepseek() works on both SM120 (RTX 5090) and SM121 (DGX Spark)
  • Run existing CI tests

Contributed by Second Nature Computing

🤖 Generated with Claude Code

Summary by CodeRabbit

  • Bug Fixes
    • FMHA optimized kernel now targets only SM100a/SM110a devices; other devices will receive an updated compatibility message with a suggested alternative backend.
    • Removed SM12x (compute capability 12.x) compilation targets for the optimized path.
    • Prefill behavior updated: the alternate prefill path will not proceed on non-SM12x hardware and now raises a clear, explicit message.

@gemini-code-assist
Contributor

Summary of Changes

Hello @blake-snc, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request refines GPU architecture support for FlashInfer's Fused Multi-Head Attention (FMHA) kernels. It addresses compatibility issues by preventing SM12x GPUs from using CUTLASS FMHA kernels that require specific MMA instructions they lack, and simultaneously expands support for the fmha_v2_prefill_deepseek kernel to correctly include SM121a devices. These changes ensure proper functionality and provide clearer guidance to users regarding backend selection based on their hardware.

Highlights

  • CUTLASS FMHA SM12x Support Removal: Removed SM12x GPUs (RTX 5090, DGX Spark) from CUTLASS FMHA support due to their lack of tcgen05 MMA instructions, which are required by the SM100 kernel. The supported_major_versions for gen_fmha_cutlass_sm100a_module were updated from [10, 11, 12] to [10, 11], and a clear error message now guides SM12x users to backend='fa2'.
  • FMHA v2 DeepSeek SM121a Fix: Corrected the fmha_v2_prefill_deepseek function's SM12x guard to properly include is_sm121a_supported() (DGX Spark), ensuring that DGX Spark users are no longer incorrectly rejected from using this prefill kernel.


Changelog
  • flashinfer/jit/attention/modules.py
    • Removed SM12x (major version 12) from the list of supported GPU architectures for gen_fmha_cutlass_sm100a_module.
  • flashinfer/prefill.py
    • Modified get_fmha_module to exclude SM12x GPUs from the CUTLASS FMHA path.
    • Updated the error message in get_fmha_module to provide specific guidance for SM12x users, directing them to backend='fa2'.
    • Adjusted the fmha_v2_prefill_deepseek function to correctly support SM121a (DGX Spark) in addition to SM120a.
Activity
  • Validated CUTLASS FMHA correctly rejects SM12x with a clear error.
  • Validated FA2 prefill works with acceptable max_diff vs SDPA reference.
  • Validated XQA decode works without NaNs.
  • Validated determine_attention_backend() returns 'fa2' for SM12x.
  • Validated fmha_v2_prefill_deepseek accepts SM121a.

@coderabbitai
Contributor

coderabbitai bot commented Feb 13, 2026

No actionable comments were generated in the recent review. 🎉


📥 Commits

Reviewing files that changed from the base of the PR and between e1479f5 and dada79d.

📒 Files selected for processing (1)
  • flashinfer/prefill.py

📝 Walkthrough

Reduced CUTLASS FMHA NVCC compile targets (dropped SM12x) and restricted FMHA module selection to SM100a/SM110a; the DeepSeek prefill logic dropped the explicit per-variant SM12x guard, so SM12x no longer follows the prior error path.

Changes

| Cohort / File(s) | Summary |
|------------------|---------|
| FMHA build config: `flashinfer/jit/attention/modules.py` | Removed SM12x (major version 12) from the CUTLASS FMHA `supported_major_versions` (changed [10, 11, 12] → [10, 11]). |
| Prefill device checks & error paths: `flashinfer/prefill.py` | `get_fmha_module` now accepts only SM100a/SM110a (SM12x removed); `fmha_v2_prefill_deepseek` dropped the explicit per-variant SM12x guard and its previous error path for SM12x devices. |

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes


Suggested reviewers

  • nvmbreughe
  • yzh119
  • cyx-6
  • bkryu
  • jimmyzho

Poem

🐰 I hopped through flags and trimmed the lines,
Dropped twelve from CUTLASS, kept the builds fine,
SM100a/110a now lead the way,
Deepseek’s old guard quietly hops away,
A little rabbit cheers the compile time! 🥕

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|------------|--------|-------------|------------|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 33.33%, below the required threshold of 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
|------------|--------|-------------|
| Title check | ✅ Passed | The title clearly describes the two main changes: removing SM12x from CUTLASS FMHA support and fixing the SM121a check in fmha_v2, both reflected in the code modifications. |
| Description check | ✅ Passed | The PR description comprehensively explains both changes with clear rationale, validation results, and test plan items, though the checklist items are not marked as completed. |




@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request correctly addresses two issues. First, it removes SM12x support for the CUTLASS FMHA kernel, which was causing compilation failures due to missing hardware instructions, and improves the error message to guide users. Second, it fixes a bug in fmha_v2_prefill_deepseek by adding support for SM121a, which was previously incorrectly rejected. The changes are clear, well-justified, and improve both correctness and user experience.

"""
if not is_sm120a_supported(query.device):
raise ValueError("fmha_v2_prefill_deepseek is only supported on SM120 GPUs.")
if not (is_sm120a_supported(query.device) or is_sm121a_supported(query.device)):

medium

To make this check more robust for new SM12x architectures, consider checking the major architecture version directly instead of listing each supported minor version. This would automatically include future SM12x GPUs (e.g., SM122a) without requiring code changes, which seems to be the intent given the error message and the build flags for this kernel.

Suggested change

```diff
-    if not (is_sm120a_supported(query.device) or is_sm121a_supported(query.device)):
+    if torch.cuda.get_device_capability(query.device)[0] != 12:
```

@blake-snc
Contributor Author

Regarding the suggestion to use get_device_capability()[0] != 12 instead of is_sm120a_supported() or is_sm121a_supported(): the is_sm12xa_supported() utility functions are the standard pattern used throughout flashinfer (they also check CUDA toolkit version requirements). Using a raw capability check here would be inconsistent with the rest of the codebase.

That said, if a future SM122a variant appears, adding is_sm122a_supported() to the check is a one-line change. Happy to refactor if the maintainers prefer a different approach.


eugr commented Feb 14, 2026

> That said, if a future SM122a variant appears, adding is_sm122a_supported() to the check is a one-line change. Happy to refactor if the maintainers prefer a different approach.

@blake-snc - would it be better to introduce is_sm12x_family_supported() to cover all such cases? Because even though it's a one-line change, there are still some places in the code where sm120 is included and sm121 is ignored, even though they are pretty much identical. Even in new-ish PRs. Case in point: #2460

And that's with sm121 being out in the wild since October.

blake-snc added a commit to blake-snc/flashinfer that referenced this pull request Feb 17, 2026
Add a major-version-based helper that covers all SM12x GPUs (SM120a,
SM121a, and future variants) so callers don't need to enumerate each
minor version individually. Uses major == 12 check, matching the
pattern of is_sm100a_supported (major == 10).

Update existing call sites in gemm_base.py and the DeepSeek MLA test.

This avoids the recurring pattern where SM121a support gets missed when
only SM120a is checked, as noted in PR flashinfer-ai#2460 and flashinfer-ai#2560 discussion.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@blake-snc
Contributor Author

@eugr Good call. I just opened #2574 which adds is_sm12x_supported() to utils.py using a major == 12 check (matching the pattern of is_sm100a_supported). This way future SM12x variants are automatically covered without needing to add and wire up individual is_sm122a_supported() functions.

The PR updates the existing call sites in gemm_base.py and the DeepSeek MLA test. Once it lands, our other SM120 PRs (#2559, #2560, #2561) can be rebased to use it too.
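The helper described above boils down to a major-version check. The sketch below assumes a plain `(major, minor)` capability tuple for illustration; the real flashinfer helper takes a `torch` device and also checks CUDA toolkit version requirements, as noted in the discussion.

```python
# Sketch of the is_sm12x_supported() idea from #2574, simplified to a
# (major, minor) capability tuple. The real flashinfer helper takes a
# torch device and additionally verifies the CUDA toolkit version.
def is_sm12x_supported(capability: tuple) -> bool:
    """True for any SM12x variant: SM120a, SM121a, future SM122a, ..."""
    return capability[0] == 12


assert is_sm12x_supported((12, 0))      # RTX 5090 (SM120a)
assert is_sm12x_supported((12, 1))      # DGX Spark (SM121a)
assert not is_sm12x_supported((10, 0))  # B200/GB200 (SM100a)
```

Because the check keys on the major version only, a hypothetical SM122a would pass without any further code change, which is the point of the helper.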

Collaborator

@yzh119 yzh119 left a comment


Makes sense to me, thanks for the fix.

@yzh119
Collaborator

yzh119 commented Feb 19, 2026

/bot run

@yzh119
Collaborator

yzh119 commented Feb 19, 2026

@flashinfer-bot run

@flashinfer-bot
Collaborator

GitLab MR !327 has been created, and the CI pipeline #44336836 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #44336836: 9/20 passed

blake-snc added a commit to blake-snc/flashinfer that referenced this pull request Feb 20, 2026
@blake-snc
Contributor Author

The internal CI pipeline shows 9/20 passed — are the failures related to this PR or pre-existing? Our changes only narrow CUTLASS FMHA support (removing SM12x) and fix the SM121a check in fmha_v2_prefill_deepseek, so they shouldn't affect existing SM100/SM110 tests.

Happy to investigate if there's something specific we need to fix.

yongwww pushed a commit that referenced this pull request Feb 25, 2026
## Summary

Adds `is_sm12x_supported()` to `flashinfer/utils.py` as a convenience
helper that covers the entire SM12x GPU family (SM120a, SM121a, and
future variants like SM122a) without requiring callers to enumerate each
minor version.

Uses a `major == 12` check, matching the existing pattern of
`is_sm100a_supported()` (`major == 10`). This means future SM12x
variants are automatically covered without code changes.

**Motivation:** SM121a (DGX Spark) keeps getting missed when only SM120a
is checked. This was noted by @eugr in #2560, and PR #2460 is another
example where SM121a was not included alongside SM120a.

## Changes

| File | Change |
|------|--------|
| `flashinfer/utils.py` | Add `is_sm12x_supported()` with `major == 12`
check |
| `flashinfer/gemm/gemm_base.py` | Replace 3 instances of
`is_sm120a_supported(a.device) or is_sm121a_supported(a.device)` |
| `tests/attention/test_fmha_v2_prefill_deepseek.py` | Update skip guard
to use `is_sm12x_supported()` |

The individual `is_sm120a_supported()` and `is_sm121a_supported()`
functions are preserved for cases that need variant-specific behavior.

Validated on DGX Spark (SM121a, CUDA 13.0).

[Second Nature Computing](https://joinsecondnature.com)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **Refactor**
* Consolidated separate SM120/SM121 capability checks into a unified
SM12x check and updated the public import surface accordingly.
* Introduced explicit CUDA-version gating for SM12x variants and
clarified related compatibility/error messages.

* **Tests**
* Updated GPU compatibility tests and skip logic/messages to target
SM12x architecture support.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
@blake-snc blake-snc closed this Feb 28, 2026
@blake-snc blake-snc reopened this Feb 28, 2026
@blake-snc
Contributor Author

Hey @yzh119 — this PR now has merge conflicts with main. Here's what changed upstream since your approval:

  1. CUTLASS FMHA: Main now includes SM12x in get_fmha_module() with is_sm12x_supported() and supported_major_versions=[10, 11, 12]. Our PR removes SM12x from that path (since SM12x lacks tcgen05 MMA). These are conflicting approaches — should SM12x be in the CUTLASS FMHA path or not?

  2. fmha_v2 DeepSeek SM121a: Main now uses is_sm12x_supported() which covers all SM12x variants, so our SM121a-specific fix is already addressed.

Happy to rebase if you can clarify the intended direction for (1). If CUTLASS FMHA genuinely works on SM12x with newer CUDA, we can drop that part and just rebase cleanly.

SM12x GPUs (RTX 5090, DGX Spark) lack tcgen05 MMA instructions required
by the CUTLASS FMHA SM100 kernel. Previously, get_fmha_module() and
gen_fmha_cutlass_sm100a_module() incorrectly included SM12x in their
support checks, causing compile failures when using backend="cutlass" or
fmha_varlen() on SM12x.

Also fix fmha_v2_prefill_deepseek() to accept SM121a (DGX Spark) in
addition to SM120a (RTX 5090), as both are SM12x-class GPUs that support
the fmha_v2 DeepSeek kernels.

Changes:
- Remove SM12x from get_fmha_module() support check with clear error msg
- Change supported_major_versions from [10, 11, 12] to [10, 11]
- Add is_sm121a_supported() check to fmha_v2_prefill_deepseek()

Validated on NVIDIA GB10 (DGX Spark, SM 12.1):
- CUTLASS FMHA correctly rejects SM12x with helpful error message
- FA2 prefill continues to work (max_diff=0.0078 vs SDPA reference)
- XQA decode continues to work (no NaN)
- determine_attention_backend() correctly returns "fa2" for SM12x

AI-assisted by Claude (Anthropic)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@blake-snc blake-snc force-pushed the fix/cutlass-fmha-sm120-guard branch from 7a7f724 to e1479f5 Compare February 28, 2026 01:31
@blake-snc
Contributor Author

Update: resolved the merge conflicts and rebased onto main. PR is mergeable now.

The resolution keeps our original intent:

  • CUTLASS FMHA guard: SM12x removed from get_fmha_module() and supported_major_versions (SM12x lacks tcgen05)
  • fmha_v2 DeepSeek: adopted upstream's is_sm12x_supported() which is more general than our original per-variant check


@coderabbitai coderabbitai bot left a comment


Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/prefill.py (1)

56-56: ⚠️ Potential issue | 🟡 Minor

Remove unused import get_compute_capability.

The pipeline is failing because get_compute_capability is imported but not used in this file. This needs to be removed to fix the Ruff F401 linting error.

🔧 Proposed fix

```diff
 from .utils import (
     log2e,
     FP4Tensor,
     MaskMode,
     PosEncodingMode,
     TensorLayout,
     _check_cached_qkv_data_type,
     _check_kv_layout,
     _check_pos_encoding_mode,
     check_shape_dtype_device,
     _get_cache_alibi_slopes_buf,
     _get_cache_buf,
     _unpack_paged_kv_cache,
     canonicalize_torch_dtype,
     determine_attention_backend,
     device_support_pdl,
-    get_compute_capability,
     get_device_sm_count,
     is_float8,
     is_sm100a_supported,
     is_sm110a_supported,
     is_sm12x_supported,
     register_custom_op,
     register_fake_op,
     ceil_div,
     round_up,
 )
```


📥 Commits

Reviewing files that changed from the base of the PR and between 7a7f724 and e1479f5.

📒 Files selected for processing (2)
  • flashinfer/jit/attention/modules.py
  • flashinfer/prefill.py

Fixes pre-commit ruff F401 lint failure.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
auto-merge was automatically disabled March 12, 2026 22:07

Head branch was pushed to by a user without write access

@jimmyzho jimmyzho merged commit 74e99e8 into flashinfer-ai:main Mar 13, 2026
30 of 36 checks passed
frankwang28 pushed a commit to frankwang28/flashinfer that referenced this pull request Mar 18, 2026
…lashinfer-ai#2560)

ameynaik-hub pushed a commit to ameynaik-hub/flashinfer that referenced this pull request Mar 18, 2026
…shinfer-ai#2574)

ameynaik-hub pushed a commit to ameynaik-hub/flashinfer that referenced this pull request Mar 18, 2026
…lashinfer-ai#2560)

