
Conversation

@nekorobov
Contributor

@nekorobov commented Nov 14, 2025

📌 Description

  • Small optimization for TRT-LLM Gen MoE finalize kernel

TopK=8, NumExperts=128, HiddenSize=4096

| BS | Baseline, us | Optimized, us | Speed-up |
| --- | --- | --- | --- |
| 256 | 11 | 6 | 1.83 |
| 512 | 12 | 7 | 1.71 |
| 1024 | 16 | 15 | 1.06 |
| 4096 | 55 | 49 | 1.12 |
| 8192 | 107 | 95 | 1.13 |
| 16384 | 205 | 183 | 1.12 |

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features
    • Enabled vectorized, Top-K unrolled finalize path for MOE (Mixture of Experts) kernel operations with improved performance.
    • Added support for multiple data types (bfloat16, float, half) with enhanced type specialization and packing.
    • Introduced runtime validation for TopK configurations (≤ 64) to ensure optimal vectorized execution.

@coderabbitai
Contributor

coderabbitai bot commented Nov 14, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

This PR introduces vectorized, Top-K unrolled finalization paths for fused MOE kernels by adding trait-driven type specializations for packed elements (bfloat16_2, bfloat16_4, half_4) and nine trait template specializations across three unroll factors, and by updating the launch macros to propagate TopK values and expose TopKUnrollFactor as a compile-time constant.

Changes

Vectorized finalize path scaffolding — csrc/trtllm_fused_moe_dev_kernel.cu
Adds a MaxTopK constant (64), new packed element types (bfloat16_2, bfloat16_4, half_4), and two primary trait templates, ScaleTraitsStruct and FinalizeTraits, each with nine specializations across unroll factors 1, 2, and 4 and data types (bfloat16_t, float, half); a sketch follows below. Extends finalizeKernelVecLoad to allocate per-chunk shared memory for scales and permuted indices, compute per-chunk values using trait-driven types, replace the scalar topK loops with chunked unrolled loops, and add per-chunk synchronization and input loading. Alters run(Data, stream) to prioritize the vectorized path via LAUNCH_TOPK_EXPW and to enforce TopK <= MaxTopK validation.

Launch macro and KernelParams updates — include/flashinfer/trtllm/fused_moe/DevKernel.h
Extends the LAUNCH_EXPW macro signature to propagate the topK parameter into type instantiations and introduces a new LAUNCH_TOPK_EXPW launcher macro that selects TopK (4, 2, or 1) from data and delegates to LAUNCH_EXPW. Updates the KernelParams template signatures in both the moe::dev and moe::dev::finalize namespaces from <Type_, TypeExpW_, UsePdl_> to <Type_, TypeExpW_, int TopKUnrollFactor_, bool UsePdl_>, and exposes TopKUnrollFactor as a public static constexpr compile-time constant.
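
For orientation, here is a minimal, self-contained sketch of the trait-dispatch shape referenced above. It uses stand-in types so it compiles on its own; the PR's actual traits use CUDA half/bfloat16 vector types and cutlass::Array, and the concrete member names may differ.

// Sketch only: the general shape of the FinalizeTraits dispatch on (UnrollFactor, TypeExpW).
// Stand-in types are used so the snippet compiles on its own; the PR uses half_4,
// bfloat16_2 / bfloat16_4, and cutlass::Array instead.
#include <array>
#include <cstdint>

struct scale4 { std::uint16_t x, y, z, w; };  // stand-in for a packed load of 4 fp16 expert scales
struct idx4   { int x, y, z, w; };            // stand-in for a packed load of 4 permuted indices

template <int UnrollFactor, typename TypeExpW>
struct FinalizeTraits;  // primary template; the PR provides nine specializations

// One specialization: unroll factor 4 with half-precision expert scales.
template <>
struct FinalizeTraits<4, std::uint16_t> {
  using ScalePackedType = scale4;                        // one vectorized load of 4 scales
  using ScaleArrayType  = std::array<std::uint16_t, 4>;  // per-lane view of the same data
  using IdxPackedType   = idx4;                          // one vectorized load of 4 indices
  using IdxArrayType    = std::array<int, 4>;
};

static_assert(sizeof(FinalizeTraits<4, std::uint16_t>::ScalePackedType) ==
                  sizeof(FinalizeTraits<4, std::uint16_t>::ScaleArrayType),
              "packed and per-lane views must match in size");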

Sequence Diagram

sequenceDiagram
    participant host as Host
    participant run as run(Data, stream)
    participant selector as LAUNCH_TOPK_EXPW
    participant macroESC as LAUNCH_ESC
    participant kernel as finalizeKernelVecLoad
    
    host->>run: invoke with Data
    run->>run: check TopK <= MaxTopK
    run->>selector: call with data.topK
    selector->>selector: select TopK ∈ {4, 2, 1}
    selector->>macroESC: LAUNCH_ESC with selected TopK
    macroESC->>macroESC: instantiate KernelParams<Type, TypeExpW, TopK, UsePdl>
    macroESC->>kernel: launch with FinalizeTraits<TopK, TypeExpW>
    
    rect rgb(200, 220, 240)
        Note over kernel: Per-chunk vectorized loop
        kernel->>kernel: load scales/indices per chunk
        kernel->>kernel: __syncthreads()
        kernel->>kernel: apply scale & accumulate
        kernel->>kernel: repeat for all chunks
    end
    
    kernel-->>host: finalize complete

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Areas requiring extra attention:

  • Nine trait template specializations across UnrollFactors and data types—verify correctness of each packing/unpacking combination (bfloat16_2, bfloat16_4, half_4)
  • Per-chunk vectorized loop logic in finalizeKernelVecLoad, particularly conversions between packed types and float arrays for accumulation and synchronization semantics
  • Runtime TopK validation check (TopK <= MaxTopK) and its interaction with launcher macro selection logic
  • Macro signature change to LAUNCH_EXPW propagating topK; verify all call sites are updated correctly

Possibly related PRs

Suggested reviewers

  • joker-eph
  • aleozlx
  • djmmoss
  • cyx-6
  • nvmbreughe
  • wenscarl

Poem

🐰 A cotton-tailed kernel hops in glee,
Vectorized traits now dance in harmony!
Four, two, or one—the chunks align,
Per-TopK packing, traits intertwine,
Fused MOE's path now streams divine! ✨

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 20.00%, which is insufficient; the required threshold is 80.00%. | You can run @coderabbitai generate docstrings to improve docstring coverage. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Title check | ✅ Passed | The title accurately summarizes the main change: a performance optimization for the TRT-LLM Gen finalize kernel, with clear specificity about what is being optimized. |
| Description check | ✅ Passed | The description includes the required context: a brief explanation of the optimization, performance benchmarks demonstrating the impact across various batch sizes, and follows the template structure with populated sections. |

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.


Comment @coderabbitai help to get the list of available commands and usage tips.

@gemini-code-assist
Contributor

Summary of Changes

Hello @nekorobov, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request focuses on improving the performance of the TRT-LLM Gen Mixture-of-Experts (MoE) finalize kernel. The changes involve refactoring the kernel to leverage vectorized memory access patterns and shared memory for frequently accessed data, which collectively contribute to faster execution times, especially noticeable at lower batch sizes. The aim is to make the MoE layer more efficient within the TRT-LLM generation pipeline.

Highlights

  • Performance Optimization: Implemented a small but significant optimization for the TRT-LLM Gen MoE finalize kernel, yielding speed-ups of up to 1.83x for smaller batch sizes, as detailed in the PR description's performance table.
  • Vectorized Loading and Shared Memory: Enhanced the finalizeKernelVecLoad by introducing TopKUnrollFactor and utilizing shared memory (__shared__) to cache expert weights and permuted indices, thereby reducing global memory access and improving data locality.
  • Templated Type Handling: Added new templated structs (ScaleTraitsStruct, FinalizeTraits) to support vectorized loading for various data types (bfloat16, half, float) and unroll factors (1, 2, 4), making the kernel more flexible and efficient.
  • Dynamic Kernel Launch Selection: Introduced a new macro LAUNCH_TOPK_EXPW to dynamically select the most efficient kernel launch configuration based on the topK value, optimizing for unroll factors of 1, 2, or 4.
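
As a quick illustration of the last highlight, the following host-side sketch mirrors the unroll-factor selection that LAUNCH_TOPK_EXPW performs; it is illustrative only, and the actual macro appears in the review diff further down.

// Illustrative only: map a runtime topK to the compile-time unroll factor (4, 2, or 1).
#include <cstdio>

int selectTopKUnrollFactor(int topK) {
  if (topK % 4 == 0) return 4;  // e.g. topK = 8 in the benchmark above
  if (topK % 2 == 0) return 2;
  return 1;                     // 1 always divides topK, so no unsupported case remains
}

int main() {
  const int samples[] = {1, 2, 6, 8, 64};
  for (int topK : samples) {
    std::printf("topK=%d -> unroll=%d\n", topK, selectTopKUnrollFactor(topK));
  }
  return 0;
}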
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

| Feature | Command | Description |
| --- | --- | --- |
| Code Review | /gemini review | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | /gemini summary | Provides a summary of the current pull request in its current state. |
| Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in pull request comments and review comments. |
| Help | /gemini help | Displays a list of available commands. |

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@nekorobov
Contributor Author

Tagging @ChristinaZ

Contributor

@gemini-code-assist bot left a comment


Code Review

This pull request introduces a significant performance optimization for the TRT-LLM MoE finalize kernel. The changes leverage shared memory to cache expert scales and indices, reducing global memory traffic and improving performance, especially for smaller batch sizes. The implementation adds new template metaprogramming structures for handling vectorized loads with different unroll factors and a new dispatch macro to select the optimal kernel. My review identified a critical correctness bug in the initialization of default expert scales, which could lead to incorrect computations. I also found a minor issue with unreachable code in one of the new dispatch macros. Overall, the optimization strategy is sound, and with the suggested fixes, this will be a valuable performance improvement.

Contributor

@coderabbitai bot left a comment


Actionable comments posted: 0

♻️ Duplicate comments (2)
include/flashinfer/trtllm/fused_moe/DevKernel.h (1)

142-151: Unreachable else branch in LAUNCH_TOPK_EXPW due to % 1 == 0

data.topK % 1 == 0 is always true for integral topK, so the final else (and "Unsupported topK" warning) is unreachable. You can simplify the macro and drop the dead branch:

-#define LAUNCH_TOPK_EXPW(data, kernel, numBlocks, numThreads, smemSize, stream) \
-  if (data.topK % 4 == 0) {                                                     \
-    LAUNCH_EXPW(data, kernel, 4, numBlocks, numThreads, smemSize, stream);      \
-  } else if (data.topK % 2 == 0) {                                              \
-    LAUNCH_EXPW(data, kernel, 2, numBlocks, numThreads, smemSize, stream);      \
-  } else if (data.topK % 1 == 0) {                                              \
-    LAUNCH_EXPW(data, kernel, 1, numBlocks, numThreads, smemSize, stream);      \
-  } else {                                                                      \
-    FLASHINFER_WARN("Unsupported topK");                                        \
-  }
+#define LAUNCH_TOPK_EXPW(data, kernel, numBlocks, numThreads, smemSize, stream) \
+  if (data.topK % 4 == 0) {                                                     \
+    LAUNCH_EXPW(data, kernel, 4, numBlocks, numThreads, smemSize, stream);      \
+  } else if (data.topK % 2 == 0) {                                              \
+    LAUNCH_EXPW(data, kernel, 2, numBlocks, numThreads, smemSize, stream);      \
+  } else {                                                                      \
+    LAUNCH_EXPW(data, kernel, 1, numBlocks, numThreads, smemSize, stream);      \
+  }

This keeps behavior identical while removing unreachable code.

csrc/trtllm_fused_moe_dev_kernel.cu (1)

784-881: Bug: default scale initialization only sets the first element to 1 when expertWeightsPtr is null

In finalizeKernelVecLoad, when params.expertWeightsPtr == nullptr you initialize scalePacked as:

auto scalePacked = (params.expertWeightsPtr != nullptr)
                       ? reinterpret_cast<ScalePackedType const*>(
                             params.expertWeightsPtr)[expandedIdx / TopKUnrollFactor]
                       : ScalePackedType{TypeExpW(1.f)};

For packed types (float2/float4, half2/half_4, bfloat16_2/bfloat16_4), aggregate initialization with a single scalar sets only the first lane to 1 and leaves the remaining lanes zero‑initialized. That means, for each unrolled Top‑K chunk, only the first expert gets scale 1, while the others get scale 0, which is incorrect when expertWeightsPtr is meant to default to all‑ones (same behavior as the scalar finalize kernel).

This is a correctness bug for any topK > 1 when expertWeightsPtr == nullptr and the vectorized path is taken; contributions from non‑first experts in each chunk are effectively dropped.

You can fix this by constructing a ScaleArrayType filled with ones and reinterpreting it into ScalePackedType:

-    auto scalePacked = (params.expertWeightsPtr != nullptr)
-                           ? reinterpret_cast<ScalePackedType const*>(
-                                 params.expertWeightsPtr)[expandedIdx / TopKUnrollFactor]
-                           : ScalePackedType{TypeExpW(1.f)};
+    ScalePackedType scalePacked;
+    if (params.expertWeightsPtr != nullptr) {
+      scalePacked = reinterpret_cast<ScalePackedType const*>(
+          params.expertWeightsPtr)[expandedIdx / TopKUnrollFactor];
+    } else {
+      ScaleArrayType ones;
+      ones.fill(TypeExpW(1.f));
+      scalePacked = *reinterpret_cast<ScalePackedType*>(&ones);
+    }

This ensures every lane in the packed scale vector is initialized to 1.0, matching the semantics of the non‑vectorized finalize kernel.
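
A minimal host-side illustration of the aggregate-initialization pitfall described above, using a float2-like stand-in rather than the PR's packed types:

// With a single initializer, aggregate initialization sets only the first member;
// the remaining members are value-initialized to zero, which is the behavior flagged above.
#include <cstdio>

struct float2_like { float x, y; };  // stand-in for float2 / half2 / bfloat16_2

int main() {
  float2_like s{1.f};
  std::printf("x=%.1f y=%.1f\n", s.x, s.y);  // prints x=1.0 y=0.0
  return 0;
}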

🧹 Nitpick comments (1)
csrc/trtllm_fused_moe_dev_kernel.cu (1)

784-881: Minor: heavy use of reinterpret_cast for packed/array types may rely on alignment/aliasing assumptions

finalizeKernelVecLoad frequently reinterprets between:

  • int* and IdxPackedType / IdxArrayType
  • TypeExpW* and ScalePackedType / ScaleArrayType
  • InputElem* and float4 (via vectorizedLoadPtx)

This is a common pattern in high‑performance CUDA, but it does assume:

  • Sufficient alignment for IdxPackedType / ScalePackedType and float4 loads.
  • That the packed structs and cutlass::Array specializations are layout‑compatible.

If you start targeting architectures or toolchains with stricter alignment/aliasing semantics, it might be safer to:

  • Explicitly enforce alignment on the underlying allocations for expandedIdxToPermutedIdx and expertWeightsPtr, and/or
  • Store shared‑memory buffers in the cutlass::Array types and only reinterpret at the global‑memory load/store boundary.

Not a blocker, but worth keeping in mind if you hit obscure misalignment or performance issues in future.
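
If such a check is ever added, a hypothetical debug-time helper could look like the following; the helper name and its usage are assumptions for illustration, not the PR's code.

// Hypothetical helper: verify that a pointer is suitably aligned for a vector load.
#include <cassert>
#include <cstdint>

template <typename VecT>
bool isAlignedFor(void const* ptr) {
  return reinterpret_cast<std::uintptr_t>(ptr) % alignof(VecT) == 0;
}

// Example use before the 128-bit load (float4 requires 16-byte alignment):
//   assert(isAlignedFor<float4>(&inputPermutedPtr[elemIndex]));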

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 636a3ab and 3145819.

📒 Files selected for processing (3)
  • csrc/trtllm_fused_moe_dev_kernel.cu (5 hunks)
  • flashinfer/jit/fused_moe.py (1 hunks)
  • include/flashinfer/trtllm/fused_moe/DevKernel.h (2 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • include/flashinfer/trtllm/fused_moe/DevKernel.h
🧬 Code graph analysis (2)
include/flashinfer/trtllm/fused_moe/DevKernel.h (1)
csrc/fmhaReduction.cu (1)
  • kernel (339-339)
csrc/trtllm_fused_moe_dev_kernel.cu (1)
csrc/trtllm_fused_moe_routing_renormalize.cu (3)
  • void (76-251)
  • void (322-325)
  • void (331-387)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (5)
include/flashinfer/trtllm/fused_moe/DevKernel.h (2)

119-140: LAUNCH_EXPW type/TopK dispatch looks consistent

The extended LAUNCH_EXPW macro cleanly threads topK into the KernelParams<Type, TypeExpW, TopKUnrollFactor, UsePdl> instantiation for all supported (mDtypeElt, mDtypeExpW) pairs. No issues from a correctness or call‑site perspective.


467-472: KernelParams TopKUnrollFactor integration looks correct

Adding TypeExpW_ and TopKUnrollFactor_ to finalize::KernelParams (and exposing static constexpr int TopKUnrollFactor) matches how LAUNCH_EXPW/LAUNCH_TOPK_EXPW now instantiate the params and how finalizeKernelVecLoad uses the unroll factor at compile time. No issues here.

flashinfer/jit/fused_moe.py (1)

50-59: Adding -lineinfo to sm100 fused MoE build is reasonable

Including "-lineinfo" in the sm100 Cutlass fused MoE nvcc flags is fine and should help debugging/profiling without changing kernel semantics. No additional concerns.

csrc/trtllm_fused_moe_dev_kernel.cu (2)

677-749: MaxTopK and trait definitions are consistent with 1/2/4-way unrolling

MaxTopK = 64 plus the bfloat16_2/bfloat16_4/half_4 packed types and ScaleTraitsStruct/FinalizeTraits specializations cover all expected (UnrollFactor ∈ {1,2,4}, TypeExpW ∈ {bf16,float,half}) combinations. Layout and alignment choices look coherent with the 128‑bit load path in finalizeKernelVecLoad. No functional issues spotted here.


951-982: Finalize run() dispatch with LAUNCH_TOPK_EXPW looks coherent

  • DeepSeek path (mUseDeepSeekFp8) now uses LAUNCH_TOPK_EXPW with finalizeDeepSeekKernel, which is compatible with the updated KernelParams<Type, TypeExpW, TopKUnrollFactor, UsePdl> (even though the kernel only consumes Type).
  • Non‑DeepSeek path:
    • For small grids (numBlocksX * numBlocksY < 1184), you keep the scalar finalizeKernel and still dispatch via LAUNCH_TOPK_EXPW, which is fine since TopKUnrollFactor is unused there.
    • For larger grids, you guard data.topK <= MaxTopK before launching finalizeKernelVecLoad with gridDim.x = data.numTokens and blockDim.x = FINALIZE_THREADS_PER_BLOCK, matching the kernel’s expectations (blockIdx.x as token index, stride = FINALIZE_THREADS_PER_BLOCK).

Assuming upstream ensures data.numTokens fits within CUDA’s grid‑dim limits, this dispatch logic looks correct.

If you want to be extra safe, you can add a runtime check close to this dispatch to assert that data.numTokens <= 65535 (or whatever max grid.x you intend to support), mirroring the existing cap you already apply on numBlocksY in the scalar path.

Contributor

@coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (4)
include/flashinfer/trtllm/fused_moe/DevKernel.h (1)

142-149: LAUNCH_TOPK_EXPW dispatch relies on topK divisibility assumptions

The % 4 / % 2 dispatch cleanly guarantees that the chosen TopKUnrollFactor divides data.topK, which is what the vectorized finalize path expects. If data.topK can ever be zero or negative, consider adding an explicit FLASHINFER_CHECK(data.topK > 0, ...) at the call‑site or here to fail fast instead of silently falling back to the scalar unroll‑1 path.

csrc/trtllm_fused_moe_dev_kernel.cu (3)

677-781: Packed type and trait wiring is sound; consider adding size/alignment static_asserts

The MaxTopK constant and the ScaleTraitsStruct / FinalizeTraits specializations for unroll factors 1/2/4 across cutlass::bfloat16_t, float, and half line up with how the vectorized finalize path uses IdxPackedType, IdxArrayType, ScalePackedType, and ScaleArrayType. To make this more robust against future refactors (or changes in CUDA/CUTLASS typedefs), you might want to add a few compile‑time checks, e.g.:

static_assert(sizeof(IdxPackedType) == sizeof(IdxArrayType), "Idx pack/array size mismatch");
static_assert(sizeof(ScalePackedType) == sizeof(ScaleArrayType), "Scale pack/array size mismatch");

(and similarly check alignment if you rely on vectorized loads). This would catch any accidental drift between the packed structs and their corresponding cutlass::Array representations.


788-830: Vectorized finalize kernel: TopKUnrollFactor invariants and alignment assumptions

The new finalizeKernelVecLoad path looks logically consistent with the scalar finalize kernel:

  • TopKUnrollFactor is taken from KernelParams::TopKUnrollFactor, statically asserted to be 1/2/4, and used to select the appropriate FinalizeTraits.
  • Shared memory buffers are sized as MaxTopK / TopKUnrollFactor, and the kChunkIdx < params.topK / TopKUnrollFactor loops stay within bounds as long as params.topK <= MaxTopK and params.topK % TopKUnrollFactor == 0.
  • The per‑chunk prefetch of expandedIdxToPermutedIdx and expertWeightsPtr via packed types, followed by the inner accumulation over TopKUnrollFactor with permutedIdx == -1 checks, mirrors the scalar logic.

Two points worth calling out:

  1. Dependence on launch‑time guarantees
    Correctness of the chunking depends on params.topK % TopKUnrollFactor == 0. Right now this is enforced implicitly by LAUNCH_TOPK_EXPW (choosing 4/2/1 based on divisibility). It would be safer to either:

    • add a defensive runtime check early in finalizeKernelVecLoad (e.g., assert(params.topK % TopKUnrollFactor == 0); or FLASHINFER_CHECK), or
    • document this invariant clearly in a comment near the kernel definition.
      This helps avoid subtle issues if someone ever instantiates KernelParams manually or bypasses LAUNCH_TOPK_EXPW.
  2. Vectorized load alignment assumptions
    vectorizedLoadPtx performs ld.global.v4.f32 on a float4 const*, where the pointer comes from:

    auto const* inputPermutedPtr = inElemPtr + permutedIdx * numElemsInPaddedCol;
    float4 input =
        vectorizedLoadPtx(reinterpret_cast<float4 const*>(&inputPermutedPtr[elemIndex]));

    Here inElemPtr is a cutlass::Array<Type, FINALIZE_ELEM_PER_THREAD> const*, and then you reinterpret its elements as float4. This assumes 16‑byte alignment of the underlying storage (both for the allocation and the cutlass::Array layout). If that alignment is not guaranteed by upstream allocations, the PTX vector load could become misaligned. It may be worth:

    • explicitly documenting the alignment requirement on params.inPtr, and/or
    • adding a debug‑mode assertion on pointer alignment before issuing the vector load.

    Even if current callers satisfy this, making the assumption explicit will prevent future regressions.

Finally, note that when useScale == false the ScalePackedType{TypeExpW(1.f)} path is effectively dead (since you fall back to scale = 1.0f later). You could skip writing scaleArrSmem in that case to shave a few instructions, but that’s a micro‑optimization.

Also applies to: 844-878


961-983: New MaxTopK guard changes behavior for large topK in the vectorized path

The updated run:

  • Uses LAUNCH_TOPK_EXPW for both DeepSeek and standard finalize kernels, so TopKUnrollFactor is consistently propagated.
  • Switches to finalizeKernelVecLoad when numBlocksX * numBlocksY >= 1184, but only after enforcing FLASHINFER_CHECK(data.topK <= MaxTopK) with MaxTopK = 64.

This introduces a new runtime constraint: previously, large topK values would still go through the scalar finalizeKernel; now, for large grids and topK > 64, you’ll trip the check instead of falling back to the scalar path.

If TRT‑LLM configs can ever use topK > 64 in this code path, you may want to:

  • Replace the FLASHINFER_CHECK with a conditional fallback:
if (data.topK > MaxTopK) {
  dim3 numBlocks(numBlocksX, numBlocksY);
  LAUNCH_TOPK_EXPW(data, finalizeKernel, numBlocks, numThreads, 0, stream);
} else {
  // current vectorized path
}
  • Or, if the design genuinely forbids topK > 64, document that constraint prominently (e.g., in a comment or higher‑level API checks) so users see it before hitting a runtime failure.

Please confirm that all supported TRT‑LLM MoE configs satisfy topK <= 64 in this path.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 3145819 and b2ca5df.

📒 Files selected for processing (2)
  • csrc/trtllm_fused_moe_dev_kernel.cu (5 hunks)
  • include/flashinfer/trtllm/fused_moe/DevKernel.h (2 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • include/flashinfer/trtllm/fused_moe/DevKernel.h
🧬 Code graph analysis (2)
include/flashinfer/trtllm/fused_moe/DevKernel.h (1)
csrc/fmhaReduction.cu (1)
  • kernel (339-339)
csrc/trtllm_fused_moe_dev_kernel.cu (1)
csrc/trtllm_fused_moe_routing_renormalize.cu (3)
  • void (76-251)
  • void (322-325)
  • void (331-387)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (2)
include/flashinfer/trtllm/fused_moe/DevKernel.h (2)

119-140: TopK unroll factor propagation through LAUNCH_EXPW looks consistent

Forwarding topK into LAUNCH_ESC so that KernelParams<Type, TypeExpW, TopKUnrollFactor, UsePdl> can see the unroll factor as a compile‑time constant is coherent, and the dtype pair coverage matches the existing combinations of (mDtypeElt, mDtypeExpW). No functional issues spotted here.


465-470: Breaking change verified: template parameter addition may impact external direct instantiations

The finalize::KernelParams template at lines 465-470 has the signature <typename Type_, typename TypeExpW_, int TopKUnrollFactor_, bool UsePdl_>. All internal instantiations use macros (e.g., KernelParams<types, true>) that correctly expand to provide all four template arguments through LAUNCH_ESC. However, if external code directly instantiated this template with the previous 3-parameter form <Type_, TypeExpW_, bool UsePdl_>, it will break and require updating.
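
For anyone carrying such an external instantiation, the migration amounts to inserting the new unroll-factor argument; the template arguments below are placeholders for illustration.

// Before (three template parameters):
//   moe::dev::finalize::KernelParams<Type, TypeExpW, /*UsePdl=*/true>
// After (four template parameters, TopKUnrollFactor inserted before UsePdl):
//   moe::dev::finalize::KernelParams<Type, TypeExpW, /*TopKUnrollFactor=*/4, /*UsePdl=*/true>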

@yzh119
Collaborator

yzh119 commented Nov 14, 2025

/bot run

@flashinfer-bot
Collaborator

GitLab MR !138 has been created, and the CI pipeline #38500464 is currently running. I'll report back once the pipeline job completes.

@yzh119 merged commit cce4952 into flashinfer-ai:main Nov 14, 2025
4 checks passed
qsang-nv pushed a commit to qsang-nv/flashinfer that referenced this pull request Nov 18, 2025