fix: cute dsl nvfp4 moe routing index error #2629
Conversation
📝 Walkthrough

This PR introduces a new mapping pointer, `mPtrPermutedIdxToExpandedIdx`, threaded through the MoE routing kernels and bindings.
Summary of Changes

Hello @nv-yunzheq, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request resolves a critical indexing bug within the CuteDSL MoE kernels: during the port from TensorRT-LLM to FlashInfer, the essential `mPtrPermutedIdxToExpandedIdx` reverse mapping was dropped from the routing kernels.
Code Review
This pull request addresses a critical bug in the CuteDSL MoE routing kernels caused by a missing mPtrPermutedIdxToExpandedIdx field during the port from TensorRT-LLM. The changes correctly reintroduce this field, pass it through the call stack, and populate it in the various routing kernels. The fix appears to be correct and comprehensive. I have one minor suggestion to improve code clarity in the llama4 routing kernel for better maintainability. The associated update to tighten accuracy checks in the tests is a positive change that validates the effectiveness of the fix.
```cpp
  }
  // write out `mPtrPermutedIdxToExpandedIdx` if required
  if (params.mPtrPermutedIdxToExpandedIdx != nullptr && isLocalExpert && isTokenRouted) {
    params.mPtrPermutedIdxToExpandedIdx[permutedIdx] = tokenIdx;
  }
```
For clarity and consistency with other routing kernels, it would be better to make it explicit that tokenIdx is being used as expandedIdx. While expandedIdx is equivalent to tokenIdx in this kernel (since topK=1), this is an important implementation detail. Adding an inline comment would help future maintainers understand the code's intent more easily, especially when comparing with other routing kernels that use an expandedIdx variable.
```cpp
params.mPtrPermutedIdxToExpandedIdx[permutedIdx] = tokenIdx;  // For llama4 routing, topK=1, so expandedIdx is equivalent to tokenIdx.
```
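To make the suggested comment concrete, the expanded-index layout can be sketched in a few lines of Python (a hypothetical illustration; `expanded_idx` is not a function in the repo):

```python
# Expanded index layout for top-k routing: each (token, k) slot gets
# expanded_idx = token_idx * top_k + k, as noted in the review above.
def expanded_idx(token_idx: int, k: int, top_k: int) -> int:
    return token_idx * top_k + k

# With topK == 1 (Llama4), the expanded index collapses to the token index.
assert expanded_idx(5, 0, 1) == 5
# With topK > 1 the two diverge, which is why writing tokenIdx would break.
assert expanded_idx(5, 2, 4) == 22
```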
🧹 Nitpick comments (2)
include/flashinfer/trtllm/fused_moe/RoutingKernel.h (1)
53-59: Minor: missing "uninitialized padding slots" caveat for `mPtrPermutedIdxToExpandedIdx`

`mPtrPermutedIdxToTokenIdx` documents that padding slots are left uninitialized ("Note: this array is uninitialized. Any out-of-bounds values are undefined."). `mPtrPermutedIdxToExpandedIdx` has the same semantics (only valid `permutedIdx` slots are written), but the note was not carried over.

📝 Suggested documentation update

```diff
-  // optional: if `nullptr`, it is not filled
-  // dim: [mTileTokensDim * mTopK + (mNumExperts × mTileTokensDim) - mNumExperts]
-  int32_t* mPtrPermutedIdxToExpandedIdx{nullptr};
+  // optional: if `nullptr`, it is not filled
+  // dim: [mTileTokensDim * mTopK + (mNumExperts × mTileTokensDim) - mNumExperts]
+  // Note: padding slots are uninitialized; only entries at valid permuted indices are written.
+  int32_t* mPtrPermutedIdxToExpandedIdx{nullptr};
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/trtllm/fused_moe/RoutingKernel.h` around lines 53-59: the comment for mPtrPermutedIdxToExpandedIdx is missing the same "uninitialized padding slots" caveat as mPtrPermutedIdxToTokenIdx; update the comment for the int32_t* member to state that the array is uninitialized for padding slots and that only valid permutedIdx slots are written (any out-of-bounds/unwritten values are undefined), matching the semantics and dimensional note already present for mPtrPermutedIdxToTokenIdx so both members share the same documentation.

csrc/trtllm_fused_moe_routing_llama4.cu (1)
305-308: Minor: stores `tokenIdx` instead of `expandedIdx`; correct only because Llama4 enforces `topK == 1`

For topK=1, `expandedIdx = tokenIdx * 1 + 0 = tokenIdx`, so the value is functionally identical. However, every other routing kernel (`routingPermutation`, `routingIndicesCoopKernel`, `storeLoopBody`) writes `expandedIdx` explicitly to `mPtrPermutedIdxToExpandedIdx`. Writing `tokenIdx` here is semantically misleading and would silently break if Llama4 ever supports topK>1. Consider using the explicit form for consistency:
♻️ Optional refactor for naming clarity
```diff
-  // write out `mPtrPermutedIdxToExpandedIdx` if required
-  if (params.mPtrPermutedIdxToExpandedIdx != nullptr && isLocalExpert && isTokenRouted) {
-    params.mPtrPermutedIdxToExpandedIdx[permutedIdx] = tokenIdx;
-  }
+  // write out `mPtrPermutedIdxToExpandedIdx` if required
+  // Note: for Llama4 (topK==1), expandedIdx == tokenIdx
+  if (params.mPtrPermutedIdxToExpandedIdx != nullptr && isLocalExpert && isTokenRouted) {
+    const int32_t expandedIdx = tokenIdx;  // topK==1 enforced by runImpl
+    params.mPtrPermutedIdxToExpandedIdx[permutedIdx] = expandedIdx;
+  }
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@csrc/trtllm_fused_moe_routing_llama4.cu` around lines 305 - 308, The code currently writes tokenIdx into params.mPtrPermutedIdxToExpandedIdx (using permutedIdx key), which only matches expandedIdx when topK==1; change this to store the explicit expandedIdx instead (i.e., either use the existing expandedIdx variable if present or compute expandedIdx = tokenIdx * params.topK + localK/the appropriate offset used in other kernels) and write that to params.mPtrPermutedIdxToExpandedIdx[permutedIdx] so the logic matches routingPermutation/routingIndicesCoopKernel/storeLoopBody and remains correct if topK>1.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Nitpick comments:
In `@csrc/trtllm_fused_moe_routing_llama4.cu`:
- Around line 305-308: The code currently writes tokenIdx into
params.mPtrPermutedIdxToExpandedIdx (using permutedIdx key), which only matches
expandedIdx when topK==1; change this to store the explicit expandedIdx instead
(i.e., either use the existing expandedIdx variable if present or compute
expandedIdx = tokenIdx * params.topK + localK/the appropriate offset used in
other kernels) and write that to
params.mPtrPermutedIdxToExpandedIdx[permutedIdx] so the logic matches
routingPermutation/routingIndicesCoopKernel/storeLoopBody and remains correct if
topK>1.
In `@include/flashinfer/trtllm/fused_moe/RoutingKernel.h`:
- Around line 53-59: The comment for mPtrPermutedIdxToExpandedIdx is missing the
same "uninitialized padding slots" caveat as mPtrPermutedIdxToTokenIdx; update
the comment for mPtrPermutedIdxToExpandedIdx (the int32_t* member) to state that
the array is uninitialized for padding slots and that only valid permutedIdx
slots are written (any out-of-bounds/unwritten values are undefined), matching
the semantics and dimensional note already present for mPtrPermutedIdxToTokenIdx
so both members share the same documentation.
ℹ️ Review info
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (8)
- csrc/moe_utils_binding.cu
- csrc/trtllm_fused_moe_routing_deepseek.cu
- csrc/trtllm_fused_moe_routing_llama4.cu
- csrc/trtllm_fused_moe_routing_renormalize.cu
- csrc/trtllm_fused_moe_runner.cu
- include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh
- include/flashinfer/trtllm/fused_moe/RoutingKernel.h
- tests/moe/test_cute_dsl_fused_moe.py
/bot run
[FAILED] Pipeline #44679068: 8/20 passed
```diff
 def check_accuracy(
-    actual: torch.Tensor, expected: torch.Tensor, percent_threshold: float = 0.925
+    actual: torch.Tensor, expected: torch.Tensor, percent_threshold: float = 0.97
```
Any specific reason we make this change?
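For context, the helper being changed is a percent-threshold check: it passes when at least `percent_threshold` of elements fall within tolerance, so raising 0.925 to 0.97 tightens the test. A torch-free Python sketch of that logic (hypothetical; the repo's actual `check_accuracy` operates on `torch.Tensor`s and its tolerance parameters may differ):

```python
import math

def check_accuracy(actual, expected, percent_threshold=0.97,
                   atol=0.0, rtol=1e-2):
    # Pass if at least `percent_threshold` of elements are within tolerance.
    close = [math.isclose(a, e, rel_tol=rtol, abs_tol=atol)
             for a, e in zip(actual, expected)]
    match_ratio = sum(close) / len(close)
    assert match_ratio >= percent_threshold, (
        f"only {match_ratio:.3f} of elements matched, "
        f"need >= {percent_threshold}")
```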
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
📌 Description
To fix the following bug:
When the CuteDSL MoE kernels were ported from TensorRT-LLM to FlashInfer, the mPtrPermutedIdxToExpandedIdx field was accidentally dropped from the routing kernel's DataBase struct in RoutingKernel.h. TRT-LLM's routing kernel produces three reverse-mapping outputs:

1. `mPtrExpandedIdxToPermutedIdx[expandedIdx] = permutedIdx`: forward mapping
2. `mPtrPermutedIdxToExpandedIdx[permutedIdx] = expandedIdx`: reverse to expanded index (`token_idx * topk + k`)
3. `mPtrPermutedIdxToTokenIdx[permutedIdx] = tokenIdx`: reverse to token index only
FlashInfer's port kept only #1 and #3, dropping #2. The binding in moe_utils_binding.cu then had to wire the Python buffer permuted_idx_to_expanded_idx to the only available reverse-mapping field — mPtrPermutedIdxToTokenIdx — which writes plain tokenIdx instead of expandedIdx.
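As a rough illustration, the three mappings can be modeled in plain Python. This is a toy sketch: the real kernels group permuted slots by expert tile and insert padding, and `build_mappings` is not a function in the repo; only the list names mirror the struct fields:

```python
# Toy model of the three reverse-mapping outputs the routing kernel produces.
# token_to_experts[t][k] is the expert chosen for token t's k-th slot.
def build_mappings(token_to_experts, top_k):
    num_tokens = len(token_to_experts)
    n = num_tokens * top_k
    # Expanded index e = token_idx * top_k + k; permuted order groups by expert
    # (stable sort stands in for the kernels' expert-tile permutation).
    order = sorted(range(n),
                   key=lambda e: token_to_experts[e // top_k][e % top_k])
    expanded_to_permuted = [0] * n
    permuted_to_expanded = [0] * n
    permuted_to_token = [0] * n
    for permuted_idx, expanded_idx in enumerate(order):
        expanded_to_permuted[expanded_idx] = permuted_idx       # mapping 1
        permuted_to_expanded[permuted_idx] = expanded_idx       # mapping 2 (the dropped one)
        permuted_to_token[permuted_idx] = expanded_idx // top_k  # mapping 3
    return expanded_to_permuted, permuted_to_expanded, permuted_to_token
```

Note that mapping 3 is lossy: it forgets which of a token's top-k slots a permuted row came from, which is exactly the information the CuteDSL kernels need.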
The Impact
The CuteDSL kernels (GEMM1 gather, moe_output_memset, GEMM2 finalize) all expect expanded indices and derive the token index via expanded_idx // topk. When they received plain tokenIdx instead, they computed tokenIdx // topk — yielding the wrong A row for gather, wrong zero-init for memset, and wrong scatter position + wrong routing scale for finalize.
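The effect can be seen with a tiny sketch (hypothetical names; `finalize_indices` stands in for the index math the CuteDSL finalize step performs on each permuted row's stored index):

```python
# The finalize step derives both the token row and the routing-scale slot
# from the *expanded* index. Feeding it a plain token index picks the wrong
# row and the wrong per-token routing scale.
top_k = 2  # assumed toy value

def finalize_indices(idx):
    # (token_idx, k): k selects the per-token routing scale
    return idx // top_k, idx % top_k

expanded = 5                 # token 2, k = 1
broken = expanded // top_k   # what the buggy mapping stored: plain tokenIdx (2)
assert finalize_indices(expanded) == (2, 1)  # correct row and scale slot
assert finalize_indices(broken) == (1, 0)    # wrong row AND wrong scale slot
```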
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit
Refactor

- Refined MOE (Mixture of Experts) routing infrastructure by extending index mapping capabilities across multiple kernel implementations to improve internal data flow consistency.

Tests

- Strengthened accuracy validation thresholds from 0.925 to 0.97 with adjusted error tolerance parameters, ensuring more rigorous testing of MOE operations under FP4 quantization conditions.