[WIP][Do not review] feat: enable sm103 fp4 gemm #2888
nv-yunzheq wants to merge 3 commits into flashinfer-ai:main from
Conversation
📝 Walkthrough

The PR enables SM103 kernel support by activating runtime conditional imports in gemm_base, enhances the SM103 dense blockscaled GEMM kernel with dynamic TMEM allocation and improved epilogue handling, removes Cutlass compatibility monkey-patching from Blackwell DSL modules, and adds a benchmark script comparing SM100 versus SM103 GEMM tactics.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
🚥 Pre-merge checks: ✅ 1 passed | ❌ 2 failed (2 warnings)
Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request focuses on integrating and optimizing FP4 General Matrix Multiply (GEMM) operations for NVIDIA's Blackwell (SM103) GPU architecture. It enables the dedicated SM103 FP4 GEMM kernel and updates the underlying kernel implementation.
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/gemm/gemm_base.py (1)
Lines 4607-4616: ⚠️ Potential issue | 🟠 Major

Fix mypy-blocking callable type mismatch for make_kernel.

Lines 4607–4622 conditionally assign make_kernel with different lambda return types: Sm103BlockScaledPersistentDenseGemmKernel (via Sm103Kernel) and Sm100BlockScaledPersistentDenseGemmKernel. This causes mypy to fail during pre-commit validation. Add a type annotation to resolve the conflict:
🔧 Proposed typing-safe fix
```diff
-from typing import List, Literal, Optional, Tuple
+from typing import Any, Callable, List, Literal, Optional, Tuple
 ...
+    make_kernel: Callable[[], Any]
     if kernel_type == "sm103" and Sm103Kernel is not None:
         make_kernel = lambda: Sm103Kernel(
             sf_vec_size,
             mma_tiler_mn,
             cluster_shape_mn,
             use_tma_store,
             enable_pdl,
         )
     else:
         make_kernel = lambda: Sm100BlockScaledPersistentDenseGemmKernel(
             sf_vec_size,
             mma_tiler_mn,
             cluster_shape_mn,
             use_prefetch,
             enable_pdl,
         )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gemm/gemm_base.py` around lines 4607 - 4616, The lambda assignments to make_kernel return two different concrete kernel classes (Sm103Kernel/Sm103BlockScaledPersistentDenseGemmKernel vs Sm100BlockScaledPersistentDenseGemmKernel), causing a mypy callable return-type mismatch; annotate make_kernel with a common kernel return type (e.g., add "from typing import Callable" and declare make_kernel: Callable[[], BlockScaledPersistentDenseGemmKernel] = ..." or the actual shared base/protocol name used by Sm100BlockScaledPersistentDenseGemmKernel and Sm103* kernels) so both lambdas satisfy the same Callable return type, or alternatively define a small Protocol/base class and use Callable[[], ThatProtocol] if no shared base exists; update the two lambda assignments (make_kernel) accordingly and import typing as needed.
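The annotation pattern the comment proposes can be shown with a minimal, self-contained sketch. The stub classes and the `pick_kernel_factory` helper below are placeholders, not the real flashinfer types; only the `make_kernel: Callable[[], Any]` annotation mirrors the suggested fix.

```python
from typing import Any, Callable

# Placeholder stand-ins for the two concrete kernel classes; the real code
# uses Sm103Kernel and Sm100BlockScaledPersistentDenseGemmKernel.
class Sm103KernelStub:
    pass

class Sm100KernelStub:
    pass

def pick_kernel_factory(kernel_type: str) -> Callable[[], Any]:
    # Annotating the variable once lets mypy accept both branches, because
    # each lambda's concrete return type is compatible with the declared
    # Callable[[], Any] (or a shared base class, if one exists).
    make_kernel: Callable[[], Any]
    if kernel_type == "sm103":
        make_kernel = lambda: Sm103KernelStub()
    else:
        make_kernel = lambda: Sm100KernelStub()
    return make_kernel
```

Using a shared base class or Protocol instead of `Any`, as the comment also suggests, would keep more type information while still satisfying mypy.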
🧹 Nitpick comments (3)
benchmarks/bench_sm103_vs_sm100.py (2)
Lines 165-173: Add an explicit CUDA availability guard for a clearer failure mode.

If CUDA is unavailable, this currently fails later in device capability calls; a direct check here gives a cleaner message.
🛡️ Suggested guard
```diff
-    device = torch.device("cuda")
+    if not torch.cuda.is_available():
+        raise RuntimeError("CUDA is required for benchmarks/bench_sm103_vs_sm100.py")
+    device = torch.device("cuda")
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/bench_sm103_vs_sm100.py` around lines 165 - 173, Add an explicit CUDA availability check before creating device and querying capabilities: verify torch.cuda.is_available() and if false log/raise a clear error or exit rather than proceeding to torch.device("cuda") and torch.cuda.get_device_capability; update the block that currently sets device, calls torch.cuda.get_device_capability(device), computes sm_version, and prints GPU info (variables/functions: device, torch.cuda.is_available, torch.cuda.get_device_capability, torch.cuda.get_device_name) so it early-exits with a readable message when CUDA is not available.
Lines 89-99: Clean up unused locals to keep lint output clean.

Line 90 (prefetch) and Line 202 (err) are unused; rename to _prefetch / _err (or remove) to silence Ruff noise.

🧹 Minimal cleanup
```diff
-    mma, cluster, swap, prefetch, ktype, tma_store = tactic
+    mma, cluster, swap, _prefetch, ktype, tma_store = tactic
 ...
-        ms, err = benchmark_one(runner, inputs, tactic, args.iters)
+        ms, _err = benchmark_one(runner, inputs, tactic, args.iters)
```

Also applies to: 200-203
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/bench_sm103_vs_sm100.py` around lines 89 - 99, The linter warning is caused by unused local variables; in format_tactic rename the unused parameter prefetch to _prefetch (e.g., def format_tactic(tactic): mma, cluster, swap, _prefetch, ktype, tma_store = tactic) so Ruff ignores it, and similarly rename the unused error variable err to _err in the other location (where err is assigned around lines ~200–203) or remove it if not needed; update any matching unpacking or assignments that reference these names (format_tactic, and the function/block that defines err) to silence the Ruff unused-variable warnings.

flashinfer/gemm/kernels/dense_blockscaled_gemm_sm103.py (1)
Lines 1634-1639: Convert lambda to nested function per style guidelines.

The alpha scaling logic is correct (FP32 multiplication for precision, then cast to c_dtype), but static analysis flags E731 for assigning a lambda to a variable.

♻️ Refactor lambda to def
```diff
-        # Wrap epilogue_op with alpha scaling.
-        # The library epilogue converts acc to c_dtype before calling epilogue_op,
-        # so alpha*x promotes to Float32; we must convert back to c_dtype for the store.
-        alpha_epilogue_op = lambda x: epilogue_op(
-            (alpha_value * x).to(self.c_dtype)
-        )
+        # Wrap epilogue_op with alpha scaling.
+        # The library epilogue converts acc to c_dtype before calling epilogue_op,
+        # so alpha*x promotes to Float32; we must convert back to c_dtype for the store.
+        def alpha_epilogue_op(x):
+            return epilogue_op((alpha_value * x).to(self.c_dtype))
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gemm/kernels/dense_blockscaled_gemm_sm103.py` around lines 1634 - 1639, Replace the lambda assigned to alpha_epilogue_op with a nested def function to satisfy style/E731: define a function named (e.g.) alpha_epilogue_op that accepts x and returns epilogue_op((alpha_value * x).to(self.c_dtype)), keeping the same semantics (FP32 multiply then cast to self.c_dtype) and using the existing symbols epilogue_op, alpha_value, and self.c_dtype so callers of alpha_epilogue_op are unchanged.
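The E731 refactor preserves semantics exactly, which a plain-Python sketch can demonstrate. Here `epilogue_op`, `alpha_value`, and the `round()` cast are hypothetical stand-ins for the real kernel callables and the tensor `.to(self.c_dtype)` cast:

```python
# Hypothetical stand-ins: the real code multiplies in FP32 and casts a
# tensor back to self.c_dtype; round() models that cast here.
alpha_value = 2.0

def epilogue_op(x):
    return x + 1

# Before (flagged by Ruff E731: lambda assigned to a variable):
alpha_epilogue_lambda = lambda x: epilogue_op(round(alpha_value * x))

# After (identical behavior, E731-clean, and the function gets a real
# __name__ for tracebacks):
def alpha_epilogue_op(x):
    return epilogue_op(round(alpha_value * x))
```

Both forms close over the same variables, so only the binding style changes.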
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: aefa4e42-4326-47bb-9f37-ea8da7670ce4
📒 Files selected for processing (5)
- benchmarks/bench_sm103_vs_sm100.py
- flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py
- flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py
- flashinfer/gemm/gemm_base.py
- flashinfer/gemm/kernels/dense_blockscaled_gemm_sm103.py
💤 Files with no reviewable changes (2)
- flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py
- flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py
Code Review
This pull request introduces a new benchmark script to compare SM103 and SM100 FP4 GEMM tactics. It re-enables the SM103 kernel and updates its implementation to leverage cutlass-dsl's API for hardware-specific values and dynamic epilogue tile computation. The changes also include refactoring accumulator handling for overlapping operations and ensuring correct type casting in the epilogue. A review comment suggests improving error logging in the benchmark_one function within the new benchmark script, as the current nested try-except blocks could mask original errors. Another comment emphasizes the importance of explicit type conversion in the epilogue to prevent potential type mismatches or precision issues.
```python
    except Exception:
        try:
            times = bench_gpu_time(
                run_fn,
                dry_run_iters=max(3, iters // 4),
                repeat_iters=iters,
                enable_cupti=False,
                use_cuda_graph=False,
                cold_l2_cache=True,
                sleep_after_run=True,
            )
            return float(np.median(times)), None
        except Exception as e2:
            return None, str(e2)
```
```diff
-        # may not be available in older cutlass-dsl versions.
-        SM103_TMEM_CAPACITY_COLUMNS = 512
-        self.num_tmem_alloc_cols = SM103_TMEM_CAPACITY_COLUMNS
+        self.num_tmem_alloc_cols = cute.arch.get_max_tmem_alloc_cols("sm_103")
```
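The diff replaces a hardcoded TMEM capacity with a runtime query. If compatibility with older cutlass-dsl versions (where, per the removed comment, the query may be missing) were still wanted, a guarded lookup could keep the old constant as a fallback. A sketch only — the module is passed in explicitly so the example does not assume cutlass-dsl is installed, and `get_max_tmem_alloc_cols` is the API name taken from the diff:

```python
SM103_TMEM_CAPACITY_COLUMNS = 512  # previous hardcoded value, kept as fallback

def get_tmem_alloc_cols(cute_arch, target="sm_103"):
    """Query the max TMEM columns via get_max_tmem_alloc_cols when the
    attribute exists on the given arch module; fall back otherwise."""
    query = getattr(cute_arch, "get_max_tmem_alloc_cols", None)
    if query is None:
        return SM103_TMEM_CAPACITY_COLUMNS
    return query(target)
```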
```python
        self.epi_tile = sm103_utils.compute_epilogue_tile_shape(
            self.cta_tile_shape_mnk,
            self.use_2cta_instrs,
            self.c_layout,
            self.c_dtype,
        )
```
```python
        alpha_epilogue_op = lambda x: epilogue_op(
            (alpha_value * x).to(self.c_dtype)
        )
```
Explicitly converting (alpha_value * x) to self.c_dtype before calling epilogue_op is important for correctness. This ensures that type promotion to Float32 during multiplication is handled, and the result is cast back to the expected output data type, preventing potential type mismatches or precision issues during storage.
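The promotion-then-cast behavior can be illustrated outside the kernel with plain Python numerics. This is an analogy only — the real code operates on CUDA tensor dtypes, not Python scalars — but the shape of the hazard is the same:

```python
# Multiplying an int by a float promotes to float, just as multiplying a
# low-precision tensor by a Float32 alpha promotes the product to Float32.
x = 3          # stands in for a value already in the output dtype
alpha = 0.5    # stands in for a Float32 scale factor
product = alpha * x
assert isinstance(product, float)  # silently promoted

# Cast back explicitly before "storing", mirroring .to(self.c_dtype):
stored = int(product)
assert isinstance(stored, int)
```

Without the explicit cast, the promoted value would flow to the store in the wrong type, which is exactly what the `.to(self.c_dtype)` call in the epilogue prevents.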
📌 Description
Issue #2621
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed pre-commit by running pip install pre-commit (or used my preferred method).
- I have installed the hooks with pre-commit install.
- I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

🧪 Tests
- Tests have been added or updated as needed (e.g., unittest, etc.).

Reviewer Notes
Summary by CodeRabbit
New Features
Chores