Upgrade cutlass 4.2.1 -> 4.4.2 #2798
Conversation
Fixes a TMA descriptor bug where the CUDA driver was not properly setting the OOB address generation mode, causing non-deterministic crashes in tma_warp_specialized_generic_moe_gemm_kernelLauncher<Sm120, fp4> on DGX Spark (SM121) with NVFP4 MoE models.
Refs: NVBug 5804240; upstream issues flashinfer-ai#2776, flashinfer-ai#2577; TRT-LLM fix NVIDIA/TensorRT-LLM#11956.
📝 Walkthrough
Updates the cutlass subproject dependency and standardizes namespace qualification for error-reporting function calls across multiple GEMM kernel launchers and builders. Additionally adjusts template parameter usage in the mixed-input GEMM builder.
Summary of Changes
Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request resolves TMA descriptor crashes observed with NVFP4 MoE models on DGX Spark systems by updating the Cutlass library to 4.4.1 and adding Grid Dependent Control (GDC) compilation flags so kernels execute correctly across GPU architectures.
/bot run
🧹 Nitpick comments (1)
flashinfer/jit/fused_moe.py (1)
33-99: Optional cleanup: centralize repeated GDC macro literals.
The same define strings are repeated across generators; extracting constants reduces drift risk in future flag edits.
♻️ Suggested refactor
+CUTLASS_GDC_FLAG_SM100 = "-DCUTLASS_ENABLE_GDC_FOR_SM100=1" +CUTLASS_GDC_FLAG_SM90 = "-DCUTLASS_ENABLE_GDC_FOR_SM90=1" + def gen_cutlass_fused_moe_sm120_module(use_fast_build: bool = False) -> JitSpec: nvcc_flags = [ @@ - "-DCUTLASS_ENABLE_GDC_FOR_SM100=1", + CUTLASS_GDC_FLAG_SM100, @@ def gen_cutlass_fused_moe_sm103_module(use_fast_build: bool = False) -> JitSpec: @@ - "-DCUTLASS_ENABLE_GDC_FOR_SM100=1", + CUTLASS_GDC_FLAG_SM100, @@ def gen_cutlass_fused_moe_sm100_module(use_fast_build: bool = False) -> JitSpec: @@ - "-DCUTLASS_ENABLE_GDC_FOR_SM100=1", + CUTLASS_GDC_FLAG_SM100, @@ def gen_cutlass_fused_moe_sm90_module(use_fast_build: bool = False) -> JitSpec: @@ - "-DCUTLASS_ENABLE_GDC_FOR_SM90=1", + CUTLASS_GDC_FLAG_SM90,🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/jit/fused_moe.py` around lines 33 - 99, Extract the repeated GDC define strings into named constants and use them in the generators instead of repeating the literal; e.g., add constants like CUTLASS_ENABLE_GDC_FOR_SM100 = "-DCUTLASS_ENABLE_GDC_FOR_SM100=1" and CUTLASS_ENABLE_GDC_FOR_SM90 = "-DCUTLASS_ENABLE_GDC_FOR_SM90=1" (or a function that returns the flag for a given SM), then replace the literal occurrences in gen_cutlass_fused_moe_sm120_module, gen_cutlass_fused_moe_sm103_module, gen_cutlass_fused_moe_sm100_module and gen_cutlass_fused_moe_sm90_module with those constants (or the helper) so all nvcc_flags lists reference the centralized symbol instead of hardcoded strings.
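The refactor the bot proposes amounts to hoisting the repeated define literals into module-level constants. A minimal runnable sketch of that idea follows; the real flashinfer generators return a JitSpec and carry many more flags, so these functions are illustrative stubs, not the actual code:

```python
# Shared constants so a future flag change happens in exactly one place.
CUTLASS_GDC_FLAG_SM100 = "-DCUTLASS_ENABLE_GDC_FOR_SM100=1"
CUTLASS_GDC_FLAG_SM90 = "-DCUTLASS_ENABLE_GDC_FOR_SM90=1"


def gen_cutlass_fused_moe_sm120_flags() -> list:
    # Blackwell-era generators reference the shared constant
    # instead of repeating the literal define string.
    return ["-DENABLE_FP4", CUTLASS_GDC_FLAG_SM100]


def gen_cutlass_fused_moe_sm90_flags() -> list:
    # Hopper generator uses the SM90 variant of the flag.
    return ["-DENABLE_FP8", CUTLASS_GDC_FLAG_SM90]


print(gen_cutlass_fused_moe_sm120_flags()[-1])
# -DCUTLASS_ENABLE_GDC_FOR_SM100=1
```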
📒 Files selected for processing (2): 3rdparty/cutlass, flashinfer/jit/fused_moe.py
Code Review
This pull request upgrades the cutlass submodule to version 4.4.1 and adds compilation flags to enable Grid Dependent Control (GDC), which is intended to fix a crash on newer GPU architectures like SM120. The changes appear correct and address the described issue. I have one suggestion regarding code duplication to improve maintainability.
flashinfer/jit/fused_moe.py (Outdated)

```python
    "-DENABLE_FP8",
    "-DENABLE_FP4",
    "-DUSING_OSS_CUTLASS_MOE_GEMM",
    "-DCUTLASS_ENABLE_GDC_FOR_SM100=1",
```
This flag -DCUTLASS_ENABLE_GDC_FOR_SM100=1 is also added to gen_cutlass_fused_moe_sm103_module and gen_cutlass_fused_moe_sm100_module. Since many flags are shared across these functions for Blackwell architectures, consider refactoring them into a common base list of flags to improve maintainability and reduce duplication.
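One way to act on this suggestion is a shared base list that each Blackwell generator extends. The names below are assumptions for illustration (the actual flashinfer module builds these lists inside its `gen_cutlass_fused_moe_*_module` functions):

```python
# Hypothetical shared base flags for the Blackwell (SM100/SM103/SM120) builders.
BLACKWELL_BASE_NVCC_FLAGS = [
    "-DENABLE_FP8",
    "-DENABLE_FP4",
    "-DUSING_OSS_CUTLASS_MOE_GEMM",
    "-DCUTLASS_ENABLE_GDC_FOR_SM100=1",
]


def blackwell_nvcc_flags(extra_flags=()):
    # Return a fresh list so callers cannot mutate the shared default.
    return [*BLACKWELL_BASE_NVCC_FLAGS, *extra_flags]


sm100_flags = blackwell_nvcc_flags()
sm120_flags = blackwell_nvcc_flags(["-DEXTRA_SM120_FLAG"])  # hypothetical extra flag
```

Each per-architecture generator then only states what is unique to it, which is exactly the duplication reduction the review asks for.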
[FAILED] Pipeline #46292785: 8/20 passed
/bot run
[FAILED] Pipeline #46371671: 11/20 passed
/bot run
Conceptually I don't have a problem with it from a code review standpoint, but there are errors in the JIT unit tests on H100. Also note that this is something that caused problems in the past, so we want to test thoroughly and watch out for the H100 and Spark/RTX PRO 6000 related tests (context: #2737). I won't approve until tests are clean on SM90 and SM120f. Also cc @bkryu as an extra set of eyes.
/bot run
/bot cancel
Unknown Command
/bot run
[SUCCESS] Pipeline #46479848: 13/20 passed
Does xqa use cutlass? Is this a precision tolerance issue? (1 failed test on Spark)
@aleozlx 98.9 seems close enough to 99 where this is probably a tolerance issue; xqa does not use cutlass.
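For context on the 98.9% figure being discussed: a pass-rate-under-tolerance check of this kind can be sketched in plain Python. The tolerances and data below are made up for illustration and are not the actual test's values:

```python
import math


def match_rate(out, ref, rel_tol=1e-2, abs_tol=1e-2):
    """Percent of elements that are within tolerance of the reference."""
    hits = sum(math.isclose(o, r, rel_tol=rel_tol, abs_tol=abs_tol)
               for o, r in zip(out, ref))
    return 100.0 * hits / len(ref)


ref = [1.0] * 1000
out = list(ref)
for i in range(11):  # perturb 11 of 1000 elements past tolerance
    out[i] = 1.5

print(match_rate(out, ref))  # 98.9
```

A test asserting a 99% threshold would flag this output even though only a handful of elements drift, which is why a failure at 98.9% often points to tolerance settings rather than a functional bug.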
📌 Description
Upgrade cutlass 4.2.1 -> 4.4.1, and add "CUTLASS_ENABLE_GDC_" defines to the cutlass compilation flags.
Addresses this issue raised on slack: "Hi team, we're seeing CUTLASS TMA descriptor crashes on DGX Spark ... the crash happens in tma_warp_specialized_generic_moe_gemm_kernelLauncher<Sm120, fp4> from fused_moe_120.so."
🔍 Related Issues
[Bug] NVFP4 MoE models crash on GB10 (SM121) during CUDA graph capture #2776
[Bug] NVFP4 mm_fp4 GEMM broken on SM120 (RTX PRO 6000 Blackwell) - all backends fail #2577
https://github.com/flashinfer-ai/flashinfer/pull/2716/changes
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- Installed pre-commit by running pip install pre-commit (or used your preferred method).
- Installed the hooks with pre-commit install.
- Ran pre-commit run --all-files and fixed any reported issues.

🧪 Tests
- Tests have been added or updated as needed (unittest, etc.).

Reviewer Notes
Summary by CodeRabbit
No functional changes or end-user impacts.