feat: Fuse shared experts into trtllm_gen moe (fp8)#2625

Open
nv-yunzheq wants to merge 5 commits into flashinfer-ai:main from nv-yunzheq:DSR1_shared_expert_fusion

Conversation

@nv-yunzheq
Collaborator

@nv-yunzheq nv-yunzheq commented Feb 23, 2026

For #2551
Integrating NVIDIA/TensorRT-LLM#11143

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Support for fused shared experts in MoE: per-token expert counts and routing expanded; FP8 MoE APIs accept an optional num_fused_shared_experts parameter.
  • Documentation

    • Docstrings and API docs updated to describe the fused-experts parameter and its usage constraints.
  • Tests

    • Added and extended tests covering fused shared-expert scenarios and routing behavior.
  • Bug Fixes

    • Initialized fused-expert routing metadata to ensure consistent runtime behavior.

@gemini-code-assist
Contributor

Summary of Changes

Hello @nv-yunzheq, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed.

This pull request integrates 'fused shared experts' into the existing TensorRT-LLM MoE framework, specifically for FP8 operations. The primary goal is to optimize the handling of experts that are shared across tokens by incorporating them directly into the MoE kernel's routing and execution logic. This change affects how memory is allocated, how routing decisions are made, and how the overall MoE computation is performed, yielding more streamlined and potentially faster processing of MoE layers with shared components.

Highlights

  • Fused Shared Experts Integration: Introduced support for fusing shared experts into the TensorRT-LLM Mixture-of-Experts (MoE) kernel, specifically for FP8 quantization, enhancing efficiency and flexibility.
  • Updated MoE Kernel Logic: Modified the MoE kernel launchers and runners to correctly account for and process num_fused_shared_experts in tensor allocations, routing calculations, and workspace sizing.
  • DeepSeekV3 Routing Enhancement: The DeepSeekV3 routing method was updated to properly handle the routing and weighting of both regular and fused shared experts.
  • Routing Method Compatibility Checks: Added explicit checks to ensure that Llama4 and Renormalize routing methods do not support the fused shared expert functionality, preventing misuse.
  • Python API and Test Coverage: Extended the Python API to expose the num_fused_shared_experts parameter and added new test cases to validate the correctness of the fused shared experts feature.


Changelog
  • csrc/trtllm_fused_moe_kernel_launcher.cu
    • Adjusted tensor allocation sizes for num_tokens_per_expert, expanded_idx_to_permuted_idx, expert_indexes, and expert_count_histogram to include num_fused_shared_experts.
    • Updated getMaxPermutedPaddedCount and getMaxNumCtasInBatchDim calls to use totalExpertsPerToken and totalNumExperts.
    • Modified moe_runner->getDefaultValidConfigIndex and getValidConfigIndices to use effectiveTopK and effectiveLocalExperts.
    • Passed num_fused_shared_experts to the routing_runner.run method.
  • csrc/trtllm_fused_moe_routing_deepseek.cu
    • Updated idxTopK and idxShared calculations to use mTotalExpertsPerToken.
    • Added logic to write packed scores and weights for fused shared experts.
    • Introduced numExperts, topK, and numThreadsHist variables to reflect total experts including fused shared ones.
    • Added FLASHINFER_CHECK for num_fused_shared_experts being less than warp size.
    • Modified useSingleCluster condition to include data.mNumTokens * topK.
    • Adjusted maxTokensCoop calculation to use topK.
    • Incremented data.mNumExperts, data.mTopK, and data.mNumLocalExperts by data.mNumFusedSharedExperts if shared experts are present.
    • Updated expandedIdxSize calculation to use topK.
  • csrc/trtllm_fused_moe_runner.cu
    • Added numFusedSharedExpert parameter to the Runner::run method signature.
    • Introduced totalExpertsPerToken variable in run method for DeepSeekV3 routing.
    • Added mNumFusedSharedExperts and mTotalExpertsPerToken to routingData for DeepSeekV3.
    • Implemented logic to calculate mSharedExpertTokenOffset and mSharedExpertNumTokens for shared experts distribution.
    • Added FLASHINFER_CHECK to disallow numFusedSharedExpert for Llama4 and Renormalize routing methods.
    • Introduced totalNumExperts and totalExpertsPerToken in setOpsData for activation and finalize data setup.
    • Updated activationData.topK and finalizeData.numExperts, finalizeData.topK to use total expert counts.
    • Modified getWorkspaceSizeInBytes to use totalLocalExperts and totalExpertsPerToken for workspace calculations.
    • Updated mPermuteGemm1.run and mGemm2.run calls to use totalLocalExperts and totalExpertsPerToken.
  • flashinfer/fused_moe/core.py
    • Added num_fused_shared_experts as an optional argument to trtllm_fp8_block_scale_moe_op and trtllm_fp8_block_scale_moe functions.
    • Passed num_fused_shared_experts to the C++ backend calls.
    • Updated docstrings to describe the new num_fused_shared_experts parameter and its implications on weight tensors.
    • Modified the forward method to pass num_fused_shared_experts to the underlying C++ function.
  • include/flashinfer/trtllm/fused_moe/RoutingKernel.h
    • Added mNumFusedSharedExperts, mSharedExpertTokenOffset, mSharedExpertNumTokens, and mTotalExpertsPerToken members to DataBase and KernelParamsBase structs.
  • include/flashinfer/trtllm/fused_moe/runner.h
    • Added numFusedSharedExpert parameter to the Runner::run method declaration.
    • Added num_fused_shared_experts member to the MoERunnerArgs struct.
  • tests/moe/test_trtllm_gen_fused_moe.py
    • Introduced num_routed_experts calculation in call_moe.
    • Adjusted num_experts and top_k parameters passed to trtllm_fp8_block_scale_moe to account for fused shared experts.
    • Modified routing_reference to accept num_fused_shared_experts and include shared experts in its logic.
    • Updated routing_reference_no_aux to pass num_fused_shared_experts to routing_reference.
    • Modified _compute_moe_actual_unified to pass num_fused_shared_experts to call_moe.
    • Adjusted run_moe_test to calculate total_experts and pass num_fused_shared_experts to reference implementations and moe_args.
    • Updated gemm1_weights and gemm2_weights tensor shapes in run_moe_test to accommodate total_experts.
    • Added new pytest.param entries for DSv3_fused_shared_1 and DSv3_fused_shared_2 to test the new functionality.
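As a sketch of what the fused-shared-expert routing reference in the tests might do (hypothetical helper name; the actual logic lives in routing_reference), shared experts can be modeled as extra expert slots appended after each token's routed top-k, using the index range just past the routed experts:

```python
def append_shared_experts(topk_ids_row, num_routed_experts, num_fused_shared_experts):
    # Illustrative convention: shared experts occupy indices
    # [num_routed_experts, num_routed_experts + num_fused_shared_experts)
    # and every token routes to all of them.
    shared = list(range(num_routed_experts,
                        num_routed_experts + num_fused_shared_experts))
    return list(topk_ids_row) + shared

# A token routed to experts 3, 17, 42 out of 256, with 2 fused shared experts:
print(append_shared_experts([3, 17, 42], num_routed_experts=256,
                            num_fused_shared_experts=2))  # [3, 17, 42, 256, 257]
```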
Activity
  • The pull request aims to integrate a feature from TensorRT-LLM PR 11143.
  • It addresses issue Shared expert fusion integration #2551, likely related to MoE functionality.
  • The author, nv-yunzheq, has included standard pre-commit checks and test requirements in the PR description.

@coderabbitai
Contributor

coderabbitai bot commented Feb 23, 2026

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 62808177-e5b9-4c60-93fc-67eb44edb576

📥 Commits

Reviewing files that changed from the base of the PR and between 2255bca and 95a01d0.

📒 Files selected for processing (1)
  • csrc/moe_utils_binding.cu

📝 Walkthrough

Adds support for fused shared experts across fused-MoE: new parameter num_fused_shared_experts is threaded through CUDA kernels, runner APIs, routing logic, Python bindings, and tests; routing, workspace, and GEMM sizing use totals (top_k + fused) and local-expert totals where applicable.
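The totals described above can be sketched in a few lines of Python (illustrative names; the actual arithmetic lives in the CUDA launcher and runner):

```python
def routing_totals(top_k: int, num_experts: int, local_num_experts: int,
                   num_fused_shared_experts: int = 0):
    """Sketch of the fused totals threaded through routing and workspace sizing."""
    total_experts_per_token = top_k + num_fused_shared_experts
    total_num_experts = num_experts + num_fused_shared_experts
    total_local_experts = local_num_experts + num_fused_shared_experts
    return total_experts_per_token, total_num_experts, total_local_experts

# e.g. a DeepSeek-V3-style config with one fused shared expert:
print(routing_totals(top_k=8, num_experts=256, local_num_experts=32,
                     num_fused_shared_experts=1))  # (9, 257, 33)
```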

Changes

  • Kernel Launcher (csrc/trtllm_fused_moe_kernel_launcher.cu): Introduce totalExpertsPerToken/totalNumExperts (top_k + fused) and update routing/workspace allocations, histogram sizing, and workspace bindings to use totals.
  • DeepSeek Routing Kernel (csrc/trtllm_fused_moe_routing_deepseek.cu): Extend per-token top-K logic to include fused shared experts, write fused shared entries into top-K outputs, add runtime checks and capacity adjustments, and update indexing/size calculations.
  • MoE Runner & Headers (csrc/trtllm_fused_moe_runner.cu, include/flashinfer/.../runner.h, include/flashinfer/.../RoutingKernel.h): Add numFusedSharedExpert param to Runner::run and MoERunnerArgs; add fused-expert fields to routing params; compute/propagate totalExpertsPerToken and totalLocalExperts; update GEMM/workspace sizing and kernel param propagation.
  • Bindings / Python (flashinfer/fused_moe/core.py): Expose num_fused_shared_experts in the FP8 block-scale MoE API, compute a local integer _nfse, and forward it to the C++/CUDA kernel call; update docstrings.
  • Tests (tests/moe/test_trtllm_gen_fused_moe.py): Extend reference helpers and tests to accept num_fused_shared_experts; use total_experts = num_experts + num_fused_shared_experts; adjust weight shapes, top_k calculations, and add fused-shared test cases.
  • Routing Data Init (csrc/moe_utils_binding.cu): Zero-initialize fused-shared routing fields (mNumFusedSharedExperts, mSharedExpertTokenOffset, mSharedExpertNumTokens, mTotalExpertsPerToken) to ensure safe kernel reads.

Sequence Diagram

sequenceDiagram
    participant Client
    participant Runner as MoE Runner
    participant Launcher as Kernel Launcher
    participant Router as Routing Kernel
    participant GEMM as GEMM Kernels
    participant Tests as Test Harness

    Client->>Runner: run(topK, numFusedSharedExpert, ...)
    Runner->>Runner: totalExpertsPerToken = topK + numFusedSharedExpert
    Runner->>Launcher: launch routing with totalExpertsPerToken, totalNumExperts
    Launcher->>Router: execute routing (select per-token topK + fused)
    Router->>Launcher: return routing indexes & histograms
    Launcher->>GEMM: invoke PermuteGemm1/Gemm2 with totalExpertsPerToken/totalLocalExperts
    Tests->>Client: validate outputs using total_experts-aware reference

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Suggested labels

op: moe-routing

Suggested reviewers

  • IwakuraRein
  • jiahanc
  • aleozlx
  • djmmoss
  • joker-eph
  • yzh119


🚥 Pre-merge checks | ✅ 1 | ❌ 2

❌ Failed checks (2 warnings)

  • Description check ⚠️ Warning — The PR description is largely incomplete: it references related issues and the integration source but lacks substantive detail about what was changed, why it was needed, or testing status, and most checklist items remain unchecked. Resolution: add a clear summary of changes in the Description section, explain the integration rationale, document test results, and mark completed checklist items.
  • Docstring Coverage ⚠️ Warning — Docstring coverage is 30.30%, below the required 80.00% threshold. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (1 passed)

  • Title check ✅ Passed — The PR title clearly describes the main feature: fusing shared experts into trtllm_gen MoE with FP8 support, which aligns with the changes across multiple CUDA and Python files.


@nv-yunzheq
Collaborator Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !344 has been created, and the CI pipeline #44669282 is currently running. I'll report back once the pipeline job completes.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces the fusion of shared experts into the trtllm_gen MoE implementation, specifically for FP8. The changes cover the routing kernel, the launcher, and the Python API. While the integration logic for shared experts is mostly sound, there are a few critical issues regarding histogram initialization and template dispatching in the routing kernel that could lead to undefined behavior or incorrect results in multi-GPU or large-token scenarios.

Comment on lines +668 to +672
if (data.mNumFusedSharedExperts > 0) {
data.mNumExperts += data.mNumFusedSharedExperts;
data.mTopK += data.mNumFusedSharedExperts;
data.mNumLocalExperts += data.mNumFusedSharedExperts;
}
Contributor

high

Updating data.mNumExperts and data.mTopK after the first kernel launch (line 656 or 662) leads to several issues:

  1. numThreadsMain (line 655) and the histogram initialization inside routingMainKernel (line 85) use the original routed expert count, meaning the histogram entries for shared experts are never initialized to zero. This can cause garbage values to be used as offsets in subsequent permutation kernels.
  2. The dispatching macro LAUNCH_ROUTING_DEEPSEEK uses data.mNumExperts to select the MaxNumExperts template parameter. If the total expert count (routed + shared) crosses a threshold (e.g., 256 to 257), the first and second launches will use different template instantiations, which is inconsistent.

You should calculate the total expert count and top-k at the beginning of runImpl and ensure that initialization kernels use the total count, while routingMainKernel receives the routed count for its indexing logic.
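The suggested ordering can be modeled in plain Python (hypothetical names; the real fix applies to CUDA kernel launch code in runImpl): compute the fused totals once up front, size the histogram-initialization work over the total expert count, and keep the main routing kernel indexing only the routed experts.

```python
def plan_launches(num_routed_experts: int, top_k: int,
                  num_fused_shared_experts: int):
    # Compute fused totals once, before any launch, so every launch agrees
    # on the same template/size selection.
    total_experts = num_routed_experts + num_fused_shared_experts
    total_top_k = top_k + num_fused_shared_experts
    # Init kernels must zero the *total* histogram range so shared-expert
    # slots never hold garbage (two counters per expert, as in the kernel).
    histogram_init_size = 2 * total_experts
    # The main routing kernel still indexes only the routed experts.
    main_kernel_experts = num_routed_experts
    return histogram_init_size, main_kernel_experts, total_top_k

print(plan_launches(256, 8, 1))  # (514, 256, 9)
```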

Comment on lines +617 to +619
FLASHINFER_CHECK(data.mNumFusedSharedExperts <= WarpSize,
"Number of fused shared experts (%d) must be less than warp size.",
data.mNumFusedSharedExperts);
Contributor

medium

The check for mNumFusedSharedExperts <= WarpSize is currently placed inside the if (data.mNumExpertGroups > 1) block. However, routingMainKernel always assumes that shared experts can be handled by a single warp (using laneIdx), regardless of whether expert groups are used. This check should be moved outside the conditional block to ensure it is always enforced.
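A minimal sketch of hoisting the check, modeled in Python (the real check is a FLASHINFER_CHECK in the CUDA dispatch path; names are illustrative):

```python
WARP_SIZE = 32  # fixed warp width assumed by the single-warp write path

def validate_routing(num_expert_groups: int, num_fused_shared_experts: int) -> bool:
    # Unconditional: the fused shared-expert writes use lane indices within
    # a single warp regardless of whether expert groups are enabled.
    if num_fused_shared_experts > WARP_SIZE:
        raise ValueError("num_fused_shared_experts must not exceed warp size")
    if num_expert_groups > 1:
        pass  # group-count-specific checks would remain here
    return True

print(validate_routing(num_expert_groups=1, num_fused_shared_experts=16))  # True
```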

weight_layout=weight_layout,
do_finalize=do_finalize,
enable_pdl=enable_pdl,
num_fused_shared_experts=num_fused_shared_experts,
Contributor

medium

The num_fused_shared_experts parameter should be included in the instance_key used by the MoERunner (around line 1045). Since the kernel's performance and configuration depend on the total number of experts (routed + shared), omitting this from the key might lead to the autotuner returning a suboptimal tactic if multiple calls with different shared expert counts are made.
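The caching concern can be illustrated with a toy tactic cache (hypothetical key shape; the real instance_key lives in MoERunner): any parameter the kernel configuration depends on must be part of the key, or calls with different fused-expert counts would reuse each other's tuned tactic.

```python
def make_instance_key(num_experts: int, top_k: int, intermediate_size: int,
                      num_fused_shared_experts: int = 0):
    # Include num_fused_shared_experts so configs tuned for different
    # total expert counts never collide.
    return (num_experts, top_k, intermediate_size, num_fused_shared_experts)

tactic_cache = {}
tactic_cache[make_instance_key(256, 8, 2048, 0)] = "tactic_a"
tactic_cache[make_instance_key(256, 8, 2048, 1)] = "tactic_b"
print(len(tactic_cache))  # 2 — distinct entries per fused-expert count
```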

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
csrc/trtllm_fused_moe_kernel_launcher.cu (1)

810-821: ⚠️ Potential issue | 🟠 Major

Shape validation for precomputed routing with fused shared experts is inconsistent.

The shape check at line 818 validates expert_indices.size(1) == args->top_k, but when fused shared experts are enabled, precomputed indices should account for the additional fused entries. At line 892, totalExpertsPerToken is calculated as args->top_k + args->num_fused_shared_experts, and the expert_weights tensor is allocated with this dimension (line 897). If precomputed routing is used alongside fused shared experts, the shape validation should check expert_indices.size(1) == totalExpertsPerToken instead of just args->top_k to ensure consistency with the routing output tensors.
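The intended validation can be sketched as follows (hypothetical function; the real check is a TVM_FFI_ICHECK_EQ in check_routing):

```python
def check_precomputed_routing(expert_indices_shape, top_k: int,
                              num_fused_shared_experts: int = 0) -> int:
    # Expected width matches the routing-output allocation:
    # totalExpertsPerToken = top_k + num_fused_shared_experts.
    num_tokens, width = expert_indices_shape
    expected = top_k + num_fused_shared_experts
    if width != expected:
        raise ValueError(f"expert_indices must have {expected} columns, got {width}")
    return expected

print(check_precomputed_routing((16, 9), top_k=8, num_fused_shared_experts=1))  # 9
```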

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/trtllm_fused_moe_kernel_launcher.cu` around lines 810 - 821, In
check_routing(), the validation of expert_indices.dim(1) only compares to
args->top_k but must account for fused shared experts; compute an expected width
like int expectedPerToken = args->top_k + args->num_fused_shared_experts (or
just use args->top_k when num_fused_shared_experts is zero) and replace the
existing TVM_FFI_ICHECK_EQ(expert_indices.size(1), args->top_k) with a check
against expectedPerToken so precomputed routing matches the allocation for
totalExpertsPerToken and expert_weights.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@csrc/trtllm_fused_moe_routing_deepseek.cu`:
- Around line 616-619: The check ensuring data.mNumFusedSharedExperts <=
WarpSize must be unconditional because fused shared-expert writes use laneIdx <
mNumFusedSharedExperts regardless of expert group count; move the
FLASHINFER_CHECK(data.mNumFusedSharedExperts <= WarpSize, ...) out of the if
(data.mNumExpertGroups > 1) block so it always runs, ensuring
data.mNumFusedSharedExperts is validated before any code paths that use
mNumFusedSharedExperts/mNumFusedSharedExperts-induced lane comparisons or writes
(references: data.mNumFusedSharedExperts, WarpSize, mNumFusedSharedExperts,
mNumExpertGroups).
- Around line 571-574: routingInitExpertCounts currently initializes only 2 *
data.mNumExperts using the pre-fusion value then data.mNumExperts is incremented
to include mNumFusedSharedExperts, leaving histogram slots for fused-shared
experts uninitialized; fix by making the initialization cover the full fused
range (initialize 2 * (data.mNumExperts + data.mNumFusedSharedExperts)) or by
moving the mutation of data.mNumExperts (add mNumFusedSharedExperts) before
calling routingInitExpertCounts so the kernel initializes the correct size, and
ensure subsequent kernels that atomicAdd into expert-count slots will see zeros
for indices [original_mNumExperts, original_mNumExperts +
mNumFusedSharedExperts). Also move the check data.mNumFusedSharedExperts <=
WarpSize out of the if (data.mNumExpertGroups > 1) block so the
fused-shared-expert write logic (the unconditional write at the fused shared
expert site) consistently validates the WarpSize constraint regardless of group
count.

In `@include/flashinfer/trtllm/fused_moe/RoutingKernel.h`:
- Around line 102-107: The new fused-shared expert members
(mNumFusedSharedExperts, mSharedExpertTokenOffset, mSharedExpertNumTokens,
mTotalExpertsPerToken) are uninitialized; initialize them to the same safe
defaults used in KernelParamsBase (e.g., zero) by adding default member
initializers or setting them in the DataBase constructor so callers that don't
set them won't propagate garbage into kernel params and cause routing/OOB
errors.



📥 Commits

Reviewing files that changed from the base of the PR and between 26ef055 and 259d279.

📒 Files selected for processing (7)
  • csrc/trtllm_fused_moe_kernel_launcher.cu
  • csrc/trtllm_fused_moe_routing_deepseek.cu
  • csrc/trtllm_fused_moe_runner.cu
  • flashinfer/fused_moe/core.py
  • include/flashinfer/trtllm/fused_moe/RoutingKernel.h
  • include/flashinfer/trtllm/fused_moe/runner.h
  • tests/moe/test_trtllm_gen_fused_moe.py

Comment on lines +616 to +619

FLASHINFER_CHECK(data.mNumFusedSharedExperts <= WarpSize,
"Number of fused shared experts (%d) must be less than warp size.",
data.mNumFusedSharedExperts);
Contributor

⚠️ Potential issue | 🟡 Minor

fusedSharedExperts <= WarpSize check should be unconditional.

This validation is guarded by if (data.mNumExpertGroups > 1) (line 605), but the fused shared expert writes at lines 261-265 and 272-274 use laneIdx < mNumFusedSharedExperts regardless of expert groups. If mNumExpertGroups <= 1 and mNumFusedSharedExperts > WarpSize, the writes would silently skip some fused experts.

Suggested fix

Move the check out of the if (data.mNumExpertGroups > 1) block:

+  FLASHINFER_CHECK(data.mNumFusedSharedExperts <= WarpSize,
+                   "Number of fused shared experts (%d) must be less than warp size.",
+                   data.mNumFusedSharedExperts);
+
   if (data.mNumExpertGroups > 1) {
     FLASHINFER_CHECK(data.mNumExpertGroups <= MaxNumGroups,
                      ...);
     ...
-
-    FLASHINFER_CHECK(data.mNumFusedSharedExperts <= WarpSize,
-                     "Number of fused shared experts (%d) must be less than warp size.",
-                     data.mNumFusedSharedExperts);
   }
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/trtllm_fused_moe_routing_deepseek.cu` around lines 616 - 619, The check
ensuring data.mNumFusedSharedExperts <= WarpSize must be unconditional because
fused shared-expert writes use laneIdx < mNumFusedSharedExperts regardless of
expert group count; move the FLASHINFER_CHECK(data.mNumFusedSharedExperts <=
WarpSize, ...) out of the if (data.mNumExpertGroups > 1) block so it always
runs, ensuring data.mNumFusedSharedExperts is validated before any code paths
that use mNumFusedSharedExperts/mNumFusedSharedExperts-induced lane comparisons
or writes (references: data.mNumFusedSharedExperts, WarpSize,
mNumFusedSharedExperts, mNumExpertGroups).

Comment on lines +102 to +107

/// For fused shared expert
int32_t mNumFusedSharedExperts;
int32_t mSharedExpertTokenOffset;
int32_t mSharedExpertNumTokens;
int32_t mTotalExpertsPerToken;
Contributor

⚠️ Potential issue | 🟠 Major

Initialize fused-shared expert metadata fields to safe defaults.

These new DataBase members are currently uninitialized, so any caller that forgets to set them will propagate garbage into kernel params and risk incorrect routing / OOB indexing. Mirror the KernelParamsBase defaults.

🛠️ Suggested fix
-  int32_t mNumFusedSharedExperts;
-  int32_t mSharedExpertTokenOffset;
-  int32_t mSharedExpertNumTokens;
-  int32_t mTotalExpertsPerToken;
+  int32_t mNumFusedSharedExperts{0};
+  int32_t mSharedExpertTokenOffset{0};
+  int32_t mSharedExpertNumTokens{0};
+  int32_t mTotalExpertsPerToken{0};
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/trtllm/fused_moe/RoutingKernel.h` around lines 102 - 107,
The new fused-shared expert members (mNumFusedSharedExperts,
mSharedExpertTokenOffset, mSharedExpertNumTokens, mTotalExpertsPerToken) are
uninitialized; initialize them to the same safe defaults used in
KernelParamsBase (e.g., zero) by adding default member initializers or setting
them in the DataBase constructor so callers that don't set them won't propagate
garbage into kernel params and cause routing/OOB errors.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #44669282: 13/20 passed

@nv-yunzheq force-pushed the DSR1_shared_expert_fusion branch from 259d279 to 2255bca on March 9, 2026 at 17:53
@nv-yunzheq
Collaborator Author

/bot run

@nv-yunzheq nv-yunzheq marked this pull request as ready for review March 9, 2026 18:45
@nv-yunzheq nv-yunzheq requested a review from bkryu as a code owner March 9, 2026 18:45
@flashinfer-bot
Collaborator

GitLab MR !344 has been updated with latest changes, and the CI pipeline #45731067 is currently running. I'll report back once the pipeline job completes.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (3)
flashinfer/fused_moe/core.py (1)

1804-1836: ⚠️ Potential issue | 🟠 Major

Fake op signature mismatch: missing num_fused_shared_experts parameter.

The _fake_trtllm_fp8_block_scale_moe function signature must exactly mirror the real op trtllm_fp8_block_scale_moe_op. The real op has num_fused_shared_experts: int = 0 at line 1661, but the fake op is missing this parameter. This will cause issues with torch.compile or other tracing scenarios.

Suggested fix
 `@register_fake_op`("flashinfer::trtllm_fp8_block_scale_moe")
 def _fake_trtllm_fp8_block_scale_moe(
     routing_logits: Optional[torch.Tensor],
     topk_ids: Optional[torch.Tensor],
     expert_weights: Optional[torch.Tensor],
     routing_bias: Optional[torch.Tensor],
     hidden_states: torch.Tensor,
     hidden_states_scale: torch.Tensor,
     gemm1_weights: torch.Tensor,
     gemm1_weights_scale: torch.Tensor,
     gemm2_weights: torch.Tensor,
     gemm2_weights_scale: torch.Tensor,
     output: torch.Tensor,
     num_experts: int,
     top_k: int,
     n_group: Optional[int],
     topk_group: Optional[int],
     intermediate_size: int,
     local_expert_offset: int,
     local_num_experts: int,
     routed_scaling_factor: Optional[float],
     routing_method_type: int = 0,
     use_shuffled_weight: bool = False,
     weight_layout: int = 0,
     do_finalize: bool = True,
     enable_pdl: Optional[bool] = None,
     tune_max_num_tokens: int = 8192,
     fp8_quantization_type: Fp8QuantizationType = Fp8QuantizationType.DeepSeekFp8,
+    num_fused_shared_experts: int = 0,
 ) -> List[torch.Tensor]:

Based on learnings: "When reviewing files that define fake ops decorated with register_fake_op (e.g., in flashinfer/fused_moe/*), ensure the function signatures exactly mirror the real op they stand in for."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/fused_moe/core.py` around lines 1804 - 1836, The fake op
_fake_trtllm_fp8_block_scale_moe must exactly mirror the real op signature
trtllm_fp8_block_scale_moe_op: add the missing parameter
num_fused_shared_experts: int = 0 to the fake function signature (position it
where the real op declares it) so tracing/torch.compile sees identical
parameters; update any callers or tests if they rely on positional args to
ensure compatibility.
csrc/trtllm_fused_moe_kernel_launcher.cu (2)

1857-1874: ⚠️ Potential issue | 🟠 Major

Validate num_fused_shared_experts before using it in size math.

The new FFI parameter is folded directly into totalExpertsPerToken and totalLocalExperts. A negative value can drive those counts to zero or below and break tile selection/workspace sizing before any lower-layer routing checks run.
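A sketch of the defensive normalization (hypothetical helper; the real validation belongs at the FFI boundary before totalExpertsPerToken/totalLocalExperts are computed):

```python
def normalize_num_fused_shared_experts(value) -> int:
    # Treat an absent value as 0 and reject negatives before any size math,
    # so downstream tile selection never sees counts driven to zero or below.
    n = 0 if value is None else int(value)
    if n < 0:
        raise ValueError("num_fused_shared_experts must be non-negative")
    return n

print(normalize_num_fused_shared_experts(None),
      normalize_num_fused_shared_experts(2))  # 0 2
```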

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/trtllm_fused_moe_kernel_launcher.cu` around lines 1857 - 1874, The code
uses the FFI parameter num_fused_shared_experts directly in size math
(totalExpertsPerToken, totalLocalExperts) which can be negative; validate and
clamp it before use (e.g., ensure num_fused_shared_experts >= 0 and fits
expected bounds) and reject or adjust invalid values; specifically,
check/convert num_fused_shared_experts (and the optional
num_fused_shared_experts.value_or(0)) to a non-negative int64_t before computing
totalExpertsPerToken and totalLocalExperts, and add a defensive check that
aborts or logs an error if the provided FFI value is out of acceptable range so
computeSelectedTileN and downstream launchers
(Fp8BlockScaleLauncher::getSupportedTileNums, computeSelectedTileN,
MoERunnerArgs) never see negative counts.
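A minimal Python sketch of the suggested validation, mirroring the launcher's size math (function and constant names here are illustrative, not the actual C++ symbols; the warp-size upper bound is an assumption taken from the routing-kernel check discussed below):

```python
WARP_SIZE = 32  # CUDA warp width; routing writes one fused expert per lane


def widened_counts(top_k: int, local_num_experts: int,
                   num_fused_shared_experts: int):
    """Validate the FFI value before folding it into the size math."""
    if num_fused_shared_experts < 0:
        raise ValueError("num_fused_shared_experts must be non-negative")
    if num_fused_shared_experts > WARP_SIZE:
        raise ValueError("num_fused_shared_experts must not exceed warp size")
    total_experts_per_token = top_k + num_fused_shared_experts
    total_local_experts = local_num_experts + num_fused_shared_experts
    return total_experts_per_token, total_local_experts
```

With the guard in place, tile selection and workspace sizing downstream never see a zero or negative count.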

919-929: ⚠️ Potential issue | 🟠 Major

Precomputed routing tensors still use the old top_k width.

Only the internally allocated expert_weights buffer is widened to top_k + num_fused_shared_experts. If the caller provides precomputed expert_indices / expert_weights, this path still follows the old top_k contract elsewhere, so fused-shared precomputed routing will either reject correctly sized tensors or consume too few columns during finalize.
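The widened contract can be stated as a small shape check; this is an illustrative sketch only (the helper name and error text are hypothetical, not flashinfer code):

```python
def check_precomputed_routing(expert_indices_cols: int, top_k: int,
                              num_fused_shared_experts: int) -> None:
    # With fused shared experts, caller-provided expert_indices /
    # expert_weights must carry top_k + num_fused_shared_experts columns,
    # and finalize must consume all of them.
    expected = top_k + num_fused_shared_experts
    if expert_indices_cols != expected:
        raise ValueError(
            f"expected {expected} routing columns, got {expert_indices_cols}")
```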

♻️ Duplicate comments (2)
csrc/trtllm_fused_moe_routing_deepseek.cu (2)

620-623: ⚠️ Potential issue | 🟡 Minor

mNumFusedSharedExperts <= WarpSize check should be unconditional.

This validation is guarded by if (data.mNumExpertGroups > 1) (line 609), but the fused shared expert writes at lines 261-265 and 272-274 use laneIdx < mNumFusedSharedExperts regardless of expert groups. If mNumExpertGroups <= 1 and mNumFusedSharedExperts > WarpSize, the writes would silently skip some fused experts.

Suggested fix

Move the check out of the if (data.mNumExpertGroups > 1) block:

+  FLASHINFER_CHECK(data.mNumFusedSharedExperts <= WarpSize,
+                   "Number of fused shared experts (%d) must be less than warp size.",
+                   data.mNumFusedSharedExperts);
+
   if (data.mNumExpertGroups > 1) {
     FLASHINFER_CHECK(data.mNumExpertGroups <= MaxNumGroups,
                      ...);
     ...
-
-    FLASHINFER_CHECK(data.mNumFusedSharedExperts <= WarpSize,
-                     "Number of fused shared experts (%d) must be less than warp size.",
-                     data.mNumFusedSharedExperts);
   }
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/trtllm_fused_moe_routing_deepseek.cu` around lines 620 - 623, The check
ensuring data.mNumFusedSharedExperts <= WarpSize must be unconditional: move the
FLASHINFER_CHECK(data.mNumFusedSharedExperts <= WarpSize, ...) out of the if
(data.mNumExpertGroups > 1) block so it always runs; this prevents laneIdx <
mNumFusedSharedExperts conditions (used in the fused shared expert writes) from
silently skipping experts when data.mNumExpertGroups <= 1 and
mNumFusedSharedExperts > WarpSize. Ensure the unique symbols FLASHINFER_CHECK,
data.mNumFusedSharedExperts, WarpSize, data.mNumExpertGroups, and laneIdx are
referenced when relocating the check.

666-676: ⚠️ Potential issue | 🟠 Major

Expert count histogram not initialized for fused shared expert indices.

The routingInitExpertCounts kernel (line 666-669) initializes 2 * data.mNumExperts elements using the pre-mutation value. After the kernel completes, data.mNumExperts is incremented at lines 673-675 to include mNumFusedSharedExperts. Subsequent kernels (lines 678+) use the mutated value but access uninitialized histogram slots for indices [original_mNumExperts, original_mNumExperts + mNumFusedSharedExperts).

This causes atomicAdd operations to accumulate into uninitialized values for fused shared expert slots.

Suggested fix

Either move the mutation before the histogram initialization or expand the initialization range:

+  if (data.mNumFusedSharedExperts > 0) {
+    data.mNumExperts += data.mNumFusedSharedExperts;
+    data.mTopK += data.mNumFusedSharedExperts;
+    data.mNumLocalExperts += data.mNumFusedSharedExperts;
+  }
+
   if (data.mPtrTopKIds == nullptr) {
     ...
   } else {
     // Reset the global histograms.
     LAUNCH_ROUTING_DEEPSEEK(data, false, routingInitExpertCounts,
                             (2 * data.mNumExperts - 1) / numThreadsHist + 1, numThreadsHist,
                             /*smemSize=*/0,
                             stream, data.mNumExpertGroups > 1);
   }

-  if (data.mNumFusedSharedExperts > 0) {
-    data.mNumExperts += data.mNumFusedSharedExperts;
-    data.mTopK += data.mNumFusedSharedExperts;
-    data.mNumLocalExperts += data.mNumFusedSharedExperts;
-  }
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/trtllm_fused_moe_routing_deepseek.cu` around lines 666 - 676, The
histogram for expert counts is initialized by the routingInitExpertCounts kernel
using the pre-mutation value of data.mNumExperts, but data.mNumExperts is then
increased by mNumFusedSharedExperts, leaving the new fused-shared slots
uninitialized; fix by either moving the mutation of data.mNumExperts (and
data.mTopK/data.mNumLocalExperts) to occur before the
LAUNCH_ROUTING_DEEPSEEK(...) call that invokes routingInitExpertCounts so the
kernel initializes the full range, or modify the initialization invocation to
cover (2 * (data.mNumExperts + data.mNumFusedSharedExperts)) elements (or
equivalent) so routingInitExpertCounts explicitly zeroes the fused-shared
indices; update references to routingInitExpertCounts, data.mNumExperts,
data.mNumFusedSharedExperts, and the LAUNCH_ROUTING_DEEPSEEK call accordingly.
🧹 Nitpick comments (1)
flashinfer/fused_moe/core.py (1)

1757-1759: Redundant None check for num_fused_shared_experts.

Since num_fused_shared_experts is typed as int = 0 at line 1661 (not Optional[int]), the None check on line 1759 is unnecessary. The parameter can never be None at this point.

Suggested simplification
-        _nfse = num_fused_shared_experts if num_fused_shared_experts is not None else 0
+        _nfse = num_fused_shared_experts
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/fused_moe/core.py` around lines 1757 - 1759, The assignment uses
an unnecessary None check for num_fused_shared_experts (typed as int with
default 0); simplify by removing the conditional and directly assign _nfse =
num_fused_shared_experts in the scope where num_fused_shared_experts is passed
(refer to the variables num_fused_shared_experts and _nfse in this
function/class), ensuring no Optional handling remains.


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: e7cc95b5-315d-4900-b7a7-8eb9b5984c8d

📥 Commits

Reviewing files that changed from the base of the PR and between 259d279 and 2255bca.

📒 Files selected for processing (7)
  • csrc/trtllm_fused_moe_kernel_launcher.cu
  • csrc/trtllm_fused_moe_routing_deepseek.cu
  • csrc/trtllm_fused_moe_runner.cu
  • flashinfer/fused_moe/core.py
  • include/flashinfer/trtllm/fused_moe/RoutingKernel.h
  • include/flashinfer/trtllm/fused_moe/runner.h
  • tests/moe/test_trtllm_gen_fused_moe.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • include/flashinfer/trtllm/fused_moe/RoutingKernel.h

Comment on lines +106 to +117
int32_t const numDevices = (localNumExperts > 0) ? numExperts / localNumExperts : 1;
int32_t const deviceIndex = (localNumExperts > 0) ? localExpertOffset / localNumExperts : 0;
int32_t const baseTokensPerDevice = numTokens / numDevices;
int32_t const remainingTokens = numTokens % numDevices;

if (deviceIndex < remainingTokens) {
routingData.mSharedExpertTokenOffset = (baseTokensPerDevice + 1) * deviceIndex;
routingData.mSharedExpertNumTokens = baseTokensPerDevice + 1;
} else {
routingData.mSharedExpertTokenOffset = remainingTokens + deviceIndex * baseTokensPerDevice;
routingData.mSharedExpertNumTokens = baseTokensPerDevice;
}

⚠️ Potential issue | 🟠 Major

Shared-expert token partition assumes uniform expert shards.

numDevices = numExperts / localNumExperts and deviceIndex = localExpertOffset / localNumExperts are only correct when every rank owns the same routed-expert count. The visible checks here only require localExpertOffset + localNumExperts <= numExperts, so uneven sharding will compute the wrong mSharedExpertTokenOffset/mSharedExpertNumTokens range and route fused shared experts against the wrong token slice.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/trtllm_fused_moe_runner.cu` around lines 106 - 117, The partition math
wrongly assumes uniform shards — replace the division-based computation
(numDevices, deviceIndex derived from numExperts / localNumExperts and
localExpertOffset / localNumExperts) with logic that computes device boundaries
from actual per-rank expert counts: build the cumulative expert-count prefix
(using the actual localNumExperts for each device/rank) to find the device index
and the exact token-offset/length for routingData.mSharedExpertTokenOffset and
routingData.mSharedExpertNumTokens; ensure you use numTokens scaled by each
device's expert count slice (not simple baseTokensPerDevice/remainingTokens
across a uniform numDevices), and reference localExpertOffset, localNumExperts,
numExperts when mapping into the cumulative ranges so uneven sharding yields
correct offsets and lengths.
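A minimal Python sketch of the cumulative-prefix approach described in the prompt, under the assumption that `local_expert_offset` always falls on a rank boundary (names and the helper itself are illustrative, not the actual runner code):

```python
def shared_expert_token_slice(num_tokens, expert_counts, local_expert_offset):
    # expert_counts: routed experts owned by each rank; may be uneven.
    prefix = [0]
    for count in expert_counts:
        prefix.append(prefix[-1] + count)
    # Locate this rank from the prefix instead of dividing by a uniform
    # shard size; assumes local_expert_offset is a boundary in the prefix.
    device_index = prefix.index(local_expert_offset)
    num_devices = len(expert_counts)
    base, rem = divmod(num_tokens, num_devices)
    if device_index < rem:
        return (base + 1) * device_index, base + 1
    return rem + device_index * base, base
```

With uneven shards such as `expert_counts = [3, 2, 3]`, the division-based scheme in the diff would misplace rank 1, while the prefix lookup still yields a correct, non-overlapping token slice per rank.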

@flashinfer-bot
Collaborator

[FAILED] Pipeline #45731067: 8/20 passed

