
Added mxint4 routed moe #2669

Open
zyongye wants to merge 5 commits into flashinfer-ai:main from zyongye:mxint4_routed_moe

Conversation


@zyongye zyongye commented Mar 2, 2026

📌 Description

Add mxint4 routed moe version

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Added a routed MoE entry that accepts pre-computed packed routing (top-k indices and weights) and optional in-place final output, enabling a routed MXInt4 pathway alongside logits-based routing.
    • Public APIs extended to accept optional routing inputs and return in-place output when requested.
  • Tests

    • Added MXInt4 routed MoE tests validating pre-computed routing outputs against reference implementations.

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the Mixture-of-Experts (MoE) functionality by adding support for MXInt4 block-scale MoE with pre-computed routing. This change is crucial for optimizing performance in advanced use cases, such as CUDA Graph capture and distributed MoE systems, where routing decisions can be prepared beforehand. By allowing external routing inputs, the system gains flexibility and efficiency, avoiding redundant computations and enabling more complex inference pipelines.

Highlights

  • MXInt4 Routed MoE Support: Introduced a new trtllm_mxint4_block_scale_routed_moe function that allows for pre-computed routing in MXInt4 block-scale Mixture-of-Experts (MoE) operations. This enables scenarios where routing decisions are determined externally or packed for efficiency.
  • Flexible Routing Inputs: Modified the underlying C++ kernel and Python API to accept either traditional routing_logits (for on-the-fly routing computation) or pre-computed expert_indices and expert_weights (for routed MoE), providing greater flexibility in MoE execution.
  • Enhanced Kernel Logic: Updated the MxInt4BlockScaleLauncher in the CUDA kernel to correctly handle and validate pre-computed expert_indices and expert_weights, ensuring proper data flow and checks for the new routing mechanism.
  • Comprehensive Testing: Added a dedicated test case to verify the correctness of the MXInt4 routed fused MoE implementation against the standard routing approach, ensuring functional parity and reliability.


Changelog
  • csrc/trtllm_fused_moe_kernel_launcher.cu
    • Updated MxInt4BlockScaleLauncher constructor to accept expert_indices and expert_weights_in.
    • Modified check_routing to validate pre-computed expert_indices dimensions and data type.
    • Adjusted prepare_routing to use either routing_logits or pre-computed expert_indices and expert_weights_in for routing.
    • Overrode the run method in MxInt4BlockScaleLauncher to integrate the pre-computed routing logic.
  • flashinfer/fused_moe/__init__.py
    • Imported the new trtllm_mxint4_block_scale_routed_moe function.
    • Added trtllm_mxint4_block_scale_routed_moe to the module's __all__ export list.
  • flashinfer/fused_moe/core.py
    • Modified trtllm_mxint4_block_scale_moe_op to accept topk_ids and expert_weights as optional inputs, making routing_logits optional.
    • Updated the logic within trtllm_mxint4_block_scale_moe_op to handle the presence or absence of routing_logits and topk_ids for routing.
    • Introduced a new API function trtllm_mxint4_block_scale_routed_moe for executing MoE with pre-computed routing, passing None for routing_logits and using topk_ids.
  • tests/moe/test_trtllm_gen_routed_fused_moe.py
    • Imported additional utility functions for quantization and weight shuffling (convert_to_block_layout, mxint4_quantize, block_scale_interleave, _maybe_get_cached_w3_w1_permute_indices, get_w2_permute_indices_with_cache).
    • Added test_trtllm_gen_mxint4_routed_fused_moe to test the new MXInt4 routed MoE functionality, comparing its output against the standard trtllm_mxint4_block_scale_moe.
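The test above prepares MXInt4-quantized weights before calling the routed API. As a rough illustration of what block-scale int4 quantization does (a hedged toy sketch, not FlashInfer's `mxint4_quantize`; the 32-element block size, power-of-two shared scale, and symmetric int4 range are assumptions for illustration):

```python
import math

BLOCK = 32  # assumed block size; illustrative only


def quantize_block(block):
    """Quantize one block to int4 values plus a single shared power-of-two scale."""
    amax = max(abs(v) for v in block)
    if amax == 0.0:
        return [0] * len(block), 1.0
    # Choose the smallest power-of-two scale so that amax / scale <= 7 (int4 max).
    scale = 2.0 ** math.ceil(math.log2(amax / 7.0))
    return [max(-8, min(7, round(v / scale))) for v in block], scale


def dequantize_block(q, scale):
    """Reconstruct approximate values from int4 codes and the shared scale."""
    return [qi * scale for qi in q]
```

The real kernel additionally permutes and interleaves the quantized weights into the block layout the GEMM expects (the `convert_to_block_layout` / `block_scale_interleave` helpers imported by the test); this sketch only shows the numeric quantization step.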
Activity
  • Pre-commit checks have been installed and run, with reported issues fixed.
  • Tests have been added and updated as needed, and all tests are passing.


coderabbitai bot commented Mar 2, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

Adds pre-computed (packed top-k) routing support to fused MoE: C++ launchers and Python ops now accept expert indices and expert weights, include validation and wiring for dual routing modes (routing_logits or pre-computed), and a new MXInt4 routed test was added.

Changes

  • Kernel launcher / CUDA (csrc/trtllm_fused_moe_kernel_launcher.cu): MxInt4BlockScaleLauncher constructor, members, routing checks, prepare_routing, and run() updated to accept and handle optional expert_indices and expert_weights_in, enabling a pre-computed routing path alongside routing_logits.
  • Python core API (flashinfer/fused_moe/core.py): routing_logits, topk_ids, and expert_weights made optional in the op/fake-op; trtllm_mxint4_block_scale_routed_moe added; optional in-place output finalization and routing-dtype selection added; backend call paths forward pre-computed routing when present.
  • Package exports (flashinfer/fused_moe/__init__.py): new symbol trtllm_mxint4_block_scale_routed_moe exported and added to __all__.
  • Tests (tests/moe/test_trtllm_gen_routed_fused_moe.py): new MXInt4 routed MoE test that prepares MXInt4-quantized, permuted, block-layout weights, packs top-k ids and expert weights, calls the routed API, and compares outputs to a reference.
  • Public API surface / Signatures (flashinfer/fused_moe/...): public signatures (the MXInt4 MoE factory/launcher and trtllm_mxint4_block_scale_moe) extended to accept optional pre-computed routing inputs and an optional in-place output across precision modes.

Sequence Diagram(s)

sequenceDiagram
    participant Client
    participant MoE_API
    participant MoE_Kernel
    participant Expert

    rect rgba(100,150,200,0.5)
        Note over Client,MoE_Kernel: Routing via logits
        Client->>MoE_API: trtllm_mxint4_block_scale_moe(routing_logits,...)
        MoE_API->>MoE_Kernel: call with routing_logits
        MoE_Kernel->>MoE_Kernel: compute top-k & expert weights
        MoE_Kernel->>Expert: dispatch per-expert GEMM
        Expert-->>MoE_Kernel: aggregated outputs
        MoE_Kernel-->>MoE_API: final output
        MoE_API-->>Client: result
    end

    rect rgba(150,200,100,0.5)
        Note over Client,MoE_Kernel: Pre-computed routing (packed)
        Client->>MoE_API: trtllm_mxint4_block_scale_routed_moe(topk_ids,expert_weights,...)
        MoE_API->>MoE_Kernel: call with expert_indices & expert_weights
        MoE_Kernel->>MoE_Kernel: validate/wire precomputed routing
        MoE_Kernel->>Expert: dispatch per-expert GEMM using precomputed indices/weights
        Expert-->>MoE_Kernel: aggregated outputs
        MoE_Kernel-->>MoE_API: final output
        MoE_API-->>Client: result
    end
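The pre-computed routing lane in the diagram above can be illustrated with a toy dense reference (a pure-Python sketch; names, shapes, and the plain matrix-product "expert" are illustrative stand-ins, not the FlashInfer kernel):

```python
def expert_ffn(x, w):
    # One "expert": a plain matrix-vector product standing in for the MoE FFN.
    # x: [din], w: [din][dout] -> returns [dout].
    return [sum(xi * wij for xi, wij in zip(x, col)) for col in zip(*w)]


def routed_moe(hidden, expert_w, topk_ids, topk_weights):
    # hidden: [tokens][din]; expert_w: [experts][din][dout]
    # topk_ids / topk_weights: [tokens][top_k] pre-computed routing, as in the
    # routed API, so no softmax/top-k is computed here.
    out = []
    for x, ids, wts in zip(hidden, topk_ids, topk_weights):
        acc = [0.0] * len(expert_w[0][0])
        for e, w in zip(ids, wts):
            y = expert_ffn(x, expert_w[e])
            # Combine expert outputs with the pre-computed routing weights.
            acc = [a + w * yi for a, yi in zip(acc, y)]
        out.append(acc)
    return out
```

This is exactly the "dispatch per-expert GEMM using precomputed indices/weights, then aggregate" step; the logits-based lane differs only in computing `topk_ids`/`topk_weights` internally from routing logits first.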

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

Suggested labels

op: moe-routing

Suggested reviewers

  • yzh119
  • cyx-6
  • bkryu
  • jimmyzho
  • nv-yunzheq

Poem

🐰 I packed my hops in top‑k rows,
indices snug where the cool wind blows.
Experts line up, weights held tight,
no softmax fuss — I hop just right.
Routed MoE, a joyful flight!

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning: docstring coverage is 38.89%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)

  • Title check ✅: the title accurately describes the main change: adding MXInt4 routed MoE functionality.
  • Description check ✅: the description follows the template structure with checked checklist items and section headings present, though the description content is minimal.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces support for pre-computed routing in the mxint4 MoE kernel, which is a valuable optimization for scenarios like CUDA graph capture. The changes are well-structured across the C++ kernel, Python bindings, and tests. My review focuses on improving code quality by addressing some minor code duplication and redundancies. Overall, this is a solid contribution.

if (expert_indices.ndim() == 2 && expert_indices.size(0) > 0) {
// Pre-computed routing: expert_indices is a packed tensor
// Format: (expert_id << 16) | (weight_bf16.view(int16))
TVM_FFI_ICHECK_EQ(expert_indices.ndim(), 2) << "expert_indices must be 2D.";

medium

This check for expert_indices.ndim() is redundant, as the if condition on line 1156 already validates that expert_indices.ndim() == 2. Removing this line will make the code cleaner.
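The packed format quoted in this snippet, `(expert_id << 16) | (weight_bf16.view(int16))`, can be sketched in plain Python (illustrative helpers, assuming little-endian IEEE-754 encoding and simple truncation from float32 to bfloat16):

```python
import struct


def bf16_bits(x: float) -> int:
    # Truncate float32 to bfloat16 by keeping the top 16 bits of its encoding.
    f32 = struct.unpack("<I", struct.pack("<f", x))[0]
    return (f32 >> 16) & 0xFFFF


def pack_entry(expert_id: int, weight: float) -> int:
    # (expert_id << 16) | weight_bf16_bits, as stated in the kernel comment.
    return ((expert_id & 0xFFFF) << 16) | bf16_bits(weight)


def unpack_entry(packed: int) -> tuple[int, float]:
    # Recover the expert id from the high half and the bf16 weight from the low half.
    expert_id = (packed >> 16) & 0xFFFF
    weight = struct.unpack("<f", struct.pack("<I", (packed & 0xFFFF) << 16))[0]
    return expert_id, weight
```

In the actual API this packing is done tensor-wide (e.g. via `topk_ids << 16` OR'd with bf16 weights viewed as int16 in torch); the per-entry helpers above just make the bit layout explicit.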

expert_weights =
alloc_tensor({args->num_tokens, args->top_k}, dl_bfloat16, hidden_states.device());
// Check ndim==2 and size>0 because empty placeholder tensors may have non-null data_ptr
bool has_precomputed_indices = expert_indices.ndim() == 2 && expert_indices.size(0) > 0;

medium

The condition expert_indices.ndim() == 2 && expert_indices.size(0) > 0 is repeated in check_routing, prepare_routing, and run. To improve maintainability and reduce code duplication, consider encapsulating this logic in a private helper method, for example:

private:
  bool has_precomputed_indices() const {
    // Check ndim==2 and size>0 because empty placeholder tensors may have non-null data_ptr
    return expert_indices.ndim() == 2 && expert_indices.size(0) > 0;
  }

You could then call has_precomputed_indices() where this check is needed. A similar helper could be created for checking pre-computed weights.

# When routing_logits is None, we either have topk_ids/expert_weights,
# packed into a single tensor as topk_ids
# or have them individually as topk_ids and expert_weights respectively
topk_ids = topk_ids

medium

This assignment topk_ids = topk_ids is a no-op and can be removed for better code clarity. It appears to be a leftover from refactoring.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
flashinfer/fused_moe/core.py (2)

2125-2131: ⚠️ Potential issue | 🔴 Critical

Routed mode can break autotuning by passing None as routing_logits input.

When routing_logits is None, this inputs list still includes None; later tuner runner paths dereference routing_logits.shape[0].

🔧 Suggested fix
-        inputs = [
-            output,
-            routing_logits,
-            topk_ids,
-            expert_weights,
-            hidden_states,
-        ]
+        routing_logits_for_tuning = (
+            torch.empty(
+                num_tokens, num_experts, dtype=routing_dtype, device="meta"
+            )
+            if routing_logits is None
+            else routing_logits
+        )
+        inputs = [
+            output,
+            routing_logits_for_tuning,
+            topk_ids,
+            expert_weights,
+            hidden_states,
+        ]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/fused_moe/core.py` around lines 2125 - 2131, The inputs list
construction currently includes routing_logits even when it's None, which later
causes dereferences like routing_logits.shape[0]; update the code that builds
the inputs list (the variable named inputs in fused_moe/core.py) to omit
routing_logits when it is None (or replace it with a safe sentinel tensor) so
downstream tuner/runner paths never receive a None; specifically, gate the
inclusion of routing_logits in the inputs list (or ensure routing_logits is a
valid tensor before building inputs) so references to routing_logits.shape[...]
are safe.

2055-2090: ⚠️ Potential issue | 🟠 Major

Honor the caller-provided output tensor instead of always reallocating.

Line 2088 always creates a new tensor, so an explicit output passed by callers is silently ignored.

🔧 Suggested fix
-        output: torch.Tensor,
+        output: Optional[torch.Tensor],
@@
-        # Create workspace buffers
-        output = torch.empty(
-            num_tokens, hidden_size, dtype=torch.bfloat16, device=hidden_states.device
-        )
+        # Create workspace buffers
+        if output is None:
+            output = torch.empty(
+                num_tokens, hidden_size, dtype=torch.bfloat16, device=hidden_states.device
+            )
+        else:
+            check_shape_dtype_device(
+                output,
+                (num_tokens, hidden_size),
+                torch.bfloat16,
+                hidden_states.device,
+                "output",
+            )
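The allocate-only-if-None pattern this fix describes generalizes beyond torch; a generic pure-Python sketch (the function name and list-of-lists "buffer" are illustrative stand-ins for the tensor allocation and `check_shape_dtype_device` validation):

```python
def get_output_buffer(num_tokens, hidden_size, output=None):
    """Honor a caller-provided buffer; allocate only when none is given."""
    if output is None:
        # No buffer supplied: allocate a fresh [num_tokens, hidden_size] one.
        return [[0.0] * hidden_size for _ in range(num_tokens)]
    # A buffer was supplied: validate it instead of silently replacing it,
    # so in-place finalization actually writes where the caller expects.
    assert len(output) == num_tokens, "output has wrong number of rows"
    assert all(len(row) == hidden_size for row in output), "output has wrong row width"
    return output
```

The bug being flagged is precisely the opposite behavior: unconditionally allocating, which makes a caller-provided `output` a silent no-op.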
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@csrc/trtllm_fused_moe_kernel_launcher.cu`:
- Around line 1191-1200: The code currently forwards
expert_weights_in.data_ptr() to workspace.expert_weights based only on ndim and
non-empty size; instead, validate the full shape, dtype and device before
binding raw pointers: ensure expert_weights_in.ndim() == 2,
expert_weights_in.size(0) == args->num_tokens, expert_weights_in.size(1) ==
args->top_k, expert_weights_in.dtype() == dl_bfloat16 and
expert_weights_in.device() == hidden_states.device(); only if all checks pass
set workspace.expert_weights = const_cast<void*>(expert_weights_in.data_ptr()),
otherwise allocate expert_weights via alloc_tensor({args->num_tokens,
args->top_k}, dl_bfloat16, hidden_states.device()) and set
workspace.expert_weights to that buffer; also ensure the tensor is contiguous
(or document required layout) before passing its data_ptr().

---

Outside diff comments:
In `@flashinfer/fused_moe/core.py`:
- Around line 2125-2131: The inputs list construction currently includes
routing_logits even when it's None, which later causes dereferences like
routing_logits.shape[0]; update the code that builds the inputs list (the
variable named inputs in fused_moe/core.py) to omit routing_logits when it is
None (or replace it with a safe sentinel tensor) so downstream tuner/runner
paths never receive a None; specifically, gate the inclusion of routing_logits
in the inputs list (or ensure routing_logits is a valid tensor before building
inputs) so references to routing_logits.shape[...] are safe.

ℹ️ Review info

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 266167d and 0ac1fac.

📒 Files selected for processing (4)
  • csrc/trtllm_fused_moe_kernel_launcher.cu
  • flashinfer/fused_moe/__init__.py
  • flashinfer/fused_moe/core.py
  • tests/moe/test_trtllm_gen_routed_fused_moe.py


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/fused_moe/core.py`:
- Around line 2087-2105: The function currently overwrites the provided output
buffer by unconditionally allocating a new tensor named output; change the
allocation logic in fused_moe.core (the function around where output is created)
to respect an incoming output parameter by only allocating torch.empty(...) when
output is None or not provided, and otherwise validate that the provided output
has the expected shape/dtype/device before using it; mirror the conditional
pattern used in trtllm_fp4_block_scale_moe_op for in-place support, and update
the function signature/annotation for output to Optional[torch.Tensor] to
reflect that None is acceptable.


📥 Commits

Reviewing files that changed from the base of the PR and between 0ac1fac and 3c9f17f.

📒 Files selected for processing (2)
  • csrc/trtllm_fused_moe_kernel_launcher.cu
  • flashinfer/fused_moe/core.py

@IwakuraRein
Collaborator

/bot run

@flashinfer-bot
Collaborator

GitLab MR !365 has been created, and the CI pipeline #45182758 is currently running. I'll report back once the pipeline job completes.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (2)
csrc/trtllm_fused_moe_kernel_launcher.cu (1)

1165-1173: ⚠️ Potential issue | 🟠 Major

Precomputed expert_weights_in still needs device validation before raw-pointer binding.

expert_weights_in.data_ptr() is bound directly without checking it is on the same device as hidden_states. Cross-device binding here can cause invalid access during routing.

🔧 Proposed fix
     if (has_precomputed_weights()) {
       // Pre-computed expert weights: validate shape and dtype
       TVM_FFI_ICHECK_EQ(expert_weights_in.size(0), hidden_states.size(0))
           << "expert_weights_in and hidden_states must have same number of tokens.";
       TVM_FFI_ICHECK_EQ(expert_weights_in.size(1), args->top_k)
           << "expert_weights_in dim1 must match top_k.";
       TVM_FFI_ICHECK_EQ(expert_weights_in.dtype(), dl_bfloat16)
           << "expert_weights_in must be bfloat16.";
+      TVM_FFI_ICHECK_EQ(expert_weights_in.device().device_type, hidden_states.device().device_type)
+          << "expert_weights_in must be on the same device type as hidden_states.";
+      TVM_FFI_ICHECK_EQ(expert_weights_in.device().device_id, hidden_states.device().device_id)
+          << "expert_weights_in must be on the same device id as hidden_states.";
     }

Also applies to: 1197-1199

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/trtllm_fused_moe_kernel_launcher.cu` around lines 1165 - 1173, The
precomputed expert_weights_in is validated for shape and dtype but not for
device, so before calling expert_weights_in.data_ptr() (in the
has_precomputed_weights() branch and the other binding site referencing
expert_weights_in.data_ptr()) add a device check to ensure
expert_weights_in.device() (or .is_cuda and .device.index) matches
hidden_states.device() (or the CUDA device used by hidden_states) and error out
(or copy/move the tensor) with a clear message if they differ; update both the
check block around has_precomputed_weights() and the other raw-pointer binding
locations that use expert_weights_in.data_ptr() accordingly.
flashinfer/fused_moe/core.py (1)

2087-2098: ⚠️ Potential issue | 🔴 Critical

Provided output is not shape-validated before kernel use.

Only dtype/device are validated. A mismatched shape can lead to invalid writes when the kernel assumes [num_tokens, hidden_size].

🛡️ Proposed fix
         else:
             check_shape_dtype_device(
-                output, None, torch.bfloat16, hidden_states.device, "output"
+                output,
+                (num_tokens, hidden_size),
+                torch.bfloat16,
+                hidden_states.device,
+                "output",
             )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/fused_moe/core.py` around lines 2087 - 2098, The provided output
buffer isn't shape-validated before kernel use, risking invalid writes; update
the validation call in the block that handles a provided output so
check_shape_dtype_device verifies shape (num_tokens, hidden_size) as well as
dtype and device (i.e., replace the None shape argument with the tuple
(num_tokens, hidden_size) when calling check_shape_dtype_device for "output"),
ensuring the output tensor matches hidden_states' expected [num_tokens,
hidden_size] layout before the kernel runs.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/fused_moe/core.py`:
- Around line 2069-2077: The code passes routing_logits (which can be None in
routed mode) into the AutoTuner/runner, and the runner later dereferences
routing_logits.shape[0], causing a crash; update the
trtllm_mxint4_block_scale_moe_op call site to create a safe substitute when
routing_logits is None (e.g., compute routing_dtype as you already do, then set
routing_logits_for_tuner = routing_logits if routing_logits is not None else
torch.empty((topk_ids.shape[0], 0), dtype=routing_dtype, device=topk_ids.device)
or alternatively pass topk_ids-derived shape info instead), and pass
routing_logits_for_tuner into AutoTuner/runner; repeat the same defensive
substitution where the other similar call is made (the block around the other
occurrence) so the runner never dereferences None.

---

Duplicate comments:
In `@csrc/trtllm_fused_moe_kernel_launcher.cu`:
- Around line 1165-1173: The precomputed expert_weights_in is validated for
shape and dtype but not for device, so before calling
expert_weights_in.data_ptr() (in the has_precomputed_weights() branch and the
other binding site referencing expert_weights_in.data_ptr()) add a device check
to ensure expert_weights_in.device() (or .is_cuda and .device.index) matches
hidden_states.device() (or the CUDA device used by hidden_states) and error out
(or copy/move the tensor) with a clear message if they differ; update both the
check block around has_precomputed_weights() and the other raw-pointer binding
locations that use expert_weights_in.data_ptr() accordingly.

In `@flashinfer/fused_moe/core.py`:
- Around line 2087-2098: The provided output buffer isn't shape-validated before
kernel use, risking invalid writes; update the validation call in the block that
handles a provided output so check_shape_dtype_device verifies shape
(num_tokens, hidden_size) as well as dtype and device (i.e., replace the None
shape argument with the tuple (num_tokens, hidden_size) when calling
check_shape_dtype_device for "output"), ensuring the output tensor matches
hidden_states' expected [num_tokens, hidden_size] layout before the kernel runs.


📥 Commits

Reviewing files that changed from the base of the PR and between 3c9f17f and d26c745.

📒 Files selected for processing (2)
  • csrc/trtllm_fused_moe_kernel_launcher.cu
  • flashinfer/fused_moe/core.py


@IwakuraRein IwakuraRein left a comment


LGTM

@aleozlx aleozlx added the op: moe label Mar 3, 2026
@aleozlx aleozlx added the run-ci label Mar 3, 2026

@coderabbitai coderabbitai bot left a comment


♻️ Duplicate comments (1)
flashinfer/fused_moe/core.py (1)

2095-2098: ⚠️ Potential issue | 🟡 Minor

Missing shape validation for user-provided output tensor.

When the user provides an output tensor, only dtype and device are validated, but shape is not checked. The trtllm_fp4_block_scale_moe_op (lines 1883-1888) includes additional shape assertions that are missing here. This inconsistency could allow incorrectly sized output tensors to pass validation silently.

🛡️ Proposed fix to add shape validation
         else:
             check_shape_dtype_device(
                 output, None, torch.bfloat16, hidden_states.device, "output"
             )
+            assert output.shape[0] == num_tokens, (
+                f"output.shape[0]={output.shape[0]} must be equal to {num_tokens}"
+            )
+            assert output.shape[1] == hidden_size, (
+                f"output.shape[1]={output.shape[1]} must be equal to {hidden_size}"
+            )
         if routing_logits is not None:
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/fused_moe/core.py` around lines 2095 - 2098, The user-provided
output tensor currently only has dtype/device checked via
check_shape_dtype_device(output, None, torch.bfloat16, hidden_states.device,
"output"); update this to validate shape as well (match hidden_states' expected
shape) — either call check_shape_dtype_device with the expected shape (e.g.,
hidden_states.shape or the computed output shape used elsewhere) instead of
None, or add an explicit assertion comparing output.shape to the expected shape;
mirror the shape assertions used in trtllm_fp4_block_scale_moe_op to ensure
consistency.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Duplicate comments:
In `@flashinfer/fused_moe/core.py`:
- Around line 2095-2098: The user-provided output tensor currently only has
dtype/device checked via check_shape_dtype_device(output, None, torch.bfloat16,
hidden_states.device, "output"); update this to validate shape as well (match
hidden_states' expected shape) — either call check_shape_dtype_device with the
expected shape (e.g., hidden_states.shape or the computed output shape used
elsewhere) instead of None, or add an explicit assertion comparing output.shape
to the expected shape; mirror the shape assertions used in
trtllm_fp4_block_scale_moe_op to ensure consistency.


📥 Commits

Reviewing files that changed from the base of the PR and between d26c745 and 6131bd3.

📒 Files selected for processing (1)
  • flashinfer/fused_moe/core.py

@flashinfer-bot
Collaborator

[FAILED] Pipeline #45182758: 8/20 passed


yzh119 commented Mar 4, 2026

Hi @zyongye, can you resolve the merge conflict?

@yzh119 yzh119 enabled auto-merge (squash) March 4, 2026 16:59
auto-merge was automatically disabled March 5, 2026 16:38

Head branch was pushed to by a user without write access

@zyongye zyongye force-pushed the mxint4_routed_moe branch from 6131bd3 to 9f3a99d Compare March 5, 2026 16:38

zyongye commented Mar 5, 2026

Hi @zyongye, can you resolve the merge conflict?

@yzh119 I rebased the branch and also fixed a number of argument mismatches in trtllm_batched_gemm_runner.cu that prevented me from compiling the JIT cache. PTAL. Thanks!


@coderabbitai coderabbitai bot left a comment


♻️ Duplicate comments (2)
csrc/trtllm_fused_moe_kernel_launcher.cu (1)

1193-1201: ⚠️ Potential issue | 🟠 Major

Add device checks before binding precomputed routing buffers to raw pointers.

expert_indices / expert_weights_in are shape/dtype-checked, but not device-checked. Binding cross-device pointers into CUDA kernels is unsafe.

🔧 Proposed fix
     if (has_precomputed_indices()) {
       // Pre-computed routing: expert_indices is a packed tensor
       // Format: (expert_id << 16) | (weight_bf16.view(int16))
       TVM_FFI_ICHECK_EQ(expert_indices.size(0), hidden_states.size(0))
           << "expert_indices and hidden_states must have same number of tokens.";
       TVM_FFI_ICHECK_EQ(expert_indices.size(1), args->top_k)
           << "expert_indices dim1 must match top_k.";
       TVM_FFI_ICHECK_EQ(expert_indices.dtype(), dl_int32) << "expert_indices must be int32.";
+      TVM_FFI_ICHECK_EQ(expert_indices.device(), hidden_states.device())
+          << "expert_indices must be on the same device as hidden_states.";
     }

     if (has_precomputed_weights()) {
       // Pre-computed expert weights: validate shape and dtype
       TVM_FFI_ICHECK_EQ(expert_weights_in.size(0), hidden_states.size(0))
           << "expert_weights_in and hidden_states must have same number of tokens.";
       TVM_FFI_ICHECK_EQ(expert_weights_in.size(1), args->top_k)
           << "expert_weights_in dim1 must match top_k.";
       TVM_FFI_ICHECK_EQ(expert_weights_in.dtype(), dl_bfloat16)
           << "expert_weights_in must be bfloat16.";
+      TVM_FFI_ICHECK_EQ(expert_weights_in.device(), hidden_states.device())
+          << "expert_weights_in must be on the same device as hidden_states.";
     }

Also applies to: 1216-1227

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/trtllm_fused_moe_kernel_launcher.cu` around lines 1193 - 1201, The
precomputed routing arrays expert_indices and expert_weights_in are validated
for shape/dtype but not checked for device, which can lead to unsafe
cross-device raw pointer bindings; update the checks (around
has_precomputed_weights(), and the similar block at 1216-1227) to verify each
array's device type and device id match hidden_states (e.g., ensure
expert_indices.device().device_type == hidden_states.device().device_type and
expert_indices.device().device_id == hidden_states.device().device_id, same for
expert_weights_in) and raise an error if they differ so raw pointer binding into
CUDA kernels only occurs when all arrays are on the same GPU.
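The packed routing layout named in the comment above — `(expert_id << 16) | (weight_bf16.view(int16))` — can be illustrated with a stdlib-only sketch. This uses bit truncation in place of torch's bfloat16 view, and the helper names are hypothetical, not flashinfer API:

```python
import struct

def bf16_bits(x: float) -> int:
    """bfloat16 is the top 16 bits of the float32 encoding (truncating round)."""
    (f32_bits,) = struct.unpack("<I", struct.pack("<f", x))
    return f32_bits >> 16

def bf16_to_float(bits: int) -> float:
    """Re-expand 16 bfloat16 bits into a float32 value."""
    (x,) = struct.unpack("<f", struct.pack("<I", bits << 16))
    return x

def pack_routing(expert_id: int, weight: float) -> int:
    """Pack one (expert, weight) pair into the int32 layout the kernel expects:
    expert id in the high 16 bits, bfloat16 weight bits in the low 16 bits."""
    return (expert_id << 16) | bf16_bits(weight)

def unpack_routing(packed: int) -> tuple[int, float]:
    return packed >> 16, bf16_to_float(packed & 0xFFFF)

# 0.5 is exactly representable in bfloat16, so the round-trip is lossless here
eid, w = unpack_routing(pack_routing(7, 0.5))
```

Because the weight is stored as raw bfloat16 bits, weights that are not exactly representable in bfloat16 lose precision at pack time, which is why the launcher insists on a bfloat16 `expert_weights_in` tensor.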
flashinfer/fused_moe/core.py (1)

2120-2124: ⚠️ Potential issue | 🔴 Critical

Validate provided output shape before passing it to the kernel.

Current checks only enforce dtype/device. If output is undersized, the kernel can write out of bounds because it assumes [num_tokens, hidden_size].

🐛 Proposed fix
         else:
             check_shape_dtype_device(
                 output, None, torch.bfloat16, hidden_states.device, "output"
             )
+            assert output.dim() == 2, f"output must be 2D, got {output.dim()}D"
+            assert output.shape[0] == num_tokens, (
+                f"output.shape[0]={output.shape[0]} must equal num_tokens={num_tokens}"
+            )
+            assert output.shape[1] >= hidden_size, (
+                f"output.shape[1]={output.shape[1]} must be >= hidden_size={hidden_size}"
+            )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/fused_moe/core.py` around lines 2120 - 2124, The code currently
only checks dtype/device for `output` via `check_shape_dtype_device` but must
also validate its shape to prevent out-of-bounds writes by the kernel; add a
shape validation that `output.shape == (num_tokens, hidden_size)` (or raise a
clear error) immediately after the dtype/device check (before the `if
routing_logits is not None` branch) using the same symbols (`output`,
`num_tokens`, `hidden_size`, `hidden_states`) so the kernel assumptions are
enforced; keep the check adjacent to the existing `check_shape_dtype_device`
call in `fused_moe/core.py`.
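The shape assertions in the proposed fix can also be factored into a small standalone validator. This is a hypothetical sketch (the function name is not flashinfer code) of the `[num_tokens, hidden_size]` contract the kernel relies on before writing into a caller-provided buffer:

```python
def validate_moe_output(output_shape: tuple, num_tokens: int, hidden_size: int) -> None:
    """Hypothetical helper enforcing the output-buffer contract assumed by the
    fused MoE kernel: 2D, exactly num_tokens rows, at least hidden_size columns."""
    if len(output_shape) != 2:
        raise ValueError(f"output must be 2D, got {len(output_shape)}D")
    if output_shape[0] != num_tokens:
        raise ValueError(
            f"output.shape[0]={output_shape[0]} must equal num_tokens={num_tokens}"
        )
    if output_shape[1] < hidden_size:
        raise ValueError(
            f"output.shape[1]={output_shape[1]} must be >= hidden_size={hidden_size}"
        )

validate_moe_output((8, 4096), num_tokens=8, hidden_size=4096)  # passes silently
```

Running the check before the kernel launch turns a silent out-of-bounds write into an immediate, descriptive Python error.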

ℹ️ Review info

Run ID: cc1b6b64-b38b-46c4-981a-eeaab559e41a

📥 Commits

Reviewing files that changed from the base of the PR and between 6131bd3 and bfd0e54.

📒 Files selected for processing (4)
  • csrc/trtllm_fused_moe_kernel_launcher.cu
  • flashinfer/fused_moe/__init__.py
  • flashinfer/fused_moe/core.py
  • tests/moe/test_trtllm_gen_routed_fused_moe.py

@yzh119 (Collaborator)

yzh119 commented Mar 5, 2026

/bot run

@flashinfer-bot (Collaborator)

GitLab MR !365 has been updated with latest changes, and the CI pipeline #45441610 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot (Collaborator)

[SUCCESS] Pipeline #45441610: 10/20 passed

@IwakuraRein (Collaborator)

@flashinfer-bot run
