fix: guard CUTLASS FMHA against SM12x and fix fmha_v2 SM121a check (#2560)
Conversation
Summary of Changes (Gemini Code Assist): This pull request refines GPU architecture support for FlashInfer's Fused Multi-Head Attention (FMHA) kernels. It prevents SM12x GPUs from using CUTLASS FMHA kernels that require MMA instructions those GPUs lack, and simultaneously expands support for the fmha_v2 DeepSeek prefill kernel to cover SM121a.
Activity
CodeRabbit review: no actionable comments were generated in the recent review.
Walkthrough (CodeRabbit): Reduced CUTLASS FMHA NVCC compile targets (dropped CUDA 12+) and restricted FMHA module selection to SM100a/SM110a; prefill DeepSeek logic removed the explicit SM12x guard, so SM12x no longer follows the prior CUDA-version error path.
Pre-merge checks: 2 passed, 1 failed (warning).
Code Review
This pull request correctly addresses two issues. First, it removes SM12x support for the CUTLASS FMHA kernel, which was causing compilation failures due to missing hardware instructions, and improves the error message to guide users. Second, it fixes a bug in fmha_v2_prefill_deepseek by adding support for SM121a, which was previously incorrectly rejected. The changes are clear, well-justified, and improve both correctness and user experience.
flashinfer/prefill.py
Outdated
```diff
 """
-if not is_sm120a_supported(query.device):
+if not (is_sm120a_supported(query.device) or is_sm121a_supported(query.device)):
     raise ValueError("fmha_v2_prefill_deepseek is only supported on SM120 GPUs.")
```
To make this check more robust for new SM12x architectures, consider checking the major architecture version directly instead of listing each supported minor version. This would automatically include future SM12x GPUs (e.g., SM122a) without requiring code changes, which seems to be the intent given the error message and the build flags for this kernel.
```diff
-if not (is_sm120a_supported(query.device) or is_sm121a_supported(query.device)):
+if torch.cuda.get_device_capability(query.device)[0] != 12:
```
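The trade-off between the two styles can be shown with a small, self-contained sketch (pure Python: `supported_by_enumeration` and `supported_by_major` are hypothetical names, and `(major, minor)` tuples stand in for `torch.cuda.get_device_capability`):

```python
def supported_by_enumeration(cap):
    """Enumerate known minor versions; misses any future SM12x variant."""
    return cap in {(12, 0), (12, 1)}  # SM120a, SM121a only

def supported_by_major(cap):
    """Check the major version only; covers SM120a, SM121a, and future SM122a."""
    major, _minor = cap
    return major == 12

# A hypothetical future SM122a device:
cap = (12, 2)
print(supported_by_enumeration(cap))  # False: would need a code change
print(supported_by_major(cap))        # True: covered automatically
```

The major-version check matches the intent of the build flags, at the cost of silently accepting SM12x variants that have never been tested.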
Regarding the suggestion to use the major-version check: noted. That said, if a future SM122a variant appears, adding support for it would be a small change.
@blake-snc - would it be better to introduce is_sm12x_family_supported() to cover all such cases? Even though it's a one-line change, there are still places in the code where sm120 is included and sm121 is ignored, even though they are pretty much identical, even in new-ish PRs. Case in point: #2460. And that's with sm121 being out in the wild since October.
Add a major-version-based helper that covers all SM12x GPUs (SM120a, SM121a, and future variants) so callers don't need to enumerate each minor version individually. Uses a `major == 12` check, matching the pattern of is_sm100a_supported (`major == 10`).

Update existing call sites in gemm_base.py and the DeepSeek MLA test. This avoids the recurring pattern where SM121a support gets missed when only SM120a is checked, as noted in PR flashinfer-ai#2460 and the flashinfer-ai#2560 discussion.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
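A minimal sketch of the helper described above, assuming the same `(major, minor)` capability tuple that `torch.cuda.get_device_capability` returns (the `_sketch` name is hypothetical; the real helper takes a device and queries torch):

```python
def is_sm12x_supported_sketch(capability):
    """Return True for any SM12x-family device (SM120a, SM121a, future variants).

    `capability` stands in for the (major, minor) tuple returned by
    torch.cuda.get_device_capability(device).
    """
    major, _minor = capability
    return major == 12

print(is_sm12x_supported_sketch((12, 0)))  # RTX 5090 (SM120a) -> True
print(is_sm12x_supported_sketch((12, 1)))  # DGX Spark (SM121a) -> True
print(is_sm12x_supported_sketch((10, 0)))  # B200 (SM100a) -> False
```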
@eugr Good call. I just opened #2574, which adds is_sm12x_supported(). The PR updates the existing call sites in gemm_base.py and the DeepSeek MLA test as well.
yzh119
left a comment
Makes sense to me, thanks for the fix.
/bot run |
@flashinfer-bot run |
[FAILED] Pipeline #44336836: 9/20 passed |
The internal CI pipeline shows 9/20 passed — are the failures related to this PR or pre-existing? Our changes only narrow CUTLASS FMHA support (removing SM12x) and fix the SM121a check in fmha_v2_prefill_deepseek. Happy to investigate if there's something specific we need to fix.
## Summary

Adds `is_sm12x_supported()` to `flashinfer/utils.py` as a convenience helper that covers the entire SM12x GPU family (SM120a, SM121a, and future variants like SM122a) without requiring callers to enumerate each minor version. Uses a `major == 12` check, matching the existing pattern of `is_sm100a_supported()` (`major == 10`). This means future SM12x variants are automatically covered without code changes.

**Motivation:** SM121a (DGX Spark) keeps getting missed when only SM120a is checked. This was noted by @eugr in #2560, and PR #2460 is another example where SM121a was not included alongside SM120a.

## Changes

| File | Change |
|------|--------|
| `flashinfer/utils.py` | Add `is_sm12x_supported()` with `major == 12` check |
| `flashinfer/gemm/gemm_base.py` | Replace 3 instances of `is_sm120a_supported(a.device) or is_sm121a_supported(a.device)` |
| `tests/attention/test_fmha_v2_prefill_deepseek.py` | Update skip guard to use `is_sm12x_supported()` |

The individual `is_sm120a_supported()` and `is_sm121a_supported()` functions are preserved for cases that need variant-specific behavior.

Validated on DGX Spark (SM121a, CUDA 13.0).

[Second Nature Computing](https://joinsecondnature.com)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

## Summary by CodeRabbit

* **Refactor**
  * Consolidated separate SM120/SM121 capability checks into a unified SM12x check and updated the public import surface accordingly.
  * Introduced explicit CUDA-version gating for SM12x variants and clarified related compatibility/error messages.
* **Tests**
  * Updated GPU compatibility tests and skip logic/messages to target SM12x architecture support.
Hey @yzh119 — this PR now has merge conflicts with main. Here's what changed upstream since your approval:
Happy to rebase if you can clarify the intended direction for (1). If CUTLASS FMHA genuinely works on SM12x with newer CUDA, we can drop that part and just rebase cleanly. |
SM12x GPUs (RTX 5090, DGX Spark) lack the tcgen05 MMA instructions required by the CUTLASS FMHA SM100 kernel. Previously, get_fmha_module() and gen_fmha_cutlass_sm100a_module() incorrectly included SM12x in their support checks, causing compile failures when using backend="cutlass" or fmha_varlen() on SM12x.

Also fix fmha_v2_prefill_deepseek() to accept SM121a (DGX Spark) in addition to SM120a (RTX 5090), as both are SM12x-class GPUs that support the fmha_v2 DeepSeek kernels.

Changes:
- Remove SM12x from get_fmha_module() support check with clear error msg
- Change supported_major_versions from [10, 11, 12] to [10, 11]
- Add is_sm121a_supported() check to fmha_v2_prefill_deepseek()

Validated on NVIDIA GB10 (DGX Spark, SM 12.1):
- CUTLASS FMHA correctly rejects SM12x with helpful error message
- FA2 prefill continues to work (max_diff=0.0078 vs SDPA reference)
- XQA decode continues to work (no NaN)
- determine_attention_backend() correctly returns "fa2" for SM12x

AI-assisted by Claude (Anthropic)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
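The guard described in this commit can be sketched as follows (illustrative only: `check_cutlass_fmha_support` and the exact message wording are assumptions based on the commit description, not the real FlashInfer API):

```python
# Illustrative sketch of the new support check, not FlashInfer's actual code.
SUPPORTED_MAJOR_VERSIONS = (10, 11)  # was (10, 11, 12) before this fix

def check_cutlass_fmha_support(major, minor):
    """Reject unsupported architectures up front, with a clear message for SM12x."""
    if major == 12:
        raise ValueError(
            f"CUTLASS FMHA is not supported on SM{major}{minor}: SM12x GPUs lack "
            "the tcgen05 MMA instructions required by the SM100 kernel. "
            "Use backend='fa2' instead."
        )
    if major not in SUPPORTED_MAJOR_VERSIONS:
        raise ValueError(f"Unsupported compute capability: {major}.{minor}")

check_cutlass_fmha_support(10, 0)      # SM100a (B200): passes silently
try:
    check_cutlass_fmha_support(12, 1)  # SM121a (DGX Spark)
except ValueError as e:
    print(e)  # points the user at backend='fa2' instead of a compile failure
```

The point of the guard is to fail fast with an actionable message, rather than letting NVCC fail later on missing MMA instructions.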
Force-pushed: 7a7f724 to e1479f5
Update: resolved the merge conflicts and rebased onto main. PR is mergeable now. The resolution keeps our original intent:
Caution: some comments are outside the diff and can't be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/prefill.py (1)
Line 56: ⚠️ Potential issue | 🟡 Minor — Remove unused import `get_compute_capability`.

The pipeline is failing because `get_compute_capability` is imported but not used in this file. It needs to be removed to fix the Ruff F401 linting error.

🔧 Proposed fix:

```diff
 from .utils import (
     log2e,
     FP4Tensor,
     MaskMode,
     PosEncodingMode,
     TensorLayout,
     _check_cached_qkv_data_type,
     _check_kv_layout,
     _check_pos_encoding_mode,
     check_shape_dtype_device,
     _get_cache_alibi_slopes_buf,
     _get_cache_buf,
     _unpack_paged_kv_cache,
     canonicalize_torch_dtype,
     determine_attention_backend,
     device_support_pdl,
-    get_compute_capability,
     get_device_sm_count,
     is_float8,
     is_sm100a_supported,
     is_sm110a_supported,
     is_sm12x_supported,
     register_custom_op,
     register_fake_op,
     ceil_div,
     round_up,
 )
```
Fixes pre-commit ruff F401 lint failure.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Head branch was pushed to by a user without write access
…lashinfer-ai#2560)

## Summary

- **Remove SM12x from CUTLASS FMHA support**: `get_fmha_module()` and `gen_fmha_cutlass_sm100a_module()` incorrectly included SM12x GPUs (RTX 5090, DGX Spark) in their support checks. SM12x lacks the `tcgen05` MMA instructions required by the CUTLASS FMHA SM100 kernel (`SM100_MMA_F16BF16_SS/TS`, `SM100_MMA_F8F6F4_SS/TS`), causing compile failures when using `backend="cutlass"` or `fmha_varlen()`. Changed `supported_major_versions` from `[10, 11, 12]` to `[10, 11]` and added a clear error message for SM12x users pointing them to `backend='fa2'`.
- **Fix `fmha_v2_prefill_deepseek` SM121a check**: The SM12x guard only checked `is_sm120a_supported()` (SM120 = RTX 5090, minor=0) but not `is_sm121a_supported()` (SM121 = DGX Spark, minor=1). DGX Spark users were incorrectly rejected from using the fmha_v2 DeepSeek prefill kernel.

## Validated on NVIDIA GB10 (DGX Spark, SM 12.1)

| Test | Result |
|------|--------|
| CUTLASS FMHA correctly rejects SM12x with clear error | PASS |
| FA2 prefill works (max_diff=0.0078 vs SDPA reference) | PASS |
| XQA decode works (no NaN) | PASS |
| `determine_attention_backend()` returns "fa2" for SM12x | PASS |
| `fmha_v2_prefill_deepseek` accepts SM121a | PASS |

## Test plan

- [ ] Verify CUTLASS FMHA still works on SM100a (B200/GB200)
- [ ] Verify `fmha_varlen()` raises clear error on SM12x instead of compile failure
- [ ] Verify `fmha_v2_prefill_deepseek()` works on both SM120 (RTX 5090) and SM121 (DGX Spark)
- [ ] Run existing CI tests

Contributed by [Second Nature Computing](https://joinsecondnature.com)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

## Summary by CodeRabbit

* **Bug Fixes**
  * FMHA optimized kernel now targets only SM100a/SM110a devices; other devices will receive an updated compatibility message with a suggested alternative backend.
  * Removed CUDA 12+ compilation support for the optimized path.
  * Prefill behavior updated: the alternate prefill path will not proceed on non-SM12x hardware and now raises a clear, explicit message.

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
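Taken together, the post-merge behavior can be summarized in a small sketch (both function names are hypothetical simplifications; the real logic lives in `determine_attention_backend()` and the fmha_v2 guards):

```python
def pick_prefill_backend(major):
    """CUTLASS FMHA is limited to SM100a/SM110a; SM12x falls back to fa2."""
    return "cutlass" if major in (10, 11) else "fa2"

def deepseek_prefill_allowed(major):
    """fmha_v2_prefill_deepseek now accepts the whole SM12x family."""
    return major == 12

print(pick_prefill_backend(12))      # fa2 (RTX 5090 / DGX Spark)
print(deepseek_prefill_allowed(12))  # True (SM120a and SM121a alike)
```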