feat: update trtllm-gen MoE cubins#2416

Merged
yzh119 merged 2 commits into flashinfer-ai:main from nekorobov:user/nkorobov/update-moe-cubins
Jan 28, 2026

Conversation

@nekorobov
Collaborator

@nekorobov nekorobov commented Jan 26, 2026

📌 Description

Update TRT-LLM Gen MoE cubins to include bias support for NVFP4 and an element-wise fused activation function.

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Added structured sparsity support for matrix operations, enabling Any_2_4 and Pairwise_4_8 sparsity formats in batched computations.
  • Refactor

    • Reorganized configuration parameters to accommodate sparsity metadata handling.
    • Enhanced memory allocation and grid computation logic for sparsity-aware operations.
  • Chores

    • Updated kernel artifact binaries and associated checksums.


Signed-off-by: Nikita Korobov <14355239+nekorobov@users.noreply.github.com>
@coderabbitai
Contributor

coderabbitai bot commented Jan 26, 2026

Important

Review skipped

Review was skipped due to path filters

⛔ Files ignored due to path filters (1)
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/SparsityDecl.h is excluded by !**/gen/**

CodeRabbit blocks several paths by default. You can override this behavior by explicitly including those paths in the path filters; for example, including **/dist/** removes the default block on the dist directory.

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

📝 Walkthrough

Walkthrough

This PR adds structured sparsity support to the batched GEMM operations, renames the mUseShuffledMatrixA field to mUseShuffledMatrix, and updates build artifact versions/checksums. Sparsity metadata is threaded through configuration options, kernel parameter setup, memory trait calculations, and TMA descriptor construction.

Changes

Cohort / File(s) Summary
Option field renaming fix
csrc/trtllm_batched_gemm_runner.cu
Corrected conditional to compare options.mUseShuffledMatrix instead of options.mUseShuffledMatrixA, fixing the gating logic for config selection.
Build artifacts
flashinfer/artifacts.py
Updated artifact path and checksum constants for TRTLLM_GEN_BMM to reflect new cubin builds.
Sparsity configuration layer
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h, GemmGatedActOptions.h, GemmOptions.h
Extended constructors with sparsity and block-size parameters; added mSparsityA, mNumRegsCopySparsityInfo, mSfBlockSizeB/C fields; renamed mUseShuffledMatrixA to mUseShuffledMatrix; updated validation logic for sparsity constraints and block-size alignment.
Sparsity interface & grid computation
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h
Added sparsity metadata input buffer field (mPtrSparsityInfoA); introduced public helpers getLaunchGrid() and getNumCtas() for CTA grid derivation; updated kernel parameter wiring to pass sparsity info and bias pointers.
Sparsity kernel setup & metadata
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h, KernelParamsDecl.h, TmaDescriptor.h
Added makeTmaShapeStrideSparsityInfoA() template function and ptrSparsityInfoA parameter to kernel parameter setup; declared tmaSparsityInfoA tensor map field; extended TMA descriptor dtype handling to support UInt8 format for sparsity information.
Sparsity kernel traits & memory allocation
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h
Expanded constructor signature to include sparsityA and per-SF element counts; added sparsity-aware SMEM/TMEM allocation chunks (smemSparsityInfoA, tmemSparsityInfoA); introduced getSmemOffsetSparsityInfoA() and getTmemOffsetSparsityInfoA() accessors; updated memory bit calculations to account for sparsity.
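
The getLaunchGrid()/getNumCtas() helpers mentioned in the walkthrough derive the CTA grid from the problem and tile shapes. As a rough sketch of that kind of derivation (the names, tile-per-CTA mapping, and divUp rounding here are illustrative assumptions, not the actual FlashInfer implementation):

```cpp
#include <cassert>

// Hypothetical sketch of a CTA-grid derivation; the exact mapping in the
// real interface may differ (e.g. cluster dims, persistent scheduling).
static int divUp(int a, int b) { return (a + b - 1) / b; }

struct Grid {
  int x, y, z;
};

// One CTA per (tileM x tileN) output tile, replicated across batches.
static Grid getLaunchGridSketch(int m, int n, int numBatches, int tileM, int tileN) {
  return Grid{divUp(m, tileM), divUp(n, tileN), numBatches};
}

static int getNumCtasSketch(const Grid& g) { return g.x * g.y * g.z; }
```

For example, a 300x256 problem with 128x128 tiles over 4 batches would launch a 3x2x4 grid, i.e. 24 CTAs, under these assumptions.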

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

Possibly related PRs

Suggested reviewers

  • joker-eph
  • aleozlx
  • cyx-6
  • wenscarl
  • bkryu
  • nvmbreughe
  • djmmoss
  • yzh119

Poem

🐰 Sparsity threads through code with care,
Options, params, memory laid bare,
Shuffled matrices renamed with grace,
TMA descriptors find their place,
A dense feature, sparse and bright!

🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 60.00% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
Check name Status Explanation
Title check ✅ Passed The title 'feat: update trtllm-gen MoE cubins' accurately summarizes the main change in the changeset, which updates artifact paths and checksums for TRT-LLM Gen MoE cubins and adds related configuration support.
Description check ✅ Passed The PR description addresses the core purpose (cubin updates for bias support and fused activation) but lacks detail on implementation specifics. Critical template sections like 'Related Issues' are incomplete, and the pre-commit and test checklist items are inconsistently marked.




@gemini-code-assist
Contributor

Summary of Changes

Hello @nekorobov, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the TRT-LLM Gen MoE batched GEMM functionality by integrating structured sparsity support for weights and improving the handling of NVFP4 operations, including bias and block scaling. It also introduces more flexibility for fused activation functions. These changes are backed by updated pre-compiled kernels and a comprehensive refactoring of internal validation and parameter management, aiming to boost performance and expand the capabilities of the GEMM operations.

Highlights

  • Sparsity Support for Batched GEMM: Introduced structured sparsity for matrix A (weights) in TRT-LLM Gen MoE batched GEMM operations. This includes support for Any_2_4 and Pairwise_4_8 sparsity patterns, requiring updates to input buffer definitions, TMA descriptors, shared memory allocation, and kernel parameter setup to handle sparsity information.
  • Enhanced NVFP4 and Fused Activation Handling: Improved support for NVFP4 operations, particularly concerning bias and block scaling factors. The validation logic for scaling block sizes (sfBlockSizeA, sfBlockSizeB, sfBlockSizeC) has been significantly refined. A new ActType::None has been added to allow for more flexible activation function handling, and related checks have been updated.
  • Updated TRT-LLM Gen MoE Cubins: The pull request updates the version hashes for TRTLLM_GEN_BMM in flashinfer/artifacts.py, indicating that new pre-compiled kernels (cubins) are being integrated. These new cubins likely contain the implementations for the added sparsity and enhanced NVFP4/fused activation features.
  • Refactoring and Validation Improvements: Several internal parameters and validation checks across BatchedGemmOptions.h, GemmOptions.h, and KernelTraits.h have been updated. This includes renaming mUseShuffledMatrixA to mUseShuffledMatrix for consistency, refining checks for LdgPlusSts and 2CTA BatchedGemm with routing, and centralizing logic for selecting/checking block sizes based on data types and sparsity.
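
For readers unfamiliar with the Any_2_4 pattern referenced above, here is an illustrative sketch (not the kernel code) of what 2:4 structured sparsity means for matrix A: each group of 4 elements along K keeps at most 2 nonzeros, stored compactly next to per-value position indices; this is why the stored K extent is halved and why separate sparsity-info buffers are needed.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical host-side illustration of 2:4 compression. The real kernels
// consume hardware-defined metadata; the layout below is only for intuition.
struct Compressed24 {
  std::vector<float> values;     // K/2 kept values per row
  std::vector<uint8_t> indices;  // which of the 4 slots each kept value came from
};

inline Compressed24 compress24(const std::vector<float>& row) {
  Compressed24 out;
  for (std::size_t g = 0; g + 4 <= row.size(); g += 4) {
    int kept = 0;
    for (int i = 0; i < 4 && kept < 2; ++i) {
      if (row[g + i] != 0.0f) {
        out.values.push_back(row[g + i]);
        out.indices.push_back(static_cast<uint8_t>(i));
        ++kept;
      }
    }
    while (kept < 2) {  // pad groups with fewer than 2 nonzeros
      out.values.push_back(0.0f);
      out.indices.push_back(3);
      ++kept;
    }
  }
  return out;
}
```

A row of logical length K compresses to K/2 values plus K/2 two-bit indices, matching the halved K dimension threaded through the TMA shape computations in this PR.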


Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

The pull request introduces comprehensive support for structured sparsity and generalizes block scaling factors within the TRT-LLM Gen MoE cubins. This involves significant updates to data structures, kernel parameters, and validation logic across various header and source files. Key changes include the introduction of sparsity-aware dimension calculations, new parameters for handling sparsity information, and refactoring of grid dimension computation. The artifact paths and checksums have also been updated to reflect the new cubins.

Comment on lines +447 to +449
options, options.mM, options.mN, options.mK >> isSparseA, options.mTileM, options.mTileN,
options.mTileK >> isSparseA, MatrixType::MatrixA, options.mValidM, options.mValidN,
options.mValidK >> isSparseA);
Contributor

critical

Adjusting options.mK and options.mTileK by right-shifting with isSparseA (effectively dividing by 2 if sparse) is a critical change for correctly handling the K dimension in sparse matrix operations. This directly impacts the logical dimensions used by TMA and ensures proper memory access.

    auto [shapeA, strideA, tileShapeA] = makeTmaShapeStrideAbc(
        options, options.mM, options.mN, options.mK >> isSparseA, options.mTileM, options.mTileN,
        options.mTileK >> isSparseA, MatrixType::MatrixA, options.mValidM, options.mValidN,
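
A minimal numeric illustration of the shift above (the helper is hypothetical; the real code applies the shift inline): right-shifting by isSparseA halves the K extent when A is 2:4 sparse and leaves it unchanged when dense.

```cpp
#include <cassert>

// Hypothetical helper illustrating options.mK >> isSparseA: for 2:4 sparse A
// only half of the logical K elements are physically stored.
inline int storedK(int logicalK, bool isSparseA) {
  return logicalK >> static_cast<int>(isSparseA);  // K/2 when sparse, K when dense
}
```

e.g. storedK(4096, true) yields 2048 stored elements along K, while storedK(4096, false) keeps 4096.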

// Logical shape is [B, divUpMul(M, tileM), K].
// Logical strides are [divUpMul(M, tileM) * K, K, 1].
// If layoutA is MatrixLayout::MajorMn
// Logical shape is [B, divUpMul(M, tileM), K / S].
Contributor

high

The introduction of K / S in the logical shape for MatrixLayout::MajorK in the batchN case is a critical change for correctly representing the dimensions of sparse matrices. This ensures that the effective K dimension is properly accounted for.

         Logical shape is [B, divUpMul(M, tileM), K / S].

Comment on lines +406 to +407
// totalNumPaddedTokens += batchM ? divUpMul(options.mBatchedM[bi], options.mTileM *
// options.mClusterDimX)
Contributor

high

Modifying the totalNumPaddedTokens calculation to include options.mClusterDimX is important for correctly handling multi-CTA GEMM configurations. This ensures accurate padding and token counting in such scenarios.

    //   totalNumPaddedTokens += batchM ? divUpMul(options.mBatchedM[bi], options.mTileM * options.mClusterDimX)

}

if (options.mClusterDimX > 1) {
TLLM_CHECK_ERROR(!batchM, "2CTA Gemm currently only supports batch N.");
Contributor

high

The removal of TLLM_CHECK_ERROR(!batchM, "2CTA Gemm currently only supports batch N."); indicates that 2CTA Gemm now supports batchM. This is a notable expansion of functionality and should be verified with appropriate test cases.

Suggested change
TLLM_CHECK_ERROR(!batchM, "2CTA Gemm currently only supports batch N.");

Comment on lines +313 to +326
{
// Number of bytes for sparsity info in SMEM.
auto const numBytesSmemSparsityInfoA =
numStages * tileM * tg::getNumBytesSparsityInfo(sparsityA, tileK);
// Number of bytes alignment for sparsity info in SMEM.
auto const numBytesAlignmentSparsityInfoA = 1024;
// No need to reuse the first chunk.
auto const reuseChunksSmemSparsityInfoA = false;

// Add info.
smemChunkNames.emplace_back("smemSparsityInfoA");
numBytesAndAlignmentPerSmemChunk.emplace_back(
std::make_pair(numBytesSmemSparsityInfoA, numBytesAlignmentSparsityInfoA));
firstChunkReuseSmem.emplace_back(reuseChunksSmemSparsityInfoA);
Contributor

high

The addition of SmemSparsityInfoA to shared memory allocation is necessary for handling sparsity information within the kernel. This ensures that the sparsity data is available for computation.

Suggested change
// SmemSparsityInfoA
{
  // Number of bytes for sparsity info in SMEM.
  auto const numBytesSmemSparsityInfoA =
      numStages * tileM * tg::getNumBytesSparsityInfo(sparsityA, tileK);
  // Number of bytes alignment for sparsity info in SMEM.
  auto const numBytesAlignmentSparsityInfoA = 1024;
  // No need to reuse the first chunk.
  auto const reuseChunksSmemSparsityInfoA = false;

  // Add info.
  smemChunkNames.emplace_back("smemSparsityInfoA");
  numBytesAndAlignmentPerSmemChunk.emplace_back(
      std::make_pair(numBytesSmemSparsityInfoA, numBytesAlignmentSparsityInfoA));
  firstChunkReuseSmem.emplace_back(reuseChunksSmemSparsityInfoA);
}

Comment on lines +555 to +568
{
// Number of columns for the sparsity info for A (note: for Dense, this is 0).
auto const numTmemColsSparsityInfoA =
numStages * tg::getNumBytesSparsityInfo(sparsityA, tileK) / 4 /* bytes */;
// Number of columns for Sf alignment.
auto const numColsAlignmentSparsityInfoA = 2;
// No need to reuse TMEM.
auto const reuseChunksTmemSparsityInfoA = false;

// Add info.
tmemChunkNames.emplace_back("tmemSparsityInfoA");
numBytesAndAlignmentPerTmemChunk.emplace_back(
std::make_pair(numTmemColsSparsityInfoA, numColsAlignmentSparsityInfoA));
firstChunkReuseTmem.emplace_back(reuseChunksTmemSparsityInfoA);
Contributor

high

The addition of Sparsity info for A to TMEM allocation is necessary for handling sparsity information during TMA operations. This ensures that sparsity data is correctly transferred and processed.

Suggested change
// Sparsity info for A
{
  // Number of columns for the sparsity info for A (note: for Dense, this is 0).
  auto const numTmemColsSparsityInfoA =
      numStages * tg::getNumBytesSparsityInfo(sparsityA, tileK) / 4 /* bytes */;
  // Number of columns for Sf alignment.
  auto const numColsAlignmentSparsityInfoA = 2;
  // No need to reuse TMEM.
  auto const reuseChunksTmemSparsityInfoA = false;

  // Add info.
  tmemChunkNames.emplace_back("tmemSparsityInfoA");
  numBytesAndAlignmentPerTmemChunk.emplace_back(
      std::make_pair(numTmemColsSparsityInfoA, numColsAlignmentSparsityInfoA));
  firstChunkReuseTmem.emplace_back(reuseChunksTmemSparsityInfoA);
}

// If layoutA is MatrixLayout::MajorMn
// Logical shape is [B, divUpMul(M, tileM), K / S].
// Logical strides are [divUpMul(M, tileM) * K / S, K / S, 1].
// If layoutA is MatrixLayout::MajorMn (sparsity not supported)
Contributor

medium

Adding (sparsity not supported) to the MatrixLayout::MajorMn description is important for clarity and correctness, explicitly stating the limitations of sparsity with this layout.

      If layoutA is MatrixLayout::MajorMn (sparsity not supported)

Comment on lines +50 to 51
dtype == tg::Dtype::UInt8) {
tmaDataFormat = CU_TENSOR_MAP_DATA_TYPE_UINT8;
Contributor

medium

Adding dtype == tg::Dtype::UInt8 to the tmaDataFormat check for CU_TENSOR_MAP_DATA_TYPE_UINT8 is necessary for correctly handling the data type of sparsity information in TMA descriptors.

  if (dtype == tg::Dtype::E4m3 || dtype == tg::Dtype::MxE4m3 || dtype == tg::Dtype::UE8m0 ||
      dtype == tg::Dtype::UInt8) {

Comment on lines +84 to +85
(dtype == tg::Dtype::UE8m0 || dtype == tg::Dtype::E4m3 || dtype == tg::Dtype::E2m1 ||
dtype == tg::Dtype::UInt8)) {
Contributor

medium

Including tg::Dtype::E2m1 and tg::Dtype::UInt8 in the swizzleType check for CU_TENSOR_MAP_SWIZZLE_NONE ensures that TMA swizzling is correctly applied (or not applied) for these data types, which are relevant for sparsity and block-scaled formats.

    } else if ((fastestDimTileSizeBytes % 16) == 0 &&
               (dtype == tg::Dtype::UE8m0 || dtype == tg::Dtype::E4m3 || dtype == tg::Dtype::E2m1 ||
                dtype == tg::Dtype::UInt8)) {

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

🤖 Fix all issues with AI agents
In
`@include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h`:
- Around line 348-351: The modulo check in BatchedGemmOptions (using
numEltsPerSfRoute derived from options.mSfBlockSizeA/B) can divide by zero when
mSfBlockSizeA or mSfBlockSizeB is 0; add a defensive guard that
asserts/validates the SF block size is > 0 (e.g., check options.mSfBlockSizeA/B
> 0 depending on batchM) before computing numEltsPerSfRoute and before the
TLLM_CHECK_ERROR call, and if it is zero produce a clear error via
TLLM_CHECK_ERROR (or an assert) indicating the SF block size must be positive so
the subsequent tileK % (numEltsPerSfRoute * 16) check is safe.

In
`@include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h`:
- Around line 152-156: The code computes hiddenGranularity = 4 *
options.mSfBlockSizeC without validating mSfBlockSizeC; add a pre-check that
options.mSfBlockSizeC > 0 (or hiddenGranularity != 0) before computing the
modulo and call TLLM_CHECK_ERROR (or an assertion) with a clear error message if
it is zero, then proceed to compute outHiddenSize and perform the existing
TLLM_CHECK_ERROR on outHiddenSize % hiddenGranularity == 0; reference
options.mDtypeC, options.mSfBlockSizeC, hiddenGranularity, outHiddenSize and
TLLM_CHECK_ERROR when adding the guard.
🧹 Nitpick comments (4)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h (1)

79-81: Handle ActType::None in getActTypeName.
With the new enum value, returning “Unknown type” can be confusing in logs or errors. Consider adding an explicit case.

♻️ Suggested tweak
   switch (type) {
     case ActType::SwiGlu:
       return "SwiGlu";
     case ActType::GeGlu:
       return "GeGlu";
+    case ActType::None:
+      return "None";
     default:
       return "Unknown type";
   }
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h (1)

142-151: Unused mmaK parameter.

The mmaK parameter was added to the function signature but is not used in the function body. If this parameter is intended for future use, consider adding a [[maybe_unused]] attribute or a comment explaining its purpose. Otherwise, remove it.

-inline int getNumSmemBitsPerElt(tg::Dtype dtype, tg::MmaKind mmaKind, int mmaK, bool isSparseA) {
+inline int getNumSmemBitsPerElt(tg::Dtype dtype, tg::MmaKind mmaKind, [[maybe_unused]] int mmaK, [[maybe_unused]] bool isSparseA) {
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h (2)

84-88: Macro with dangling else can cause unexpected control flow.

This macro ends with else return false without braces, which can lead to subtle bugs when used inside an if-else chain without braces.

Consider wrapping in a do-while(0) or ensuring all usages are properly braced:

-#define GEMM_UPDATE_OR_ERROR(OPTION, VALUE) \
-  if (updateOptions) {                      \
-    OPTION = VALUE;                         \
-  } else                                    \
-    return false
+#define GEMM_UPDATE_OR_ERROR(OPTION, VALUE) \
+  do {                                      \
+    if (updateOptions) {                    \
+      OPTION = VALUE;                       \
+    } else {                                \
+      return false;                         \
+    }                                       \
+  } while (0)

604-627: Consider making shuffle maps const.

These shuffle maps appear to be constant lookup tables that should not be modified at runtime. Making them const would prevent accidental modification and enable compiler optimizations.

-inline std::vector<int> srcToDstBlk16RowMap =
+inline const std::vector<int> srcToDstBlk16RowMap =
       {
         0,  8,
         ...
       };
-inline std::vector<int> srcToDstBlk32RowMap =
+inline const std::vector<int> srcToDstBlk32RowMap =
       {
         0,  8, 16, 24,
         ...
       };

Alternatively, consider using std::array with constexpr for compile-time initialization:

inline constexpr std::array<int, 16> srcToDstBlk16RowMap = { ... };

Comment on lines +348 to +351
int const numEltsPerSfRoute = batchM ? options.mSfBlockSizeA : options.mSfBlockSizeB;
TLLM_CHECK_ERROR(options.mTileK % (numEltsPerSfRoute * 16) == 0,
"tileK needs to be a multiple of 16 * numEltsPerSf (", numEltsPerSfRoute,
") = ", numEltsPerSfRoute * 16);
Contributor

⚠️ Potential issue | 🟡 Minor

Add a defensive guard for zero sfBlockSize when routing SFs via TMA.
If mSfBlockSizeA/B is ever 0 in a block-format + TMA SF route, the modulo check will divide by zero. Consider asserting a positive block size before the modulo.

💡 Proposed guard
-        int const numEltsPerSfRoute = batchM ? options.mSfBlockSizeA : options.mSfBlockSizeB;
+        int const numEltsPerSfRoute = batchM ? options.mSfBlockSizeA : options.mSfBlockSizeB;
+        TLLM_CHECK_ERROR(numEltsPerSfRoute > 0,
+                         "sfBlockSizeA/B must be > 0 when routing SFs with TMA.");
         TLLM_CHECK_ERROR(options.mTileK % (numEltsPerSfRoute * 16) == 0,
                          "tileK needs to be a multiple of 16 * numEltsPerSf (", numEltsPerSfRoute,
                          ") = ", numEltsPerSfRoute * 16);

Comment on lines 152 to 156
  if (options.mDtypeC == tg::Dtype::E2m1 || options.mDtypeC == tg::Dtype::MxE4m3) {
    int const outHiddenSize = (options.mTransposeMmaOutput ? options.mM : options.mN) / 2;
-   int const hiddenGranularity = 4 * tg::dtypeNumEltsPerSf(options.mDtypeC);
+   int const hiddenGranularity = 4 * options.mSfBlockSizeC;
    TLLM_CHECK_ERROR(outHiddenSize % hiddenGranularity == 0, "Output hidden size (", outHiddenSize,
                     ") must be a multiple of ", hiddenGranularity, " for block-scaled outputs.");
Contributor

⚠️ Potential issue | 🟡 Minor

Add a guard for zero mSfBlockSizeC.
hiddenGranularity now derives from mSfBlockSizeC; if it is 0, the modulo check will divide by zero. Consider asserting it > 0 before use.

💡 Proposed guard
   if (options.mDtypeC == tg::Dtype::E2m1 || options.mDtypeC == tg::Dtype::MxE4m3) {
     int const outHiddenSize = (options.mTransposeMmaOutput ? options.mM : options.mN) / 2;
     int const hiddenGranularity = 4 * options.mSfBlockSizeC;
+    TLLM_CHECK_ERROR(options.mSfBlockSizeC > 0,
+                     "sfBlockSizeC must be > 0 for block-scaled outputs.");
     TLLM_CHECK_ERROR(outHiddenSize % hiddenGranularity == 0, "Output hidden size (", outHiddenSize,
                      ") must be a multiple of ", hiddenGranularity, " for block-scaled outputs.");
   }

@yzh119
Collaborator

yzh119 commented Jan 27, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !270 has been created, and the CI pipeline #42635521 is currently running. I'll report back once the pipeline job completes.

@yzh119
Collaborator

yzh119 commented Jan 27, 2026

See some compilation errors such as:

/workspace/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h:27:10: fatal error: trtllm/gen/SparsityDecl.h: No such file or directory
   27 | #include "trtllm/gen/SparsityDecl.h"
      |          ^~~~~~~~~~~~~~~~~~~~~~~~~~~

Can you double-check?

Signed-off-by: Nikita Korobov <14355239+nekorobov@users.noreply.github.com>
@nekorobov
Collaborator Author

@yzh119, apologies, I forgot to add the new file to the commit. It should be fixed now.

@yzh119
Collaborator

yzh119 commented Jan 27, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !270 has been updated with latest changes, and the CI pipeline #42669928 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #42669928: 10/20 passed
