
[WIP][Do not review] feat: enable sm103 fp4 gemm#2888

Draft
nv-yunzheq wants to merge 3 commits into flashinfer-ai:main from
nv-yunzheq:cute_dsl_sm103_gemm_enable

Conversation

@nv-yunzheq
Collaborator

@nv-yunzheq nv-yunzheq commented Mar 25, 2026

📌 Description

Issue #2621

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Added benchmark script to evaluate and compare SM103 vs. SM100 performance across configurable problem sizes with CSV export support.
    • Enabled SM103 kernel support with runtime optimizations for memory management and tile shape computation.
  • Chores

    • Removed deprecated compatibility code for older framework versions.

@coderabbitai
Contributor

coderabbitai bot commented Mar 25, 2026

📝 Walkthrough

Walkthrough

The PR enables SM103 kernel support by activating runtime conditional imports in gemm_base, enhances the SM103 dense blockscaled GEMM kernel with dynamic TMEM allocation and improved epilogue handling, removes Cutlass compatibility monkey-patching from Blackwell DSL modules, and adds a benchmark script comparing SM100 versus SM103 GEMM tactics.

Changes

Cohort / File(s) Summary
SM103 Kernel Activation
flashinfer/gemm/gemm_base.py
Switched SM103 kernel selection from commented-out code to active runtime conditional imports, enabling functional SM103 support while maintaining fallback behavior.
SM103 Kernel Enhancement
flashinfer/gemm/kernels/dense_blockscaled_gemm_sm103.py
Enhanced TMEM allocation with runtime queries, refined epilogue tile shape computation, added overlapping accumulator mode for single-stage configurations, and improved alpha scaling with FP32 intermediate conversion.
Cutlass Compatibility Cleanup
flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py, flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py
Removed local monkey-patching helpers and Int32 imports that provided fallback support for older Cutlass versions lacking swizzle_size and FastDivmod support.
Benchmark Tooling
benchmarks/bench_sm103_vs_sm100.py
Added standalone benchmark script for enumerating and measuring FP4 GEMM tactics across SM100 and SM103, supporting configurable problem sizes, output dtype, and CSV reporting.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Suggested labels

cute-dsl, op: gemm, run-ci

Suggested reviewers

  • aleozlx
  • yongwww
  • djmmoss
  • cyx-6
  • jimmyzho
  • yzh119

Poem

🐰 Hops with glee, a kernel springs to life,
SM103 bounds through the CuTe-DSL strife,
No more monkeypatch in shadowy code,
TMEM queries light the optimization road!

🚥 Pre-merge checks | ✅ 1 | ❌ 2

❌ Failed checks (2 warnings)

  • Docstring Coverage — ⚠️ Warning: Docstring coverage is 58.33%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.
  • Description check — ⚠️ Warning: The pull request description is entirely a template, with all sections empty, unchecked checkboxes, and no implementation details, rationale, or context. Resolution: fill in the Description section explaining what changes are made and why, provide context on the SM103 FP4 GEMM enablement, link related issues, and check off completed pre-commit and test items.

✅ Passed checks (1 passed)

  • Title check — ✅ Passed: The title accurately describes the main feature: enabling SM103 FP4 GEMM support. However, the [WIP][Do not review] prefixes indicate the PR is incomplete and not ready for review.


@gemini-code-assist
Contributor

Summary of Changes


This pull request focuses on integrating and optimizing FP4 General Matrix Multiply (GEMM) operations for NVIDIA's Blackwell (SM103) GPU architecture. It enables the dedicated SM103 FP4 GEMM kernel, updates the underlying cutlass-dsl dependency to remove compatibility workarounds, and introduces a comprehensive benchmark to evaluate the performance gains of SM103-specific tactics over SM100 tactics. The changes aim to leverage Blackwell's capabilities for more efficient FP4 computations, particularly relevant for large language models.

Highlights

  • SM103 FP4 GEMM Enablement: The SM103 FP4 GEMM kernel has been officially enabled by uncommenting its import and usage in gemm_base.py, allowing Blackwell GPUs to utilize optimized FP4 matrix multiplication.
  • Cutlass-DSL Dependency Update: Outdated monkey-patching code for PersistentTileSchedulerParams and _get_cluster_work_idx_with_fastdivmod has been removed from fused MoE kernels, indicating an upgrade to a cutlass-dsl version (4.4.0+) that natively supports these features.
  • New SM103 vs SM100 FP4 GEMM Benchmark: A new benchmark script bench_sm103_vs_sm100.py was added to compare the performance of SM100 tactics against SM103-specific 3xFP4 tactics for GEMM across various LLM problem sizes on Blackwell hardware.
  • SM103 Kernel Optimizations: Several optimizations were introduced to the SM103 dense blockscaled GEMM kernel, including dynamic epilogue tile shape computation, improved TMEM column allocation, and accumulator overlapping/double buffering logic for enhanced performance.



@coderabbitai coderabbitai bot left a comment


Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/gemm/gemm_base.py (1)

4607-4616: ⚠️ Potential issue | 🟠 Major

Fix mypy-blocking callable type mismatch for make_kernel.

Lines 4607–4622 conditionally assign make_kernel to lambdas with different return types: Sm103BlockScaledPersistentDenseGemmKernel (via Sm103Kernel) and Sm100BlockScaledPersistentDenseGemmKernel. This causes mypy to fail during pre-commit validation.

Add type annotation to resolve the conflict:

🔧 Proposed typing-safe fix
-from typing import List, Literal, Optional, Tuple
+from typing import Any, Callable, List, Literal, Optional, Tuple
...
+            make_kernel: Callable[[], Any]
             if kernel_type == "sm103" and Sm103Kernel is not None:
                 make_kernel = lambda: Sm103Kernel(
                     sf_vec_size,
                     mma_tiler_mn,
                     cluster_shape_mn,
                     use_tma_store,
                     enable_pdl,
                 )
             else:
                 make_kernel = lambda: Sm100BlockScaledPersistentDenseGemmKernel(
                     sf_vec_size,
                     mma_tiler_mn,
                     cluster_shape_mn,
                     use_prefetch,
                     enable_pdl,
                 )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gemm/gemm_base.py` around lines 4607 - 4616, The lambda
assignments to make_kernel return two different concrete kernel classes
(Sm103Kernel/Sm103BlockScaledPersistentDenseGemmKernel vs
Sm100BlockScaledPersistentDenseGemmKernel), causing a mypy callable return-type
mismatch; annotate make_kernel with a common kernel return type (e.g., add "from
typing import Callable" and declare make_kernel: Callable[[],
BlockScaledPersistentDenseGemmKernel] = ..." or the actual shared base/protocol
name used by Sm100BlockScaledPersistentDenseGemmKernel and Sm103* kernels) so
both lambdas satisfy the same Callable return type, or alternatively define a
small Protocol/base class and use Callable[[], ThatProtocol] if no shared base
exists; update the two lambda assignments (make_kernel) accordingly and import
typing as needed.
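The Protocol alternative mentioned in the prompt above can be sketched as follows. This is illustrative only: `GemmKernel`, `Sm100Stub`, and `Sm103Stub` are hypothetical stand-ins for the real FlashInfer kernel classes, not the actual types.

```python
from typing import Callable, Protocol


class GemmKernel(Protocol):
    """Hypothetical shared interface; the real kernels expose more methods."""

    def __call__(self) -> None: ...


# Stand-ins for the Sm100/Sm103 kernel classes (illustrative only).
class Sm100Stub:
    def __call__(self) -> None:
        print("sm100")


class Sm103Stub:
    def __call__(self) -> None:
        print("sm103")


def select_kernel(kernel_type: str) -> Callable[[], GemmKernel]:
    # Annotating the factory with a shared Protocol return type keeps
    # mypy happy even though each branch constructs a different class.
    make_kernel: Callable[[], GemmKernel]
    if kernel_type == "sm103":
        make_kernel = lambda: Sm103Stub()
    else:
        make_kernel = lambda: Sm100Stub()
    return make_kernel
```

Either approach (a `Callable[[], Any]` annotation or a shared Protocol) satisfies mypy; the Protocol keeps more type information at the cost of one extra definition.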
🧹 Nitpick comments (3)
benchmarks/bench_sm103_vs_sm100.py (2)

165-173: Add an explicit CUDA availability guard for clearer failure mode.

If CUDA is unavailable, this currently fails later in device capability calls; a direct check here gives a cleaner message.

🛡️ Suggested guard
-    device = torch.device("cuda")
+    if not torch.cuda.is_available():
+        raise RuntimeError("CUDA is required for benchmarks/bench_sm103_vs_sm100.py")
+    device = torch.device("cuda")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_sm103_vs_sm100.py` around lines 165 - 173, Add an explicit
CUDA availability check before creating device and querying capabilities: verify
torch.cuda.is_available() and if false log/raise a clear error or exit rather
than proceeding to torch.device("cuda") and torch.cuda.get_device_capability;
update the block that currently sets device, calls
torch.cuda.get_device_capability(device), computes sm_version, and prints GPU
info (variables/functions: device, torch.cuda.is_available,
torch.cuda.get_device_capability, torch.cuda.get_device_name) so it early-exits
with a readable message when CUDA is not available.

89-99: Clean up unused locals to keep lint output clean.

Line 90 (prefetch) and Line 202 (err) are unused; rename to _prefetch / _err (or remove) to silence Ruff noise.

🧹 Minimal cleanup
-    mma, cluster, swap, prefetch, ktype, tma_store = tactic
+    mma, cluster, swap, _prefetch, ktype, tma_store = tactic
...
-                ms, err = benchmark_one(runner, inputs, tactic, args.iters)
+                ms, _err = benchmark_one(runner, inputs, tactic, args.iters)

Also applies to: 200-203

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_sm103_vs_sm100.py` around lines 89 - 99, The linter warning
is caused by unused local variables; in format_tactic rename the unused
parameter prefetch to _prefetch (e.g., def format_tactic(tactic): mma, cluster,
swap, _prefetch, ktype, tma_store = tactic) so Ruff ignores it, and similarly
rename the unused error variable err to _err in the other location (where err is
assigned around lines ~200–203) or remove it if not needed; update any matching
unpacking or assignments that reference these names (format_tactic, and the
function/block that defines err) to silence the Ruff unused-variable warnings.
flashinfer/gemm/kernels/dense_blockscaled_gemm_sm103.py (1)

1634-1639: Convert lambda to nested function per style guidelines.

The alpha scaling logic is correct (FP32 multiplication for precision, then cast to c_dtype), but static analysis flags E731 for assigning a lambda to a variable.

♻️ Refactor lambda to def
-            # Wrap epilogue_op with alpha scaling.
-            # The library epilogue converts acc to c_dtype before calling epilogue_op,
-            # so alpha*x promotes to Float32; we must convert back to c_dtype for the store.
-            alpha_epilogue_op = lambda x: epilogue_op(
-                (alpha_value * x).to(self.c_dtype)
-            )
+            # Wrap epilogue_op with alpha scaling.
+            # The library epilogue converts acc to c_dtype before calling epilogue_op,
+            # so alpha*x promotes to Float32; we must convert back to c_dtype for the store.
+            def alpha_epilogue_op(x):
+                return epilogue_op((alpha_value * x).to(self.c_dtype))
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gemm/kernels/dense_blockscaled_gemm_sm103.py` around lines 1634 -
1639, Replace the lambda assigned to alpha_epilogue_op with a nested def
function to satisfy style/E731: define a function named (e.g.) alpha_epilogue_op
that accepts x and returns epilogue_op((alpha_value * x).to(self.c_dtype)),
keeping the same semantics (FP32 multiply then cast to self.c_dtype) and using
the existing symbols epilogue_op, alpha_value, and self.c_dtype so callers of
alpha_epilogue_op are unchanged.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: aefa4e42-4326-47bb-9f37-ea8da7670ce4

📥 Commits

Reviewing files that changed from the base of the PR and between ede7a27 and 830bea9.

📒 Files selected for processing (5)
  • benchmarks/bench_sm103_vs_sm100.py
  • flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py
  • flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py
  • flashinfer/gemm/gemm_base.py
  • flashinfer/gemm/kernels/dense_blockscaled_gemm_sm103.py
💤 Files with no reviewable changes (2)
  • flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py
  • flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a new benchmark script to compare SM103 and SM100 FP4 GEMM tactics. It re-enables the SM103 kernel and updates its implementation to leverage cutlass-dsl's API for hardware-specific values and dynamic epilogue tile computation. The changes also include refactoring accumulator handling for overlapping operations and ensuring correct type casting in the epilogue. A review comment suggests improving error logging in the benchmark_one function within the new benchmark script, as the current nested try-except blocks could mask original errors. Another comment emphasizes the importance of explicit type conversion in the epilogue to prevent potential type mismatches or precision issues.

Comment on lines +126 to +139
except Exception:
try:
times = bench_gpu_time(
run_fn,
dry_run_iters=max(3, iters // 4),
repeat_iters=iters,
enable_cupti=False,
use_cuda_graph=False,
cold_l2_cache=True,
sleep_after_run=True,
)
return float(np.median(times)), None
except Exception as e2:
return None, str(e2)

medium

The nested try-except blocks in benchmark_one are too broad. The inner except Exception catches any exception and then retries with different parameters, potentially masking the original error. It would be beneficial to log the initial exception before retrying to aid in debugging.
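One way to address this is to log the first exception before falling back, so the root cause is preserved. A minimal sketch (the helper name and logging setup are illustrative, not part of the benchmark script):

```python
import logging

logger = logging.getLogger("bench_sm103_vs_sm100")


def run_with_fallback(primary, fallback):
    """Try `primary`; on failure, log the original error before trying `fallback`.

    Returns (result, error_message). Logging the first exception before
    retrying preserves the root cause instead of silently masking it.
    """
    try:
        return primary(), None
    except Exception as e1:
        logger.warning("primary benchmark path failed: %s; retrying fallback", e1)
        try:
            return fallback(), None
        except Exception as e2:
            # Report both errors so neither is lost.
            return None, f"fallback failed: {e2} (original: {e1})"
```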

-        # may not be available in older cutlass-dsl versions.
-        SM103_TMEM_CAPACITY_COLUMNS = 512
-        self.num_tmem_alloc_cols = SM103_TMEM_CAPACITY_COLUMNS
+        self.num_tmem_alloc_cols = cute.arch.get_max_tmem_alloc_cols("sm_103")

medium

Replacing the hardcoded 512 with cute.arch.get_max_tmem_alloc_cols("sm_103") is a good improvement. It makes the code more robust by relying on the library's API for hardware-specific values, which can adapt to future changes or more accurate definitions within cutlass-dsl.

Comment on lines +240 to +245
self.epi_tile = sm103_utils.compute_epilogue_tile_shape(
self.cta_tile_shape_mnk,
self.use_2cta_instrs,
self.c_layout,
self.c_dtype,
)

medium

Using sm103_utils.compute_epilogue_tile_shape instead of a hardcoded value for self.epi_tile is a positive change. This makes the epilogue tile computation more dynamic and adaptable to various kernel configurations, improving flexibility and correctness.

Comment on lines +1637 to +1638
alpha_epilogue_op = lambda x: epilogue_op(
    (alpha_value * x).to(self.c_dtype)
)

medium

Explicitly converting (alpha_value * x) to self.c_dtype before calling epilogue_op is important for correctness. This ensures that type promotion to Float32 during multiplication is handled, and the result is cast back to the expected output data type, preventing potential type mismatches or precision issues during storage.
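The promote-then-cast pattern can be illustrated with a small NumPy analogue (NumPy stands in for the CuTe DSL types here; the kernel itself operates on TMEM fragments, not host arrays, and `float16` stands in for `c_dtype`):

```python
import numpy as np

# Low-precision values, as the epilogue would see them after the library
# converts the accumulator to the output dtype.
x = np.array([0.5, 1.25, -2.0], dtype=np.float16)
alpha = np.float32(1.0 / 3.0)  # alpha is kept in FP32 for precision

# Multiply in FP32, then cast back to the output dtype for the store,
# mirroring `(alpha_value * x).to(self.c_dtype)` in the kernel.
scaled = (alpha * x.astype(np.float32)).astype(np.float16)

# Without the explicit cast, the product would stay FP32 and the store
# would see the wrong dtype.
assert scaled.dtype == np.float16
```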

