perf: TRT-LLM Gen finalize kernel optimization #2092
Conversation
Note: Other AI code review bot(s) detected. CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

This PR introduces vectorized, Top-K unrolled finalization paths for fused MOE kernels by adding trait-driven type specializations for packed elements (bfloat16_2, bfloat16_4, half_4) and nine trait template specializations across three unroll factors, and by updating launch macros to propagate TopK values and expose TopKUnrollFactor as a compile-time constant.
Sequence Diagram

```mermaid
sequenceDiagram
    participant host as Host
    participant run as run(Data, stream)
    participant selector as LAUNCH_TOPK_EXPW
    participant macroESC as LAUNCH_ESC
    participant kernel as finalizeKernelVecLoad
    host->>run: invoke with Data
    run->>run: check TopK <= MaxTopK
    run->>selector: call with data.topK
    selector->>selector: select TopK ∈ {4, 2, 1}
    selector->>macroESC: LAUNCH_ESC with selected TopK
    macroESC->>macroESC: instantiate KernelParams<Type, TypeExpW, TopK, UsePdl>
    macroESC->>kernel: launch with FinalizeTraits<TopK, TypeExpW>
    rect rgb(200, 220, 240)
        Note over kernel: Per-chunk vectorized loop
        kernel->>kernel: load scales/indices per chunk
        kernel->>kernel: __syncthreads()
        kernel->>kernel: apply scale & accumulate
        kernel->>kernel: repeat for all chunks
    end
    kernel-->>host: finalize complete
```
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Summary of Changes

Hello @nekorobov, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed.

This pull request focuses on improving the performance of the TRT-LLM Gen Mixture-of-Experts (MoE) finalize kernel. The changes refactor the kernel to use vectorized memory access patterns and shared memory for frequently accessed data, which together yield faster execution times, especially noticeable at lower batch sizes. The aim is to make the MoE layer more efficient within the TRT-LLM generation pipeline.
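To make the described optimization concrete, here is a minimal, self-contained CUDA sketch of the general pattern (caching per-token expert scales and permuted row indices in shared memory, then accumulating the hidden dimension with 128-bit `float4` loads). All names, the fp32-only data layout, and the launch shape are illustrative assumptions; this is not the actual `finalizeKernelVecLoad` from the PR, which additionally handles mixed precision and the compile-time Top-K unroll factor.

```cpp
#include <cuda_runtime.h>

constexpr int kMaxTopK = 64;           // assumed cap, analogous to MaxTopK in the PR
constexpr int kThreadsPerBlock = 256;  // assumed block size

// One block per token; requires hidden % 4 == 0 and 16-byte aligned buffers.
__global__ void finalizeSketch(float const* permutedRows,  // [numPermutedRows, hidden]
                               int const* tokenToPermuted,  // [numTokens, topK], -1 = dropped
                               float const* expertScales,   // [numTokens, topK]
                               float* out,                  // [numTokens, hidden]
                               int topK, int hidden) {
  __shared__ int idxSmem[kMaxTopK];
  __shared__ float scaleSmem[kMaxTopK];
  int const token = blockIdx.x;

  // Stage the per-token metadata once per block instead of re-reading it per element.
  for (int k = threadIdx.x; k < topK; k += blockDim.x) {
    idxSmem[k] = tokenToPermuted[token * topK + k];
    scaleSmem[k] = expertScales[token * topK + k];
  }
  __syncthreads();

  // Each thread handles a strided set of float4 chunks of the hidden dimension.
  for (int v = threadIdx.x; v < hidden / 4; v += blockDim.x) {
    float4 acc = make_float4(0.f, 0.f, 0.f, 0.f);
    for (int k = 0; k < topK; ++k) {
      int const row = idxSmem[k];
      if (row < 0) continue;  // skip dropped expert slots
      float4 const x = reinterpret_cast<float4 const*>(permutedRows + row * hidden)[v];
      float const s = scaleSmem[k];
      acc.x += s * x.x; acc.y += s * x.y; acc.z += s * x.z; acc.w += s * x.w;
    }
    reinterpret_cast<float4*>(out + token * hidden)[v] = acc;
  }
}
// Launch sketch: finalizeSketch<<<numTokens, kThreadsPerBlock>>>(rows, idx, scales, out, topK, hidden);
```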
Tagging @ChristinaZ
Code Review
This pull request introduces a significant performance optimization for the TRT-LLM MoE finalize kernel. The changes leverage shared memory to cache expert scales and indices, reducing global memory traffic and improving performance, especially for smaller batch sizes. The implementation adds new template metaprogramming structures for handling vectorized loads with different unroll factors and a new dispatch macro to select the optimal kernel. My review identified a critical correctness bug in the initialization of default expert scales, which could lead to incorrect computations. I also found a minor issue with unreachable code in one of the new dispatch macros. Overall, the optimization strategy is sound, and with the suggested fixes, this will be a valuable performance improvement.
Actionable comments posted: 0
♻️ Duplicate comments (2)
include/flashinfer/trtllm/fused_moe/DevKernel.h (1)
142-151: Unreachable `else` branch in LAUNCH_TOPK_EXPW due to `% 1 == 0`

`data.topK % 1 == 0` is always true for integral `topK`, so the final `else` (and the `"Unsupported topK"` warning) is unreachable. You can simplify the macro and drop the dead branch:

```diff
-#define LAUNCH_TOPK_EXPW(data, kernel, numBlocks, numThreads, smemSize, stream) \
-  if (data.topK % 4 == 0) {                                                     \
-    LAUNCH_EXPW(data, kernel, 4, numBlocks, numThreads, smemSize, stream);      \
-  } else if (data.topK % 2 == 0) {                                              \
-    LAUNCH_EXPW(data, kernel, 2, numBlocks, numThreads, smemSize, stream);      \
-  } else if (data.topK % 1 == 0) {                                              \
-    LAUNCH_EXPW(data, kernel, 1, numBlocks, numThreads, smemSize, stream);      \
-  } else {                                                                      \
-    FLASHINFER_WARN("Unsupported topK");                                        \
-  }
+#define LAUNCH_TOPK_EXPW(data, kernel, numBlocks, numThreads, smemSize, stream) \
+  if (data.topK % 4 == 0) {                                                     \
+    LAUNCH_EXPW(data, kernel, 4, numBlocks, numThreads, smemSize, stream);      \
+  } else if (data.topK % 2 == 0) {                                              \
+    LAUNCH_EXPW(data, kernel, 2, numBlocks, numThreads, smemSize, stream);      \
+  } else {                                                                      \
+    LAUNCH_EXPW(data, kernel, 1, numBlocks, numThreads, smemSize, stream);      \
+  }
```

This keeps behavior identical while removing the unreachable code.
csrc/trtllm_fused_moe_dev_kernel.cu (1)
784-881: Bug: default scale initialization only sets the first element to 1 when expertWeightsPtr is null

In `finalizeKernelVecLoad`, when `params.expertWeightsPtr == nullptr` you initialize `scalePacked` as:

```cpp
auto scalePacked = (params.expertWeightsPtr != nullptr)
                       ? reinterpret_cast<ScalePackedType const*>(
                             params.expertWeightsPtr)[expandedIdx / TopKUnrollFactor]
                       : ScalePackedType{TypeExpW(1.f)};
```

For packed types (`float2`/`float4`, `half2`/`half_4`, `bfloat16_2`/`bfloat16_4`), aggregate initialization with a single scalar sets only the first lane to 1 and leaves the remaining lanes zero-initialized. That means, for each unrolled Top-K chunk, only the first expert gets scale 1, while the others get scale 0, which is incorrect when `expertWeightsPtr` is meant to default to all-ones (same behavior as the scalar finalize kernel).

This is a correctness bug for any `topK > 1` when `expertWeightsPtr == nullptr` and the vectorized path is taken; contributions from non-first experts in each chunk are effectively dropped.

You can fix this by constructing a `ScaleArrayType` filled with ones and reinterpreting it into `ScalePackedType`:

```diff
-    auto scalePacked = (params.expertWeightsPtr != nullptr)
-                           ? reinterpret_cast<ScalePackedType const*>(
-                                 params.expertWeightsPtr)[expandedIdx / TopKUnrollFactor]
-                           : ScalePackedType{TypeExpW(1.f)};
+    ScalePackedType scalePacked;
+    if (params.expertWeightsPtr != nullptr) {
+      scalePacked = reinterpret_cast<ScalePackedType const*>(
+          params.expertWeightsPtr)[expandedIdx / TopKUnrollFactor];
+    } else {
+      ScaleArrayType ones;
+      ones.fill(TypeExpW(1.f));
+      scalePacked = *reinterpret_cast<ScalePackedType*>(&ones);
+    }
```

This ensures every lane in the packed scale vector is initialized to 1.0, matching the semantics of the non-vectorized finalize kernel.
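To illustrate the aggregate-initialization behavior this comment relies on, the following standalone host-side snippet (not from the PR; the struct is a stand-in for `float4`-style packed types) shows that a single-scalar brace initializer only sets the first lane and value-initializes the rest to zero:

```cpp
#include <cstdio>

struct float4_like {  // stand-in for float4 / half_4 / bfloat16_4-style packed structs
  float x, y, z, w;
};

int main() {
  float4_like v{1.0f};  // aggregate init: x = 1, remaining members value-initialized to 0
  std::printf("%g %g %g %g\n", v.x, v.y, v.z, v.w);  // prints: 1 0 0 0
  return 0;
}
```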
🧹 Nitpick comments (1)
csrc/trtllm_fused_moe_dev_kernel.cu (1)
784-881: Minor: heavy use of `reinterpret_cast` for packed/array types may rely on alignment/aliasing assumptions

`finalizeKernelVecLoad` frequently reinterprets between:

- `int*` and `IdxPackedType` / `IdxArrayType`
- `TypeExpW*` and `ScalePackedType` / `ScaleArrayType`
- `InputElem*` and `float4` (via `vectorizedLoadPtx`)

This is a common pattern in high-performance CUDA, but it does assume:

- Sufficient alignment for `IdxPackedType` / `ScalePackedType` and `float4` loads.
- That the packed structs and `cutlass::Array` specializations are layout-compatible.

If you start targeting architectures or toolchains with stricter alignment/aliasing semantics, it might be safer to:

- Explicitly enforce alignment on the underlying allocations for `expandedIdxToPermutedIdx` and `expertWeightsPtr`, and/or
- Store shared-memory buffers in the `cutlass::Array` types and only reinterpret at the global-memory load/store boundary.

Not a blocker, but worth keeping in mind if you hit obscure misalignment or performance issues in future.
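If you decide to make those assumptions explicit, one option is a small helper combining compile-time size checks with debug-only pointer-alignment asserts. The member type names below are taken from this review; the helper itself is hypothetical, not code from the PR:

```cpp
#include <cassert>
#include <cstdint>

template <typename Traits>
__device__ __forceinline__ void checkFinalizeAlignment(void const* scalePtr, void const* inPtr) {
  // The packed and cutlass::Array forms must agree in size for the reinterpret_casts to be safe.
  static_assert(sizeof(typename Traits::ScalePackedType) == sizeof(typename Traits::ScaleArrayType),
                "scale pack/array size mismatch");
  static_assert(sizeof(typename Traits::IdxPackedType) == sizeof(typename Traits::IdxArrayType),
                "index pack/array size mismatch");
  // 128-bit vectorized loads require 16-byte-aligned global pointers (checked in debug builds).
  assert(reinterpret_cast<std::uintptr_t>(inPtr) % 16 == 0);
  assert(reinterpret_cast<std::uintptr_t>(scalePtr) % alignof(typename Traits::ScalePackedType) == 0);
}
```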
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- `csrc/trtllm_fused_moe_dev_kernel.cu` (5 hunks)
- `flashinfer/jit/fused_moe.py` (1 hunks)
- `include/flashinfer/trtllm/fused_moe/DevKernel.h` (2 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
include/flashinfer/trtllm/fused_moe/DevKernel.h
🧬 Code graph analysis (2)
include/flashinfer/trtllm/fused_moe/DevKernel.h (1)
csrc/fmhaReduction.cu (1)
`kernel` (339-339)
csrc/trtllm_fused_moe_dev_kernel.cu (1)
csrc/trtllm_fused_moe_routing_renormalize.cu (3)
`void` (76-251), `void` (322-325), `void` (331-387)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (5)
include/flashinfer/trtllm/fused_moe/DevKernel.h (2)
119-140: LAUNCH_EXPW type/TopK dispatch looks consistent

The extended `LAUNCH_EXPW` macro cleanly threads `topK` into the `KernelParams<Type, TypeExpW, TopKUnrollFactor, UsePdl>` instantiation for all supported `(mDtypeElt, mDtypeExpW)` pairs. No issues from a correctness or call-site perspective.
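For readers less familiar with this macro idiom, the core trick is converting the runtime `data.topK` into a compile-time template argument that the kernel can use for unrolling. A stripped-down host-side sketch of the same idea (illustrative names, plain templates instead of the project's macros):

```cpp
#include <cstdio>

// Parameter bundle exposing the unroll factor as a compile-time constant, in the
// spirit of finalize::KernelParams<Type, TypeExpW, TopKUnrollFactor, UsePdl>.
template <int TopKUnrollFactor_>
struct ParamsSketch {
  static constexpr int TopKUnrollFactor = TopKUnrollFactor_;
  int topK;
};

template <int Unroll>
void launchFinalizeSketch(int topK) {
  ParamsSketch<Unroll> params{topK};
  // A real launcher would instantiate and launch the kernel here; inside the kernel,
  // ParamsSketch<Unroll>::TopKUnrollFactor can drive "#pragma unroll" loop bounds.
  std::printf("topK=%d, unroll=%d\n", params.topK, ParamsSketch<Unroll>::TopKUnrollFactor);
}

// Runtime-to-compile-time dispatch, mirroring the % 4 / % 2 / fallback selection.
void dispatchByTopK(int topK) {
  if (topK % 4 == 0) {
    launchFinalizeSketch<4>(topK);
  } else if (topK % 2 == 0) {
    launchFinalizeSketch<2>(topK);
  } else {
    launchFinalizeSketch<1>(topK);
  }
}

int main() { dispatchByTopK(8); }  // prints: topK=8, unroll=4
```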
467-472: KernelParams TopKUnrollFactor integration looks correct

Adding `TypeExpW_` and `TopKUnrollFactor_` to `finalize::KernelParams` (and exposing `static constexpr int TopKUnrollFactor`) matches how `LAUNCH_EXPW` / `LAUNCH_TOPK_EXPW` now instantiate the params and how `finalizeKernelVecLoad` uses the unroll factor at compile time. No issues here.

flashinfer/jit/fused_moe.py (1)
50-59: Adding `-lineinfo` to sm100 fused MoE build is reasonable

Including `"-lineinfo"` in the sm100 Cutlass fused MoE nvcc flags is fine and should help debugging/profiling without changing kernel semantics. No additional concerns.

csrc/trtllm_fused_moe_dev_kernel.cu (2)
677-749: MaxTopK and trait definitions are consistent with 1/2/4-way unrolling
`MaxTopK = 64` plus the `bfloat16_2` / `bfloat16_4` / `half_4` packed types and `ScaleTraitsStruct` / `FinalizeTraits` specializations cover all expected `(UnrollFactor ∈ {1,2,4}, TypeExpW ∈ {bf16, float, half})` combinations. Layout and alignment choices look coherent with the 128-bit load path in `finalizeKernelVecLoad`. No functional issues spotted here.
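For intuition, a trait specialization of this shape might look roughly like the sketch below. This is a hypothetical reconstruction from the type names mentioned in this review (`FinalizeTraits`, `IdxPackedType`, `ScalePackedType`, `bfloat16_4`, `cutlass::Array`); the real definitions in `csrc/trtllm_fused_moe_dev_kernel.cu` will differ in detail:

```cpp
#include <cuda_runtime.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>

// Hypothetical packed twin of cutlass::Array<bfloat16_t, 4>: 8 bytes, 8-byte aligned.
struct alignas(8) bfloat16_4_sketch {
  cutlass::bfloat16_t x, y, z, w;
};

// Primary template intentionally undefined; only the supported
// (UnrollFactor, TypeExpW) pairs receive specializations.
template <int UnrollFactor, typename TypeExpW>
struct FinalizeTraitsSketch;

template <>
struct FinalizeTraitsSketch<4, cutlass::bfloat16_t> {
  using IdxPackedType = int4;                                    // four permuted indices per 128-bit load
  using IdxArrayType = cutlass::Array<int, 4>;
  using ScalePackedType = bfloat16_4_sketch;                     // four expert scales per 64-bit load
  using ScaleArrayType = cutlass::Array<cutlass::bfloat16_t, 4>;
};
```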
951-982: Finalize `run()` dispatch with LAUNCH_TOPK_EXPW looks coherent

- DeepSeek path (`mUseDeepSeekFp8`) now uses `LAUNCH_TOPK_EXPW` with `finalizeDeepSeekKernel`, which is compatible with the updated `KernelParams<Type, TypeExpW, TopKUnrollFactor, UsePdl>` (even though the kernel only consumes `Type`).
- Non-DeepSeek path:
  - For small grids (`numBlocksX * numBlocksY < 1184`), you keep the scalar `finalizeKernel` and still dispatch via `LAUNCH_TOPK_EXPW`, which is fine since `TopKUnrollFactor` is unused there.
  - For larger grids, you guard `data.topK <= MaxTopK` before launching `finalizeKernelVecLoad` with `gridDim.x = data.numTokens` and `blockDim.x = FINALIZE_THREADS_PER_BLOCK`, matching the kernel's expectations (`blockIdx.x` as token index, `stride = FINALIZE_THREADS_PER_BLOCK`).

Assuming upstream ensures `data.numTokens` fits within CUDA's grid-dim limits, this dispatch logic looks correct.

If you want to be extra safe, you can add a runtime check close to this dispatch to assert that `data.numTokens <= 65535` (or whatever max grid.x you intend to support), mirroring the existing cap you already apply on `numBlocksY` in the scalar path.
Actionable comments posted: 0
🧹 Nitpick comments (4)
include/flashinfer/trtllm/fused_moe/DevKernel.h (1)
142-149: `LAUNCH_TOPK_EXPW` dispatch relies on `topK` divisibility assumptions

The `% 4` / `% 2` dispatch cleanly guarantees that the chosen `TopKUnrollFactor` divides `data.topK`, which is what the vectorized finalize path expects. If `data.topK` can ever be zero or negative, consider adding an explicit `FLASHINFER_CHECK(data.topK > 0, ...)` at the call-site or here to fail fast instead of silently falling back to the scalar unroll-1 path.

csrc/trtllm_fused_moe_dev_kernel.cu (3)
677-781: Packed type and trait wiring is sound; consider adding size/alignment static_asserts

The `MaxTopK` constant and the `ScaleTraitsStruct` / `FinalizeTraits` specializations for unroll factors 1/2/4 across `cutlass::bfloat16_t`, `float`, and `half` line up with how the vectorized finalize path uses `IdxPackedType`, `IdxArrayType`, `ScalePackedType`, and `ScaleArrayType`. To make this more robust against future refactors (or changes in CUDA/CUTLASS typedefs), you might want to add a few compile-time checks, e.g.:

```cpp
static_assert(sizeof(IdxPackedType) == sizeof(IdxArrayType), "Idx pack/array size mismatch");
static_assert(sizeof(ScalePackedType) == sizeof(ScaleArrayType), "Scale pack/array size mismatch");
```

(and similarly check alignment if you rely on vectorized loads). This would catch any accidental drift between the packed structs and their corresponding `cutlass::Array` representations.
788-830: Vectorized finalize kernel: TopKUnrollFactor invariants and alignment assumptions

The new `finalizeKernelVecLoad` path looks logically consistent with the scalar finalize kernel:

- `TopKUnrollFactor` is taken from `KernelParams::TopKUnrollFactor`, statically asserted to be 1/2/4, and used to select the appropriate `FinalizeTraits`.
- Shared memory buffers are sized as `MaxTopK / TopKUnrollFactor`, and the `kChunkIdx < params.topK / TopKUnrollFactor` loops stay within bounds as long as `params.topK <= MaxTopK` and `params.topK % TopKUnrollFactor == 0`.
- The per-chunk prefetch of `expandedIdxToPermutedIdx` and `expertWeightsPtr` via packed types, followed by the inner accumulation over `TopKUnrollFactor` with `permutedIdx == -1` checks, mirrors the scalar logic.

Two points worth calling out:

1. Dependence on launch-time guarantees
   Correctness of the chunking depends on `params.topK % TopKUnrollFactor == 0`. Right now this is enforced implicitly by `LAUNCH_TOPK_EXPW` (choosing 4/2/1 based on divisibility). It would be safer to either:
   - add a defensive runtime check early in `finalizeKernelVecLoad` (e.g., `assert(params.topK % TopKUnrollFactor == 0);` or `FLASHINFER_CHECK`), or
   - document this invariant clearly in a comment near the kernel definition.
   This helps avoid subtle issues if someone ever instantiates `KernelParams` manually or bypasses `LAUNCH_TOPK_EXPW`.

2. Vectorized load alignment assumptions
   `vectorizedLoadPtx` performs `ld.global.v4.f32` on a `float4 const*`, where the pointer comes from:

   ```cpp
   auto const* inputPermutedPtr = inElemPtr + permutedIdx * numElemsInPaddedCol;
   float4 input = vectorizedLoadPtx(reinterpret_cast<float4 const*>(&inputPermutedPtr[elemIndex]));
   ```

   Here `inElemPtr` is a `cutlass::Array<Type, FINALIZE_ELEM_PER_THREAD> const*`, and then you reinterpret its elements as `float4`. This assumes 16-byte alignment of the underlying storage (both for the allocation and the `cutlass::Array` layout). If that alignment is not guaranteed by upstream allocations, the PTX vector load could become misaligned. It may be worth:
   - explicitly documenting the alignment requirement on `params.inPtr`, and/or
   - adding a debug-mode assertion on pointer alignment before issuing the vector load.
   Even if current callers satisfy this, making the assumption explicit will prevent future regressions.

Finally, note that when `useScale == false` the `ScalePackedType{TypeExpW(1.f)}` path is effectively dead (since you fall back to `scale = 1.0f` later). You could skip writing `scaleArrSmem` in that case to shave a few instructions, but that's a micro-optimization.

Also applies to: 844-878
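For context on the `ld.global.v4.f32` point above, a `vectorizedLoadPtx`-style helper is usually a thin inline-PTX wrapper like the generic sketch below (not necessarily byte-for-byte what this PR uses):

```cpp
#include <cuda_runtime.h>

// Generic 128-bit vectorized global load via inline PTX. The pointer must be
// 16-byte aligned and refer to global memory.
__device__ __forceinline__ float4 vectorizedLoadSketch(float4 const* ptr) {
  float4 v;
  asm volatile("ld.global.v4.f32 {%0, %1, %2, %3}, [%4];"
               : "=f"(v.x), "=f"(v.y), "=f"(v.z), "=f"(v.w)
               : "l"(ptr));
  return v;
}
```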
961-983: New `MaxTopK` guard changes behavior for large `topK` in the vectorized path

The updated `run`:

- Uses `LAUNCH_TOPK_EXPW` for both DeepSeek and standard finalize kernels, so `TopKUnrollFactor` is consistently propagated.
- Switches to `finalizeKernelVecLoad` when `numBlocksX * numBlocksY >= 1184`, but only after enforcing `FLASHINFER_CHECK(data.topK <= MaxTopK)` with `MaxTopK = 64`.

This introduces a new runtime constraint: previously, large `topK` values would still go through the scalar `finalizeKernel`; now, for large grids and `topK > 64`, you'll trip the check instead of falling back to the scalar path.

If TRT-LLM configs can ever use `topK > 64` in this code path, you may want to:

- Replace the `FLASHINFER_CHECK` with a conditional fallback:

  ```cpp
  if (data.topK > MaxTopK) {
    dim3 numBlocks(numBlocksX, numBlocksY);
    LAUNCH_TOPK_EXPW(data, finalizeKernel, numBlocks, numThreads, 0, stream);
  } else {
    // current vectorized path
  }
  ```

- Or, if the design genuinely forbids `topK > 64`, document that constraint prominently (e.g., in a comment or higher-level API checks) so users see it before hitting a runtime failure.

Please confirm that all supported TRT-LLM MoE configs satisfy `topK <= 64` in this path.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- `csrc/trtllm_fused_moe_dev_kernel.cu` (5 hunks)
- `include/flashinfer/trtllm/fused_moe/DevKernel.h` (2 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
include/flashinfer/trtllm/fused_moe/DevKernel.h
🧬 Code graph analysis (2)
include/flashinfer/trtllm/fused_moe/DevKernel.h (1)
csrc/fmhaReduction.cu (1)
`kernel` (339-339)
csrc/trtllm_fused_moe_dev_kernel.cu (1)
csrc/trtllm_fused_moe_routing_renormalize.cu (3)
`void` (76-251), `void` (322-325), `void` (331-387)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (2)
include/flashinfer/trtllm/fused_moe/DevKernel.h (2)
119-140: TopK unroll factor propagation through `LAUNCH_EXPW` looks consistent

Forwarding `topK` into `LAUNCH_ESC` so that `KernelParams<Type, TypeExpW, TopKUnrollFactor, UsePdl>` can see the unroll factor as a compile-time constant is coherent, and the dtype pair coverage matches the existing combinations of `(mDtypeElt, mDtypeExpW)`. No functional issues spotted here.
465-470: Breaking change verified: template parameter addition may impact external direct instantiations

The `finalize::KernelParams` template at lines 465-470 has the signature `<typename Type_, typename TypeExpW_, int TopKUnrollFactor_, bool UsePdl_>`. All internal instantiations use macros (e.g., `KernelParams<types, true>`) that correctly expand to provide all four template arguments through `LAUNCH_ESC`. However, if external code directly instantiated this template with the previous 3-parameter form `<Type_, TypeExpW_, bool UsePdl_>`, it will break and require updating.
/bot run
## 📌 Description

- Small optimization for TRT-LLM Gen MoE finalize kernel

TopK=8, NumExperts=128, HiddenSize=4096

| BS | Baseline, us | Optimized, us | Speed-up |
| --- | --- | --- | --- |
| 256 | 11 | 6 | 1.83 |
| 512 | 12 | 7 | 1.71 |
| 1024 | 16 | 15 | 1.06 |
| 4096 | 55 | 49 | 1.12 |
| 8192 | 107 | 95 | 1.13 |
| 16384 | 205 | 183 | 1.12 |

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

## Summary by CodeRabbit

* **New Features**
  * Enabled vectorized, Top-K unrolled finalize path for MOE (Mixture of Experts) kernel operations with improved performance.
  * Added support for multiple data types (bfloat16, float, half) with enhanced type specialization and packing.
  * Introduced runtime validation for TopK configurations (≤ 64) to ensure optimal vectorized execution.