Update trtllm-gen fused moe routing kernel and add more kernels #1955
Conversation
Walkthrough

Add tile-based (non-power-of-two) tiling support to fused MoE routing via a compile-time isPow2 switch.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Caller
    participant KernelParams
    participant RoutingKernel
    participant TileHelpers
    Note over Caller,KernelParams: Initialize kernel params (isPow2, mTileTokensDim)
    Caller->>KernelParams: setBaseParams(data)
    alt isPow2 == true
        RoutingKernel->>TileHelpers: divUpLog2(expert_count)
        TileHelpers-->>RoutingKernel: numCta
        RoutingKernel->>TileHelpers: mulLog2(idx)
        TileHelpers-->>RoutingKernel: offset/permutedSize
    else isPow2 == false
        RoutingKernel->>TileHelpers: divUpTileN(expert_count, tileN)
        TileHelpers-->>RoutingKernel: numCta
        RoutingKernel->>TileHelpers: mulTileN(idx, tileN)
        TileHelpers-->>RoutingKernel: offset/permutedSize
    end
```
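For orientation, a rough sketch of the two arithmetic paths the diagram refers to. The helper names come from the PR, but the signatures and bodies below are illustrative assumptions, not the kernels' actual code.

```cpp
#include <cstdint>

// Power-of-two path: the tile size travels as log2(tile), so CTA counts and
// offsets reduce to shifts. (Hypothetical signatures.)
inline int32_t divUpLog2(int32_t count, int32_t log2Tile) {
  return (count + (1 << log2Tile) - 1) >> log2Tile;
}
inline int32_t mulLog2(int32_t idx, int32_t log2Tile) { return idx << log2Tile; }

// Tile-N path: the tile size is an arbitrary positive integer, so plain
// integer division and multiplication are used instead of shifts.
inline int32_t divUpTileN(int32_t count, int32_t tileN) { return (count + tileN - 1) / tileN; }
inline int32_t mulTileN(int32_t idx, int32_t tileN) { return idx * tileN; }

// Example: 100 tokens for one expert with a non-power-of-two tile of 24 needs
// divUpTileN(100, 24) == 5 CTAs; a power-of-two tile of 32 (log2 == 5) needs
// divUpLog2(100, 5) == 4 CTAs.
```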
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Force-pushed from 9d9ad95 to 7cd156d (Compare)

Force-pushed from f060ab9 to e7ac015 (Compare)
Actionable comments posted: 5
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
csrc/trtllm_batched_gemm_runner.cu (1)
423-450: Make `isValidConfigIndex` honor the WAR override

`getValidConfigIndices` now copies a config and patches `mValidK`/`mValidN`/`mValidM` before calling `isValidConfig`, but `isValidConfigIndex` still forwards the unmodified config. That means a config reported as valid by `getValidConfigIndices` (or returned by `getDefaultValidConfigIndex`) can immediately fail `isValidConfigIndex`, breaking callers that double-check the chosen index. Please mirror the same WAR here so the validation helpers stay consistent with `run`.

```diff
-  auto const& config = configs[configIndex];
-
-  return bmm.isValidConfig(config, gemmData);
+  auto myConfig = configs[configIndex];
+  myConfig.mOptions.mValidK = k;
+  myConfig.mOptions.mValidN = gemmData.mProblemDimensions.mN;
+  myConfig.mOptions.mValidM = gemmData.mProblemDimensions.mM;
+
+  return bmm.isValidConfig(myConfig, gemmData);
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
⛔ Files ignored due to path filters (2)
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CommonUtils.h is excluded by `!**/gen/**`
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/SfLayoutDecl.h is excluded by `!**/gen/**`
📒 Files selected for processing (23)
- csrc/trtllm_batched_gemm_runner.cu (2 hunks)
- csrc/trtllm_fused_moe_kernel_launcher.cu (5 hunks)
- csrc/trtllm_fused_moe_routing_deepseek.cu (2 hunks)
- csrc/trtllm_fused_moe_routing_llama4.cu (3 hunks)
- csrc/trtllm_fused_moe_routing_renormalize.cu (1 hunks)
- csrc/trtllm_fused_moe_runner.cu (4 hunks)
- flashinfer/artifacts.py (1 hunks)
- flashinfer/autotuner.py (1 hunks)
- flashinfer/fused_moe/core.py (6 hunks)
- flashinfer/jit/fused_moe.py (1 hunks)
- flashinfer/jit/gemm/core.py (2 hunks)
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmEnums.h (2 hunks)
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h (3 hunks)
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h (10 hunks)
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h (4 hunks)
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h (25 hunks)
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h (11 hunks)
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParamsDecl.h (0 hunks)
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h (7 hunks)
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h (2 hunks)
- include/flashinfer/trtllm/fused_moe/DevKernel.h (2 hunks)
- include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh (6 hunks)
- include/flashinfer/trtllm/fused_moe/RoutingKernel.h (7 hunks)
💤 Files with no reviewable changes (1)
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParamsDecl.h
🧰 Additional context used
🧬 Code graph analysis (9)
- flashinfer/jit/fused_moe.py (1)
  - flashinfer/artifacts.py (1): ArtifactPath (83-98)
- csrc/trtllm_batched_gemm_runner.cu (2)
  - csrc/trtllm_gemm_runner.cu (8): m (111-126), m (111-111), m (128-179), m (128-130), m (181-236), m (181-181), m (238-250), m (238-238)
  - include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h (1): isValidConfig (710-720)
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h (2)
  - include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h (4): std (229-583), std (239-244), std (279-300), std (290-296)
  - include/flashinfer/trtllm/fmha/kernelParams.h (4): std (215-230), std (273-278), std (361-366), std (398-402)
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h (1)
  - include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/MmaDecl.h (1): getTmemColStridePerGroup (99-103)
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h (2)
  - include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h (4): trtllm (58-63), gen (59-62), mExecPath (377-435), mInstanceIdx (380-380)
  - include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h (6): trtllm (82-87), gen (83-86), getShuffleBlockSize (602-608), string (438-440), string (445-447), mInstanceIdx (421-421)
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h (2)
  - include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h (6): trtllm (48-53), gen (49-52), gemm (147-152), gemmGatedAct (55-191), ActType (62-180), mOptions (213-213)
  - include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmEnums.h (1): RouteImpl (28-57)
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h (2)
  - csrc/trtllm_batched_gemm_runner.cu (10): run (156-259), run (156-164), run (261-275), run (261-265), run (277-293), run (277-283), run (295-310), run (295-299), getWorkspaceSizeInBytes (129-154), getWorkspaceSizeInBytes (129-131)
  - include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h (4): std (229-583), std (239-244), std (279-300), std (290-296)
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h (1)
  - include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h (1): gemm (32-297)
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h (3)
  - include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/GemmInterface.h (3): flashinfer (36-38), gemm (42-347), gemm (468-488)
  - include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h (6): string (133-181), trtllm (38-271), gen (39-269), Dtype (43-268), dtypeIsBlockFmt (96-99), dtypeNumEltsPerSf (198-209)
  - include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h (1): gemm (30-417)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (6)
flashinfer/jit/gemm/core.py (2)
384-384: LGTM! The addition of the `-DTLLM_GEN_EXPORT_FLASHINFER` flag is consistent with the existing preprocessor define pattern and is appropriately placed alongside related TRTLLM flags.

535-535: LGTM! The flag addition is consistent with the change at Line 384, ensuring both TRTLLM GEMM module variants have the same FlashInfer export configuration.

flashinfer/artifacts.py (1)

91-93: Artifact path update looks good. The refreshed hash keeps the TRTLLM fused artifacts aligned with the latest cubin drop; no further action from my side.

flashinfer/jit/fused_moe.py (1)

235-242: Build flag alignment looks consistent. Adding `-DTLLM_GEN_EXPORT_FLASHINFER` and switching to `TLLM_GEN_GEMM_CUBIN_PATH` cleanly mirror the artifact rename; everything lines up with the updated loader flow.

csrc/trtllm_fused_moe_runner.cu (1)

35-177: Nice touch on the routing metadata. Letting `computeLog2` fall back to `-1` and threading `mTileTokensDim` into every routing path keeps the pow2 and non-pow2 kernels in sync. Looks solid.

csrc/trtllm_fused_moe_routing_renormalize.cu (1)

168-214: Tile-aware CTA math looks correct. Splitting the CTA/count arithmetic between `divUpLog2`/`mulLog2` and the new `divUpTileN`/`mulTileN` helpers is exactly what the non-power-of-two path needs. No issues spotted.
```diff
   int32_t tile_tokens_dim = std::clamp(nextPowerOfTwo(avg_tokens_per_expert),
                                        supported_tile_nums.front(), supported_tile_nums.back());

-  std::set<int32_t> selected_tile_nums = {
-      std::max(supported_tile_nums.front(), tile_tokens_dim / 2), tile_tokens_dim,
-      std::min(supported_tile_nums.back(), tile_tokens_dim * 2),
-      std::min(supported_tile_nums.back(), tile_tokens_dim * 4)};
+  auto it = std::find(supported_tile_nums.begin(), supported_tile_nums.end(), tile_tokens_dim);
+
+  std::set<int32_t> selected_tile_nums;
+  selected_tile_nums.insert(tile_tokens_dim);
+  if (std::next(it) != supported_tile_nums.end()) {
+    selected_tile_nums.insert(*std::next(it));
+    if (std::next(std::next(it)) != supported_tile_nums.end()) {
+      selected_tile_nums.insert(*std::next(std::next(it)));
+    }
+  }
+  if (it != supported_tile_nums.begin()) {
+    selected_tile_nums.insert(*std::prev(it));
+  }
```
Guard against missing tile entry before iterating neighbors
tile_tokens_dim is clamped to the numeric range of supported_tile_nums, but the clamped value is not guaranteed to be present in the container. If the closest supported value is different (e.g., the list is {16, 24, 40, 64} and nextPowerOfTwo returns 32), std::find returns end(). The very next statement calls std::next(it), invoking undefined behaviour and potentially crashing the process. Please snap tile_tokens_dim to an actual supported entry (e.g., via std::lower_bound) before walking neighbours.
```diff
-  auto it = std::find(supported_tile_nums.begin(), supported_tile_nums.end(), tile_tokens_dim);
-
-  std::set<int32_t> selected_tile_nums;
-  selected_tile_nums.insert(tile_tokens_dim);
+  auto it =
+      std::lower_bound(supported_tile_nums.begin(), supported_tile_nums.end(), tile_tokens_dim);
+  if (it == supported_tile_nums.end()) {
+    it = std::prev(supported_tile_nums.end());
+  }
+  tile_tokens_dim = *it;
+
+  std::set<int32_t> selected_tile_nums;
+  selected_tile_nums.insert(tile_tokens_dim);
   if (std::next(it) != supported_tile_nums.end()) {
     selected_tile_nums.insert(*std::next(it));
     if (std::next(std::next(it)) != supported_tile_nums.end()) {
       selected_tile_nums.insert(*std::next(std::next(it)));
     }
   }
   if (it != supported_tile_nums.begin()) {
     selected_tile_nums.insert(*std::prev(it));
   }
```
🤖 Prompt for AI Agents
In csrc/trtllm_fused_moe_kernel_launcher.cu around lines 67-82, the code uses
std::find on supported_tile_nums with a clamped tile_tokens_dim but that value
may not exist in the container and iterating neighbors from an end() iterator is
undefined; fix by snapping tile_tokens_dim to an actual supported entry via
std::lower_bound: use lower_bound to get the first element >= tile_tokens_dim,
if lower_bound == supported_tile_nums.end() set the iterator to std::prev(end())
(the largest supported value), otherwise keep the lower_bound result (or if you
want the closest neighbor compare the lower_bound and its previous element and
pick the nearer value), then set tile_tokens_dim = *it and proceed to insert *it
and its neighbors into selected_tile_nums while guarding with iterator boundary
checks (only call std::next/std::prev after verifying they are not
end()/begin()).
Good catch. We should make sure supported_tile_nums contains consecutive pow of 2 numbers.
@IwakuraRein, I appreciate you acknowledging the issue! However, I notice that the current implementation already includes non-power-of-2 values in supported_tile_nums (e.g., 192 at line 256), so strictly enforcing "consecutive powers of 2" would require broader changes.
More importantly, even if we constrain supported_tile_nums to powers of 2, the current code still has a vulnerability: nextPowerOfTwo(avg_tokens_per_expert) can produce a power-of-2 value that's clamped to the numeric range but doesn't exist in the container. For example:
- If `supported_tile_nums = {8, 16, 64, 128}` (missing 32)
- And `nextPowerOfTwo(...)` returns 32
- `std::clamp(32, 8, 128)` → 32
- `std::find` returns `end()` → ❌ UB when calling `std::next(it)`
Recommendation:
I suggest implementing CodeRabbit's defensive fix using std::lower_bound, which gracefully handles any input:
```cpp
auto it = std::lower_bound(supported_tile_nums.begin(), supported_tile_nums.end(), tile_tokens_dim);
if (it == supported_tile_nums.end()) {
  it = std::prev(it);  // Use largest supported value
}
tile_tokens_dim = *it;  // Snap to actual supported value
```

This ensures robustness regardless of the supported tile set. If you want to enforce a design constraint (consecutive powers of 2), consider adding a debug assertion in the callers that construct supported_tile_nums, but the runtime logic should still be defensive.
Would you like me to open an issue to track adding this fix?
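To make the failure mode and the snapping fix concrete, here is a small self-contained sketch; the helper name `snapToSupported` and the example values are hypothetical, not part of the launcher.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical standalone illustration: snap a requested tile size to an
// entry that actually exists in the supported list before walking neighbours.
int32_t snapToSupported(std::vector<int32_t> const& supported, int32_t requested) {
  // `supported` is assumed sorted ascending, as std::lower_bound requires.
  auto it = std::lower_bound(supported.begin(), supported.end(), requested);
  if (it == supported.end()) {
    it = std::prev(it);  // Fall back to the largest supported value.
  }
  return *it;
}

int main() {
  std::vector<int32_t> const supported = {8, 16, 64, 128};  // deliberately missing 32
  assert(snapToSupported(supported, 32) == 64);    // std::find would have returned end()
  assert(snapToSupported(supported, 256) == 128);  // past the back: use the largest entry
  return 0;
}
```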
```cpp
std::stringstream ss;
ss << gemmGatedAct::dumpOptions(options) << ", ";
ss << "mBatchedM={}," << std::endl;
ss << "mBatchedN={}," << std::endl;
ss << gemmGatedAct::dumpOptions(options, dumpRuntimeParams) << ", ";
if (dumpRuntimeParams) {
  ss << "mBatchedM={}," << std::endl;
  ss << "mBatchedN={}," << std::endl;
}
ss << "mBatchMode=batchedGemm::BatchedGemmOptions::BatchMode("
   << static_cast<int32_t>(options.mBatchMode) << ")," << std::endl;
ss << "mNumBatches=" << options.mNumBatches << "," << std::endl;
ss << "mFusedAct=" << options.mFusedAct << "," << std::endl;
ss << "mGridWaitForPrimaryRouting=" << options.mGridWaitForPrimaryRouting << "," << std::endl;
ss << "mIsStaticBatch=" << options.mIsStaticBatch << "," << std::endl;
ss << "mNumTokens=" << options.mNumTokens << "," << std::endl;
if (dumpRuntimeParams) {
  ss << "mNumBatches=" << options.mNumBatches << "," << std::endl;
}
ss << "mNumRegsPerThreadLoadB=" << options.mNumRegsPerThreadLoadB << "," << std::endl;
ss << "mNumRegsPerThreadLoadSfB=" << options.mNumRegsPerThreadLoadSfB << "," << std::endl;
if (dumpRuntimeParams) {
  ss << "mNumTokens=" << options.mNumTokens << "," << std::endl;
}
ss << "mNumWarpsLoadB=" << options.mNumWarpsLoadB << "," << std::endl;
ss << "mNumWarpsLoadSfB=" << options.mNumWarpsLoadSfB << "," << std::endl;
ss << "mRouteImpl=batchedGemm::RouteImpl(" << static_cast<int32_t>(options.mRouteImpl) << "),"
   << std::endl;
ss << "mRouteSfsImpl={batchedGemm::RouteImpl("
   << static_cast<int32_t>(options.mRouteSfsImpl.value()) << ")}," << std::endl;
ss << "mGridWaitForPrimaryRouting=" << options.mGridWaitForPrimaryRouting << "," << std::endl;
ss << "mFusedAct=" << options.mFusedAct << "," << std::endl;
ss << "mNumRegsPerThreadNonEpilogueWarp=" << options.mNumRegsPerThreadNonEpilogueWarp << ","
   << std::endl;
ss << "mNumRegsPerThreadEpilogueWarp=" << options.mNumRegsPerThreadEpilogueWarp << ","
   << std::endl;
ss << "mNumRegsCastAWarps=" << options.mNumRegsCastAWarps << "," << std::endl;
ss << "mUseTmaOobOpt=" << options.mUseTmaOobOpt << std::endl;
return ss.str();
```
Restore batched-dimension values in dumpOptions
dumpOptions now emits mBatchedM={} / mBatchedN={} regardless of the actual contents, which removes the batch-dimension information we rely on when debugging tuning logs. Please keep printing the real vectors.
- ss << "mBatchedM={}," << std::endl;
- ss << "mBatchedN={}," << std::endl;
+ ss << "mBatchedM={";
+ for (size_t i = 0; i < options.mBatchedM.size(); ++i) {
+ ss << (i ? ", " : "") << options.mBatchedM[i];
+ }
+ ss << "}," << std::endl;
+ ss << "mBatchedN={";
+ for (size_t i = 0; i < options.mBatchedN.size(); ++i) {
+ ss << (i ? ", " : "") << options.mBatchedN[i];
+ }
+ ss << "}," << std::endl;🤖 Prompt for AI Agents
In
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h
around lines 389 to 415, the dumpOptions implementation emits literal
placeholders "mBatchedM={}" and "mBatchedN={}" when dumpRuntimeParams is true,
removing the actual batched-dimension data; replace those placeholders with the
real contents of options.mBatchedM and options.mBatchedN so the vectors are
printed (format them the same way other vector fields are emitted, including
separators/newlines and trailing commas) — e.g. serialize each vector into the
stringstream instead of hard-coded braces, preserving the existing comma and
std::endl layout.
```cpp
namespace trtllm {
namespace gen {
class CudaRunner;
class GenCfg;
}  // namespace gen
}  // namespace trtllm
```
Remove the nested trtllm namespace block.
By adding namespace trtllm { … } inside namespace batchedGemm, every unqualified trtllm::gen in this file now resolves to batchedGemm::trtllm::gen, not ::trtllm::gen. As a result, the alias namespace tg = trtllm::gen; and members like trtllm::gen::CudaRunner* point at the new nested namespace, which only has the forward declarations and none of the real definitions (dtypeGetNumBits, Dtype, etc.). This breaks compilation immediately. Please drop this block (or move the forward declarations out to global scope / prefix uses with ::). One minimal fix is:
```diff
-namespace trtllm {
-namespace gen {
-class CudaRunner;
-class GenCfg;
-}  // namespace gen
-}  // namespace trtllm
```

That restores lookup to the existing ::trtllm::gen symbols provided via GemmOptions.h.
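A minimal self-contained illustration of the lookup problem, with hypothetical namespace names (`outer`/`inner` standing in for `::trtllm::gen`, `consumer` for `batchedGemm`):

```cpp
#include <iostream>

namespace outer {            // stands in for ::trtllm::gen
namespace inner {
inline int realDefinition() { return 42; }  // e.g. dtypeGetNumBits, Dtype, ...
}  // namespace inner
}  // namespace outer

namespace consumer {         // stands in for batchedGemm
// Re-opening "outer" here creates consumer::outer, which shadows ::outer for
// every unqualified use later in this namespace; only the forward
// declarations live here, the real definitions do not.
namespace outer {
namespace inner {
class Forward;
}  // namespace inner
}  // namespace outer

// With the nested block above, `namespace oi = outer::inner;` would bind to
// consumer::outer::inner and realDefinition() would not be found. Qualifying
// with :: (or deleting the nested block) restores the intended lookup.
namespace oi = ::outer::inner;
inline int use() { return oi::realDefinition(); }
}  // namespace consumer

int main() { std::cout << consumer::use() << "\n"; }  // prints 42
```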
🤖 Prompt for AI Agents
In
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h
around lines 48-54, the nested "namespace trtllm { namespace gen { ... } }"
block creates a batchedGemm::trtllm::gen shadow that breaks lookup of the real
::trtllm::gen symbols; remove that nested namespace block (or move the forward
declarations to the global scope or use ::trtllm::gen prefixes) so all
unqualified trtllm::gen references resolve to the existing ::trtllm::gen
definitions provided by GemmOptions.h.
```cpp
#define LAUNCH_TILEN(data, coopLaunch, types, kernel, numBlocks, numThreads, smemSize, stream) \
  if (data.mPaddingLog2 > 0) {                                                                 \
    LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(types, true), kernel, numBlocks, numThreads,       \
               smemSize, stream);                                                              \
  } else {                                                                                     \
    LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(types, false), kernel, numBlocks, numThreads,      \
               smemSize, stream);                                                              \
  }
```
Keep mPaddingLog2 == 0 on the pow2 path
The new LAUNCH_TILEN macro only treats mPaddingLog2 > 0 as a power-of-two configuration. Existing callers legitimately pass mPaddingLog2 == 0 (tile size 1) and rely on the previous pow2 code path. After this change those cases fall into the “tile-N” branch (KernelParams::isPow2 == false), but the host code never populates mTileTokensDim for pow2 flows, so the CUDA kernels dereference it as 0 and hit divUpTileN(count, 0) / mulTileN(..., 0) in RoutingKernel.cuh, yielding a divide-by-zero/device fault at runtime. Please keep zero in the pow2 branch (or initialise mTileTokensDim to 1 before launching).
```diff
-  if (data.mPaddingLog2 > 0) {
+  if (data.mPaddingLog2 >= 0) {
```

🤖 Prompt for AI Agents
+ if (data.mPaddingLog2 >= 0) {🤖 Prompt for AI Agents
In include/flashinfer/trtllm/fused_moe/DevKernel.h around lines 116-123, the
macro treats only mPaddingLog2 > 0 as the pow2 path causing mPaddingLog2 == 0
callers to take the tile-N branch and crash; change the pow2 check to include
zero (e.g., use data.mPaddingLog2 >= 0) so mPaddingLog2 == 0 follows the
pow2/LAUNCH_ESC(..., true) path, or alternatively ensure host code sets
mTileTokensDim = 1 for pow2 launches before invoking the macro.
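A small host-side sketch of why the boundary matters; the encoding assumed here (paddingLog2 >= 0 for power-of-two tiles, -1 plus an explicit tile dimension otherwise) mirrors the description above but is not the actual launcher code.

```cpp
#include <cstdint>
#include <cstdio>

// Assumed encoding for illustration: paddingLog2 >= 0 means a power-of-two
// tile of size (1 << paddingLog2); -1 means "use the explicit tile dimension".
int32_t numCtasForExpert(int32_t count, int32_t paddingLog2, int32_t tileTokensDim) {
  if (paddingLog2 >= 0) {  // pow2 path must also cover tile size 1 (log2 == 0)
    return (count + (1 << paddingLog2) - 1) >> paddingLog2;
  }
  // Tile-N path: only safe when tileTokensDim was actually populated (> 0);
  // a zero here is the divide-by-zero / device-fault scenario described above.
  return (count + tileTokensDim - 1) / tileTokensDim;
}

int main() {
  // Tile size 1 expressed as log2 == 0 stays on the pow2 branch: prints 37.
  std::printf("%d\n", numCtasForExpert(37, /*paddingLog2=*/0, /*tileTokensDim=*/0));
  // Non-power-of-two tile of 24 tokens, pow2 path disabled via -1: prints 2.
  std::printf("%d\n", numCtasForExpert(37, /*paddingLog2=*/-1, /*tileTokensDim=*/24));
  return 0;
}
```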
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- csrc/trtllm_batched_gemm_runner.cu (4 hunks)
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
- csrc/trtllm_batched_gemm_runner.cu (1)
  - include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h (1): isValidConfig (718-728)
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h (4)
  - csrc/trtllm_batched_gemm_runner.cu (10): run (160-261), run (160-168), run (263-277), run (263-267), run (279-295), run (279-285), run (297-312), run (297-301), getWorkspaceSizeInBytes (129-158), getWorkspaceSizeInBytes (129-131)
  - include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h (4): std (229-583), std (239-244), std (279-300), std (290-296)
  - include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h (6): if (1071-1079), string (438-440), string (445-447), loadCubinData (1494-1514), trtllm (82-87), gen (83-86)
  - include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CudaKernelLauncher.h (3): trtllm (28-90), gen (29-89), launchKernel (34-84)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (11)
csrc/trtllm_batched_gemm_runner.cu (4)
147-149: LGTM! Valid dimension initialization. The initialization of valid dimensions to match the full problem dimensions is correct and aligns with the documented default behavior.

246-248: LGTM! Consistent valid dimension initialization. The valid dimensions are correctly initialized before passing gemmData to the run method.

338-340: LGTM! Valid dimension initialization consistent across all methods. The initialization pattern is consistent with getWorkspaceSizeInBytes and run methods.

402-402: LGTM! Minor refactoring. Passing `configs[configIndex]` directly instead of using an intermediate reference is a valid simplification.

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h (7)

27-32: LGTM! Conditional compilation for export modes. The conditional includes properly handle different export targets (FlashInfer vs standard).

76-87: LGTM! New valid dimension fields with clear documentation. The addition of mValidM/N/K fields with documentation explaining their purpose is well done. Default initialization to 0 is appropriate given these fields must be explicitly set by callers.

459-469: LGTM! New constructor and compile-time method. The constructor properly accepts cubin export and rotation parameters, and the generateAndCompileKernel method is appropriately guarded for non-export builds.

475-597: LGTM! Enhanced run method with proper implementation. The rewritten run method includes several improvements:

- Const reference parameter instead of copy (better performance)
- Complete implementation with module caching using context-aware keys
- Proper error handling and cleanup (unloading modules when no cache provided)
- Conditional compilation support for different build modes

696-713: LGTM! Proper propagation of valid dimensions. The method correctly propagates the valid dimension fields from BatchedGemmData to BatchedGemmOptions, ensuring they're available for validation.

718-728: LGTM! Config validation implementation. The isValidConfig method properly validates configurations by extracting options and checking them without modification (updateOptions=false).

800-802: LGTM! Private member variables. The mExportsCubin and mNumRotations members properly store the constructor parameters for later use.
```cpp
std::vector<size_t> getWorkspaceSizesInBytes(BatchedGemmConfig const& config,
                                             BatchedGemmData const& data) const {
  std::vector<size_t> workspaceSizes;

  // Get options from config and data.
  auto options = getOptionsFromConfigAndData(config, data);

  if (options.mUseDeepSeekFp8 && options.mFusedAct) {
    int32_t totalNumPaddedTokens = 0;
    auto const batchM = options.mBatchMode == BatchedGemmOptions::BatchMode::BatchM;
    if (!options.mEnablesEarlyExit || options.mNumTokens == 0) {
      for (int32_t bi = 0; bi < options.mNumBatches; ++bi) {
        totalNumPaddedTokens += batchM ? gemm::divUpMul(options.mBatchedM[bi], options.mTileM)
                                       : gemm::divUpMul(options.mBatchedN[bi], options.mTileN);
      }
    } else {
      // Get tile in token dim.
      auto tileTokensDim = batchM ? options.mTileM : options.mTileN;
      totalNumPaddedTokens = data.mProblemDimensions.mMaxNumCtasInTokenDim * tileTokensDim;
    }

    // Get options from config.
    auto& options = config.mOptions;

    int const tokenTile = batchM ? options.mTileM : options.mTileN;

    auto const numTokens = totalNumPaddedTokens;
    auto const intermediateDim = batchM ? options.mN : options.mM;
    auto const intermediateTile = batchM ? options.mTileN : options.mTileM;

    auto const numBytesRowMax = intermediateDim * totalNumPaddedTokens / 128 * sizeof(float);

    auto const numTilesToken = numTokens / tokenTile;
    auto const numTilesInt = intermediateDim / intermediateTile;
    auto const numBytesRowMaxBars = numTilesToken * numTilesInt / 2 * sizeof(uint32_t);

    // TODO: do we need to pad to 1024?
    workspaceSizes.push_back(getSizePaddedToAlignment(numBytesRowMax, 1024));
    workspaceSizes.push_back(getSizePaddedToAlignment(numBytesRowMaxBars, 1024));
  }

  return workspaceSizes;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

int32_t BatchedGemmInterface::run(BatchedGemmConfig const& config, void* workspace,
                                  BatchedGemmData const& batchedGemmData, void* cudaStream,
                                  int32_t /* multiProcessorCount */, bool usePdl,
                                  std::optional<std::reference_wrapper<ModuleCache>> moduleCache) {
  // Might be used.
  (void)usePdl;
  (void)moduleCache;
  // Get options from config and data.
  auto options = getOptionsFromConfigAndData(config, batchedGemmData);

  bool const batchM = options.mBatchMode == BatchedGemmOptions::BatchMode::BatchM;
  bool const useDeepSeekFp8 = options.mUseDeepSeekFp8 && options.mDtypeA == tg::Dtype::E4m3 &&
                              options.mDtypeB == tg::Dtype::E4m3;

  auto workspaceSizes = getWorkspaceSizesInBytes(config, batchedGemmData);
  float* dPtrRowMax{nullptr};
  uint32_t* dPtrRowMaxBars{nullptr};

  // Set the completion barriers to 0 if needed.
  if (useDeepSeekFp8 && options.mFusedAct) {
    dPtrRowMax = reinterpret_cast<float*>(alignPtr(reinterpret_cast<char*>(workspace), 1024));
    dPtrRowMaxBars = reinterpret_cast<uint32_t*>(
        alignPtr(reinterpret_cast<char*>(dPtrRowMax) + workspaceSizes[0], 1024));
    auto err = cudaMemsetAsync((void*)dPtrRowMaxBars, 0x00, workspaceSizes[1],
                               reinterpret_cast<cudaStream_t>(cudaStream));
    if (err != cudaSuccess) {
      return 1;
    }
  }

  auto [numCtaBatch, numCtaTile, numCtaInner] =
      getGridDim(options, batchedGemmData.mProblemDimensions.mMaxNumCtasInTokenDim);
  auto kernelParams = KernelParamsSetup::setKernelParams(
      options, batchM, batchedGemmData.mInputBuffers.mPtrA, batchedGemmData.mInputBuffers.mPtrB,
      batchedGemmData.mOutputBuffers.mPtrC, batchedGemmData.mInputBuffers.mPtrSfA,
      batchedGemmData.mInputBuffers.mPtrSfB, batchedGemmData.mInputBuffers.mPtrPerTokenSfA,
      batchedGemmData.mInputBuffers.mPtrPerTokenSfB, batchedGemmData.mInputBuffers.mPtrBias,
      batchedGemmData.mOutputBuffers.mPtrSfC, batchedGemmData.mInputBuffers.mPtrScaleC,
      batchedGemmData.mInputBuffers.mPtrScaleGate, batchedGemmData.mInputBuffers.mPtrClampLimit,
      batchedGemmData.mInputBuffers.mPtrGatedActAlpha,
      batchedGemmData.mInputBuffers.mPtrGatedActBeta, batchedGemmData.mInputBuffers.mPtrRouteMap,
      dPtrRowMax, dPtrRowMaxBars, batchedGemmData.mInputBuffers.mPtrNumNonExitingCtas,
      batchedGemmData.mInputBuffers.mPtrTotalNumPaddedTokens,
      batchedGemmData.mInputBuffers.mPtrCtaIdxXyToBatchIdx,
      batchedGemmData.mInputBuffers.mPtrCtaIdxXyToMnLimit, numCtaBatch);

  // The size of the grid.
  std::vector<int32_t> grid = batchM ? std::vector<int32_t>{numCtaBatch, numCtaTile, numCtaInner}
                                     : std::vector<int32_t>{numCtaTile, numCtaBatch, numCtaInner};

#ifdef TLLM_GEN_EXPORT_INTERFACE
  CUmodule cuModule;
  CUfunction cuFunction;

  auto fiModuleLoadData = [&](CUmodule* module) {
    const std::string sha256 = config.mHash ? config.mHash : "";
    std::string fname_cubin = config.mFunctionName;
    if (!fname_cubin.empty()) {
      fname_cubin[0] = static_cast<char>(std::toupper(static_cast<unsigned char>(fname_cubin[0])));
    }
    fname_cubin = tllm_gen_bmm_cubin_path + "/" + fname_cubin + ".cubin";
    std::string cubin = flashinfer::trtllm_cubin_loader::getCubin(fname_cubin, sha256);
    cuModuleLoadData(&cuModule, cubin.c_str());
  };

  if (moduleCache.has_value()) {
    ModuleCache& moduleCacheRef = moduleCache.value().get();

    // Modules are associated with a specific context, so the context is included in the key
    CUcontext ctx;
    unsigned long long ctxId;
    cuCtxGetCurrent(&ctx);
    cuCtxGetId(ctx, &ctxId);

    // Reinterpret the ctxId as a string to avoid needing a custom hash or converting it to a
    // string in decimal representation.
    std::string const ctxName =
        std::string(reinterpret_cast<char*>(&ctxId), sizeof(unsigned long long) / sizeof(char));
    std::string const funcName = std::string(config.mFunctionName);
    auto const moduleKey = ctxName + funcName;
    auto module = moduleCacheRef.find(moduleKey);

    // Use cache if module is found, otherwise load and insert into cache
    if (module != moduleCacheRef.end()) {
      cuFunction = std::get<1>(module->second);
    } else {
      fiModuleLoadData(&cuModule);
      cuModuleGetFunction(&cuFunction, cuModule, config.mFunctionName);
      moduleCacheRef.insert(std::make_pair(moduleKey, std::make_tuple(cuModule, cuFunction)));
    }
  } else {
    fiModuleLoadData(&cuModule);
    cuModuleGetFunction(&cuFunction, cuModule, config.mFunctionName);
  }
```
Variable shadowing issue in workspace size calculation.
Line 767 creates a new reference `auto& options = config.mOptions` that shadows the `options` variable created on line 750 via `getOptionsFromConfigAndData()`. This means:

- Lines 752-764 use the complete `options` (with M/N/K from data)
- Lines 767-779 use only `config.mOptions` (without runtime dimensions)

This shadowing likely loses the populated mM, mN, mK, mValidM, mValidN, mValidK, and other runtime fields from the data, which could lead to incorrect workspace size calculations.

Apply this diff to remove the shadowing:

```diff
-    // Get options from config.
-    auto& options = config.mOptions;
+    // Use the same options variable created earlier (no shadowing)
```

Then use the existing `options` variable for the calculations that follow.
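For readers less familiar with this C++ pitfall, a tiny self-contained illustration of the shadowing pattern, using hypothetical stand-in types rather than the real BatchedGemm classes:

```cpp
#include <cstdio>

// Hypothetical stand-ins for the real config/options types.
struct Options { int mValidM; };
struct Config  { Options mOptions; };

Options makeRuntimeOptions() { return {128}; }  // plays the role of getOptionsFromConfigAndData()

int validRows(Config const& config) {
  auto options = makeRuntimeOptions();                 // complete, runtime-populated copy
  std::printf("runtime copy: %d\n", options.mValidM);  // 128
  {
    auto& options = config.mOptions;  // shadows the runtime copy for the rest of this scope
    return options.mValidM;           // returns whatever the bare config holds (0 here), not 128
  }
}

int main() {
  Config config{};  // value-initialized: mOptions.mValidM == 0
  return validRows(config);
}
```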
🤖 Prompt for AI Agents
In
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h
around lines 745-787, a new reference `auto& options = config.mOptions` (line
~767) shadows the earlier `options` obtained from getOptionsFromConfigAndData(),
causing loss of runtime-populated fields; remove that shadowing declaration and
reuse the original `options` variable for the subsequent calculations
(tokenTile, intermediateDim, intermediateTile, numBytesRowMax,
numBytesRowMaxBars, and workspace push_backs) so the runtime M/N/K/valid fields
from data are used when computing workspace sizes.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (1)
287-297: Restore the CLI entrypoint

Splitting the FP8/FP4 helpers removed the global `bench_trtllm_gen_fused_moe_autotuner`, so executing the script now raises a `NameError`. Reintroduce a dispatcher that forwards to the specialised implementations.

```diff
+def bench_trtllm_gen_fused_moe_autotuner(
+    tune_max_num_tokens: Optional[int],
+    quant_mode: str,
+    num_tokens: int,
+    num_experts: int,
+    hidden_size: int,
+    intermediate_size: int,
+    top_k: int,
+    warmups: int,
+    iterations: int,
+):
+    if quant_mode == "Fp8-Per-Tensor":
+        return bench_trtllm_gen_fused_moe_autotuner_fp8(
+            tune_max_num_tokens,
+            quant_mode,
+            num_tokens,
+            num_experts,
+            hidden_size,
+            intermediate_size,
+            top_k,
+            warmups,
+            iterations,
+        )
+    return bench_trtllm_gen_fused_moe_autotuner_fp4(
+        tune_max_num_tokens,
+        quant_mode,
+        num_tokens,
+        num_experts,
+        hidden_size,
+        intermediate_size,
+        top_k,
+        warmups,
+        iterations,
+    )
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (4)
- flashinfer/fused_moe/core.py (3)
  - RoutingMethodType (58-72), trtllm_fp4_block_scale_moe (1827-1961), trtllm_fp8_per_tensor_scale_moe (1661-1739)
- flashinfer/autotuner.py (1)
  - autotune (251-262)
- flashinfer/testing/utils.py (1)
  - bench_gpu_time (972-1033)
- flashinfer/utils.py (1)
  - device_support_pdl (568-572)
🪛 Ruff (0.14.3)
benchmarks/bench_trtllm_gen_fused_moe_autotuner.py
29-29: Unused function argument: quant_mode
(ARG001)
67-89: Do not assign a lambda expression, use a def
Rewrite fn as a def
(E731)
```python
def fp8_quantize(x):
    max = x.float().abs().nan_to_num().max()
    scale = FLOAT8_E4M3_MAX / max
    x = (x * scale).to(torch.float8_e4m3fn)
    return x, 1.0 / scale
```
Guard FP8 quantization against all-zero inputs
All-zero inputs make max zero, scale infinite, and the quantized tensor NaN (0 × ∞), which breaks the benchmark when buffers start cleared. Please handle the zero case before inverting the scale.
```diff
 def fp8_quantize(x):
-    max = x.float().abs().nan_to_num().max()
-    scale = FLOAT8_E4M3_MAX / max
-    x = (x * scale).to(torch.float8_e4m3fn)
-    return x, 1.0 / scale
+    max_val = x.float().abs().nan_to_num().max()
+    if max_val == 0:
+        return torch.zeros_like(x, dtype=torch.float8_e4m3fn), 1.0
+    scale = FLOAT8_E4M3_MAX / max_val
+    quantized = (x * scale).to(torch.float8_e4m3fn)
+    return quantized, 1.0 / scale
```

🤖 Prompt for AI Agents
+ return quantized, 1.0 / scale🤖 Prompt for AI Agents
In benchmarks/bench_trtllm_gen_fused_moe_autotuner.py around lines 21 to 25, the
fp8_quantize function divides by max which can be zero for all-zero inputs;
guard against that by computing max = x.float().abs().nan_to_num().max(), then
check if max == 0 (or torch.isclose(max, torch.tensor(0., device=max.device)))
before inverting it; if it is zero, return the input cast to torch.float8_e4m3fn
(or an all-zero tensor of the same shape) and a safe inverse scale (e.g. 1.0),
otherwise compute scale = FLOAT8_E4M3_MAX / max and proceed with quantization
and return x.to(torch.float8_e4m3fn) and 1.0/scale.
Actionable comments posted: 3
♻️ Duplicate comments (1)
benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (1)
21-25: Guard FP8 quantization against all-zero inputs.

All-zero inputs make `max` zero, `scale` infinite, and the quantized tensor NaN (0 × ∞), which breaks the benchmark when buffers start cleared. Please handle the zero case before inverting the scale.

Apply this diff to fix the issue:

```diff
 def fp8_quantize(x):
-    max = x.float().abs().nan_to_num().max()
-    scale = FLOAT8_E4M3_MAX / max
-    x = (x * scale).to(torch.float8_e4m3fn)
-    return x, 1.0 / scale
+    max_val = x.float().abs().nan_to_num().max()
+    if max_val == 0:
+        return torch.zeros_like(x, dtype=torch.float8_e4m3fn), 1.0
+    scale = FLOAT8_E4M3_MAX / max_val
+    quantized = (x * scale).to(torch.float8_e4m3fn)
+    return quantized, 1.0 / scale
```
🧹 Nitpick comments (1)
benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (1)
67-89: Consider refactoring lambda to a named function.

The lambda expression is quite long (22 lines). While functional, a named function would improve readability and follow PEP 8 style guidelines.

Apply this diff to refactor:

```diff
+    def run_moe():
+        return trtllm_fp8_per_tensor_scale_moe(
+            routing_logits,
+            None,  # routing_bias
+            hidden_states,
+            w13,
+            output1_scale_scalar,
+            output1_scales_gate_scalar,
+            w2,
+            output2_scale_scalar,
+            num_experts,
+            top_k,
+            None,  # n_group
+            None,  # topk_group
+            intermediate_size,
+            0,  # local_expert_offset
+            num_experts,
+            1.0,  # routed_scaling_factor
+            False,  # use_routing_scales_on_input
+            None,
+            RoutingMethodType.TopK.value,
+            enable_pdl,
+            tune_max_num_tokens
+        )
+
-    fn = lambda: trtllm_fp8_per_tensor_scale_moe(
-        routing_logits,
-        None,  # routing_bias
-        hidden_states,
-        w13,
-        output1_scale_scalar,
-        output1_scales_gate_scalar,
-        w2,
-        output2_scale_scalar,
-        num_experts,
-        top_k,
-        None,  # n_group
-        None,  # topk_group
-        intermediate_size,
-        0,  # local_expert_offset
-        num_experts,
-        1.0,  # routed_scaling_factor
-        False,  # use_routing_scales_on_input
-        None,
-        RoutingMethodType.TopK.value,
-        enable_pdl,
-        tune_max_num_tokens
-    )
+    fn = run_moe
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (5)
- flashinfer/fused_moe/core.py (3)
  - GatedActType (173-177), trtllm_fp4_block_scale_moe (1827-1961), trtllm_fp8_per_tensor_scale_moe (1661-1739)
- csrc/trtllm_fused_moe_kernel_launcher.cu (4)
  - trtllm_fp4_block_scale_moe (1177-1273), trtllm_fp4_block_scale_moe (1177-1190), trtllm_fp8_per_tensor_scale_moe (352-412), trtllm_fp8_per_tensor_scale_moe (352-360)
- flashinfer/autotuner.py (1)
  - autotune (251-262)
- flashinfer/testing/utils.py (1)
  - bench_gpu_time (972-1033)
- flashinfer/utils.py (1)
  - device_support_pdl (568-572)
🪛 GitHub Actions: pre-commit
benchmarks/bench_trtllm_gen_fused_moe_autotuner.py
[error] 4-4: Unused import: torch.nn.functional. The import statement is present but not used in the code.
[error] 9-9: Removed import statement from flashinfer.fused_moe. The old import is replaced with a formatted multi-line import block.
🪛 Ruff (0.14.3)
benchmarks/bench_trtllm_gen_fused_moe_autotuner.py
29-29: Unused function argument: quant_mode
(ARG001)
67-89: Do not assign a lambda expression, use a def
Rewrite fn as a def
(E731)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (5)
benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (5)
12-12: LGTM! The new import for `trtllm_fp8_per_tensor_scale_moe` is properly used in the FP8 benchmark function.

18-19: LGTM! The quantization constants are properly defined and will be used in the FP8 quantization logic.

237-246: LGTM! The benchmarking logic correctly runs with and without autotuning and collects timing measurements.

262-262: LGTM! The CLI correctly extends quantization mode choices to include the new FP8 per-tensor option.

287-310: LGTM! The routing logic correctly dispatches to the appropriate benchmark function based on the quantization mode, and all parameters are properly forwarded.
```python
import argparse
from typing import Optional, Literal
import torch
import torch.nn.functional as F
```
Remove unused import.
The import torch.nn.functional as F is not used anywhere in the code.
Apply this diff to remove the unused import:
```diff
-import torch.nn.functional as F
```
🧰 Tools
🪛 GitHub Actions: pre-commit
[error] 4-4: Unused import: torch.nn.functional. The import statement is present but not used in the code.
🤖 Prompt for AI Agents
In benchmarks/bench_trtllm_gen_fused_moe_autotuner.py around line 4, the import
"torch.nn.functional as F" is unused; remove that import line from the file to
eliminate the unused dependency and clean up imports.
def bench_trtllm_gen_fused_moe_autotuner_fp8(
    tune_max_num_tokens: Optional[int],
    quant_mode: Literal["Fp8-Per-Tensor"],
Remove unused parameter.
The quant_mode parameter is not used within the function body. If it's intended for future use or external validation, consider adding a comment explaining its purpose.
Apply this diff if the parameter is not needed:
def bench_trtllm_gen_fused_moe_autotuner_fp8(
tune_max_num_tokens: Optional[int],
- quant_mode: Literal["Fp8-Per-Tensor"],
num_tokens: int,
num_experts: int,
hidden_size: int,
intermediate_size: int,
top_k: int,
warmups: int,
iterations: int,
):📝 Committable suggestion
quant_mode: Literal["Fp8-Per-Tensor"],
def bench_trtllm_gen_fused_moe_autotuner_fp8(
    tune_max_num_tokens: Optional[int],
    num_tokens: int,
    num_experts: int,
    hidden_size: int,
    intermediate_size: int,
    top_k: int,
    warmups: int,
    iterations: int,
):
🧰 Tools
🪛 Ruff (0.14.3)
29-29: Unused function argument: quant_mode
(ARG001)
🤖 Prompt for AI Agents
In benchmarks/bench_trtllm_gen_fused_moe_autotuner.py around line 29, the
function signature includes an unused parameter quant_mode:
Literal["Fp8-Per-Tensor"]; remove this parameter from the signature and any
references to it, or if it is intentionally reserved for future use, keep it but
add a clear comment above the parameter explaining its purpose and why it is
unused (e.g., "reserved for future quantization modes"); update any call sites
if you remove it to avoid breaking callers.
99eb4ec to 8768aad
Actionable comments posted: 1
♻️ Duplicate comments (1)
benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (1)
23-27: Guard FP8 quantization against zero max. Line 23 lets scale = FLOAT8_E4M3_MAX / 0 when the tensor is all zeros, so x * scale turns into NaNs and the benchmark fails immediately. Bail out with zeros and a unit inverse-scale before dividing. Apply this diff to harden the quantizer:

 def fp8_quantize(x):
-    max = x.float().abs().nan_to_num().max()
-    scale = FLOAT8_E4M3_MAX / max
-    x = (x * scale).to(torch.float8_e4m3fn)
-    return x, 1.0 / scale
+    max_val = x.float().abs().nan_to_num().amax()
+    if max_val.item() == 0:
+        zeros = torch.zeros_like(x, dtype=torch.float8_e4m3fn)
+        return zeros, torch.tensor(1.0, device=x.device, dtype=torch.float32)
+    scale = FLOAT8_E4M3_MAX / max_val
+    quantized = (x * scale).to(torch.float8_e4m3fn)
+    return quantized, scale.reciprocal()
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (4)
flashinfer/fused_moe/core.py (3)
trtllm_fp4_block_scale_moe (1827-1961), trtllm_fp8_per_tensor_scale_moe (1661-1739), RoutingMethodType (58-72)
flashinfer/autotuner.py (1)
autotune (251-262)
flashinfer/testing/utils.py (1)
bench_gpu_time (972-1033)
flashinfer/utils.py (1)
device_support_pdl (568-572)
🪛 Ruff (0.14.3)
benchmarks/bench_trtllm_gen_fused_moe_autotuner.py
32-32: Unused function argument: quant_mode
(ARG001)
70-92: Do not assign a lambda expression, use a def
Rewrite fn as a def
(E731)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
output1_scale_scalar = torch.tensor(
    [hidden_states_scale * w13_scale] * num_experts, device=device
)
output1_scales_gate_scalar = torch.ones(
    num_experts, device=device, dtype=torch.float32
)
output2_scale_scalar = torch.tensor(
    [hidden_states_scale * w2_scale] * num_experts, device=device
)
Construct FP8 scale vectors without CPU conversion errors.
Lines 60-68 call torch.tensor([...], device=device) on CUDA scalars, which raises TypeError: can't convert CUDA tensor to numpy(). That stops the FP8 path before benchmarking. Build the vectors on device without Python lists.
Apply this diff to keep the scales on CUDA:
- output1_scale_scalar = torch.tensor(
- [hidden_states_scale * w13_scale] * num_experts, device=device
- )
+ scale_prod_1 = (hidden_states_scale * w13_scale).item()
+ output1_scale_scalar = torch.full(
+ (num_experts,),
+ scale_prod_1,
+ device=device,
+ dtype=torch.float32,
+ )
@@
- output2_scale_scalar = torch.tensor(
- [hidden_states_scale * w2_scale] * num_experts, device=device
- )
+ scale_prod_2 = (hidden_states_scale * w2_scale).item()
+ output2_scale_scalar = torch.full(
+ (num_experts,),
+ scale_prod_2,
+ device=device,
+ dtype=torch.float32,
+ )🤖 Prompt for AI Agents
In benchmarks/bench_trtllm_gen_fused_moe_autotuner.py around lines 60 to 68, the
code creates FP8 scale vectors using Python lists of CUDA scalars which triggers
"can't convert CUDA tensor to numpy()" on CUDA; replace those list constructions
with device-native tensor factories (e.g., use torch.full or torch.ones with
shape (num_experts,) and the desired dtype/device) to produce
output1_scale_scalar and output2_scale_scalar directly on the CUDA device (and
keep output1_scales_gate_scalar as torch.ones on device with correct dtype).
/bot run
Signed-off-by: Siyuan Fu <[email protected]>
Signed-off-by: Siyuan Fu <[email protected]>
Signed-off-by: Siyuan Fu <[email protected]>
Signed-off-by: Siyuan Fu <[email protected]>
Signed-off-by: Siyuan Fu <[email protected]>
Signed-off-by: Siyuan Fu <[email protected]>
Signed-off-by: jiahanc <[email protected]>
Signed-off-by: jiahanc <[email protected]>
Signed-off-by: Siyuan Fu <[email protected]>
Signed-off-by: jiahanc <[email protected]>
703ed28 to 0e88417
/bot run
Actionable comments posted: 2
♻️ Duplicate comments (10)
include/flashinfer/trtllm/fused_moe/DevKernel.h (1)
116-123: Critical: Fix the pow2 condition to include zero. This issue was previously flagged: the condition data.mPaddingLog2 > 0 excludes mPaddingLog2 == 0, causing it to fall into the tile-N branch (false). However, mPaddingLog2 == 0 represents tile size 1 (2^0 = 1), which is a power of 2 and should follow the pow2 path. When zero takes the false branch, mTileTokensDim is not populated for pow2 flows, leading to divide-by-zero runtime faults in the CUDA kernels. Apply this diff:

-  if (data.mPaddingLog2 > 0) {
+  if (data.mPaddingLog2 >= 0) {

csrc/trtllm_fused_moe_kernel_launcher.cu (1)
66-81: Fix UB when tile not present; snap to nearest supported entry before walking neighbors.
std::find may return end() (e.g., nextPowerOfTwo=32 but 32 not in the list). Using std::next(it) on end() is UB. Replace with lower_bound, clamp the iterator, and snap tile_tokens_dim to an actual element; also guard neighbors.

-  // assume supported_tile_nums is sorted
-  int32_t tile_tokens_dim = std::clamp(nextPowerOfTwo(avg_tokens_per_expert),
-                                       supported_tile_nums.front(), supported_tile_nums.back());
-  auto it = std::find(supported_tile_nums.begin(), supported_tile_nums.end(), tile_tokens_dim);
+  // assume supported_tile_nums is sorted
+  int32_t tile_tokens_dim = std::clamp(nextPowerOfTwo(avg_tokens_per_expert),
+                                       supported_tile_nums.front(), supported_tile_nums.back());
+  auto it = std::lower_bound(supported_tile_nums.begin(), supported_tile_nums.end(), tile_tokens_dim);
+  if (it == supported_tile_nums.end()) { it = std::prev(supported_tile_nums.end()); }
+  // Optionally pick the closer neighbor if lower_bound overshoots
+  if (it != supported_tile_nums.begin()) {
+    auto prev = std::prev(it);
+    if (std::abs(*prev - tile_tokens_dim) <= std::abs(*it - tile_tokens_dim)) it = prev;
+  }
+  tile_tokens_dim = *it;
   std::set<int32_t> selected_tile_nums;
   selected_tile_nums.insert(tile_tokens_dim);
-  if (std::next(it) != supported_tile_nums.end()) {
+  if (std::next(it) != supported_tile_nums.end()) {
     selected_tile_nums.insert(*std::next(it));
-    if (std::next(std::next(it)) != supported_tile_nums.end()) {
+    if (std::next(std::next(it)) != supported_tile_nums.end()) {
       selected_tile_nums.insert(*std::next(std::next(it)));
     }
   }
-  if (it != supported_tile_nums.begin()) {
+  if (it != supported_tile_nums.begin()) {
     selected_tile_nums.insert(*std::prev(it));
   }

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h (1)
48-54: Remove nested trtllm::gen inside batchedGemm; it shadows the real ::trtllm::gen. This creates batchedGemm::trtllm::gen and breaks lookups for the actual ::trtllm::gen (dtype helpers, enums). Drop the nested block and reference the global namespace with ::.

-namespace trtllm {
-namespace gen {
-class CudaRunner;
-class GenCfg;
-} // namespace gen
-} // namespace trtllm
+// Forward declarations, when needed, should reference global namespace or use fully qualified ::trtllm::gen.
+// (Remove nested shadow to keep tg = ::trtllm::gen valid.)

And make the alias explicit:

-namespace tg = trtllm::gen;
+namespace tg = ::trtllm::gen;

Also applies to: 59-60
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h (2)
58-64: Avoid nested trtllm::gen in batchedGemm; use the global namespace to prevent shadowing. Same shadowing issue as in GemmGatedActOptions.h. Remove the nested block and use ::trtllm::gen in aliases.

-namespace trtllm {
-namespace gen {
-class CudaRunner;
-class GenCfg;
-} // namespace gen
-} // namespace trtllm
+// Use ::trtllm::gen forward decls if needed; avoid introducing batchedGemm::trtllm::gen.

-namespace tg = trtllm::gen;
+namespace tg = ::trtllm::gen;

Also applies to: 69-71
388-415: Restore printing of mBatchedM/mBatchedN in dumpOptions. Placeholders {} drop essential tuning logs. Serialize the actual vectors when dumpRuntimeParams is true.

-    if (dumpRuntimeParams) {
-      ss << "mBatchedM={}," << std::endl;
-      ss << "mBatchedN={}," << std::endl;
-    }
+    if (dumpRuntimeParams) {
+      ss << "mBatchedM={";
+      for (size_t i = 0; i < options.mBatchedM.size(); ++i) {
+        ss << (i ? ", " : "") << options.mBatchedM[i];
+      }
+      ss << "}," << std::endl;
+      ss << "mBatchedN={";
+      for (size_t i = 0; i < options.mBatchedN.size(); ++i) {
+        ss << (i ? ", " : "") << options.mBatchedN[i];
+      }
+      ss << "}," << std::endl;
+    }

benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (3)
30-40: Remove or document the unused quant_mode parameter. The quant_mode parameter at line 32 is not used within the function body. If this parameter is reserved for future use, add a comment explaining its purpose. Otherwise, remove it:

 def bench_trtllm_gen_fused_moe_autotuner_fp8(
     tune_max_num_tokens: Optional[int],
-    quant_mode: Literal["Fp8-Per-Tensor"],
     num_tokens: int,
     num_experts: int,
     hidden_size: int,
     intermediate_size: int,
     top_k: int,
     warmups: int,
     iterations: int,
 ):
23-27: Guard FP8 quantization against all-zero inputs. When the input tensor is all zeros, max becomes zero, scale becomes infinite, and the quantized result is NaN (0 × ∞). This will break benchmarking when buffers start cleared. Apply this diff to handle the zero case:

 def fp8_quantize(x):
-    max = x.float().abs().nan_to_num().max()
-    scale = FLOAT8_E4M3_MAX / max
-    x = (x * scale).to(torch.float8_e4m3fn)
-    return x, 1.0 / scale
+    max_val = x.float().abs().nan_to_num().max()
+    if max_val == 0:
+        return torch.zeros_like(x, dtype=torch.float8_e4m3fn), 1.0
+    scale = FLOAT8_E4M3_MAX / max_val
+    quantized = (x * scale).to(torch.float8_e4m3fn)
+    return quantized, 1.0 / scale
60-68: Construct FP8 scale vectors without CPU conversion errors. Lines 60-68 call torch.tensor([...], device=device) on CUDA scalars, which raises TypeError: can't convert CUDA tensor to numpy(). Build the vectors on device without Python lists. Apply this diff to keep scales on CUDA:

-    output1_scale_scalar = torch.tensor(
-        [hidden_states_scale * w13_scale] * num_experts, device=device
-    )
+    scale_prod_1 = (hidden_states_scale * w13_scale).item()
+    output1_scale_scalar = torch.full(
+        (num_experts,),
+        scale_prod_1,
+        device=device,
+        dtype=torch.float32,
+    )
@@
-    output2_scale_scalar = torch.tensor(
-        [hidden_states_scale * w2_scale] * num_experts, device=device
-    )
+    scale_prod_2 = (hidden_states_scale * w2_scale).item()
+    output2_scale_scalar = torch.full(
+        (num_experts,),
+        scale_prod_2,
+        device=device,
+        dtype=torch.float32,
+    )

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h (1)
745-787: Fix variable shadowing that loses runtime dimensions in workspace calculation.Line 767 creates a new reference
auto& options = config.mOptionsthat shadows theoptionsvariable from line 750. This causes:
- Lines 752-764 to use complete options (with runtime M/N/K/valid fields from data)
- Lines 769-779 to use only
config.mOptions(missing runtime dimensions)The shadowing loses the populated
mM,mN,mK,mValidM,mValidN,mValidKfields, potentially leading to incorrect workspace size calculations.Apply this diff to remove the shadowing:
- // Get options from config. - auto& options = config.mOptions; + // Use the same options variable created earlier (no shadowing)Then use the existing
optionsvariable for all subsequent calculations.include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h (1)
652-665: Fix warning macro that prevents valid dimension clamping.When
options.mValidK > options.mK, theTLLM_LOG_WARNINGat line 655 fires. UnderTLLM_GEN_EXPORT_INTERFACE, this macro immediately returnsfalse, so the subsequent clamp operation at lines 659-661 never executes. This breaks callers relying onupdateOptions == trueto sanitize oversized valid dimensions.Replace the warning with a conditional log that doesn't trigger early return:
- if (options.mValidM > options.mM || options.mValidN > options.mN || - options.mValidK > options.mK) { - TLLM_LOG_WARNING( - options.mValidK <= options.mK, - "ValidM, ValidN, and ValidK must be less than or equal to M, N, and K respectively."); - if (updateOptions) { - options.mValidM = std::min(options.mValidM, options.mM); - options.mValidN = std::min(options.mValidN, options.mN); - options.mValidK = std::min(options.mValidK, options.mK); - } else { - return false; - } - } + if (options.mValidM > options.mM || options.mValidN > options.mN || + options.mValidK > options.mK) { +#ifdef TLLM_GEN_DEBUG + printArgs("WARNING: ValidM, ValidN, and ValidK must be less than or equal to M, N, and K respectively.\n"); +#endif + if (updateOptions) { + options.mValidM = std::min(options.mValidM, options.mM); + options.mValidN = std::min(options.mValidN, options.mN); + options.mValidK = std::min(options.mValidK, options.mK); + } else { + return false; + } + }
🧹 Nitpick comments (5)
csrc/trtllm_fused_moe_kernel_launcher.cu (1)
381-392: Tile sets look good; ensure sorted+unique before selection to avoid duplicates.When conditionally pushing 128/192/256, duplicates can appear across branches; keep the vector sorted and unique before passing to
computeSelectedTileN.- std::vector<int32_t> mSupportedTileN = {8, 16, 32, 64, 128, 192, 256}; + std::vector<int32_t> mSupportedTileN = {8, 16, 32, 64, 128, 192, 256}; + mSupportedTileN.erase(std::unique(mSupportedTileN.begin(), mSupportedTileN.end()), mSupportedTileN.end());- std::vector<int32_t> mSupportedTileN = {8, 16, 32, 64, 128}; + std::vector<int32_t> mSupportedTileN = {8, 16, 32, 64, 128}; + mSupportedTileN.erase(std::unique(mSupportedTileN.begin(), mSupportedTileN.end()), mSupportedTileN.end());- std::vector<int32_t> supported_tile_nums = {8, 16, 32, 64}; + std::vector<int32_t> supported_tile_nums = {8, 16, 32, 64}; ... // after conditional push_backs + std::sort(supported_tile_nums.begin(), supported_tile_nums.end()); + supported_tile_nums.erase(std::unique(supported_tile_nums.begin(), supported_tile_nums.end()), + supported_tile_nums.end());Also applies to: 730-740, 1322-1336
flashinfer/fused_moe/core.py (1)
125-133: Minor nit: duplicate enum in list.
MxE4m3appears twice intrtllm_gen_dtype_has_scale. Harmless; remove duplicate for clarity.- DtypeTrtllmGen.MxE4m3, DtypeTrtllmGen.E2m1, DtypeTrtllmGen.MxE2m1, - DtypeTrtllmGen.MxE4m3,include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h (1)
154-169: New validation is good; message tweaks optional.DeepSeek FP8 and shuffled-MatrixA constraints on hidden/valid sizes are correct. Consider clarifying error text to include
hiddenSizeStrfor both fields.include/flashinfer/trtllm/fused_moe/RoutingKernel.h (1)
53-57: Template split (isPow2_, UsePdl_) and new mTileTokensDim are wired correctly.Defaults (
mPaddingLog2=-1) are safe under non-pow2 path; Data->Params copy includesmTileTokensDim. Comment onmPtrPermutedIdxToTokenIdxshape still mentions a derived formula and may confuse readers.- // dim: [mTileTokensDim * mTopK + (mNumExperts × mTileTokensDim) - mNumExperts] + // dim: [total_permuted_tokens]; actual size equals padded tokens count reported via mPtrPermutedIdxSize[0]Also applies to: 95-103, 104-157, 232-254, 274-301
benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (1)
70-92: Preferdefover lambda assignment for clarity.Assigning a lambda expression to a variable reduces readability and makes debugging harder (stack traces show
<lambda>instead of a meaningful function name).Refactor to use a proper function definition:
- fn = lambda: trtllm_fp8_per_tensor_scale_moe( - routing_logits, - None, # routing_bias - hidden_states, - w13, - output1_scale_scalar, - output1_scales_gate_scalar, - w2, - output2_scale_scalar, - num_experts, - top_k, - None, # n_group - None, # topk_group - intermediate_size, - 0, # local_expert_offset - num_experts, - 1.0, # routed_scaling_factor - False, # use_routing_scales_on_input - None, - RoutingMethodType.TopK.value, - enable_pdl, - num_tokens if tune_max_num_tokens is None else tune_max_num_tokens, - ) + def fn(): + return trtllm_fp8_per_tensor_scale_moe( + routing_logits, + None, # routing_bias + hidden_states, + w13, + output1_scale_scalar, + output1_scales_gate_scalar, + w2, + output2_scale_scalar, + num_experts, + top_k, + None, # n_group + None, # topk_group + intermediate_size, + 0, # local_expert_offset + num_experts, + 1.0, # routed_scaling_factor + False, # use_routing_scales_on_input + None, + RoutingMethodType.TopK.value, + enable_pdl, + num_tokens if tune_max_num_tokens is None else tune_max_num_tokens, + )
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
⛔ Files ignored due to path filters (2)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CommonUtils.his excluded by!**/gen/**include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/SfLayoutDecl.his excluded by!**/gen/**
📒 Files selected for processing (25)
benchmarks/bench_trtllm_gen_fused_moe_autotuner.py(4 hunks)csrc/trtllm_batched_gemm_runner.cu(5 hunks)csrc/trtllm_fused_moe_kernel_launcher.cu(5 hunks)csrc/trtllm_fused_moe_routing_deepseek.cu(2 hunks)csrc/trtllm_fused_moe_routing_llama4.cu(3 hunks)csrc/trtllm_fused_moe_routing_renormalize.cu(1 hunks)csrc/trtllm_fused_moe_runner.cu(4 hunks)flashinfer/artifacts.py(2 hunks)flashinfer/autotuner.py(1 hunks)flashinfer/fused_moe/core.py(6 hunks)flashinfer/jit/fused_moe.py(1 hunks)flashinfer/jit/gemm/core.py(2 hunks)include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmEnums.h(2 hunks)include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h(4 hunks)include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h(10 hunks)include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h(4 hunks)include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h(25 hunks)include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h(11 hunks)include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParamsDecl.h(0 hunks)include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h(7 hunks)include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h(2 hunks)include/flashinfer/trtllm/fused_moe/DevKernel.h(2 hunks)include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh(6 hunks)include/flashinfer/trtllm/fused_moe/RoutingKernel.h(7 hunks)tests/moe/test_trtllm_gen_fused_moe.py(7 hunks)
💤 Files with no reviewable changes (1)
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParamsDecl.h
🚧 Files skipped from review as they are similar to previous changes (3)
- flashinfer/artifacts.py
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h
- csrc/trtllm_fused_moe_runner.cu
🧰 Additional context used
🧬 Code graph analysis (10)
flashinfer/jit/fused_moe.py (1)
flashinfer/artifacts.py (1)
ArtifactPath(83-98)
csrc/trtllm_batched_gemm_runner.cu (1)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h (1)
isValidConfig(718-728)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h (1)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/MmaDecl.h (1)
getTmemColStridePerGroup(99-103)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h (3)
csrc/trtllm_batched_gemm_runner.cu (10)
run(160-261)run(160-168)run(263-277)run(263-267)run(279-295)run(279-285)run(297-312)run(297-301)getWorkspaceSizeInBytes(129-158)getWorkspaceSizeInBytes(129-131)include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h (4)
std(229-583)std(239-244)std(279-300)std(290-296)include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CudaKernelLauncher.h (3)
trtllm(28-90)gen(29-89)launchKernel(34-84)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h (2)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h (3)
trtllm(48-53)gen(49-52)gemm(147-152)include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmEnums.h (1)
RouteImpl(28-57)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h (3)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h (4)
trtllm(58-63)gen(59-62)mExecPath(377-435)mInstanceIdx(380-380)include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h (6)
trtllm(82-87)gen(83-86)getShuffleBlockSize(602-608)string(438-440)string(445-447)mInstanceIdx(421-421)include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/GemmOptions.h (4)
getShuffleBlockSize(539-545)string(407-409)string(414-416)string(420-521)
benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (5)
csrc/trtllm_fused_moe_kernel_launcher.cu (4)
trtllm_fp4_block_scale_moe(1177-1273)trtllm_fp4_block_scale_moe(1177-1190)trtllm_fp8_per_tensor_scale_moe(352-412)trtllm_fp8_per_tensor_scale_moe(352-360)flashinfer/fused_moe/core.py (3)
trtllm_fp4_block_scale_moe(1827-1961)trtllm_fp8_per_tensor_scale_moe(1661-1739)RoutingMethodType(58-72)flashinfer/autotuner.py (1)
autotune(251-262)flashinfer/testing/utils.py (1)
bench_gpu_time(972-1033)flashinfer/utils.py (1)
device_support_pdl(569-573)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h (1)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h (1)
gemm(32-297)
tests/moe/test_trtllm_gen_fused_moe.py (2)
flashinfer/fused_moe/core.py (2)
RoutingMethodType(58-72)WeightLayout(161-168)include/flashinfer/trtllm/fused_moe/runner.h (1)
RoutingMethodType(37-136)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h (3)
include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/GemmInterface.h (3)
flashinfer(36-38)gemm(42-347)gemm(468-488)include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h (6)
string(133-181)trtllm(38-271)gen(39-269)Dtype(43-268)dtypeIsBlockFmt(96-99)dtypeNumEltsPerSf(198-209)include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h (1)
gemm(30-417)
🪛 Ruff (0.14.3)
benchmarks/bench_trtllm_gen_fused_moe_autotuner.py
32-32: Unused function argument: quant_mode
(ARG001)
70-92: Do not assign a lambda expression, use a def
Rewrite fn as a def
(E731)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (29)
flashinfer/autotuner.py (1)
479-506: LGTM! Improved error visibility.Including the exception details in the warning log enhances observability during profiling failures. The debug log on line 491 continues to provide full context with shapes and error details, while the warning now gives a quick overview of what went wrong.
include/flashinfer/trtllm/fused_moe/DevKernel.h (3)
127-132: LGTM: Clean replacement with preserved signaturesThe replacement of
LAUNCH_PDLwithLAUNCH_TILENcorrectly preserves type signatures and parameters for both FP32 and BF16 routing paths.
141-174: LGTM: Consistent replacement across all dtype combinationsThe replacement of
LAUNCH_PDLwithLAUNCH_TILENis correctly applied across all nine dtype combinations (score/bias/expW), preserving type signatures, parameters, and the extraFlag/numExperts propagation.
194-204: LGTM: Correct replacement for all expert-count pathsThe replacement of
LAUNCH_PDLwithLAUNCH_TILENis correctly applied across all four branches (FP32/BF16 × extraFlag1), preserving type signatures and the numExperts parameter.flashinfer/jit/gemm/core.py (2)
533-539: LGTM! Consistent flag addition across TRTLLM modules.Good to see the
-DTLLM_GEN_EXPORT_FLASHINFERflag added consistently to both the regular and low-latency TRTLLM GEMM modules. This ensures uniform behavior across both compilation paths.
382-387: The flag addition is not used by the compiled sources and should be reviewed for correctness.The
TLLM_GEN_EXPORT_FLASHINFERpreprocessor flag exists in the codebase but is only used within thebatched_gemm/trtllmGen_bmm_export/module, not in thegemm/trtllmGen_gemm_export/path that the compiled runners include. The gemm runners compile againstGemmInterface.hwhich usesTLLM_GEN_EXPORT_INTERFACEinstead, notTLLM_GEN_EXPORT_FLASHINFER. Adding this flag to the gemm runners' compilation will have no effect on their behavior.Verify whether:
- The flag should be added to a different compilation target (batched_gemm runners)
- The gemm runners should be using
TLLM_GEN_EXPORT_FLASHINFERinstead ofTLLM_GEN_EXPORT_INTERFACE- This flag addition was unintended
Likely an incorrect or invalid review comment.
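To make the flag-propagation point above concrete, here is a tiny, self-contained illustration; the macro and names are placeholders, not FlashInfer or TensorRT-LLM symbols — a compile-time define only changes behavior in translation units whose included code actually tests it.

#include <cstdio>

// Placeholder macro: stands in for an export flag passed via -D on the
// compiler command line. Nothing here refers to the real build system.
#ifdef MY_EXPORT_FLAG
static constexpr const char* kBuildMode = "export-flag build";
#else
static constexpr const char* kBuildMode = "default build";
#endif

int main() {
  // Built with `g++ -DMY_EXPORT_FLAG demo.cpp`, this prints the first string;
  // without the define it prints the second. A source file that never checks
  // the macro is unaffected either way.
  std::printf("%s\n", kBuildMode);
  return 0;
}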
flashinfer/jit/fused_moe.py (1)
236-236: LGTM: New export flag added.The addition of
-DTLLM_GEN_EXPORT_FLASHINFERis straightforward and aligns with the PR's objective to add this export flag for the TRTLLM fused MOE SM100 module.csrc/trtllm_fused_moe_kernel_launcher.cu (1)
560-566: Verify gemm1_output_scale shape (tokens dimension).Block-scale path uses
max_num_padded_tokensforgemm1_output_scalewhilegemm1_outputusesmax_num_padded_tokens_gemm1. Mismatch might over-allocate or later index past produced rows. Confirm intended dimension and align both if necessary.include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh (1)
70-87: isPow2/tileN dual-path arithmetic looks consistent.
mulTileN/divUpTileN/divUpMulTileNare correct and used consistently in CTA count, mnLimit, and permuted size computations. No issues.Also applies to: 321-347, 361-369, 554-564, 586-596, 603-613
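For readers following the arithmetic, a minimal sketch of what such helpers can look like is below. The function names echo the review comment, but the exact signatures, qualifiers (e.g. __host__ __device__), and home header in FlashInfer may differ.

#include <cassert>
#include <cstdint>

// Tile-N path: plain integer arithmetic for arbitrary tile sizes (e.g. 192).
constexpr int32_t divUpTileN(int32_t n, int32_t tileN) {
  return (n + tileN - 1) / tileN;  // ceil(n / tileN): number of tiles/CTAs
}
constexpr int32_t mulTileN(int32_t idx, int32_t tileN) {
  return idx * tileN;  // starting offset of tile `idx`
}
constexpr int32_t divUpMulTileN(int32_t n, int32_t tileN) {
  return divUpTileN(n, tileN) * tileN;  // n rounded up to a multiple of tileN
}

// Pow2 path: the same quantities via shifts when tileN == (1 << log2TileN).
constexpr int32_t divUpLog2(int32_t n, int32_t log2TileN) {
  return (n + (1 << log2TileN) - 1) >> log2TileN;
}
constexpr int32_t mulLog2(int32_t idx, int32_t log2TileN) {
  return idx << log2TileN;
}

int main() {
  assert(divUpTileN(385, 192) == 3);       // 385 tokens, tile 192 -> 3 tiles
  assert(divUpMulTileN(385, 192) == 576);  // padded token count
  assert(divUpLog2(130, 7) == divUpTileN(130, 128));
  assert(mulLog2(3, 7) == mulTileN(3, 128));
  return 0;
}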
flashinfer/fused_moe/core.py (1)
1319-1347: Parameter passthrough of tune_max_num_tokens looks correct.Autotuner reconfiguration and op wrappers correctly forward
tune_max_num_tokens. No blocking issues.Also applies to: 1540-1567, 1712-1739, 1763-1824, 1922-1961, 2061-2100
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h (1)
184-191: dumpOptions(runtime flag) and new config fields LGTM.Extended dump with
dumpRuntimeParamsand runtime wiring fields looks fine.Also applies to: 199-215
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h (1)
237-255: DeepSeek FP8 valid-dimension checks are correct; LdgPlusSts guards are helpful.Constraint checks on (M/N/K)%128 and route Sfs impl combinations look consistent.
Also applies to: 276-284
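As a rough illustration of the kind of constraint being checked here (not the actual BatchedGemmOptions code — the field names below are made up), DeepSeek-FP8-style per-128-channel scaling only works when both the padded and the valid extents are 128-aligned:

#include <cstdint>
#include <cstdio>

// Hypothetical option fields; the real headers carry many more.
struct DimOptions {
  int32_t mK = 0;       // padded K extent
  int32_t mValidK = 0;  // valid (unpadded) K extent
  bool mUseDeepSeekFp8 = false;
};

// Returns false when per-128-channel scaling cannot be applied cleanly.
bool kDimsValidForDeepSeekFp8(DimOptions const& o) {
  if (!o.mUseDeepSeekFp8) return true;
  return (o.mK % 128 == 0) && (o.mValidK % 128 == 0);
}

int main() {
  DimOptions ok{512, 384, true};   // both multiples of 128 -> accepted
  DimOptions bad{512, 300, true};  // valid range misaligned -> rejected
  std::printf("%d %d\n", kDimsValidForDeepSeekFp8(ok), kDimsValidForDeepSeekFp8(bad));
  return 0;
}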
csrc/trtllm_batched_gemm_runner.cu (1)
147-158: Good: propagate validM/N/K and explicitly gate split‑K.Setting valid dims before queries and requiring
mClusterDimZ==1avoids known failures; config sorting and fallback look fine.Also applies to: 246-255, 338-341, 449-452
csrc/trtllm_fused_moe_routing_renormalize.cu (1)
168-176: Tile/pow2 dual-path applied consistently in renormalize kernels.
divUp{Log2,TileN},mul{Log2,TileN}are used coherently for CTA count, mnLimit, and permuted size. Looks correct.Also applies to: 179-186, 193-203, 206-216
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmEnums.h (1)
34-36: Enum update aligns with new routing path.The extra variant plus helper keep the routing helpers consistent and makes the downstream checks straightforward. Looks good.
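For context on what an "extra variant plus helper" amounts to, here is an illustrative enum-plus-predicate pattern; the enumerator set and helper names are assumptions and may not match BatchedGemmEnums.h exactly.

#include <cstdint>

// Illustrative routing-implementation enum. The real header may use
// different enumerators (and more of them).
enum class RouteImpl : uint32_t {
  NoRoute = 0,  // dense batched GEMM, no token gather
  Ldgsts,       // gather routed tokens with LDGSTS copies
  Tma,          // gather routed tokens with TMA loads
};

// Small helpers keep call sites readable and make new variants easy to gate.
constexpr bool doesRouteImplUseNoRoute(RouteImpl r) { return r == RouteImpl::NoRoute; }
constexpr bool doesRouteImplUseTma(RouteImpl r) { return r == RouteImpl::Tma; }

static_assert(doesRouteImplUseTma(RouteImpl::Tma), "helper must match its variant");
static_assert(!doesRouteImplUseTma(RouteImpl::Ldgsts), "other variants must not match");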
csrc/trtllm_fused_moe_routing_deepseek.cu (1)
396-421: Nice split between pow2 and tile flows.The constexpr branch cleanly swaps in the tile arithmetic without touching the existing log2 path. That should really help with the new non‑power-of-two kernels.
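A compressed sketch of that compile-time split is below — an if constexpr on a pow2 flag choosing between shift-based and plain-division tile math. This mirrors the idea only; the actual kernel code uses different names and carries far more context.

#include <cstdint>

// isPow2 selects the legacy log2/shift arithmetic; the tile-N branch handles
// arbitrary tile sizes such as 192. Names are illustrative.
template <bool isPow2>
constexpr int32_t numTilesForExpert(int32_t expertTokens, int32_t tileTokensDim,
                                    int32_t paddingLog2) {
  if constexpr (isPow2) {
    // Power-of-two path: tileTokensDim == (1 << paddingLog2).
    return (expertTokens + (1 << paddingLog2) - 1) >> paddingLog2;
  } else {
    // Tile-N path: paddingLog2 is unused (typically -1 here).
    return (expertTokens + tileTokensDim - 1) / tileTokensDim;
  }
}

static_assert(numTilesForExpert<true>(130, 128, 7) == 2);
static_assert(numTilesForExpert<false>(385, 192, -1) == 3);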
csrc/trtllm_fused_moe_routing_llama4.cu (1)
188-244: Thanks for mirroring the tile-aware logic here.The consistent min(mnLimit1, mnLimit2) handling across both branches is reassuring, especially with the new tile sizes.
tests/moe/test_trtllm_gen_fused_moe.py (1)
2088-2175: Great to see FP8 block-scale in the renorm matrix.The added parameter coverage (tiles, routing configs, GeGlu guard) should catch regressions once the new kernels land.
benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (1)
291-314: LGTM!The dispatch logic cleanly separates FP8 and FP4 benchmark paths based on the quantization mode. The implementation correctly routes to the appropriate function with all necessary parameters.
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h (2)
85-96: LGTM!The addition of
validM/N/Kparameters with default values of-1provides a clean backward-compatible interface. The initialization logic correctly defaults valid dimensions to their corresponding size dimensions when not explicitly provided.
109-209: LGTM!The shape/stride computation correctly distinguishes between padded dimensions (for strides) and valid dimensions (for shapes). This optimization reduces unnecessary memory traffic by clamping TMA shapes to the valid data range while maintaining correct stride calculations for the full allocated memory.
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h (3)
76-87: LGTM!The addition of
mValidM/N/Kfields toProblemDimensionsis well-documented and properly initialized. The comment clearly explains their purpose: tracking the valid range of dimensions separately from padded dimensions due to alignment constraints.
461-469: LGTM!The new constructor properly initializes the
mExportsCubinandmNumRotationsmember variables. ThegenerateAndCompileKernelmethod is appropriately guarded with#ifndef TLLM_GEN_EXPORT_INTERFACEto keep compilation-related functionality separate from the export interface.
696-713: LGTM!The
getOptionsFromConfigAndDatamethod correctly populates all problem dimension fields, including the newly addedmValidM/N/Kfields from the problem dimensions data.include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h (5)
641-650: LGTM!The default initialization logic for
mValidM/N/Kis clean and correct. When these fields are not explicitly set (-1), they default to the full dimension sizes, which is the expected behavior for non-padded cases.
667-678: LGTM!The validation correctly prevents use of
validM/N/Kparameters withBlockMajorKlayout, as the swizzled memory layout is incompatible with the valid-dimension optimization. The check properly detects when any valid parameter differs from its corresponding size parameter.
1199-1201: LGTM!The validation correctly ensures both
mKandmValidKare multiples of 128 when using DeepSeek FP8. This is essential for the per-128-channel scaling mechanism to work correctly.
1014-1019: LGTM!The validation correctly requires both
mMandmValidMto be multiples of the shuffle block size when using shuffled matrix A. This ensures the memory reordering is consistent across both the allocated and valid data ranges.
1493-1514: LGTM!The
loadCubinDatafunction cleanly abstracts the two cubin loading paths:
- FlashInfer mode: loads from storage using the cubin loader with proper path construction
- TensorRT-LLM mode: uses embedded cubin data from
config.mDataThe
static_assertat line 1508 appropriately catches configuration errors at compile time when the required macro is missing.
| "-DENABLE_FP8", | ||
| "-DENABLE_FP4", | ||
| f'-DTLLM_GEN_BMM_CUBIN_PATH=\\"{ArtifactPath.TRTLLM_GEN_BMM}\\"', | ||
| f'-DTLLM_GEN_GEMM_CUBIN_PATH=\\"{ArtifactPath.TRTLLM_GEN_BMM}\\"', |
🧩 Analysis chain
Verify artifact path consistency after macro rename.
The macro has been renamed to TLLM_GEN_GEMM_CUBIN_PATH but still points to ArtifactPath.TRTLLM_GEN_BMM. From artifacts.py, there are two distinct artifact paths:
TRTLLM_GEN_BMM:"d60ea310477e47941ad9527c6d61c0531d82dd26/batched_gemm-ac83afb-96dfa16"TRTLLM_GEN_GEMM:"1fddc48b7b48af33914d040051b3e2ee9ba4701e/gemm-145d1b1-9b113e3"
Please confirm whether TLLM_GEN_GEMM_CUBIN_PATH should point to the BMM artifact path (batched GEMM) or if it should be updated to use ArtifactPath.TRTLLM_GEN_GEMM.
Run the following script to verify which artifact path the C++ code expects:
🏁 Script executed:
#!/bin/bash
# Description: Check how TLLM_GEN_GEMM_CUBIN_PATH and TLLM_GEN_BMM_CUBIN_PATH
# are used in the C++ source files to determine the correct artifact path.
# Search for usage of the macro in C++ files
echo "=== Searching for TLLM_GEN_GEMM_CUBIN_PATH usage ==="
rg -n -C3 --type=cpp --type=cu 'TLLM_GEN_GEMM_CUBIN_PATH'
echo -e "\n=== Searching for TLLM_GEN_BMM_CUBIN_PATH usage ==="
rg -n -C3 --type=cpp --type=cu 'TLLM_GEN_BMM_CUBIN_PATH'
echo -e "\n=== Checking trtllm_batched_gemm_runner.cu (from compilation list) ==="
fd -t f 'trtllm_batched_gemm_runner.cu' --exec rg -n -C5 'CUBIN_PATH|batched.*gemm|TLLM_GEN'Length of output: 2345
🏁 Script executed:
#!/bin/bash
# Search for the macro names without type restrictions
echo "=== Searching for TLLM_GEN_GEMM_CUBIN_PATH in all files ==="
rg -n -C2 'TLLM_GEN_GEMM_CUBIN_PATH' --max-count=20
echo -e "\n=== Searching for TLLM_GEN_BMM_CUBIN_PATH in all files ==="
rg -n -C2 'TLLM_GEN_BMM_CUBIN_PATH' --max-count=20
echo -e "\n=== Checking fused_moe.py around line 241 ==="
fd -t f 'fused_moe.py' --exec cat {} \; | head -250 | tail -50
echo -e "\n=== Looking for artifact path TRTLLM_GEN_BMM definition ==="
rg -n 'TRTLLM_GEN_BMM|TRTLLM_GEN_GEMM' artifacts.py -A1 -B1Length of output: 5352
🏁 Script executed:
# Script to verify: search in codebase without skippingLength of output: 59
Update fused_moe.py line 241 to use the correct artifact path.
The macro TLLM_GEN_GEMM_CUBIN_PATH at line 241 is defined with the wrong artifact path. Evidence shows:
flashinfer/jit/gemm/core.py(lines 386, 537) correctly usesArtifactPath.TRTLLM_GEN_GEMMfor the same macro- Both C++ header files in
trtllm/gemm/trtllmGen_gemm_export/andtrtllm/batched_gemm/trtllmGen_bmm_export/expectTLLM_GEN_GEMM_CUBIN_PATH(lines 30–33 and 1497–1509 respectively) TLLM_GEN_BMM_CUBIN_PATHdoes not exist anywhere in the codebase
Line 241 should be:
f'-DTLLM_GEN_GEMM_CUBIN_PATH=\\"{ArtifactPath.TRTLLM_GEN_GEMM}\\"',🤖 Prompt for AI Agents
In flashinfer/jit/fused_moe.py around line 241, the macro definition uses the
wrong artifact path (ArtifactPath.TRTLLM_GEN_BMM) for TLLM_GEN_GEMM_CUBIN_PATH;
change it to use ArtifactPath.TRTLLM_GEN_GEMM instead, preserving the existing
string escaping/formatting so the line becomes the same f-string but referencing
ArtifactPath.TRTLLM_GEN_GEMM.
if (numEpilogueWarps) {
  extraGmemCMultiplier *= numEpilogueWarps / 4;
}

// Number of bytes to store the output in smem.
auto const numBytesSmemStoreC = usesSmemForGmemC
Fix shared-memory sizing for epilogue warps.
Multiplying by numEpilogueWarps / 4 truncates toward zero. For example, tileM=192 → numEpilogueWarps = 6, so the multiplier stays at 1 and we under-allocate the DSMEM buffers. Even worse, 1–3 warps become 0 and allocate nothing. We need a ceil-div against the 4‑warp baseline.
Please replace the scaling with something like:
- if (numEpilogueWarps) {
- extraGmemCMultiplier *= numEpilogueWarps / 4;
- }
+ if (numEpilogueWarps) {
+ auto const warpGroups = (numEpilogueWarps + 3) / 4;
+ extraGmemCMultiplier *= warpGroups;
+ }|
[FAILED] Pipeline #37954416: 12/17 passed
Signed-off-by: Siyuan Fu <[email protected]>
|
/bot run |
|
@IwakuraRein is not authorized to trigger this CI job. cc: @yzh119, @sricketts, @yongwww
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParamsDecl.h
yzh119
left a comment
LGTM
|
/bot run |
|
[FAILED] Pipeline #37976772: 12/17 passed
📌 Description
co-work with @IwakuraRein
computeSelectedTileN
tune_max_num_tokens to FP8 per-tensor and FP8 block scale
TLLM_GEN_BMM_CUBIN_PATH to TLLM_GEN_GEMM_CUBIN_PATH
TLLM_GEN_EXPORT_FLASHINFER
NOTE: split-k kernels are temporarily disabled as they cause failure in renormalize + expert 256 tests.
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
I have installed pre-commit by running pip install pre-commit (or used your preferred method).
I have installed the hooks with pre-commit install.
I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.
🧪 Tests
All tests are passing (unittest, etc.).
Reviewer Notes
Summary by CodeRabbit
New Features
Improvements
Bug Fixes
Tests