
perf: add fp4 GEMM tile configs and streamK scheduler for SM120#2460

Merged
yzh119 merged 1 commit into flashinfer-ai:main from Yuening-wa:add_sm120_config
Feb 6, 2026

Conversation

Contributor

@Yuening-wa Yuening-wa commented Feb 2, 2026

📌 Description

Add SM120 FP4 GEMM tile configs and enable the StreamK scheduler for configs that can be selected by the autotuner.
Take the problem size m=32, n=5120, k=25600 as an example: after adding these configs, the latency of the mm_fp4 kernel drops from 0.124 ms to 0.069 ms on an RTX PRO 6000.

[PERF] cutlass_autotun:: median time 0.124 ms; std 0.000 ms; achieved tflops 67.751 TFLOPs/sec; achieved tb_per_sec 0.535 TB/sec
[PERF] cutlass_autotun:: median time 0.069 ms; std 0.000 ms; achieved tflops 121.902 TFLOPs/sec; achieved tb_per_sec 0.963 TB/sec
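As a sanity check, the reported TFLOPs follow directly from the GEMM flop count (2·m·n·k) and the median times; a quick back-of-envelope script (small differences come from the truncated millisecond readings):

```python
m, n, k = 32, 5120, 25600
flops = 2 * m * n * k  # one multiply-accumulate counted as 2 FLOPs

def achieved_tflops(time_ms: float) -> float:
    return flops / (time_ms * 1e-3) / 1e12

# Agrees with the [PERF] lines above to within rounding:
# achieved_tflops(0.124) ~ 67.7, achieved_tflops(0.069) ~ 121.6
```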

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

Release Notes

  • New Features
    • Added StreamK scheduler option for SM120 FP4 GEMM operations, alongside the existing DP scheduler.
    • Expanded SM120 FP4 GEMM tile configurations with two additional options (128×128×256 and 256×128×128).
    • Extended configuration generation to support all tile shapes with both schedulers for improved optimization flexibility.

@gemini-code-assist
Contributor

Summary of Changes

Hello @Yuening-wa, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces performance enhancements for FP4 General Matrix Multiply (GEMM) operations on SM120 GPUs. By expanding the available tile configurations and integrating the StreamK scheduler from CUTLASS, the changes aim to provide more optimized execution paths. This results in improved latency and throughput for certain computational workloads, as evidenced by the provided benchmarks.

Highlights

  • Expanded FP4 GEMM Tile Configurations: New tile sizes (128, 128, 256) and (256, 128, 128) have been added for SM120 architecture, alongside the existing (128, 128, 128).
  • StreamK Scheduler Integration: The CUTLASS StreamK scheduler is now enabled for SM120 FP4 GEMM operations, offering an alternative to the default Data Parallel (DP) scheduler, particularly beneficial for specific problem sizes.
  • Performance Optimization: Benchmarks demonstrate a significant reduction in kernel latency (from 0.124ms to 0.069ms) for a sample problem size (m=32, n=5120, k=25600) on RTX PRO 6000, attributed to these new configurations and scheduler.


@coderabbitai
Contributor

coderabbitai bot commented Feb 2, 2026

📝 Walkthrough

Walkthrough

These changes extend the SM120 FP4 GEMM kernel infrastructure to support configurable scheduler selection (DP vs StreamK) alongside expanded tile configurations. A new use_stream_k flag is added to the configuration structure, dispatch routing is made scheduler-aware across multiple tile shapes, and shared kernel launch helpers consolidate logic across both scheduler variants.

Changes

| Cohort / File(s) | Summary |
|---|---|
| Tile Configuration Expansion<br>`flashinfer/jit/gemm/core.py` | Added two new SM120 FP4 Cutlass tile configurations, (128,128,256) and (256,128,128), alongside the existing (128,128,128). Generalized the associated comment to cover all tile configurations. |
| Configuration Structure<br>`include/flashinfer/gemm/cutlass_gemm_configs.h` | Introduced a `use_stream_k` boolean field (default `false`) to select between DP and StreamK schedulers on SM120. Updated the SM120 constructor to accept this parameter and modified `toString()` to emit scheduler information when in Warp Specialized mode. |
| Dispatch Routing<br>`include/flashinfer/gemm/fp4_gemm_cutlass_template_sm120.h` | Added a `UseStreamK` template parameter to `dispatchNVFP4xNVFP4GemmClusterShapeSm120`. Introduced internal macros for scheduler-aware routing. Extended `getConfigs()` to generate configurations for all three tile shapes across both DP and StreamK schedulers instead of a single hardcoded path. |
| Kernel Launch Helpers<br>`include/flashinfer/gemm/fp4_gemm_template_sm120.h` | Introduced shared helpers `prepareGemmArgsImpl` and `runFp4GemmImpl` to unify argument preparation and execution logic across DP and StreamK schedulers. Refactored both launcher variants to use the new helpers, reducing code duplication and consolidating capability checks and initialization. |
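The `getConfigs()` expansion described above amounts to a cross-product of tile shapes and schedulers. A Python sketch with hypothetical names (the real implementation is C++ in `fp4_gemm_cutlass_template_sm120.h`):

```python
from itertools import product

# Hypothetical sketch of the config enumeration; actual identifiers differ.
TILE_SHAPES = [(128, 128, 128), (128, 128, 256), (256, 128, 128)]
USE_STREAM_K = [False, True]  # False = DP scheduler, True = StreamK

def enumerate_sm120_fp4_configs():
    """Pair every tile shape with both schedulers, as getConfigs() now does."""
    return [
        {"tile_shape": shape, "use_stream_k": sk}
        for shape, sk in product(TILE_SHAPES, USE_STREAM_K)
    ]

configs = enumerate_sm120_fp4_configs()
assert len(configs) == 6  # 3 tile shapes x 2 schedulers
```

The autotuner can then benchmark all six candidates and pick the fastest for a given problem size.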

Sequence Diagram(s)

sequenceDiagram
    participant User as User Code
    participant Config as CutlassGemmConfig
    participant Dispatch as dispatchNVFP4xNVFP4Gemm<br/>ClusterShapeSm120
    participant Router as Scheduler<br/>Router
    participant DPLauncher as DP Launcher
    participant StreamKLauncher as StreamK Launcher
    participant Helper as prepareGemmArgsImpl<br/>& runFp4GemmImpl

    User->>Config: Create with tile shape<br/>& use_stream_k flag
    User->>Dispatch: Call with UseStreamK<br/>template param
    Dispatch->>Router: Route based on<br/>tile config & scheduler
    alt use_stream_k == false
        Router->>DPLauncher: Select DP path
        DPLauncher->>Helper: Call runFp4GemmImpl<br/>(DP variant)
    else use_stream_k == true
        Router->>StreamKLauncher: Select StreamK path
        StreamKLauncher->>Helper: Call runFp4GemmImpl<br/>(StreamK variant)
    end
    Helper->>Helper: prepareGemmArgsImpl<br/>(unified args)
    Helper->>Helper: Check workspace,<br/>capability
    Helper->>Helper: Launch kernel
    Helper-->>User: Return execution result

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

  • flashinfer#2095: Modifies the SM120 FP4 GEMM kernel launch path (fp4_gemm_template_sm120.h) with changes to the enablePDL flag, directly overlapping with the launcher refactoring in this PR.

Suggested reviewers

  • bkryu
  • nvmbreughe
  • cyx-6
  • jimmyzho
  • ttyio
  • aleozlx
  • yongwww

Poem

🐰 Hops through dual schedulers with glee,
StreamK and DP now dance in harmony!
Tile shapes expand, refactored with care,
Shared helpers unite what we used to declare.

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 8.33%, below the required threshold of 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The title accurately summarizes the main changes: adding FP4 GEMM tile configs and StreamK scheduler support for SM120, which is the core of the changeset. |
| Description check | ✅ Passed | The pull request provides a clear description of changes, related performance improvements with concrete metrics, and confirms completion of pre-commit checks and tests. |




@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces performance optimizations for FP4 GEMM on SM120 architecture by adding new tile configurations and enabling the StreamK scheduler. The changes are well-structured, particularly the refactoring in fp4_gemm_template_sm120.h to abstract common kernel launching logic, which significantly improves code clarity and maintainability. My main feedback is a minor style suggestion in fp4_gemm_cutlass_template_sm120.h to improve the readability of a switch statement. Overall, this is a great performance enhancement.

Comment on lines 113 to +118

```diff
     case CutlassTileConfigSM120::CtaShape128x128x128B:
-      // Always use 1x1x1 cluster shape for SM120
-      return dispatchNVFP4xNVFP4GemmClusterShapeSm120<T, cute::Int<128>, cute::Int<128>,
-                                                      cute::Int<128>>(
-          D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace,
-          workspaceBytes, stream, occupancy);
-      break;
+      DISPATCH_WITH_SCHEDULER(128, 128, 128);
+    case CutlassTileConfigSM120::CtaShape128x128x256B:
+      DISPATCH_WITH_SCHEDULER(128, 128, 256);
+    case CutlassTileConfigSM120::CtaShape256x128x128B:
+      DISPATCH_WITH_SCHEDULER(256, 128, 128);
```
Contributor

Severity: medium

The break statements are missing between these case labels. While the DISPATCH_WITH_SCHEDULER macro expands to a return statement, which prevents functional issues from fallthrough, this structure is confusing and can be flagged by compilers with -Wimplicit-fallthrough. Adding break statements makes the control flow explicit and improves readability and maintainability, even if the break is currently unreachable.

Suggested change (add an explicit `break` after each macro invocation):

```diff
 case CutlassTileConfigSM120::CtaShape128x128x128B:
   DISPATCH_WITH_SCHEDULER(128, 128, 128);
+  break;
 case CutlassTileConfigSM120::CtaShape128x128x256B:
   DISPATCH_WITH_SCHEDULER(128, 128, 256);
+  break;
 case CutlassTileConfigSM120::CtaShape256x128x128B:
   DISPATCH_WITH_SCHEDULER(256, 128, 128);
+  break;
```
Collaborator

Each DISPATCH macro will return, so I suppose it's fine.

@yzh119
Collaborator

yzh119 commented Feb 2, 2026

/bot run

@yzh119
Collaborator

yzh119 commented Feb 2, 2026

@flashinfer-bot run

@flashinfer-bot
Collaborator

GitLab MR !287 has been created, and the CI pipeline #43072123 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #43072123: 7/20 passed

Collaborator

@yzh119 yzh119 left a comment


Thanks for the improvement, do we have any performance numbers btw?

@yzh119 yzh119 merged commit 57ef44b into flashinfer-ai:main Feb 6, 2026
42 checks passed
@eugr

eugr commented Feb 8, 2026

Is there any reason to restrict this to sm120? It looks like it won't be used on sm121, even though the architecture is identical.
@johnnynunez - FYI.

@yzh119
Collaborator

yzh119 commented Feb 8, 2026

@eugr I don't see why not, it should be applicable to sm121.

@eugr

eugr commented Feb 8, 2026

@yzh119 - my concern is that it has guards like this: if (sm_version == 120) {, but sm_version will be 121 on sm121. I'm not that familiar with the flashinfer codebase, but a quick check showed the same pattern across the project: return sm_major * 10 + sm_minor;
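To make the concern concrete (function name hypothetical; the codebase computes this inline), the `major * 10 + minor` encoding gives SM121 a value that an exact `sm_version == 120` guard rejects:

```python
def get_sm_version(sm_major: int, sm_minor: int) -> int:
    # The pattern quoted above: return sm_major * 10 + sm_minor
    return sm_major * 10 + sm_minor

# SM120 passes an exact-match guard; SM121 does not,
# even though the architectures are essentially identical.
assert get_sm_version(12, 0) == 120
assert get_sm_version(12, 1) == 121  # fails an `sm_version == 120` check
```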

@johnnynunez
Contributor

@yzh119 - my concern is that it has guards like this: if (sm_version == 120) {, but sm_version will be 121 on sm121. I'm not that familiar with the flashinfer codebase, but a quick check showed the same pattern across the project: return sm_major * 10 + sm_minor;

FlashInfer builds with the "a" suffix, i.e., per specific chip, so I think an sm121 condition should be added.

@Yuening-wa
Contributor Author

@yzh119 - my concern is that it has guards like this: if (sm_version == 120) {, but sm_version will be 121 on sm121. I'm not that familiar with the flashinfer codebase, but a quick check showed the same pattern across the project: return sm_major * 10 + sm_minor;

FlashInfer builds with the "a" suffix, i.e., per specific chip, so I think an sm121 condition should be added.

Thanks for the comments! Let me add the sm_version == 121 condition later in a new PR.

@Yuening-wa
Contributor Author

Thanks for the improvement, do we have any performance numbers btw?

I have shown an example in the description; it is our target low-latency case of Qwen3-32B on an RTX PRO 6000.

Take the problem size m=32, n=5120, k=25600 as an example: after adding these configs, the latency of the mm_fp4 kernel drops from 0.124 ms to 0.069 ms on an RTX PRO 6000.

[PERF] cutlass_autotun:: median time 0.124 ms; std 0.000 ms; achieved tflops 67.751 TFLOPs/sec; achieved tb_per_sec 0.535 TB/sec
[PERF] cutlass_autotun:: median time 0.069 ms; std 0.000 ms; achieved tflops 121.902 TFLOPs/sec; achieved tb_per_sec 0.963 TB/sec

@eugr

eugr commented Feb 14, 2026

Thanks for the comments! Let me add sm_version=121 condition later on a new PR.

Just so it stays on your radar :)

blake-snc added a commit to blake-snc/flashinfer that referenced this pull request Feb 17, 2026
Add a major-version-based helper that covers all SM12x GPUs (SM120a,
SM121a, and future variants) so callers don't need to enumerate each
minor version individually. Uses major == 12 check, matching the
pattern of is_sm100a_supported (major == 10).

Update existing call sites in gemm_base.py and the DeepSeek MLA test.

This avoids the recurring pattern where SM121a support gets missed when
only SM120a is checked, as noted in PR flashinfer-ai#2460 and flashinfer-ai#2560 discussion.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
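A minimal sketch of the helper that commit describes, simplified to take major/minor directly (the real `is_sm12x_supported()` in `flashinfer/utils.py` takes a device):

```python
def is_sm12x_supported(sm_major: int, sm_minor: int) -> bool:
    # Major-version check covering SM120a, SM121a, and future SM12x variants,
    # mirroring the is_sm100a_supported (major == 10) pattern.
    return sm_major == 12

assert is_sm12x_supported(12, 0)      # SM120a
assert is_sm12x_supported(12, 1)      # SM121a (DGX Spark)
assert not is_sm12x_supported(10, 0)  # SM100a handled separately
```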
blake-snc added a commit to blake-snc/flashinfer that referenced this pull request Feb 20, 2026
@Yuening-wa
Contributor Author

Just so it stays on your radar :)

Just added sm121 support in a new PR, #2631. Happy CNY :)

@eugr

eugr commented Feb 24, 2026

Happy CNY!

yongwww pushed a commit that referenced this pull request Feb 25, 2026
## Summary

Adds `is_sm12x_supported()` to `flashinfer/utils.py` as a convenience
helper that covers the entire SM12x GPU family (SM120a, SM121a, and
future variants like SM122a) without requiring callers to enumerate each
minor version.

Uses a `major == 12` check, matching the existing pattern of
`is_sm100a_supported()` (`major == 10`). This means future SM12x
variants are automatically covered without code changes.

**Motivation:** SM121a (DGX Spark) keeps getting missed when only SM120a
is checked. This was noted by @eugr in #2560, and PR #2460 is another
example where SM121a was not included alongside SM120a.

## Changes

| File | Change |
|------|--------|
| `flashinfer/utils.py` | Add `is_sm12x_supported()` with `major == 12` check |
| `flashinfer/gemm/gemm_base.py` | Replace 3 instances of `is_sm120a_supported(a.device) or is_sm121a_supported(a.device)` |
| `tests/attention/test_fmha_v2_prefill_deepseek.py` | Update skip guard to use `is_sm12x_supported()` |

The individual `is_sm120a_supported()` and `is_sm121a_supported()`
functions are preserved for cases that need variant-specific behavior.

Validated on DGX Spark (SM121a, CUDA 13.0).

[Second Nature Computing](https://joinsecondnature.com)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **Refactor**
* Consolidated separate SM120/SM121 capability checks into a unified
SM12x check and updated the public import surface accordingly.
* Introduced explicit CUDA-version gating for SM12x variants and
clarified related compatibility/error messages.

* **Tests**
* Updated GPU compatibility tests and skip logic/messages to target
SM12x architecture support.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
ameynaik-hub pushed a commit to ameynaik-hub/flashinfer that referenced this pull request Mar 18, 2026
…shinfer-ai#2574)
