feat: update trtllm-gen MoE cubins #2416
Conversation
Signed-off-by: Nikita Korobov <14355239+nekorobov@users.noreply.github.com>
📝 Walkthrough: This PR adds structured sparsity support to the batched GEMM operations.
Summary of Changes: This pull request enhances the TRT-LLM Gen MoE batched GEMM functionality by integrating structured sparsity support for weights and improving the handling of NVFP4 operations, including bias and block scaling. It also introduces more flexibility for fused activation functions. These changes are backed by updated pre-compiled kernels and a comprehensive refactoring of internal validation and parameter management, aiming to boost performance and expand the capabilities of the GEMM operations.
Code Review
The pull request introduces comprehensive support for structured sparsity and generalizes block scaling factors within the TRT-LLM Gen MoE cubins. This involves significant updates to data structures, kernel parameters, and validation logic across various header and source files. Key changes include the introduction of sparsity-aware dimension calculations, new parameters for handling sparsity information, and refactoring of grid dimension computation. The artifact paths and checksums have also been updated to reflect the new cubins.
```cpp
    options, options.mM, options.mN, options.mK >> isSparseA, options.mTileM, options.mTileN,
    options.mTileK >> isSparseA, MatrixType::MatrixA, options.mValidM, options.mValidN,
    options.mValidK >> isSparseA);
```
Adjusting options.mK and options.mTileK by right-shifting with isSparseA (effectively dividing by 2 if sparse) is a critical change for correctly handling the K dimension in sparse matrix operations. This directly impacts the logical dimensions used by TMA and ensures proper memory access.
```cpp
// Logical shape is [B, divUpMul(M, tileM), K].
// Logical strides are [divUpMul(M, tileM) * K, K, 1].
// If layoutA is MatrixLayout::MajorMn
// Logical shape is [B, divUpMul(M, tileM), K / S].
```
```cpp
// totalNumPaddedTokens += batchM ? divUpMul(options.mBatchedM[bi], options.mTileM *
//                                           options.mClusterDimX)
```
Modifying the totalNumPaddedTokens calculation to include options.mClusterDimX is important for correctly handling multi-CTA GEMM configurations. This ensures accurate padding and token counting in such scenarios.
```diff
 if (options.mClusterDimX > 1) {
-  TLLM_CHECK_ERROR(!batchM, "2CTA Gemm currently only supports batch N.");
```
The removal of TLLM_CHECK_ERROR(!batchM, "2CTA Gemm currently only supports batch N."); indicates that 2CTA Gemm now supports batchM. This is a notable expansion of functionality and should be verified with appropriate test cases.
```cpp
{
  // Number of bytes for sparsity info in SMEM.
  auto const numBytesSmemSparsityInfoA =
      numStages * tileM * tg::getNumBytesSparsityInfo(sparsityA, tileK);
  // Number of bytes alignment for sparsity info in SMEM.
  auto const numBytesAlignmentSparsityInfoA = 1024;
  // No need to reuse the first chunk.
  auto const reuseChunksSmemSparsityInfoA = false;

  // Add info.
  smemChunkNames.emplace_back("smemSparsityInfoA");
  numBytesAndAlignmentPerSmemChunk.emplace_back(
      std::make_pair(numBytesSmemSparsityInfoA, numBytesAlignmentSparsityInfoA));
  firstChunkReuseSmem.emplace_back(reuseChunksSmemSparsityInfoA);
}
```
The addition of SmemSparsityInfoA to shared memory allocation is necessary for handling sparsity information within the kernel. This ensures that the sparsity data is available for computation.
```cpp
{
  // Number of columns for the sparsity info for A (note: for Dense, this is 0).
  auto const numTmemColsSparsityInfoA =
      numStages * tg::getNumBytesSparsityInfo(sparsityA, tileK) / 4 /* bytes */;
  // Number of columns for Sf alignment.
  auto const numColsAlignmentSparsityInfoA = 2;
  // No need to reuse TMEM.
  auto const reuseChunksTmemSparsityInfoA = false;

  // Add info.
  tmemChunkNames.emplace_back("tmemSparsityInfoA");
  numBytesAndAlignmentPerTmemChunk.emplace_back(
      std::make_pair(numTmemColsSparsityInfoA, numColsAlignmentSparsityInfoA));
  firstChunkReuseTmem.emplace_back(reuseChunksTmemSparsityInfoA);
}
```
The addition of Sparsity info for A to TMEM allocation is necessary for handling sparsity information during TMA operations. This ensures that sparsity data is correctly transferred and processed.
```diff
-// If layoutA is MatrixLayout::MajorMn
+// If layoutA is MatrixLayout::MajorMn (sparsity not supported)
 // Logical shape is [B, divUpMul(M, tileM), K / S].
 // Logical strides are [divUpMul(M, tileM) * K / S, K / S, 1].
```
```cpp
if (dtype == tg::Dtype::E4m3 || dtype == tg::Dtype::MxE4m3 || dtype == tg::Dtype::UE8m0 ||
    dtype == tg::Dtype::UInt8) {
  tmaDataFormat = CU_TENSOR_MAP_DATA_TYPE_UINT8;
```
Adding dtype == tg::Dtype::UInt8 to the tmaDataFormat check for CU_TENSOR_MAP_DATA_TYPE_UINT8 is necessary for correctly handling the data type of sparsity information in TMA descriptors.
```cpp
} else if ((fastestDimTileSizeBytes % 16) == 0 &&
           (dtype == tg::Dtype::UE8m0 || dtype == tg::Dtype::E4m3 || dtype == tg::Dtype::E2m1 ||
            dtype == tg::Dtype::UInt8)) {
```
Including tg::Dtype::E2m1 and tg::Dtype::UInt8 in the swizzleType check for CU_TENSOR_MAP_SWIZZLE_NONE ensures that TMA swizzling is correctly applied (or not applied) for these data types, which are relevant for sparsity and block-scaled formats.
Actionable comments posted: 2
🧹 Nitpick comments (4)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h (1)
79-81: Handle `ActType::None` in `getActTypeName`.

