perf: add fp4 GEMM tile configs and streamK scheduler for SM120 #2460
Merged: yzh119 merged 1 commit into flashinfer-ai:main
Conversation
**Summary of Changes** (Gemini Code Assist): This pull request introduces performance enhancements for FP4 General Matrix Multiply (GEMM) operations on SM120 GPUs. By expanding the available tile configurations and integrating the StreamK scheduler from CUTLASS, the changes provide more optimized execution paths, improving latency and throughput for certain computational workloads, as evidenced by the provided benchmarks.
**📝 Walkthrough** (CodeRabbit): These changes extend the SM120 FP4 GEMM kernel infrastructure to support configurable scheduler selection (DP vs. StreamK) alongside expanded tile configurations.
**Sequence Diagram(s)**

```mermaid
sequenceDiagram
    participant User as User Code
    participant Config as CutlassGemmConfig
    participant Dispatch as dispatchNVFP4xNVFP4Gemm<br/>ClusterShapeSm120
    participant Router as Scheduler<br/>Router
    participant DPLauncher as DP Launcher
    participant StreamKLauncher as StreamK Launcher
    participant Helper as prepareGemmArgsImpl<br/>& runFp4GemmImpl
    User->>Config: Create with tile shape<br/>& use_stream_k flag
    User->>Dispatch: Call with UseStreamK<br/>template param
    Dispatch->>Router: Route based on<br/>tile config & scheduler
    alt use_stream_k == false
        Router->>DPLauncher: Select DP path
        DPLauncher->>Helper: Call runFp4GemmImpl<br/>(DP variant)
    else use_stream_k == true
        Router->>StreamKLauncher: Select StreamK path
        StreamKLauncher->>Helper: Call runFp4GemmImpl<br/>(StreamK variant)
    end
    Helper->>Helper: prepareGemmArgsImpl<br/>(unified args)
    Helper->>Helper: Check workspace,<br/>capability
    Helper->>Helper: Launch kernel
    Helper-->>User: Return execution result
```
Code Review

This pull request introduces performance optimizations for FP4 GEMM on the SM120 architecture by adding new tile configurations and enabling the StreamK scheduler. The changes are well-structured, particularly the refactoring in `fp4_gemm_template_sm120.h` to abstract common kernel-launching logic, which significantly improves code clarity and maintainability. My main feedback is a minor style suggestion in `fp4_gemm_cutlass_template_sm120.h` to improve the readability of a switch statement. Overall, this is a great performance enhancement.
```diff
 case CutlassTileConfigSM120::CtaShape128x128x128B:
-  // Always use 1x1x1 cluster shape for SM120
-  return dispatchNVFP4xNVFP4GemmClusterShapeSm120<T, cute::Int<128>, cute::Int<128>,
-                                                  cute::Int<128>>(
-      D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace,
-      workspaceBytes, stream, occupancy);
-  break;
+  DISPATCH_WITH_SCHEDULER(128, 128, 128);
+case CutlassTileConfigSM120::CtaShape128x128x256B:
+  DISPATCH_WITH_SCHEDULER(128, 128, 256);
+case CutlassTileConfigSM120::CtaShape256x128x128B:
+  DISPATCH_WITH_SCHEDULER(256, 128, 128);
```
The `break` statements are missing between these case labels. While the `DISPATCH_WITH_SCHEDULER` macro expands to a `return` statement, which prevents functional issues from fallthrough, this structure is confusing and can be flagged by compilers with `-Wimplicit-fallthrough`. Adding `break` statements makes the control flow explicit and improves readability and maintainability, even if the `break` is currently unreachable.
Suggested change:

```cpp
case CutlassTileConfigSM120::CtaShape128x128x128B:
  DISPATCH_WITH_SCHEDULER(128, 128, 128);
  break;
case CutlassTileConfigSM120::CtaShape128x128x256B:
  DISPATCH_WITH_SCHEDULER(128, 128, 256);
  break;
case CutlassTileConfigSM120::CtaShape256x128x128B:
  DISPATCH_WITH_SCHEDULER(256, 128, 128);
  break;
```
Each DISPATCH macro will return, so I suppose it's fine.
/bot run
@flashinfer-bot run
[FAILED] Pipeline #43072123: 7/20 passed
yzh119 left a comment:
Thanks for the improvement, do we have any performance numbers btw?
Is there any reason to restrict this to SM120? It looks like it won't be used on SM121, even though the architecture is identical.
@eugr I don't see why not, it should be applicable to sm121.
@yzh119 - my concern is that it has guards like this:
FlashInfer builds use the "a" suffix, i.e., they are specific to each chip; in that case, I think an sm121 condition should be added.
Thanks for the comments! Let me add an sm_version=121 condition later in a new PR.
I have shown an example in the description, which is our target low-latency case of Qwen3-32B on RTX PRO 6000.
Just so it stays on your radar :)
Add a major-version-based helper that covers all SM12x GPUs (SM120a, SM121a, and future variants) so callers don't need to enumerate each minor version individually. Uses major == 12 check, matching the pattern of is_sm100a_supported (major == 10). Update existing call sites in gemm_base.py and the DeepSeek MLA test. This avoids the recurring pattern where SM121a support gets missed when only SM120a is checked, as noted in PR flashinfer-ai#2460 and flashinfer-ai#2560 discussion. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Just added sm121 support in a new PR, #2631. Happy CNY :)
Happy CNY! |
## Summary

Adds `is_sm12x_supported()` to `flashinfer/utils.py` as a convenience helper that covers the entire SM12x GPU family (SM120a, SM121a, and future variants like SM122a) without requiring callers to enumerate each minor version.

Uses a `major == 12` check, matching the existing pattern of `is_sm100a_supported()` (`major == 10`). This means future SM12x variants are automatically covered without code changes.

**Motivation:** SM121a (DGX Spark) keeps getting missed when only SM120a is checked. This was noted by @eugr in #2560, and PR #2460 is another example where SM121a was not included alongside SM120a.

## Changes

| File | Change |
|------|--------|
| `flashinfer/utils.py` | Add `is_sm12x_supported()` with `major == 12` check |
| `flashinfer/gemm/gemm_base.py` | Replace 3 instances of `is_sm120a_supported(a.device) or is_sm121a_supported(a.device)` |
| `tests/attention/test_fmha_v2_prefill_deepseek.py` | Update skip guard to use `is_sm12x_supported()` |

The individual `is_sm120a_supported()` and `is_sm121a_supported()` functions are preserved for cases that need variant-specific behavior.

Validated on DGX Spark (SM121a, CUDA 13.0).

## Summary by CodeRabbit

* **Refactor**
  * Consolidated separate SM120/SM121 capability checks into a unified SM12x check and updated the public import surface accordingly.
  * Introduced explicit CUDA-version gating for SM12x variants and clarified related compatibility/error messages.
* **Tests**
  * Updated GPU compatibility tests and skip logic/messages to target SM12x architecture support.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
📌 Description
Add SM120 FP4 GEMM tile configs and enable the StreamK scheduler for configs that can be selected by the autotuner.
Take the problem size m=32, n=5120, k=25600 as an example. After adding these configs, the latency of the mm_fp4 kernel dropped from 0.124 ms to 0.069 ms on RTX PRO 6000 (roughly a 1.8x speedup).
🔍 Related Issues
🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

- Installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- Installed the hooks with `pre-commit install`.
- Ran `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- Tests pass locally (`unittest`, etc.).

Reviewer Notes