
fix: cute dsl nvfp4 moe routing index error#2629

Merged
yzh119 merged 1 commit into flashinfer-ai:main from nv-yunzheq:fix_cute_dsl_moe_error on Feb 25, 2026

Conversation

@nv-yunzheq (Collaborator) commented Feb 24, 2026

📌 Description

This PR fixes the following bug:
When the CuteDSL MoE kernels were ported from TensorRT-LLM to FlashInfer, the mPtrPermutedIdxToExpandedIdx field was accidentally dropped from the routing kernel's DataBase struct in RoutingKernel.h. TRT-LLM's routing kernel produces three reverse-mapping outputs:

  1. mPtrExpandedIdxToPermutedIdx[expandedIdx] = permutedIdx — forward mapping
  2. mPtrPermutedIdxToExpandedIdx[permutedIdx] = expandedIdx — reverse to expanded index (token_idx * topk + k)
  3. mPtrPermutedIdxToTokenIdx[permutedIdx] = tokenIdx — reverse to token index only

FlashInfer's port kept only #1 and #3, dropping #2. The binding in moe_utils_binding.cu then had to wire the Python buffer permuted_idx_to_expanded_idx to the only available reverse-mapping field — mPtrPermutedIdxToTokenIdx — which writes plain tokenIdx instead of expandedIdx.
The Impact
The CuteDSL kernels (GEMM1 gather, moe_output_memset, GEMM2 finalize) all expect expanded indices and derive the token index via expanded_idx // topk. When they received plain tokenIdx instead, they computed tokenIdx // topk — yielding the wrong A row for gather, wrong zero-init for memset, and wrong scatter position + wrong routing scale for finalize.
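The index algebra above can be sketched in a few lines of Python (a toy model of the routing outputs, not the CUDA kernels; the permutation values are made up for illustration):

```python
# Toy model of the three reverse mappings described above.
# Each token routes to topk experts; expanded_idx = token_idx * topk + k.
topk = 2
num_tokens = 3

# Mapping #1: expanded -> permuted (hypothetical permutation for illustration).
expanded_to_permuted = [4, 1, 5, 0, 2, 3]

permuted_to_expanded = [0] * (num_tokens * topk)  # mapping #2 (the dropped field)
permuted_to_token = [0] * (num_tokens * topk)     # mapping #3
for expanded_idx, permuted_idx in enumerate(expanded_to_permuted):
    permuted_to_expanded[permuted_idx] = expanded_idx
    permuted_to_token[permuted_idx] = expanded_idx // topk

# Downstream kernels derive the token index as expanded_idx // topk.
# With the correct mapping #2 this recovers the token index:
assert [e // topk for e in permuted_to_expanded] == permuted_to_token

# The bug wired mapping #3 where #2 was expected, so kernels effectively
# computed token_idx // topk instead -- wrong whenever topk > 1:
buggy = [t // topk for t in permuted_to_token]
assert buggy != permuted_to_token
```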

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Refactor

    • Refined MOE (Mixture of Experts) routing infrastructure by extending index mapping capabilities across multiple kernel implementations to improve internal data flow consistency.
  • Tests

    • Strengthened accuracy validation thresholds from 0.925 to 0.97 with adjusted error tolerance parameters, ensuring more rigorous testing of MOE operations under FP4 quantization conditions.

@coderabbitai bot (Contributor) commented Feb 24, 2026

📝 Walkthrough

This PR introduces a new mapping pointer mPtrPermutedIdxToExpandedIdx to track permuted-to-expanded index relationships in MOE routing kernels. The pointer is added to data structures, propagated through kernel parameters, and used in routing implementations to record expanded indices for local experts, with corresponding updates to kernel runners and test tolerances.

Changes

Cohort / File(s) — Summary
  • MOE Data Structure Headers — include/flashinfer/trtllm/fused_moe/RoutingKernel.h: adds the new pointer member mPtrPermutedIdxToExpandedIdx to the DataBase and KernelParamsBase structs; propagates it through the setBaseParams method.
  • MOE Routing Kernel Implementations — include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh, csrc/trtllm_fused_moe_routing_deepseek.cu, csrc/trtllm_fused_moe_routing_llama4.cu, csrc/trtllm_fused_moe_routing_renormalize.cu: adds writes to mPtrPermutedIdxToExpandedIdx in the routing kernels for local experts; includes null-checks where needed to prevent dereferencing a null pointer.
  • MOE Binding & Runner — csrc/moe_utils_binding.cu, csrc/trtllm_fused_moe_runner.cu: updates the output pointer mapping in the binding; initializes mPtrPermutedIdxToExpandedIdx in the runner for the DeepSeek, Llama4, and Renormalize routing paths.
  • Test Tolerances — tests/moe/test_cute_dsl_fused_moe.py: adjusts accuracy-check tolerances: stricter percent_threshold (0.925 → 0.97) and revised atol/rtol with scaling based on output magnitude for improved FP4 quantization robustness.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~22 minutes

Suggested labels

run-ci, v0.6.3, op: moe

Suggested reviewers

  • djmmoss
  • cyx-6
  • jiahanc
  • aleozlx
  • yzh119

Poem

🐰 A hop, skip, and a bounded jump,
New indices mapped without a thump,
From permuted paths to expanded arrays,
Routing kernels dance through clever ways,
MOE logistics bloom so bright! 🌟

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning — Docstring coverage is 25.00%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)
  • Title check ✅ Passed — The title accurately summarizes the main fix: restoring the missing mPtrPermutedIdxToExpandedIdx mapping in the CuteDSL MoE routing kernel to fix index computation errors.
  • Description check ✅ Passed — The description provides a detailed explanation of the bug, the three mappings involved, the missing field, the impact on downstream kernels, and the fix. However, the Related Issues and Reviewer Notes sections remain empty templates.


@gemini-code-assist bot (Contributor) commented

Summary of Changes

Hello @nv-yunzheq, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request resolves a critical indexing bug within the CuteDSL MoE kernels. During the port from TensorRT-LLM to FlashInfer, the essential mPtrPermutedIdxToExpandedIdx field, responsible for mapping permuted indices back to expanded indices, was inadvertently dropped. This omission led to downstream kernels receiving incorrect token indices, causing errors in subsequent operations like GEMM gather, memset, and finalize. The fix reintroduces and correctly populates this missing expanded index mapping, ensuring that all MoE kernels now operate with the precise indexing required for accurate functionality.

Highlights

  • Reintroduction of Expanded Index Mapping: The mPtrPermutedIdxToExpandedIdx field was reintroduced in RoutingKernel.h to correctly store the reverse mapping from permuted indices to expanded indices, which was previously missing.
  • Corrected Python Binding: The Python binding in moe_utils_binding.cu was updated to properly assign the permuted_idx_to_expanded_idx_ptr to the newly available mPtrPermutedIdxToExpandedIdx and explicitly set mPtrPermutedIdxToTokenIdx to nullptr when not used for its specific purpose.
  • Updated MoE Routing Kernels: Logic was implemented across various MoE routing kernels (trtllm_fused_moe_routing_deepseek.cu, trtllm_fused_moe_routing_llama4.cu, trtllm_fused_moe_routing_renormalize.cu, and RoutingKernel.cuh) to write the correct expanded or token indices to the mPtrPermutedIdxToExpandedIdx field.
  • Runner Class Integration: The Runner class in trtllm_fused_moe_runner.cu was updated to pass the permutedIdxToExpandedIdx pointer to the routing data structure, ensuring the correct data is available for processing.
  • Adjusted Accuracy Test Parameters: Accuracy check parameters in test_cute_dsl_fused_moe.py were modified to better accommodate FP4 quantization noise, including changes to percent_threshold, atol, and rtol.


Changelog
  • csrc/moe_utils_binding.cu
    • Corrected the assignment of permuted_idx_to_expanded_idx_ptr to routingData.mPtrPermutedIdxToExpandedIdx.
    • Explicitly set routingData.mPtrPermutedIdxToTokenIdx to nullptr.
  • csrc/trtllm_fused_moe_routing_deepseek.cu
    • Added conditional logic to write expandedIdx to params.mPtrPermutedIdxToExpandedIdx.
    • Updated the FLASHINFER_CHECK condition to include data.mPtrPermutedIdxToExpandedIdx.
  • csrc/trtllm_fused_moe_routing_llama4.cu
    • Added conditional logic to write tokenIdx to params.mPtrPermutedIdxToExpandedIdx.
  • csrc/trtllm_fused_moe_routing_renormalize.cu
    • Added conditional logic to write expandedIdx to params.mPtrPermutedIdxToExpandedIdx.
  • csrc/trtllm_fused_moe_runner.cu
    • Added assignment of permutedIdxToExpandedIdx to routingData.mPtrPermutedIdxToExpandedIdx in three Runner::run overloads.
  • include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh
    • Added conditional logic to write expandedIdx to params.mPtrPermutedIdxToExpandedIdx within routingPermutation and a __global__ kernel.
  • include/flashinfer/trtllm/fused_moe/RoutingKernel.h
    • Introduced mPtrPermutedIdxToExpandedIdx member to the DataBase struct.
    • Introduced mPtrPermutedIdxToExpandedIdx member to the KernelParamsBase struct.
  • tests/moe/test_cute_dsl_fused_moe.py
    • Updated percent_threshold from 0.925 to 0.97.
    • Adjusted atol calculation from max(0.1, 3.0 * output_scale) to max(0.05, 1.5 * output_scale).
    • Adjusted rtol from 0.85 to 0.5.
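A plain-Python sketch of what the adjusted check plausibly looks like (the function name, percent_threshold, atol formula, and rtol come from the diff summaries above; the output_scale parameter and the elementwise isclose-style comparison are assumptions — the actual test operates on torch tensors and may differ):

```python
def check_accuracy(actual, expected, output_scale, percent_threshold=0.97):
    """Pass if at least percent_threshold of elements are within tolerance.

    Tolerances follow the values listed above: atol = max(0.05,
    1.5 * output_scale), rtol = 0.5. Each element is compared with the
    torch.isclose convention |a - e| <= atol + rtol * |e|.
    """
    atol = max(0.05, 1.5 * output_scale)
    rtol = 0.5
    n_close = sum(abs(a - e) <= atol + rtol * abs(e)
                  for a, e in zip(actual, expected))
    frac = n_close / len(expected)
    return frac >= percent_threshold, frac

# With the stricter 0.97 threshold, 97% of elements must fall inside the
# (smaller) tolerance band for the test to pass.
ok, frac = check_accuracy([1.0, 1.2, 3.0], [1.0, 1.0, 2.5], output_scale=0.1)
assert ok and frac == 1.0
```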

@gemini-code-assist bot left a comment

Code Review

This pull request addresses a critical bug in the CuteDSL MoE routing kernels caused by a missing mPtrPermutedIdxToExpandedIdx field during the port from TensorRT-LLM. The changes correctly reintroduce this field, pass it through the call stack, and populate it in the various routing kernels. The fix appears to be correct and comprehensive. I have one minor suggestion to improve code clarity in the llama4 routing kernel for better maintainability. The associated update to tighten accuracy checks in the tests is a positive change that validates the effectiveness of the fix.

In csrc/trtllm_fused_moe_routing_llama4.cu:

      }
      // write out `mPtrPermutedIdxToExpandedIdx` if required
      if (params.mPtrPermutedIdxToExpandedIdx != nullptr && isLocalExpert && isTokenRouted) {
        params.mPtrPermutedIdxToExpandedIdx[permutedIdx] = tokenIdx;

@gemini-code-assist bot (medium):

For clarity and consistency with the other routing kernels, it would be better to make explicit that tokenIdx is being used as expandedIdx. While expandedIdx is equivalent to tokenIdx in this kernel (since topK=1), this is an important implementation detail. An inline comment would help future maintainers understand the code's intent, especially when comparing with routing kernels that use an expandedIdx variable.

Suggested change:

        params.mPtrPermutedIdxToExpandedIdx[permutedIdx] = tokenIdx; // For llama4 routing, topK=1, so expandedIdx is equivalent to tokenIdx.

@coderabbitai bot left a comment

🧹 Nitpick comments (2)
include/flashinfer/trtllm/fused_moe/RoutingKernel.h (1)

53-59: Minor: missing "uninitialized padding slots" caveat for mPtrPermutedIdxToExpandedIdx

mPtrPermutedIdxToTokenIdx documents that padding slots are left uninitialized ("Note: this array is uninitialized. Any out-of-bounds values are undefined."). mPtrPermutedIdxToExpandedIdx has the same semantics — only valid permutedIdx slots are written — but the note was not carried over.

📝 Suggested documentation update
-  // optional: if `nullptr`, it is not filled
-  // dim: [mTileTokensDim * mTopK + (mNumExperts × mTileTokensDim) - mNumExperts]
-  int32_t* mPtrPermutedIdxToExpandedIdx{nullptr};
+  // optional: if `nullptr`, it is not filled
+  // dim: [mTileTokensDim * mTopK + (mNumExperts × mTileTokensDim) - mNumExperts]
+  // Note: padding slots are uninitialized; only entries at valid permuted indices are written.
+  int32_t* mPtrPermutedIdxToExpandedIdx{nullptr};
csrc/trtllm_fused_moe_routing_llama4.cu (1)

305-308: Minor: stores tokenIdx instead of expandedIdx — correct only because Llama4 enforces topK == 1

For topK=1, expandedIdx = tokenIdx * 1 + 0 = tokenIdx, so the value is functionally identical. However, every other routing kernel (routingPermutation, routingIndicesCoopKernel, storeLoopBody) writes expandedIdx explicitly to mPtrPermutedIdxToExpandedIdx. Writing tokenIdx here is semantically misleading and would silently break if Llama4 ever supports topK>1.

Consider using the explicit form for consistency:

♻️ Optional refactor for naming clarity
-      // write out `mPtrPermutedIdxToExpandedIdx` if required
-      if (params.mPtrPermutedIdxToExpandedIdx != nullptr && isLocalExpert && isTokenRouted) {
-        params.mPtrPermutedIdxToExpandedIdx[permutedIdx] = tokenIdx;
-      }
+      // write out `mPtrPermutedIdxToExpandedIdx` if required
+      // Note: for Llama4 (topK==1), expandedIdx == tokenIdx
+      if (params.mPtrPermutedIdxToExpandedIdx != nullptr && isLocalExpert && isTokenRouted) {
+        const int32_t expandedIdx = tokenIdx;  // topK==1 enforced by runImpl
+        params.mPtrPermutedIdxToExpandedIdx[permutedIdx] = expandedIdx;
+      }

ℹ️ Review info

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 26ef055 and 2117710.

📒 Files selected for processing (8)
  • csrc/moe_utils_binding.cu
  • csrc/trtllm_fused_moe_routing_deepseek.cu
  • csrc/trtllm_fused_moe_routing_llama4.cu
  • csrc/trtllm_fused_moe_routing_renormalize.cu
  • csrc/trtllm_fused_moe_runner.cu
  • include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh
  • include/flashinfer/trtllm/fused_moe/RoutingKernel.h
  • tests/moe/test_cute_dsl_fused_moe.py

@aleozlx added the run-ci label Feb 24, 2026

@aleozlx (Collaborator) commented Feb 24, 2026

/bot run

@flashinfer-bot (Collaborator) commented

GitLab MR !346 has been created, and the CI pipeline #44679068 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot (Collaborator) commented

[FAILED] Pipeline #44679068: 8/20 passed


In tests/moe/test_cute_dsl_fused_moe.py:

     def check_accuracy(
-        actual: torch.Tensor, expected: torch.Tensor, percent_threshold: float = 0.925
+        actual: torch.Tensor, expected: torch.Tensor, percent_threshold: float = 0.97

A collaborator commented:

Any specific reason we make this change?

@yzh119 yzh119 merged commit 9826c26 into flashinfer-ai:main Feb 25, 2026
44 of 86 checks passed
@nv-yunzheq nv-yunzheq deleted the fix_cute_dsl_moe_error branch March 2, 2026 19:52
ameynaik-hub pushed a commit to ameynaik-hub/flashinfer that referenced this pull request Mar 18, 2026
