
Upgrade cutlass 4.2.1 -> 4.4.2 #2798

Merged
kahyunnam merged 5 commits into flashinfer-ai:main from kahyunnam:knam/spark-cutlass-moe-fix
Mar 19, 2026

Conversation

@kahyunnam
Collaborator

@kahyunnam kahyunnam commented Mar 16, 2026

📌 Description

Upgrade cutlass 4.2.1 -> 4.4.1, also add "CUTLASS_ENABLE_GDC_" to cutlass compilation flags.

Addresses this issue raised on slack: "Hi team, we're seeing CUTLASS TMA descriptor crashes on DGX Spark ... the crash happens in tma_warp_specialized_generic_moe_gemm_kernelLauncher<Sm120, fp4> from fused_moe_120.so."
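For context, these defines reach the compiler as ordinary `-D` entries in the JIT nvcc flag lists. A minimal sketch of the idea (the helper below is hypothetical, not the actual flashinfer/jit/fused_moe.py code):

```python
# Hypothetical sketch: how a Grid Dependent Control (GDC) define could be
# selected per SM family and passed to nvcc at JIT-compile time.
# The flag strings mirror this PR; the helper itself is illustrative only.

def gdc_flag_for_sm(sm_major: int) -> str:
    """Return the -D define enabling GDC for a given SM major version."""
    # SM90 (Hopper) uses its own define; the Blackwell-family modules
    # (SM100/SM103/SM120) all reuse the SM100 define in this PR.
    if sm_major == 9:
        return "-DCUTLASS_ENABLE_GDC_FOR_SM90=1"
    return "-DCUTLASS_ENABLE_GDC_FOR_SM100=1"

# Example flag list for an SM120 fused-MoE module build:
nvcc_flags = [
    "-DENABLE_FP8",
    "-DENABLE_FP4",
    "-DUSING_OSS_CUTLASS_MOE_GEMM",
    gdc_flag_for_sm(12),
]
```

CUTLASS uses these defines to gate its grid-dependence-control code paths at compile time.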

🔍 Related Issues

[Bug] NVFP4 MoE models crash on GB10 (SM121) during CUDA graph capture #2776

[Bug] NVFP4 mm_fp4 GEMM broken on SM120 (RTX PRO 6000 Blackwell) - all backends fail #2577

https://github.com/flashinfer-ai/flashinfer/pull/2716/changes

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Chores
    • Updated CUTLASS subproject dependency to latest version.
    • Applied internal namespace qualification adjustments for improved code consistency.
    • Refined template parameter handling in extension modules.

No functional changes or end-user impacts.

Fixes TMA descriptor bug where the CUDA driver was not properly setting
the OOB address gen mode, causing non-deterministic crashes in
tma_warp_specialized_generic_moe_gemm_kernelLauncher<Sm120, fp4> on
DGX Spark (SM121) with NVFP4 MoE models.

Ref: NVBug 5804240, upstream issues flashinfer-ai#2776, flashinfer-ai#2577
Ref: TRT-LLM fix NVIDIA/TensorRT-LLM#11956
@coderabbitai
Contributor

coderabbitai bot commented Mar 16, 2026

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 9c9ee9c5-4046-4fc3-a951-1ffcef53eaf6

📥 Commits

Reviewing files that changed from the base of the PR and between 228d88d and 6fdaa51.

📒 Files selected for processing (5)
  • csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_mixed_input.inl
  • csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h
  • csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.inl
  • csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl
  • csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h

📝 Walkthrough

Walkthrough

Updates cutlass subproject dependency and standardizes namespace qualification for error-reporting function calls across multiple GEMM kernel launchers and builders. Additionally adjusts template parameter usage in mixed-input GEMM builder from StageCountType::bytes to StageCountType.

Changes

Cohort / File(s): Summary

Cutlass Dependency
  • 3rdparty/cutlass: Subproject commit updated to incorporate upstream changes; no local code modifications.

Namespace Qualification - Error Reporting
  • csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h, csrc/nv_internal/.../fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.inl, csrc/nv_internal/.../moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl, csrc/nv_internal/.../moe_gemm/moe_gemm_template_dispatch.h: Qualifies three cutlassGetStatusString() calls with the cutlass:: namespace prefix in error-path branches (can_implement, init_status, run_status); applied consistently across the fpA_intB and MoE GEMM launchers.

Template Parameter Adjustment
  • csrc/nv_internal/tensorrt_llm/cutlass_extensions/.../sm90_gmma_builder_mixed_input.inl: Replaces the StageCountType::bytes template specialization with StageCountType in two mixed-input branches and one non-mixed-input branch for stage-count computation calls.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~12 minutes

Suggested reviewers

  • djmmoss
  • yzh119
  • aleozlx
  • bkryu
  • jimmyzho
  • nv-yunzheq

Poem

🐰 Hopping through namespaces with care,
Qualifying strings here and there,
Stage counts now bloom without ::bytes,
Cutlass aligned—our gemm kernels ignite!
Rabbits approve this tidy refine.

🚥 Pre-merge checks | ✅ 2 passed | ❌ 1 failed

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 0.00%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)
  • Title check ✅ Passed: The title 'Upgrade cutlass 4.2.1 -> 4.4.2' clearly summarizes the main change, a CUTLASS dependency version upgrade, which aligns with the primary objective of this PR.
  • Description check ✅ Passed: The description includes all required template sections: a clear description of changes (CUTLASS upgrade and GDC flag addition), related issues linked, and a completed checklist confirming pre-commit hooks and tests. The description adequately documents the purpose and supporting evidence for the changes.


@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request aims to resolve critical TMA descriptor crashes observed in NVFP4 MoE models on DGX Spark systems. It does so by updating the underlying Cutlass library to a newer version (4.4.1) and introducing Grid Dependent Control (GDC) compilation flags so that kernels launch correctly across the affected GPU architectures, improving stability and compatibility.

Highlights

  • Cutlass Upgrade: Upgraded the Cutlass library from version 4.2.1 to 4.4.1, which includes performance improvements and bug fixes.
  • Compilation Flags: Added CUTLASS_ENABLE_GDC_FOR_SM100=1 and CUTLASS_ENABLE_GDC_FOR_SM90=1 compilation flags to various fused MoE module generation functions.
  • Crash Resolution: Addressed an issue causing TMA descriptor crashes on DGX Spark with NVFP4 MoE models, specifically tma_warp_specialized_generic_moe_gemm_kernelLauncher.


Changelog
  • 3rdparty/cutlass
    • Updated the Cutlass subproject to commit 4370102f9dacab813282e1d67722fceb0b90a019 from f3fde58372d33e9a5650ba7b80fc48b3b49d40c8.
  • flashinfer/jit/fused_moe.py
    • Added -DCUTLASS_ENABLE_GDC_FOR_SM100=1 to gen_cutlass_fused_moe_sm120_module's nvcc flags.
    • Added -DCUTLASS_ENABLE_GDC_FOR_SM100=1 to gen_cutlass_fused_moe_sm103_module's nvcc flags.
    • Added -DCUTLASS_ENABLE_GDC_FOR_SM100=1 to gen_cutlass_fused_moe_sm100_module's nvcc flags.
    • Added -DCUTLASS_ENABLE_GDC_FOR_SM90=1 to gen_cutlass_fused_moe_sm90_module's nvcc flags.
Activity
  • The pull request was opened by kahyunnam to address reported issues with NVFP4 MoE models crashing on specific hardware configurations.

@kahyunnam kahyunnam marked this pull request as ready for review March 16, 2026 23:15
@kahyunnam
Collaborator Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !419 has been created, and the CI pipeline #46292785 is currently running. I'll report back once the pipeline job completes.

Contributor

@coderabbitai coderabbitai bot left a comment


🧹 Nitpick comments (1)
flashinfer/jit/fused_moe.py (1)

33-99: Optional cleanup: centralize repeated GDC macro literals.

The same define strings are repeated across generators; extracting constants reduces drift risk in future flag edits.

♻️ Suggested refactor
+CUTLASS_GDC_FLAG_SM100 = "-DCUTLASS_ENABLE_GDC_FOR_SM100=1"
+CUTLASS_GDC_FLAG_SM90 = "-DCUTLASS_ENABLE_GDC_FOR_SM90=1"
+
 def gen_cutlass_fused_moe_sm120_module(use_fast_build: bool = False) -> JitSpec:
     nvcc_flags = [
@@
-        "-DCUTLASS_ENABLE_GDC_FOR_SM100=1",
+        CUTLASS_GDC_FLAG_SM100,
@@
 def gen_cutlass_fused_moe_sm103_module(use_fast_build: bool = False) -> JitSpec:
@@
-        "-DCUTLASS_ENABLE_GDC_FOR_SM100=1",
+        CUTLASS_GDC_FLAG_SM100,
@@
 def gen_cutlass_fused_moe_sm100_module(use_fast_build: bool = False) -> JitSpec:
@@
-        "-DCUTLASS_ENABLE_GDC_FOR_SM100=1",
+        CUTLASS_GDC_FLAG_SM100,
@@
 def gen_cutlass_fused_moe_sm90_module(use_fast_build: bool = False) -> JitSpec:
@@
-        "-DCUTLASS_ENABLE_GDC_FOR_SM90=1",
+        CUTLASS_GDC_FLAG_SM90,
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/jit/fused_moe.py` around lines 33 - 99, Extract the repeated GDC
define strings into named constants and use them in the generators instead of
repeating the literal; e.g., add constants like CUTLASS_ENABLE_GDC_FOR_SM100 =
"-DCUTLASS_ENABLE_GDC_FOR_SM100=1" and CUTLASS_ENABLE_GDC_FOR_SM90 =
"-DCUTLASS_ENABLE_GDC_FOR_SM90=1" (or a function that returns the flag for a
given SM), then replace the literal occurrences in
gen_cutlass_fused_moe_sm120_module, gen_cutlass_fused_moe_sm103_module,
gen_cutlass_fused_moe_sm100_module and gen_cutlass_fused_moe_sm90_module with
those constants (or the helper) so all nvcc_flags lists reference the
centralized symbol instead of hardcoded strings.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 8cf53880-75ce-4532-a6ec-473fae61abd7

📥 Commits

Reviewing files that changed from the base of the PR and between a5e5cae and b5dba94.

📒 Files selected for processing (2)
  • 3rdparty/cutlass
  • flashinfer/jit/fused_moe.py

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request upgrades the cutlass submodule to version 4.4.1 and adds compilation flags to enable Grid Dependent Control (GDC), which is intended to fix a crash on newer GPU architectures like SM120. The changes appear correct and address the described issue. I have one suggestion regarding code duplication to improve maintainability.

"-DENABLE_FP8",
"-DENABLE_FP4",
"-DUSING_OSS_CUTLASS_MOE_GEMM",
"-DCUTLASS_ENABLE_GDC_FOR_SM100=1",
Contributor


Severity: medium

This flag -DCUTLASS_ENABLE_GDC_FOR_SM100=1 is also added to gen_cutlass_fused_moe_sm103_module and gen_cutlass_fused_moe_sm100_module. Since many flags are shared across these functions for Blackwell architectures, consider refactoring them into a common base list of flags to improve maintainability and reduce duplication.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #46292785: 8/20 passed

@yzh119 yzh119 added the run-ci label Mar 17, 2026
@yzh119
Collaborator

yzh119 commented Mar 17, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !419 has been created, and the CI pipeline #46371671 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #46371671: 11/20 passed

@kahyunnam kahyunnam changed the title Upgrade cutlass 4.2.1 -> 4.4.1, also add "CUTLASS_ENABLE_GDC_" to cutlass compilation flags. Upgrade cutlass 4.2.1 -> 4.4.2, also add "CUTLASS_ENABLE_GDC_" to cutlass compilation flags. Mar 18, 2026
@kahyunnam
Collaborator Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !419 has been updated with latest changes, and the CI pipeline #46402158 is currently running. I'll report back once the pipeline job completes.

@aleozlx
Collaborator

aleozlx commented Mar 18, 2026

conceptually i don't have problem with it from code review standpoint

there are errors in the JIT unit test H100

also note that this is something that caused problems in the past and we wanna test thoroughly and watch out for the H100 and the Spark/RTX PRO 6000 related tests

context: #2737

so i won't approve until tests are clean on SM90 and SM120

also cc @bkryu as extra set of eyes

@kahyunnam
Collaborator Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !419 has been updated with latest changes, and the CI pipeline #46450640 is currently running. I'll report back once the pipeline job completes.

@kahyunnam kahyunnam changed the title Upgrade cutlass 4.2.1 -> 4.4.2, also add "CUTLASS_ENABLE_GDC_" to cutlass compilation flags. Upgrade cutlass 4.2.1 -> 4.4.2 Mar 18, 2026
@kahyunnam kahyunnam requested a review from jiahanc as a code owner March 19, 2026 00:43
@kahyunnam
Collaborator Author

/bot cancel

@flashinfer-bot
Collaborator

Unknown Command

Command /bot cancel is not recognized.

Use /bot help for available commands.

@kahyunnam
Collaborator Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !419 has been updated with latest changes, and the CI pipeline #46479848 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[SUCCESS] Pipeline #46479848: 13/20 passed

@aleozlx
Collaborator

aleozlx commented Mar 19, 2026

does xqa use cutlass?

is this a precision tolerance issue? (1 failed test on spark)

E   AssertionError: Batch validation failed: Total 4096 elements, only 4052 (98.9%) meet tolerance criteria, require at least 99.0%
    assert 0.9892578125 >= 0.99
/tmp/flashinfer/tests/attention/test_xqa.py:463: AssertionError: Batch validation failed: Total 4096 elements, only 4052 (98.9%) meet tolerance criteria, require at least 99.0%
=============================== warnings summary ===============================
../../opt/conda/envs/py312/lib/python3.12/site-packages/torch/cuda/__init__.py:435
  /opt/conda/envs/py312/lib/python3.12/site-packages/torch/cuda/__init__.py:435: UserWarning: 
      Found GPU0 NVIDIA GB10 which is of cuda capability 12.1.
      Minimum and Maximum cuda capability supported by this version of PyTorch is
      (8.0) - (12.0)
      
    queued_call()
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
-------- generated xml file: /tmp/junit/tests_attention_test_xqa.py.xml --------
=========================== short test summary info ============================
FAILED tests/attention/test_xqa.py::test_xqa[True-1.0-True-0.5-HND-8-128-16-4-1-512-True-input_type1-True-False]
= 1 failed,

@kahyunnam
Collaborator Author


@aleozlx 98.9 seems close enough to 99 where this is probably a tolerance issue, xqa does not use cutlass
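For what it's worth, the failing check is a per-element pass rate rather than a single allclose; a small sketch of that shape with synthetic data (assumed tolerances, not the actual test_xqa code):

```python
import numpy as np

def pass_fraction(out, ref, atol=1e-2, rtol=1e-2):
    """Fraction of elements whose absolute error is within atol + rtol * |ref|."""
    ok = np.abs(out - ref) <= atol + rtol * np.abs(ref)
    return float(ok.mean())

rng = np.random.default_rng(0)
ref = rng.standard_normal(4096).astype(np.float32)
out = ref.copy()
out[:44] += 1.0  # push exactly 44 elements far outside tolerance

frac = pass_fraction(out, ref)
print(frac)  # 0.9892578125 == 4052/4096, just under a 99% bar as in the CI log
```

With a 99% bar on 4096 elements, 4052 passing misses the threshold by only four elements, which is consistent with a borderline precision tolerance rather than a functional break.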

@kahyunnam kahyunnam merged commit 9276e44 into flashinfer-ai:main Mar 19, 2026
30 checks passed
@kahyunnam kahyunnam deleted the knam/spark-cutlass-moe-fix branch March 19, 2026 18:16