With the new enum value, returning "Unknown type" can be confusing in logs or errors. Consider adding an explicit case.

♻️ Suggested tweak

```diff
 switch (type) {
   case ActType::SwiGlu:
     return "SwiGlu";
   case ActType::GeGlu:
     return "GeGlu";
+  case ActType::None:
+    return "None";
   default:
     return "Unknown type";
 }
```

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h (1)

142-151: Unused `mmaK` parameter.

The `mmaK` parameter was added to the function signature but is not used in the function body. If this parameter is intended for future use, consider adding a `[[maybe_unused]]` attribute or a comment explaining its purpose. Otherwise, remove it.

```diff
-inline int getNumSmemBitsPerElt(tg::Dtype dtype, tg::MmaKind mmaKind, int mmaK, bool isSparseA) {
+inline int getNumSmemBitsPerElt(tg::Dtype dtype, tg::MmaKind mmaKind, [[maybe_unused]] int mmaK,
+                                [[maybe_unused]] bool isSparseA) {
```

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h (2)
84-88: Macro with a dangling `else` can cause unexpected control flow.

This macro ends with `else return false` without braces, which can lead to subtle bugs when used inside an `if`-`else` chain without braces. Consider wrapping it in a `do { ... } while (0)` or ensuring all usages are properly braced:

```diff
-#define GEMM_UPDATE_OR_ERROR(OPTION, VALUE) \
-  if (updateOptions) {                      \
-    OPTION = VALUE;                         \
-  } else                                    \
-    return false
+#define GEMM_UPDATE_OR_ERROR(OPTION, VALUE) \
+  do {                                      \
+    if (updateOptions) {                    \
+      OPTION = VALUE;                       \
+    } else {                                \
+      return false;                         \
+    }                                       \
+  } while (0)
```
604-627: Consider making the shuffle maps `const`.

