
Bf16 routed moe #2594

Merged
IwakuraRein merged 21 commits into flashinfer-ai:main from IwakuraRein:bf16-routed-moe
Mar 3, 2026

Conversation

@IwakuraRein
Collaborator

@IwakuraRein IwakuraRein commented Feb 19, 2026

📌 Description

Add the trtllm_bf16_routed_moe API.

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

pytest tests/moe/test_trtllm_gen_routed_fused_moe.py::test_trtllm_gen_bf16_routed_fused_moe

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features
    • Added support for pre-computed routing in MoE operations, enabling flexible routing input strategies.
    • New routed MoE APIs now available: BF16 and FP8 variants support pre-packed top-k routing information.
    • Introduced dual-path mechanism allowing MoE operations to accept either routing logits or pre-computed routing data.
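The pre-packed top-k routing format referenced throughout this PR, (expert_id << 16) | (weight_bf16.view(int16)), can be sketched in pure Python. pack_routing and unpack_routing are hypothetical helpers for illustration only; the actual tests pack tensors with torch bit operations:

```python
import struct

def bf16_bits(x: float) -> int:
    # Truncate a float32 to bfloat16 bit pattern (the top 16 bits of the
    # float32 representation; round-to-nearest-even is omitted for brevity).
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    return (bits >> 16) & 0xFFFF

def pack_routing(expert_id: int, weight: float) -> int:
    # Packed format assumed from the launcher comments in this PR:
    # high 16 bits carry the expert id, low 16 bits carry the bf16 weight.
    return (expert_id << 16) | bf16_bits(weight)

def unpack_routing(packed: int):
    expert_id = packed >> 16
    (weight,) = struct.unpack("<f", struct.pack("<I", (packed & 0xFFFF) << 16))
    return expert_id, weight
```

For example, packing expert 3 with weight 1.0 yields (3 << 16) | 0x3F80, since 0x3F80 is the bfloat16 bit pattern of 1.0.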

Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
@coderabbitai
Contributor

coderabbitai bot commented Feb 19, 2026

📝 Walkthrough


This PR extends the MoE kernel infrastructure to support precomputed routing data (expert indices and weights) alongside traditional routing logits, enabling alternative routing input paths. The changes propagate expert_indices and expert_weights through launcher constructors, add optional routing parameters to Python APIs, introduce a new trtllm_bf16_routed_moe function, and expand test coverage for routed MoE inference.

Changes

Cohort / File(s) Summary
C++ MoE Launcher Routing
csrc/trtllm_fused_moe_kernel_launcher.cu
Added expert_indices and expert_weights as private data members to Bf16MoeLauncher; updated constructor signature to accept these parameters as Optional inputs; modified routing path to use precomputed workspace.routing_expert_indexes when expert_indices are available, with fallback to routing_logits; added validation checks for precomputed routing data shape, dtype, and top_k compatibility.
Python Package Exports
flashinfer/__init__.py, flashinfer/fused_moe/__init__.py
Added three new public API exports: trtllm_bf16_moe, trtllm_bf16_routed_moe, and trtllm_fp8_block_scale_routed_moe; expanded top-level namespace for fused MoE operations without removing existing exports.
Python Core MoE Operations
flashinfer/fused_moe/core.py
Extended trtllm_bf16_moe_op signature to accept Optional routing_logits, topk_ids, and expert_weights; added validation requiring at least routing_logits or topk_ids; implemented dynamic buffer handling based on which routing input is provided; introduced new trtllm_bf16_routed_moe function for pre-packed top-k routing inference; updated _fake_trtllm_bf16_moe to match new optional parameter scheme.
MoE Test Implementations
tests/moe/test_trtllm_gen_fused_moe.py
Modified return value handling in FP4Moe and MXInt4BlockScaleMoe call_moe methods to extract first element of output tuple as float tensor (output[0].to(torch.float)).
Routed MoE Tests
tests/moe/test_trtllm_gen_routed_fused_moe.py
Added routing method branching (Renormalize/RenormalizeNaive/TopK) selection based on routing_method_type; introduced utilities for shuffled weight preparation and block layout conversion; added comprehensive BF16 routed MoE test matching FP8 structure with top-k packing, reference routing validation, and equality checks.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

Suggested labels

ready, run-ci, op: moe

Suggested reviewers

  • aleozlx
  • cyx-6
  • djmmoss
  • jiahanc
  • yzh119
  • wenscarl

Poem

🐰 Hopping through MoE paths so bright,
With expert routes packed just right,
Precomputed indices lead the way,
Routed inference saves the day! 🎯✨

🚥 Pre-merge checks | ✅ 1 | ❌ 2

❌ Failed checks (1 warning, 1 inconclusive)

  • Docstring Coverage (⚠️ Warning): Docstring coverage is 30.77%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
  • Description check (❓ Inconclusive): The PR description provides a clear title but lacks detail about what changes enable the new API, the design rationale, or implementation specifics. Resolution: expand the description to explain (1) what trtllm_bf16_routed_moe does and how it differs from existing APIs, (2) key implementation changes (e.g., support for precomputed routing indices/weights), (3) use cases or benefits, and (4) any breaking changes or migration considerations.

✅ Passed checks (1 passed)

  • Title check (✅ Passed): The title 'Bf16 routed moe' is concise and directly reflects the main change, adding a new API for BF16 routed MoE functionality as documented in the PR objectives.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


@gemini-code-assist
Contributor

Summary of Changes

Hello @IwakuraRein, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enhances the Mixture of Experts (MoE) functionality by introducing a new API for bfloat16 operations with pre-computed routing. It also refactors existing MoE implementations across different data types to provide greater control over output finalization, allowing for the retrieval of intermediate results. These changes aim to improve flexibility and integration within MoE pipelines.

Highlights

  • New API for BF16 Routed MoE: Introduced trtllm_bf16_routed_moe API to support Mixture of Experts (MoE) operations with pre-computed routing information in bfloat16 precision.
  • Flexible Output Finalization: Added a do_finalize parameter to various MoE functions (BF16, FP8, MXINT4) in both C++ and Python, allowing users to choose between receiving the final output or intermediate tensors for further processing.
  • Support for Pre-computed Expert Indices and Weights: Modified the Bf16MoeLauncher to accept and utilize pre-computed expert_indices and expert_weights, enabling more efficient routing scenarios.
  • Unified Return Type for MoE Functions: Changed the return type of several trtllm_ MoE functions from a single Tensor to an Array<Tensor> (C++) or List[torch.Tensor] (Python) to consistently handle both final and intermediate outputs.
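The unified return type described above (a list holding either the single finalized tensor or the intermediate tensors) can be handled on the caller side roughly as follows. handle_moe_output is an illustrative helper, not part of the FlashInfer API:

```python
def handle_moe_output(outputs, do_finalize):
    # Sketch of the calling convention described in this PR: when
    # do_finalize is True the returned list holds exactly one finalized
    # tensor; otherwise it holds intermediate tensors that the caller
    # finalizes itself (e.g., for custom reduction or fusion).
    if do_finalize:
        (final,) = outputs  # raises if the contract is violated
        return final
    return outputs
```

This mirrors how the updated tests extract output[0] after calling the MoE ops with do_finalize=True.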


Changelog
  • csrc/trtllm_fused_moe_kernel_launcher.cu
    • Updated FusedMoeLauncher::run to use workspace pointers for routing_expert_indexes and expert_weights.
    • Modified FusedMoeLauncher::run return statement to reference FusedMoeLauncher::expert_weights.
    • Refactored Bf16MoeLauncher constructor to accept expert_indices and expert_weights.
    • Added validation for expert_indices in Bf16MoeLauncher::check_routing.
    • Implemented logic in Bf16MoeLauncher::init to handle pre-computed expert_indices and expert_weights.
    • Modified Bf16MoeLauncher::init to conditionally allocate output if args->output is null.
    • Added expert_weights and expert_indices as private members to Bf16MoeLauncher.
    • Removed args->do_finalize = true; from Fp8PerTensorLauncher::init and Fp8BlockScaleLauncher::init.
    • Changed return types of trtllm_bf16_moe, trtllm_fp8_per_tensor_scale_moe, trtllm_fp8_block_scale_moe, and trtllm_mxint4_block_scale_moe from Tensor to Array<Tensor>.
    • Added do_finalize parameter to trtllm_bf16_moe, trtllm_fp8_per_tensor_scale_moe, trtllm_fp8_block_scale_moe, and trtllm_mxint4_block_scale_moe functions.
    • Updated trtllm_bf16_moe to pass expert_indices and expert_weights to the Bf16MoeLauncher constructor.
    • Modified trtllm_bf16_moe, trtllm_fp8_per_tensor_scale_moe, trtllm_fp8_block_scale_moe, and trtllm_mxint4_block_scale_moe to return the full Array<Tensor> from the launcher.
    • Updated FP4BlockScaleLauncher::run to return output when do_finalize is true, and FusedMoeLauncher::expert_weights otherwise.
  • flashinfer/fused_moe/core.py
    • Modified forward methods for various MoE operations to pass topk_ids, expert_weights, output, and do_finalize parameters.
    • Changed the return type annotations of trtllm_bf16_moe_op, trtllm_fp8_per_tensor_scale_moe_op, trtllm_fp8_block_scale_moe_op, _fake_trtllm_bf16_moe, _fake_trtllm_fp8_per_tensor_scale_moe, _fake_trtllm_fp8_block_scale_moe, and trtllm_mxint4_block_scale_moe_op from torch.Tensor to List[torch.Tensor].
    • Updated trtllm_bf16_moe_op to handle routing either from routing_logits or from pre-computed topk_ids and expert_weights.
    • Added logic in trtllm_bf16_moe_op to conditionally allocate topk_ids and expert_weights based on the presence of routing_logits.
    • Modified return logic in trtllm_bf16_moe_op, trtllm_fp8_per_tensor_scale_moe_op, trtllm_fp8_block_scale_moe_op, and trtllm_mxint4_block_scale_moe_op to return [output] if do_finalize is true, or intermediate tensors otherwise.
    • Introduced do_finalize parameter to trtllm_bf16_moe_op and its fake counterpart, as well as to trtllm_fp8_per_tensor_scale_moe_op, trtllm_fp8_block_scale_moe_op, and trtllm_mxint4_block_scale_moe_op.
    • Added a new Python API trtllm_bf16_routed_moe for bfloat16 MoE with pre-computed routing.
    • Updated docstrings for trtllm_bf16_moe, trtllm_fp8_per_tensor_scale_moe, trtllm_fp8_block_scale_moe, trtllm_fp8_block_scale_routed_moe, trtllm_fp4_block_scale_moe, trtllm_fp4_block_scale_routed_moe, and trtllm_mxint4_block_scale_moe to reflect the do_finalize parameter and List[torch.Tensor] return type.
    • Removed some activation types from docstrings for FP8 and FP4 MoE functions.
    • Corrected the return value handling for trtllm_fp4_block_scale_moe_op when do_finalize is false.
  • tests/moe/test_trtllm_gen_fused_moe.py
    • Updated call_moe functions for MXINT4, FP8 block-scale, FP8 per-tensor, and BF16 MoE tests to extract the first element from the returned list (output[0]) to match the new return type.
  • tests/moe/test_trtllm_gen_routed_fused_moe.py
    • Imported new functions (convert_to_block_layout, trtllm_bf16_moe, trtllm_bf16_routed_moe, shuffle_matrix_a, WeightLayout) for expanded testing.
    • Updated test_trtllm_gen_fp8_routed_fused_moe to extract output[0] from the returned list and expanded routing method types for reference computation.
    • Added a new test test_trtllm_gen_bf16_routed_fused_moe to verify the functionality of the new trtllm_bf16_routed_moe API against standard BF16 MoE.
Activity
  • IwakuraRein created this pull request to add a trtllm_bf16_routed_moe API.
  • The PR description indicates that pre-commit checks and tests are yet to be completed.

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a new API, trtllm_bf16_routed_moe, and enhances existing MoE functions to support pre-computed routing and the option to return intermediate results. The changes are consistently applied across the C++ kernel launcher and Python bindings, improving flexibility and control over the MoE computation flow. New tests have been added to validate the functionality of the BF16 routed MoE. The do_finalize parameter is a valuable addition, allowing users to either get the final output or inspect intermediate tensors for further processing.

if (has_precomputed_indices) {
  // Use expert_indices directly
  workspace.routing_expert_indexes =
      static_cast<int*>(const_cast<void*>(expert_indices.data_ptr()));

medium

The use of const_cast<void*>(expert_indices.data_ptr()) here bypasses const-correctness. While it might be necessary due to the underlying API expecting a non-const pointer, it's important to ensure that the expert_indices data is indeed treated as read-only within the kernel to prevent unintended modifications. If the kernel truly modifies this data, it should be explicitly copied to a mutable buffer.

@IwakuraRein IwakuraRein marked this pull request as ready for review February 23, 2026 19:39

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/fused_moe/core.py (1)

1317-1371: ⚠️ Potential issue | 🟠 Major

Validate topk_ids/expert_weights when routing_logits is absent.
Without dtype/shape checks, empty or malformed packed routing data can reach the kernel and produce undefined routing. Add explicit validation for topk_ids (int32, [num_tokens, top_k]) and expert_weights (if provided) before proceeding.

🔧 Suggested guardrails
     assert routing_logits is not None or topk_ids is not None, (
         "either routing_logits or topk_ids must be provided"
     )
@@
     num_tokens = hidden_states.shape[0]
     hidden_size = hidden_states.shape[-1]
+    if routing_logits is None:
+        if topk_ids is None:
+            raise ValueError("topk_ids must be provided when routing_logits is None.")
+        if topk_ids.dtype != torch.int32:
+            raise TypeError("topk_ids must be int32 when routing_logits is None.")
+        if topk_ids.ndim != 2 or topk_ids.shape[0] != num_tokens or topk_ids.shape[1] != top_k:
+            raise ValueError("topk_ids must have shape [num_tokens, top_k].")
+        if expert_weights is not None and expert_weights.numel() > 0:
+            if expert_weights.dtype != torch.bfloat16:
+                raise TypeError("expert_weights must be bfloat16.")
+            if expert_weights.shape != topk_ids.shape:
+                raise ValueError("expert_weights must match topk_ids shape.")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/fused_moe/core.py` around lines 1317 - 1371, When routing_logits
is None, validate that topk_ids and expert_weights are well-formed before use:
in the function handling routing (the block that currently assigns
topk_ids/topk_ids = topk_ids and expert_weights = ...), assert topk_ids is a
torch.Tensor of dtype torch.int32 and shape [num_tokens, top_k], and if
expert_weights is provided assert it is a torch.Tensor with matching first two
dims [num_tokens, top_k] and a sensible dtype (e.g., routing_logits.dtype if
available or torch.bfloat16), otherwise raise a clear ValueError; use the local
symbols topk_ids, expert_weights, num_tokens, and top_k to implement these
checks so malformed/empty packed routing data cannot reach the kernel.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@csrc/trtllm_fused_moe_kernel_launcher.cu`:
- Around line 463-474: In check_routing(), enforce that exactly one routing
source is provided: if routing_logits is null/absent then expert_indices must be
non-empty, and if expert_indices is provided/non-empty then routing_logits must
be absent; error out when both are present or both absent. Modify
FusedMoeLauncher::check_routing_common() caller logic in check_routing() to add
TVM_FFI_ICHECK-style guards referencing expert_indices, routing_logits, and
hidden_states (and args->top_k where relevant) so you reject the case of empty
expert_indices with no routing_logits and the case where both routing_logits and
a non-empty expert_indices are supplied.

In `@flashinfer/fused_moe/core.py`:
- Around line 1448-1452: The function _fake_trtllm_bf16_moe has parameters
(routing_logits, routing_bias, expert_indices, expert_weights, hidden_states)
flagged as unused by Ruff; to silence the lint while preserving the signature,
rename the unused parameters by prefixing them with an underscore (e.g.,
routing_logits -> _routing_logits, routing_bias -> _routing_bias, expert_indices
-> _expert_indices, expert_weights -> _expert_weights, and if hidden_states is
unused rename to _hidden_states) in the _fake_trtllm_bf16_moe definition and any
internal references so the signature stays compatible but Ruff no longer reports
them as unused.

In `@tests/moe/test_trtllm_gen_routed_fused_moe.py`:
- Around line 403-530: The test's call to shuffle_matrix_a(..., 64) produces a
shuffle_block_size of 16 that doesn't match
BF16Moe.prepare_static_weights_for_kernel which uses epilogue_tile_m=128
(shuffle_block_size=32); update the test to use the same tile param (use 128
instead of 64) or derive the tile size from
BF16Moe.prepare_static_weights_for_kernel so that shuffle_matrix_a and the
production preprocessing use the same epilogue_tile_m/shuffle_block_size,
ensuring gemm1_weights/gemm2_weights are shuffled into the identical layout the
kernel expects.
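The review comment above implies a fixed relation between the epilogue tile and the shuffle block: a tile of 64 yields a shuffle_block_size of 16, and a tile of 128 yields 32, i.e. tile_m // 4. A hypothetical helper makes the assumed relation explicit (this ratio is inferred from the review comment, not confirmed by the kernel source):

```python
def shuffle_block_size(epilogue_tile_m: int) -> int:
    # Assumed mapping from the review comment: 64 -> 16, 128 -> 32.
    # Deriving the shuffle block from the tile parameter keeps test
    # preprocessing and production preprocessing in sync.
    return epilogue_tile_m // 4
```

Deriving the value from a single tile parameter, rather than hard-coding both numbers, is exactly the fix the comment suggests for the test.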


ℹ️ Review info

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 26ef055 and 83a634d.

📒 Files selected for processing (6)
  • csrc/trtllm_fused_moe_kernel_launcher.cu
  • flashinfer/__init__.py
  • flashinfer/fused_moe/__init__.py
  • flashinfer/fused_moe/core.py
  • tests/moe/test_trtllm_gen_fused_moe.py
  • tests/moe/test_trtllm_gen_routed_fused_moe.py

Comment on lines 463 to +474
void check_routing() const override {
  FusedMoeLauncher::check_routing_common();
  if (expert_indices.ndim() == 2 && expert_indices.size(0) > 0) {
    // Pre-computed routing: expert_indices is a packed tensor
    // Format: (expert_id << 16) | (weight_bf16.view(int16))
    TVM_FFI_ICHECK_EQ(expert_indices.ndim(), 2) << "expert_indices must be 2D.";
    TVM_FFI_ICHECK_EQ(expert_indices.size(0), hidden_states.size(0))
        << "expert_indices and hidden_states must have same number of tokens.";
    TVM_FFI_ICHECK_EQ(expert_indices.size(1), args->top_k)
        << "expert_indices dim1 must match top_k.";
    TVM_FFI_ICHECK_EQ(expert_indices.dtype(), dl_int32) << "expert_indices must be int32.";
  }

⚠️ Potential issue | 🟠 Major

Enforce exactly one routing source to avoid undefined routing data.
If routing_logits is absent and expert_indices is empty, routing runs with uninitialized indices; if both are present, routing can overwrite user-provided indices. Add an explicit exclusivity/required check.

🔧 Suggested validation
   void check_routing() const override {
     FusedMoeLauncher::check_routing_common();
-    if (expert_indices.ndim() == 2 && expert_indices.size(0) > 0) {
+    bool has_precomputed = expert_indices.ndim() == 2 && expert_indices.size(0) > 0;
+    bool has_logits = routing_logits.has_value();
+    TVM_FFI_ICHECK(has_precomputed || has_logits)
+        << "either routing_logits or expert_indices must be provided.";
+    TVM_FFI_ICHECK(!(has_precomputed && has_logits))
+        << "provide either routing_logits or expert_indices, not both.";
+    if (has_precomputed) {
       // Pre-computed routing: expert_indices is a packed tensor
       // Format: (expert_id << 16) | (weight_bf16.view(int16))
       TVM_FFI_ICHECK_EQ(expert_indices.ndim(), 2) << "expert_indices must be 2D.";
       TVM_FFI_ICHECK_EQ(expert_indices.size(0), hidden_states.size(0))
           << "expert_indices and hidden_states must have same number of tokens.";
       TVM_FFI_ICHECK_EQ(expert_indices.size(1), args->top_k)
           << "expert_indices dim1 must match top_k.";
       TVM_FFI_ICHECK_EQ(expert_indices.dtype(), dl_int32) << "expert_indices must be int32.";
     }

Comment on lines +1448 to 1452
routing_logits: Optional[torch.Tensor],
routing_bias: Optional[torch.Tensor],
expert_indices: Optional[torch.Tensor],
expert_weights: Optional[torch.Tensor],
hidden_states: torch.Tensor,

⚠️ Potential issue | 🟡 Minor

Silence unused-argument lint in _fake_trtllm_bf16_moe.
Ruff flags these parameters as unused; prefixing with _ keeps lint clean while preserving signature parity.

🔧 Minimal lint-safe rename
 def _fake_trtllm_bf16_moe(
-    routing_logits: Optional[torch.Tensor],
-    routing_bias: Optional[torch.Tensor],
-    expert_indices: Optional[torch.Tensor],
-    expert_weights: Optional[torch.Tensor],
+    _routing_logits: Optional[torch.Tensor],
+    _routing_bias: Optional[torch.Tensor],
+    _expert_indices: Optional[torch.Tensor],
+    _expert_weights: Optional[torch.Tensor],
     hidden_states: torch.Tensor,
     gemm1_weights: torch.Tensor,
     gemm2_weights: torch.Tensor,
🧰 Tools
🪛 Ruff (0.15.2)

[warning] 1448: Unused function argument: routing_logits (ARG001)
[warning] 1449: Unused function argument: routing_bias (ARG001)
[warning] 1450: Unused function argument: expert_indices (ARG001)
[warning] 1451: Unused function argument: expert_weights (ARG001)


Comment on lines +403 to +530
@pytest.mark.parametrize("num_tokens", [8, 64])
@pytest.mark.parametrize("hidden_size", [1024, 2048])
@pytest.mark.parametrize("intermediate_size", [1024, 2048])
@pytest.mark.parametrize("num_experts", [8, 16])
@pytest.mark.parametrize("top_k", [2, 4])
@pytest.mark.parametrize(
"routing_method_type",
[
RoutingMethodType.Renormalize,
],
)
def test_trtllm_gen_bf16_routed_fused_moe(
num_tokens: int,
hidden_size: int,
intermediate_size: int,
top_k: int,
num_experts: int,
routing_method_type: RoutingMethodType,
):
"""Test Bf16 scale routed MoE matches standard routing."""
compute_capability = get_compute_capability(torch.device(device="cuda"))
if compute_capability[0] not in [10]:
pytest.skip("These tests are only guaranteed to work on SM100 and SM103 GPUs.")
torch.manual_seed(42)
device = torch.device("cuda:0")
enable_pdl = device_support_pdl(device)

# Generate random routing logits for reference
routing_logits = torch.rand(num_tokens, num_experts, device=device).to(
torch.bfloat16
)

# Generate random hidden states in FP8
hidden_states = (
torch.randn(num_tokens, hidden_size, device=device).to(torch.bfloat16) * 0.1
)

# Generate weights
gemm1_weights = torch.randn(
num_experts, 2 * intermediate_size, hidden_size, device=device
).to(torch.bfloat16)
gemm2_weights = torch.randn(
num_experts, hidden_size, intermediate_size, device=device
).to(torch.bfloat16)

gemm1_weights_shuffled = []
gemm2_weights_shuffled = []
for i in range(num_experts):
tmp_weights1 = shuffle_matrix_a(gemm1_weights[i].view(torch.uint8), 64)
tmp_weights2 = shuffle_matrix_a(gemm2_weights[i].view(torch.uint8), 64)
block_k = 128
gemm1_weights_shuffled.append(convert_to_block_layout(tmp_weights1, block_k))
gemm2_weights_shuffled.append(convert_to_block_layout(tmp_weights2, block_k))
gemm1_weights = torch.stack(gemm1_weights_shuffled).view(torch.bfloat16)
gemm2_weights = torch.stack(gemm2_weights_shuffled).view(torch.bfloat16)

    # Run reference with routing_logits
    reference_output = trtllm_bf16_moe(
        routing_logits=routing_logits,
        routing_bias=None,
        hidden_states=hidden_states,
        gemm1_weights=gemm1_weights,
        gemm2_weights=gemm2_weights,
        num_experts=num_experts,
        top_k=top_k,
        n_group=None,
        topk_group=None,
        intermediate_size=intermediate_size,
        local_expert_offset=0,
        local_num_experts=num_experts,
        routed_scaling_factor=None,
        routing_method_type=routing_method_type.value,
        use_shuffled_weight=True,
        weight_layout=WeightLayout.BlockMajorK,
        do_finalize=True,
        enable_pdl=enable_pdl,
    ).to(torch.float)

    # Compute routing using reference implementation
    if routing_method_type == RoutingMethodType.Renormalize:
        permute_info, expert_weights_ref = routing_reference_renormalize(
            routing_logits, top_k, num_experts, 8
        )
    elif routing_method_type == RoutingMethodType.RenormalizeNaive:
        permute_info, expert_weights_ref = routing_reference_renormalize_naive(
            routing_logits, top_k, num_experts, 8
        )
    elif routing_method_type == RoutingMethodType.TopK:
        permute_info, expert_weights_ref = routing_reference_topk(
            routing_logits, top_k, num_experts, 8
        )
    topk_ids = permute_info["topKIndices"].to(torch.int32)
    expert_weights = expert_weights_ref.view(num_tokens, num_experts)[
        torch.arange(num_tokens, device=device).unsqueeze(1), topk_ids
    ].to(torch.bfloat16)
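The fancy-indexing gather above selects, for each token, the dense routing weights at that token's top-k expert ids. A pure-Python sketch of the same operation (hypothetical helper name, plain lists standing in for tensors):

```python
def gather_topk_weights(full_weights, topk_ids):
    """full_weights: [num_tokens][num_experts]; topk_ids: [num_tokens][top_k].

    Returns a [num_tokens][top_k] list: for each token, the weight of each
    of its selected experts, in top-k order.
    """
    return [[row[e] for e in ids] for row, ids in zip(full_weights, topk_ids)]

weights = [[0.1, 0.7, 0.2], [0.5, 0.1, 0.4]]
ids = [[1, 2], [0, 2]]
gather_topk_weights(weights, ids)  # [[0.7, 0.2], [0.5, 0.4]]
```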

    # Pack topk_ids and expert_weights into a single tensor
    # Format: (expert_id << 16) | (weight_bf16.view(int16))
    packed_topk_ids = (topk_ids << 16) | expert_weights.view(torch.int16).to(
        torch.int32
    )

    # Run with pre-computed routing (packed format)
    output = trtllm_bf16_routed_moe(
        topk_ids=packed_topk_ids,
        hidden_states=hidden_states,
        gemm1_weights=gemm1_weights,
        gemm2_weights=gemm2_weights,
        num_experts=num_experts,
        top_k=top_k,
        n_group=None,
        topk_group=None,
        intermediate_size=intermediate_size,
        local_expert_offset=0,
        local_num_experts=num_experts,
        routed_scaling_factor=None,
        routing_method_type=routing_method_type.value,
        use_shuffled_weight=True,
        weight_layout=WeightLayout.BlockMajorK,
        do_finalize=True,
        enable_pdl=enable_pdl,
    ).to(torch.float)

    mask = torch.isclose(output, reference_output, rtol=1e-2, atol=1e-2)

    # Mismatch percentage
    mismatch_pct = (~mask).float().mean().item() * 100
    assert mismatch_pct < 10, f"Mismatch percentage is {mismatch_pct:.2f}%"
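The tolerance check at the end follows `torch.isclose`'s finite-value criterion, `|a - b| <= atol + rtol * |b|`, then asserts that fewer than 10% of elements fail it. A small pure-Python sketch of the same metric (hypothetical helper, plain lists instead of tensors):

```python
def mismatch_pct(out, ref, rtol=1e-2, atol=1e-2):
    """Percentage of elements failing torch.isclose's finite-value test,
    |a - b| <= atol + rtol * |b|, where b is the reference value."""
    bad = sum(1 for a, b in zip(out, ref) if abs(a - b) > atol + rtol * abs(b))
    return 100.0 * bad / len(ref)

mismatch_pct([1.0, 1.5], [1.0, 1.0])  # 50.0 -- second element is far off
```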
Contributor

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Search for BF16 weight preprocessing patterns and tile sizes
rg -n "class BF16Moe|prepare_static_weights_for_kernel|epilogue_tile_m|shuffle_matrix_a" \
  tests/moe/ flashinfer/fused_moe/

Repository: flashinfer-ai/flashinfer

Length of output: 4614


🏁 Script executed:

#!/bin/bash
# Look at BF16Moe class and its prepare_static_weights_for_kernel implementation
sed -n '1375,1450p' tests/moe/test_trtllm_gen_fused_moe.py



🏁 Script executed:

#!/bin/bash
# Get the full BF16Moe.prepare_static_weights_for_kernel implementation
sed -n '1401,1460p' tests/moe/test_trtllm_gen_fused_moe.py



🏁 Script executed:

#!/bin/bash
# Look at the get_shuffle_matrix_a_row_indices and shuffle_matrix_a implementations
grep -A 20 "def get_shuffle_matrix_a_row_indices" flashinfer/fused_moe/core.py
echo "---"
grep -B 5 -A 10 "def shuffle_matrix_a" flashinfer/fused_moe/core.py | head -30



🏁 Script executed:

#!/bin/bash
# Check imports in the test file
head -50 tests/moe/test_trtllm_gen_routed_fused_moe.py | grep -E "import|from"



🏁 Script executed:

#!/bin/bash
# Get the full imports from both flashinfer and flashinfer.fused_moe
head -100 tests/moe/test_trtllm_gen_routed_fused_moe.py



🏁 Script executed:

#!/bin/bash
# Find where shuffle_matrix_a is defined
rg -n "def shuffle_matrix_a|shuffle_matrix_a =" flashinfer/



🏁 Script executed:

#!/bin/bash
# Look at shuffle_matrix_a definition and its implementation
sed -n '771,850p' flashinfer/fp4_quantization.py



🏁 Script executed:

#!/bin/bash
# Look at get_shuffle_matrix_a_row_indices definition
rg -n "def get_shuffle_matrix_a_row_indices" flashinfer/



🏁 Script executed:

#!/bin/bash
# Look at get_shuffle_matrix_a_row_indices implementation
sed -n '757,820p' flashinfer/utils.py



🏁 Script executed:

#!/bin/bash
# Look at get_shuffle_block_size to understand how epilogue_tile_m affects shuffling
rg -n "def get_shuffle_block_size" flashinfer/



🏁 Script executed:

#!/bin/bash
# Look at get_shuffle_block_size implementation
sed -n '750,756p' flashinfer/utils.py



🏁 Script executed:

#!/bin/bash
# Check if there are other BF16 tests in the routed test file and what tile sizes they use
grep -n "def test_trtllm_gen.*bf16\|shuffle_matrix_a" tests/moe/test_trtllm_gen_routed_fused_moe.py



🏁 Script executed:

#!/bin/bash
# Let's verify the exact context around how the routed test uses these functions
sed -n '445,460p' tests/moe/test_trtllm_gen_routed_fused_moe.py



🏁 Script executed:

#!/bin/bash
# Let's double-check the relationship between epilogue_tile_m and block size
cat > /tmp/verify_tile_size.py << 'EOF'
# Test epilogue_tile_m values and their corresponding shuffle_block_size
values = [64, 128]
for epilogue_tile_m in values:
    shuffle_block_size = 16
    if epilogue_tile_m % 128 == 0:
        shuffle_block_size = 32
    print(f"epilogue_tile_m={epilogue_tile_m} => shuffle_block_size={shuffle_block_size}")
EOF
python /tmp/verify_tile_size.py



🏁 Script executed:

#!/bin/bash
# Check if there are comments explaining the 64 tile size choice in the routed test
sed -n '403,460p' tests/moe/test_trtllm_gen_routed_fused_moe.py | head -60



Fix hardcoded tile size to match BF16Moe preprocessing path.

The test uses shuffle_matrix_a(..., 64) which produces a shuffle_block_size=16 row permutation, but BF16Moe.prepare_static_weights_for_kernel uses epilogue_tile_m=128, producing shuffle_block_size=32. This mismatch means the test shuffles weights with a different layout than the production code path, potentially allowing the test to pass while feeding incorrect weights to the kernel. Use epilogue_tile_m=128 to match the BF16Moe preprocessing logic, or document why a different tile size is required.
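The tile-size rule the review relies on can be stated as a one-liner. This is a sketch mirroring the `get_shuffle_block_size` logic exercised by the verification script above (the threshold comes from that script, not from reading the library source directly):

```python
def shuffle_block_size(epilogue_tile_m: int) -> int:
    """Row-shuffle block size implied by the epilogue tile height:
    epilogue tiles that are a multiple of 128 use 32-row shuffle
    blocks; everything else uses 16-row blocks."""
    return 32 if epilogue_tile_m % 128 == 0 else 16

assert shuffle_block_size(64) == 16   # what the test's shuffle_matrix_a(..., 64) implies
assert shuffle_block_size(128) == 32  # what BF16Moe's epilogue_tile_m=128 implies
```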

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/moe/test_trtllm_gen_routed_fused_moe.py` around lines 403 - 530, The
test's call to shuffle_matrix_a(..., 64) produces a shuffle_block_size of 16
that doesn't match BF16Moe.prepare_static_weights_for_kernel which uses
epilogue_tile_m=128 (shuffle_block_size=32); update the test to use the same
tile param (use 128 instead of 64) or derive the tile size from
BF16Moe.prepare_static_weights_for_kernel so that shuffle_matrix_a and the
production preprocessing use the same epilogue_tile_m/shuffle_block_size,
ensuring gemm1_weights/gemm2_weights are shuffled into the identical layout the
kernel expects.

@IwakuraRein
Collaborator Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !348 has been created, and the CI pipeline #44689639 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #44689639: 10/20 passed

@aleozlx aleozlx self-assigned this Feb 27, 2026
@aleozlx aleozlx added the run-ci label Feb 27, 2026
@aleozlx
Collaborator

aleozlx commented Feb 27, 2026

tests are good. need to resolve one api compatibility issue as above. otherwise lgtm

Collaborator

@aleozlx aleozlx left a comment


lgtm

@aleozlx aleozlx added the ready label Feb 27, 2026
@aleozlx
Collaborator

aleozlx commented Feb 27, 2026

this is ready for merging

@IwakuraRein IwakuraRein requested a review from kahyunnam as a code owner March 2, 2026 17:50
Contributor

@jimmyzho jimmyzho left a comment


lgtm based on other's approvals to help unblock

@IwakuraRein IwakuraRein merged commit a6a60f1 into flashinfer-ai:main Mar 3, 2026
92 of 106 checks passed
ameynaik-hub pushed a commit to ameynaik-hub/flashinfer that referenced this pull request Mar 18, 2026