fix: add SM121 support to SM120 version guards #2631
yzh119 merged 2 commits into flashinfer-ai:main
Conversation
Summary of Changes

Hello @Yuening-wa, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enhances FlashInfer's compatibility with newer NVIDIA GPU architectures by integrating SM121 into existing SM120-specific logic. This ensures that advanced features like FP4 GEMM and optimized Cutlass configurations are correctly applied to SM121 devices, maintaining performance and functionality across a broader range of hardware.
Activity
📝 Walkthrough

This PR expands SM120-specific checks, messages, and config paths to also include SM121 across GEMM, Cutlass config, XQA, and tests; it adds a new explicit StreamK scheduler configuration option for these devices.
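The guard pattern at the center of this change can be sketched as follows. This is a minimal illustration with hypothetical helper bodies; the real check is `_match_sm_version` in `flashinfer/gemm/gemm_base.py`.

```python
import torch

def _sm_version(device: torch.device) -> str:
    """Return the device's SM version as a compact string, e.g. "120"."""
    major, minor = torch.cuda.get_device_capability(device)
    return f"{major}{minor}"

def _match_sm_version(device: torch.device, versions: list[str]) -> bool:
    """True if the device's SM version is in the allowed list."""
    return _sm_version(device) in versions

# Before this PR, SM121 devices (compute capability 12.1, e.g. the GB10 in
# DGX Spark) fell through guards written as:
#     _match_sm_version(a.device, ["120"])
# After this PR, the same guards accept both parts:
#     _match_sm_version(a.device, ["120", "121"])
```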
Sequence Diagram(s): (omitted)

Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
🚥 Pre-merge checks: ✅ 1 passed | ❌ 2 failed

❌ Failed checks (2 warnings)
✅ Passed checks (1 passed)
Code Review
This pull request adds support for the SM121 architecture by including it in version guards that previously only checked for SM120. The changes are straightforward and correctly extend support across the Python and C++ codebases. I've identified a couple of minor issues where comments and error messages should be updated to reflect the inclusion of SM121, ensuring code clarity and accurate user feedback. Overall, the changes are good and align with the PR's goal.
```diff
 if (
     not use_nvfp4
-    and _match_sm_version(a.device, ["120"])
+    and _match_sm_version(a.device, ["120", "121"])
```
While this change correctly adds support for SM121, the error message raised on line 3044, which is controlled by the CUDNN_FP4_MXFP4_SM120_CUDNN_VERSION_ERROR constant, is now potentially misleading as it only mentions SM120. To avoid confusion for users on SM121 devices, please consider updating the error message to include SM121.
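A hedged sketch of the suggested fix, broadening the constant's message so SM121 users see accurate guidance. The constant name comes from the review comment; the wording below is illustrative, and the version is inferred from the `cudnn.backend_version() < 91400` guard (i.e. cuDNN 9.14.0).

```python
# Illustrative wording only; the real constant lives in
# flashinfer/gemm/gemm_base.py and is raised when the cuDNN backend is too
# old for MXFP4 GEMM on these architectures.
CUDNN_FP4_MXFP4_SM120_CUDNN_VERSION_ERROR = (
    "MXFP4 GEMM on SM120/SM121 requires cuDNN backend >= 9.14.0; "
    "please upgrade cuDNN."
)
```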
```diff
       << "\n\tenable cuda kernel: " << (enableCudaKernel ? "true" : "false");
   // SM120 specific: StreamK scheduler option
-  if (sm_version == 120) {
+  if (sm_version == 120 || sm_version == 121) {
```
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/gemm/gemm_base.py`:
- Around line 3039-3043: The conditional checks use_nvfp4,
_match_sm_version(a.device, ["120", "121"]), and cudnn.backend_version() < 91400
but the raised error string only mentions “SM120”; update the error message
produced in this branch to mention both SM120 and SM121 (or use a combined
phrase like “SM120/SM121”) so users on SM121 see the correct guidance; locate
the raise or processLogger.error call inside the block guarded by
use_nvfp4/_match_sm_version and change the literal to include SM121.
ℹ️ Review info
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- flashinfer/gemm/gemm_base.py
- include/flashinfer/gemm/cutlass_gemm_configs.h
- include/flashinfer/trtllm/common.h
/bot run

@flashinfer-bot run
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
include/flashinfer/gemm/fp4_gemm_template_sm120.h (1)
Lines 254–259: ⚠️ Potential issue | 🟡 Minor — Stale inline comment at line 259 not updated to match the SM120/SM121 change.
Lines 254–255 were updated to reference "SM120/SM121", but the adjacent inline comment on line 259 still only mentions "SM120":
```cpp
cutlass::gemm::collective::StageCount<2>,  /* Fixed 2 stages for SM120 */
```

✏️ Proposed fix

```diff
- cutlass::gemm::collective::StageCount<2>,  /* Fixed 2 stages for SM120 */
+ cutlass::gemm::collective::StageCount<2>,  /* Fixed 2 stages for SM120/SM121 */
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/gemm/fp4_gemm_template_sm120.h` around lines 254–259, update the stale inline comment that mentions only "SM120" next to cutlass::gemm::collective::StageCount<2> to reference "SM120/SM121" (or make it generic) so it matches the surrounding lines that were changed; locate the CollectiveMainloop typedef in fp4_gemm_template_sm120.h and modify the comment after StageCount<2> to reflect SM120/SM121 consistency.
🧹 Nitpick comments (1)
include/flashinfer/gemm/fp4_gemm_template_sm120.h (1)
Lines 52–73: `SMTypeAdapter` definitions are unused in SM120 — remove or document to avoid confusion.

The `SMTypeAdapter<_1SM>` and `SMTypeAdapter<_2SM>` structs define SM100-specific schedules (`KernelTmaWarpSpecialized{1,2}SmNvf4Sm100`) that are never referenced in the SM120 kernel macro. SM120 hardcodes `EpilogueScheduleAuto` and `KernelScheduleAuto` directly (lines 244–245), bypassing `SMTypeAdapter` entirely. In contrast, SM100 (`fp4_gemm_template_sm100.h`) actively uses these fields at lines 129–130 within its macro definition. The SM120 definitions appear to be an artifact of copying from SM100 and represent dead code that risks future misinterpretation.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/gemm/fp4_gemm_template_sm120.h` around lines 52 - 73, SMTypeAdapter<_1SM> and SMTypeAdapter<_2SM> declare SM100-specific EpilogueSchedule and MainloopSchedule values that are never used by the SM120 kernel (which uses EpilogueScheduleAuto and KernelScheduleAuto); remove these dead structs or clearly document them as SM100-only artifacts to avoid confusion. Locate the template specializations SMTypeAdapter<_1SM> and SMTypeAdapter<_2SM> and either delete them or add a comment indicating they are intentionally present for SM100 compatibility only, ensuring references to EpilogueSchedule and MainloopSchedule and the fact SM120 uses EpilogueScheduleAuto/KernelScheduleAuto are mentioned.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Outside diff comments:
In `@include/flashinfer/gemm/fp4_gemm_template_sm120.h`:
- Around line 254-259: Update the stale inline comment that mentions only
"SM120" next to cutlass::gemm::collective::StageCount<2> to reference
"SM120/SM121" (or make it generic) so it matches the surrounding lines that were
changed; locate the CollectiveMainloop typedef in fp4_gemm_template_sm120.h and
modify the comment after StageCount<2> to reflect SM120/SM121 consistency.
---
Nitpick comments:
In `@include/flashinfer/gemm/fp4_gemm_template_sm120.h`:
- Around line 52-73: SMTypeAdapter<_1SM> and SMTypeAdapter<_2SM> declare
SM100-specific EpilogueSchedule and MainloopSchedule values that are never used
by the SM120 kernel (which uses EpilogueScheduleAuto and KernelScheduleAuto);
remove these dead structs or clearly document them as SM100-only artifacts to
avoid confusion. Locate the template specializations SMTypeAdapter<_1SM> and
SMTypeAdapter<_2SM> and either delete them or add a comment indicating they are
intentionally present for SM100 compatibility only, ensuring references to
EpilogueSchedule and MainloopSchedule and the fact SM120 uses
EpilogueScheduleAuto/KernelScheduleAuto are mentioned.
ℹ️ Review info
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (19)
- csrc/gemm_groupwise_sm120.cu
- csrc/group_gemm_fp8_groupwise_sm120.cu
- csrc/trtllm_fmha_v2_binding.cu
- flashinfer/decode.py
- flashinfer/gemm/gemm_base.py
- flashinfer/jit/gemm/core.py
- flashinfer/jit/gemm/cutlass/generate_kernels.py
- flashinfer/mla.py
- flashinfer/xqa.py
- include/flashinfer/gemm/cutlass_gemm_configs.h
- include/flashinfer/gemm/fp4_gemm_cutlass_template_sm120.h
- include/flashinfer/gemm/fp4_gemm_template_sm120.h
- include/flashinfer/gemm/gemm_groupwise_sm120.cuh
- include/flashinfer/gemm/group_gemm_fp8_groupwise_sm120.cuh
- tests/attention/test_trtllm_gen_mla.py
- tests/attention/test_xqa.py
- tests/attention/test_xqa_batch_decode.py
- tests/attention/test_xqa_mla_batch_decode.py
- tests/moe/test_trtllm_cutlass_fused_moe.py
✅ Files skipped from review due to trivial changes (6)
- include/flashinfer/gemm/fp4_gemm_cutlass_template_sm120.h
- csrc/trtllm_fmha_v2_binding.cu
- tests/attention/test_xqa_batch_decode.py
- include/flashinfer/gemm/gemm_groupwise_sm120.cuh
- tests/attention/test_trtllm_gen_mla.py
- csrc/gemm_groupwise_sm120.cu
Thanks, mates! Let's keep improving DGX Spark

[FAILED] Pipeline #44685986: 9/20 passed

keeping fingers crossed :)
## 📌 Description

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Summary by CodeRabbit

* **New Features**
  * Broadened hardware compatibility to include additional NVIDIA Blackwell GPUs (SM121) and added an explicit configuration option for the newer scheduler on those devices.
* **Documentation**
  * Updated user-facing messages, docs and comments to reflect SM121 support.
* **Tests**
  * Adjusted test skip messages to include SM121 where applicable.

Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
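The release-note item about test skip messages corresponds to device guards along these lines — a hedged sketch with a hypothetical helper name and message, not the repository's exact test code:

```python
import pytest
import torch

def skip_unless_sm120_family():
    # SM120 is compute capability 12.0; SM121 is 12.1.
    major, minor = torch.cuda.get_device_capability()
    if (major, minor) not in [(12, 0), (12, 1)]:
        pytest.skip("requires SM120/SM121 (compute capability 12.0 or 12.1)")
```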