These shuffle maps appear to be constant lookup tables that should not be modified at runtime. Making them `const` would prevent accidental modification and enable compiler optimizations.

```diff
-inline std::vector<int> srcToDstBlk16RowMap =
+inline const std::vector<int> srcToDstBlk16RowMap =
     { 0, 8, ... };

-inline std::vector<int> srcToDstBlk32RowMap =
+inline const std::vector<int> srcToDstBlk32RowMap =
     { 0, 8, 16, 24, ... };
```

Alternatively, consider using `std::array` with `constexpr` for compile-time initialization:

```cpp
inline constexpr std::array<int, 16> srcToDstBlk16RowMap = { ... };
```
```cpp
int const numEltsPerSfRoute = batchM ? options.mSfBlockSizeA : options.mSfBlockSizeB;
TLLM_CHECK_ERROR(options.mTileK % (numEltsPerSfRoute * 16) == 0,
                 "tileK needs to be a multiple of 16 * numEltsPerSf (", numEltsPerSfRoute,
                 ") = ", numEltsPerSfRoute * 16);
```
Add a defensive guard for zero sfBlockSize when routing SFs via TMA.
If mSfBlockSizeA/B is ever 0 in a block-format + TMA SF route, the modulo check will divide by zero. Consider asserting a positive block size before the modulo.
💡 Proposed guard

```diff
 int const numEltsPerSfRoute = batchM ? options.mSfBlockSizeA : options.mSfBlockSizeB;
+TLLM_CHECK_ERROR(numEltsPerSfRoute > 0,
+                 "sfBlockSizeA/B must be > 0 when routing SFs with TMA.");
 TLLM_CHECK_ERROR(options.mTileK % (numEltsPerSfRoute * 16) == 0,
                  "tileK needs to be a multiple of 16 * numEltsPerSf (", numEltsPerSfRoute,
                  ") = ", numEltsPerSfRoute * 16);
```
```diff
 if (options.mDtypeC == tg::Dtype::E2m1 || options.mDtypeC == tg::Dtype::MxE4m3) {
   int const outHiddenSize = (options.mTransposeMmaOutput ? options.mM : options.mN) / 2;
-  int const hiddenGranularity = 4 * tg::dtypeNumEltsPerSf(options.mDtypeC);
+  int const hiddenGranularity = 4 * options.mSfBlockSizeC;
   TLLM_CHECK_ERROR(outHiddenSize % hiddenGranularity == 0, "Output hidden size (", outHiddenSize,
                    ") must be a multiple of ", hiddenGranularity, " for block-scaled outputs.");
```
Add a guard for zero mSfBlockSizeC.
hiddenGranularity now derives from mSfBlockSizeC; if it is 0, the modulo check will divide by zero. Consider asserting it > 0 before use.
💡 Proposed guard

```diff
 if (options.mDtypeC == tg::Dtype::E2m1 || options.mDtypeC == tg::Dtype::MxE4m3) {
   int const outHiddenSize = (options.mTransposeMmaOutput ? options.mM : options.mN) / 2;
   int const hiddenGranularity = 4 * options.mSfBlockSizeC;
+  TLLM_CHECK_ERROR(options.mSfBlockSizeC > 0,
+                   "sfBlockSizeC must be > 0 for block-scaled outputs.");
   TLLM_CHECK_ERROR(outHiddenSize % hiddenGranularity == 0, "Output hidden size (", outHiddenSize,
                    ") must be a multiple of ", hiddenGranularity, " for block-scaled outputs.");
 }
```
/bot run
Seeing some compilation errors. Can you double-check?
Signed-off-by: Nikita Korobov <14355239+nekorobov@users.noreply.github.com>
@yzh119, apologies, I forgot to add the new file to the commit. Should be fixed now.
/bot run
[FAILED] Pipeline #42669928: 10/20 passed |
📌 Description
Update TRT-LLM Gen MoE cubins to include bias support for NVFP4 and an element-wise fused activation function.
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

Reviewer Notes