feat: Fuse shared experts into trtllm_gen moe (fp8) #2625
nv-yunzheq wants to merge 5 commits into flashinfer-ai:main
Conversation
Summary of Changes

Hello @nv-yunzheq, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request integrates the concept of 'fused shared experts' into the existing TensorRT-LLM MoE framework, particularly for FP8 operations. The primary goal is to optimize the handling of experts that are shared across multiple tokens by incorporating them directly into the MoE kernel's routing and execution logic. This change impacts how memory is allocated, how routing decisions are made, and how the overall MoE computation is performed, leading to more streamlined and potentially faster processing of MoE layers with shared components.
Activity
No actionable comments were generated in the recent review. 🎉
📝 Walkthrough

Adds support for fused shared experts across fused-MoE: a new parameter `num_fused_shared_experts` is threaded through the routing kernels, kernel launcher, runner, Python API, and tests.
Sequence Diagram

```mermaid
sequenceDiagram
    participant Client
    participant Runner as MoE Runner
    participant Launcher as Kernel Launcher
    participant Router as Routing Kernel
    participant GEMM as GEMM Kernels
    participant Tests as Test Harness
    Client->>Runner: run(topK, numFusedSharedExpert, ...)
    Runner->>Runner: totalExpertsPerToken = topK + numFusedSharedExpert
    Runner->>Launcher: launch routing with totalExpertsPerToken, totalNumExperts
    Launcher->>Router: execute routing (select per-token topK + fused)
    Router->>Launcher: return routing indexes & histograms
    Launcher->>GEMM: invoke PermuteGemm1/Gemm2 with totalExpertsPerToken/totalLocalExperts
    Tests->>Client: validate outputs using total_experts-aware reference
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks: ✅ 1 passed | ❌ 2 failed (2 warnings)
/bot run
Code Review
This pull request introduces the fusion of shared experts into the trtllm_gen MoE implementation, specifically for FP8. The changes cover the routing kernel, the launcher, and the Python API. While the integration logic for shared experts is mostly sound, there are a few critical issues regarding histogram initialization and template dispatching in the routing kernel that could lead to undefined behavior or incorrect results in multi-GPU or large-token scenarios.
```cpp
if (data.mNumFusedSharedExperts > 0) {
  data.mNumExperts += data.mNumFusedSharedExperts;
  data.mTopK += data.mNumFusedSharedExperts;
  data.mNumLocalExperts += data.mNumFusedSharedExperts;
}
```
Updating `data.mNumExperts` and `data.mTopK` after the first kernel launch (line 656 or 662) leads to several issues:

- `numThreadsMain` (line 655) and the histogram initialization inside `routingMainKernel` (line 85) use the original routed expert count, meaning the histogram entries for shared experts are never initialized to zero. This can cause garbage values to be used as offsets in subsequent permutation kernels.
- The dispatching macro `LAUNCH_ROUTING_DEEPSEEK` uses `data.mNumExperts` to select the `MaxNumExperts` template parameter. If the total expert count (routed + shared) crosses a threshold (e.g., 256 to 257), the first and second launches will use different template instantiations, which is inconsistent.

You should calculate the total expert count and top-k at the beginning of `runImpl` and ensure that initialization kernels use the total count, while `routingMainKernel` receives the routed count for its indexing logic.
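The ordering problem can be illustrated with a minimal sketch. This is plain Python standing in for the kernel launches; the function and variable names are hypothetical, not the actual `runImpl` code:

```python
def run_impl(num_experts: int, top_k: int, num_fused_shared_experts: int):
    # Compute the fused totals up front, before any "launch", so every
    # subsequent step sees one consistent expert count.
    total_num_experts = num_experts + num_fused_shared_experts
    total_top_k = top_k + num_fused_shared_experts

    # The histogram must cover the *total* count: sizing it from the
    # pre-fusion num_experts would leave the fused slots uninitialized.
    histogram = [0] * (2 * total_num_experts)

    # The main routing kernel still indexes only the routed experts.
    routed_slots = list(range(num_experts))
    return histogram, total_top_k, routed_slots
```

With 256 routed experts, top-k 8, and one fused shared expert, the histogram covers all 257 expert slots while routing still indexes only the 256 routed ones.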
```cpp
FLASHINFER_CHECK(data.mNumFusedSharedExperts <= WarpSize,
                 "Number of fused shared experts (%d) must be less than warp size.",
                 data.mNumFusedSharedExperts);
```
The check for `mNumFusedSharedExperts <= WarpSize` is currently placed inside the `if (data.mNumExpertGroups > 1)` block. However, `routingMainKernel` always assumes that shared experts can be handled by a single warp (using `laneIdx`), regardless of whether expert groups are used. This check should be moved outside the conditional block to ensure it is always enforced.
```python
weight_layout=weight_layout,
do_finalize=do_finalize,
enable_pdl=enable_pdl,
num_fused_shared_experts=num_fused_shared_experts,
```
The `num_fused_shared_experts` parameter should be included in the `instance_key` used by the `MoERunner` (around line 1045). Since the kernel's performance and configuration depend on the total number of experts (routed + shared), omitting this from the key might lead to the autotuner returning a suboptimal tactic if multiple calls with different shared expert counts are made.
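A hedged sketch of why the key matters. The key shape and cache here are illustrative stand-ins, not the actual `MoERunner` internals:

```python
tactic_cache: dict = {}

def get_tactic(num_experts: int, top_k: int, num_fused_shared_experts: int) -> str:
    # Include num_fused_shared_experts in the key: two calls that differ
    # only in fused shared expert count have different total expert counts
    # and must not share a tuned configuration.
    instance_key = (num_experts, top_k, num_fused_shared_experts)
    if instance_key not in tactic_cache:
        # Stand-in for autotuning; real code would benchmark kernel tactics.
        total = num_experts + num_fused_shared_experts
        tactic_cache[instance_key] = f"tactic_for_{total}_experts"
    return tactic_cache[instance_key]
```

Without the extra key component, the second call below would silently reuse the tactic tuned for the first call's smaller total expert count.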
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
csrc/trtllm_fused_moe_kernel_launcher.cu (1)
810-821: ⚠️ Potential issue | 🟠 Major — Shape validation for precomputed routing with fused shared experts is inconsistent.

The shape check at line 818 validates `expert_indices.size(1) == args->top_k`, but when fused shared experts are enabled, precomputed indices should account for the additional fused entries. At line 892, `totalExpertsPerToken` is calculated as `args->top_k + args->num_fused_shared_experts`, and the `expert_weights` tensor is allocated with this dimension (line 897). If precomputed routing is used alongside fused shared experts, the shape validation should check `expert_indices.size(1) == totalExpertsPerToken` instead of just `args->top_k` to ensure consistency with the routing output tensors.
ℹ️ Review info
Configuration used: defaults
Review profile: CHILL
Plan: Pro

📒 Files selected for processing (7)

- csrc/trtllm_fused_moe_kernel_launcher.cu
- csrc/trtllm_fused_moe_routing_deepseek.cu
- csrc/trtllm_fused_moe_runner.cu
- flashinfer/fused_moe/core.py
- include/flashinfer/trtllm/fused_moe/RoutingKernel.h
- include/flashinfer/trtllm/fused_moe/runner.h
- tests/moe/test_trtllm_gen_fused_moe.py
```cpp
FLASHINFER_CHECK(data.mNumFusedSharedExperts <= WarpSize,
                 "Number of fused shared experts (%d) must be less than warp size.",
                 data.mNumFusedSharedExperts);
```
`mNumFusedSharedExperts <= WarpSize` check should be unconditional.

This validation is guarded by `if (data.mNumExpertGroups > 1)` (line 605), but the fused shared expert writes at lines 261-265 and 272-274 use `laneIdx < mNumFusedSharedExperts` regardless of expert groups. If `mNumExpertGroups <= 1` and `mNumFusedSharedExperts > WarpSize`, the writes would silently skip some fused experts.
Suggested fix

Move the check out of the `if (data.mNumExpertGroups > 1)` block:

```diff
+  FLASHINFER_CHECK(data.mNumFusedSharedExperts <= WarpSize,
+                   "Number of fused shared experts (%d) must be less than warp size.",
+                   data.mNumFusedSharedExperts);
+
   if (data.mNumExpertGroups > 1) {
     FLASHINFER_CHECK(data.mNumExpertGroups <= MaxNumGroups, ...);
     ...
-
-    FLASHINFER_CHECK(data.mNumFusedSharedExperts <= WarpSize,
-                     "Number of fused shared experts (%d) must be less than warp size.",
-                     data.mNumFusedSharedExperts);
   }
```
```cpp
/// For fused shared expert
int32_t mNumFusedSharedExperts;
int32_t mSharedExpertTokenOffset;
int32_t mSharedExpertNumTokens;
int32_t mTotalExpertsPerToken;
```
Initialize fused-shared expert metadata fields to safe defaults.
These new DataBase members are currently uninitialized, so any caller that forgets to set them will propagate garbage into kernel params and risk incorrect routing / OOB indexing. Mirror the KernelParamsBase defaults.
🛠️ Suggested fix

```diff
-  int32_t mNumFusedSharedExperts;
-  int32_t mSharedExpertTokenOffset;
-  int32_t mSharedExpertNumTokens;
-  int32_t mTotalExpertsPerToken;
+  int32_t mNumFusedSharedExperts{0};
+  int32_t mSharedExpertTokenOffset{0};
+  int32_t mSharedExpertNumTokens{0};
+  int32_t mTotalExpertsPerToken{0};
```
[FAILED] Pipeline #44669282: 13/20 passed
Force-pushed 259d279 to 2255bca (Compare)
/bot run
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (3)
flashinfer/fused_moe/core.py (1)

1804-1836: ⚠️ Potential issue | 🟠 Major — Fake op signature mismatch: missing `num_fused_shared_experts` parameter.

The `_fake_trtllm_fp8_block_scale_moe` function signature must exactly mirror the real op `trtllm_fp8_block_scale_moe_op`. The real op has `num_fused_shared_experts: int = 0` at line 1661, but the fake op is missing this parameter. This will cause issues with torch.compile or other tracing scenarios.

Suggested fix

```diff
 @register_fake_op("flashinfer::trtllm_fp8_block_scale_moe")
 def _fake_trtllm_fp8_block_scale_moe(
     routing_logits: Optional[torch.Tensor],
     topk_ids: Optional[torch.Tensor],
     expert_weights: Optional[torch.Tensor],
     routing_bias: Optional[torch.Tensor],
     hidden_states: torch.Tensor,
     hidden_states_scale: torch.Tensor,
     gemm1_weights: torch.Tensor,
     gemm1_weights_scale: torch.Tensor,
     gemm2_weights: torch.Tensor,
     gemm2_weights_scale: torch.Tensor,
     output: torch.Tensor,
     num_experts: int,
     top_k: int,
     n_group: Optional[int],
     topk_group: Optional[int],
     intermediate_size: int,
     local_expert_offset: int,
     local_num_experts: int,
     routed_scaling_factor: Optional[float],
     routing_method_type: int = 0,
     use_shuffled_weight: bool = False,
     weight_layout: int = 0,
     do_finalize: bool = True,
     enable_pdl: Optional[bool] = None,
     tune_max_num_tokens: int = 8192,
     fp8_quantization_type: Fp8QuantizationType = Fp8QuantizationType.DeepSeekFp8,
+    num_fused_shared_experts: int = 0,
 ) -> List[torch.Tensor]:
```

Based on learnings: "When reviewing files that define fake ops decorated with register_fake_op (e.g., in flashinfer/fused_moe/*), ensure the function signatures exactly mirror the real op they stand in for."
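A hedged sketch of how such a mismatch can be caught in a unit test with `inspect.signature`. The three functions here are tiny stand-ins, not the real ops:

```python
import inspect

def real_op(x, top_k: int, num_fused_shared_experts: int = 0):
    return x

def fake_op(x, top_k: int, num_fused_shared_experts: int = 0):
    # Mirrors the real signature, so tracing sees identical parameters.
    return x

def fake_op_missing(x, top_k: int):
    # Missing num_fused_shared_experts: inspect.signature flags the drift.
    return x

def signatures_match(a, b) -> bool:
    # Signature equality compares parameter names, order, annotations,
    # and defaults, which is exactly the "must mirror" contract.
    return inspect.signature(a) == inspect.signature(b)
```

A test asserting `signatures_match(real_op, fake_op)` for each real/fake pair would have caught this review finding automatically.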
csrc/trtllm_fused_moe_kernel_launcher.cu (2)
1857-1874: ⚠️ Potential issue | 🟠 Major — Validate `num_fused_shared_experts` before using it in size math.

The new FFI parameter is folded directly into `totalExpertsPerToken` and `totalLocalExperts`. A negative value can drive those counts to zero or below and break tile selection/workspace sizing before any lower-layer routing checks run.
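A minimal sketch of the validate-before-use pattern, in Python rather than the launcher's C++; the function name is hypothetical:

```python
def total_experts_per_token(top_k: int, num_fused_shared_experts: int) -> int:
    # Reject negative values at the boundary so downstream tile selection
    # and workspace sizing never see a count <= 0.
    if num_fused_shared_experts < 0:
        raise ValueError(
            f"num_fused_shared_experts must be >= 0, got {num_fused_shared_experts}"
        )
    return top_k + num_fused_shared_experts
```

In the C++ launcher the equivalent would be a `TVM_FFI_ICHECK`-style guard on the FFI value before it feeds `totalExpertsPerToken` and `totalLocalExperts`.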
919-929: ⚠️ Potential issue | 🟠 Major — Precomputed routing tensors still use the old `top_k` width.

Only the internally allocated `expert_weights` buffer is widened to `top_k + num_fused_shared_experts`. If the caller provides precomputed `expert_indices`/`expert_weights`, this path still follows the old `top_k` contract elsewhere, so fused-shared precomputed routing will either reject correctly sized tensors or consume too few columns during finalize.
♻️ Duplicate comments (2)

csrc/trtllm_fused_moe_routing_deepseek.cu (2)

620-623: ⚠️ Potential issue | 🟡 Minor — `mNumFusedSharedExperts <= WarpSize` check should be unconditional.

This validation is guarded by `if (data.mNumExpertGroups > 1)` (line 609), but the fused shared expert writes at lines 261-265 and 272-274 use `laneIdx < mNumFusedSharedExperts` regardless of expert groups. If `mNumExpertGroups <= 1` and `mNumFusedSharedExperts > WarpSize`, the writes would silently skip some fused experts.

Suggested fix

Move the check out of the `if (data.mNumExpertGroups > 1)` block:

```diff
+  FLASHINFER_CHECK(data.mNumFusedSharedExperts <= WarpSize,
+                   "Number of fused shared experts (%d) must be less than warp size.",
+                   data.mNumFusedSharedExperts);
+
   if (data.mNumExpertGroups > 1) {
     FLASHINFER_CHECK(data.mNumExpertGroups <= MaxNumGroups, ...);
     ...
-
-    FLASHINFER_CHECK(data.mNumFusedSharedExperts <= WarpSize,
-                     "Number of fused shared experts (%d) must be less than warp size.",
-                     data.mNumFusedSharedExperts);
   }
```
666-676: ⚠️ Potential issue | 🟠 Major — Expert count histogram not initialized for fused shared expert indices.

The `routingInitExpertCounts` kernel (lines 666-669) initializes `2 * data.mNumExperts` elements using the pre-mutation value. After the kernel completes, `data.mNumExperts` is incremented at lines 673-675 to include `mNumFusedSharedExperts`. Subsequent kernels (lines 678+) use the mutated value but access uninitialized histogram slots for indices `[original_mNumExperts, original_mNumExperts + mNumFusedSharedExperts)`.

This causes atomicAdd operations to accumulate into uninitialized values for fused shared expert slots.

Suggested fix

Either move the mutation before the histogram initialization or expand the initialization range:

```diff
+  if (data.mNumFusedSharedExperts > 0) {
+    data.mNumExperts += data.mNumFusedSharedExperts;
+    data.mTopK += data.mNumFusedSharedExperts;
+    data.mNumLocalExperts += data.mNumFusedSharedExperts;
+  }
+
   if (data.mPtrTopKIds == nullptr) {
     ...
   } else {
     // Reset the global histograms.
     LAUNCH_ROUTING_DEEPSEEK(data, false, routingInitExpertCounts,
                             (2 * data.mNumExperts - 1) / numThreadsHist + 1,
                             numThreadsHist, /*smemSize=*/0, stream,
                             data.mNumExpertGroups > 1);
   }
-  if (data.mNumFusedSharedExperts > 0) {
-    data.mNumExperts += data.mNumFusedSharedExperts;
-    data.mTopK += data.mNumFusedSharedExperts;
-    data.mNumLocalExperts += data.mNumFusedSharedExperts;
-  }
```
🧹 Nitpick comments (1)

flashinfer/fused_moe/core.py (1)

1757-1759: Redundant None check for `num_fused_shared_experts`.

Since `num_fused_shared_experts` is typed as `int = 0` at line 1661 (not `Optional[int]`), the None check on line 1759 is unnecessary. The parameter can never be `None` at this point.

Suggested simplification

```diff
-    _nfse = num_fused_shared_experts if num_fused_shared_experts is not None else 0
+    _nfse = num_fused_shared_experts
```
ℹ️ Review info

⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: e7cc95b5-315d-4900-b7a7-8eb9b5984c8d

📒 Files selected for processing (7)

- csrc/trtllm_fused_moe_kernel_launcher.cu
- csrc/trtllm_fused_moe_routing_deepseek.cu
- csrc/trtllm_fused_moe_runner.cu
- flashinfer/fused_moe/core.py
- include/flashinfer/trtllm/fused_moe/RoutingKernel.h
- include/flashinfer/trtllm/fused_moe/runner.h
- tests/moe/test_trtllm_gen_fused_moe.py
🚧 Files skipped from review as they are similar to previous changes (1)
- include/flashinfer/trtllm/fused_moe/RoutingKernel.h
```cpp
int32_t const numDevices = (localNumExperts > 0) ? numExperts / localNumExperts : 1;
int32_t const deviceIndex = (localNumExperts > 0) ? localExpertOffset / localNumExperts : 0;
int32_t const baseTokensPerDevice = numTokens / numDevices;
int32_t const remainingTokens = numTokens % numDevices;

if (deviceIndex < remainingTokens) {
  routingData.mSharedExpertTokenOffset = (baseTokensPerDevice + 1) * deviceIndex;
  routingData.mSharedExpertNumTokens = baseTokensPerDevice + 1;
} else {
  routingData.mSharedExpertTokenOffset = remainingTokens + deviceIndex * baseTokensPerDevice;
  routingData.mSharedExpertNumTokens = baseTokensPerDevice;
}
```
Shared-expert token partition assumes uniform expert shards.

`numDevices = numExperts / localNumExperts` and `deviceIndex = localExpertOffset / localNumExperts` are only correct when every rank owns the same routed-expert count. The visible checks here only require `localExpertOffset + localNumExperts <= numExperts`, so uneven sharding will compute the wrong `mSharedExpertTokenOffset`/`mSharedExpertNumTokens` range and route fused shared experts against the wrong token slice.
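A hedged sketch of a prefix-sum partition that stays correct under uneven shards. This is illustrative Python, not the runner's actual API; it assumes per-rank expert counts are available, which the current FFI surface does not expose:

```python
from itertools import accumulate

def shared_expert_token_slice(num_tokens, per_rank_expert_counts, local_expert_offset):
    # Locate this rank by matching its expert offset against the cumulative
    # per-rank expert counts, instead of dividing by a uniform shard size.
    prefix = [0] + list(accumulate(per_rank_expert_counts))
    device_index = prefix.index(local_expert_offset)
    num_devices = len(per_rank_expert_counts)

    # Split tokens across ranks, giving one extra token to the first
    # `rem` ranks (same remainder scheme as the C++ snippet above).
    base, rem = divmod(num_tokens, num_devices)
    if device_index < rem:
        return (base + 1) * device_index, base + 1
    return rem + device_index * base, base
```

With uneven shards `[3, 2, 3]` and 10 tokens, rank offsets 0/3/5 map to slices of length 4/3/3 that tile the token range exactly, which the division-based scheme cannot guarantee.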
[FAILED] Pipeline #45731067: 8/20 passed
For #2551
Integrating NVIDIA/TensorRT-LLM#11143
🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- Tests have been added or updated as needed (e.g., `unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit
New Features
Documentation
Tests
Bug Fixes