Fused moe all-reduce routed scaling factor + quant support #2966
aleozlx merged 7 commits into flashinfer-ai:main
Conversation
📝 Walkthrough

This change extends the MoE all-reduce finalization to accept optional output buffers (`norm_out`, `residual_out`, `quant_out`, `scale_out`) and a global `routed_scaling_factor`.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Py as Python Wrapper
    participant OP as CustomOp Module
    participant Kernel as CUDA Kernel
    participant GPU as GPU Memory
    Py->>OP: call trtllm_moe_finalize_allreduce_fusion(..., norm_out?, residual_out?, quant_out?, scale_out?, routed_scaling_factor?)
    OP->>Kernel: build params (device ptrs or nullptr, routed_scaling_factor)
    Kernel->>GPU: read inputs (allreduce_in, idx maps, expert_scale_factor?)
    Kernel->>Kernel: for each top_k: accumulate (apply expert_scale_factor? per-element)
    Kernel->>Kernel: if routed_scaling_factor != 1.0: accumulator *= routed_scaling_factor
    Kernel->>GPU: write outputs to provided buffers (norm_out?, residual_out?, quant_out?, scale_out?)
    Kernel-->>OP: return status
    OP-->>Py: return / raise on error
```
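The per-token math that the diagram describes can be sketched in plain Python. This is only an illustrative model of the arithmetic (the real kernel operates on device tensors with vectorized loads); the function and parameter names follow the diagram:

```python
# Illustrative model of the finalize accumulation shown in the sequence
# diagram: gather top_k expert contributions, optionally weight each by its
# expert scale factor, then apply routed_scaling_factor once at the end.

def finalize_token(expert_outputs, expert_scale_factor=None,
                   routed_scaling_factor=1.0):
    """Accumulate top_k expert contributions for a single token."""
    hidden = len(expert_outputs[0])
    acc = [0.0] * hidden
    for k, out in enumerate(expert_outputs):            # for each top_k expert
        scale = expert_scale_factor[k] if expert_scale_factor else 1.0
        for h in range(hidden):
            acc[h] += scale * out[h]                    # per-element accumulate
    if routed_scaling_factor != 1.0:                    # applied once, post-loop
        acc = [v * routed_scaling_factor for v in acc]
    return acc

token = finalize_token([[1.0, 2.0], [3.0, 4.0]],
                       expert_scale_factor=[0.5, 0.25],
                       routed_scaling_factor=2.0)
# (0.5*[1,2] + 0.25*[3,4]) * 2 = [1.25, 2.0] * 2 = [2.5, 4.0]
```

Note that `routed_scaling_factor` multiplies the accumulator once after the expert loop, not inside it, which is the optimization the review below also asks for.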
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
Code Review
This pull request extends the MoE finalize all-reduce fusion kernel to support optional quantization outputs (quant_out, scale_out) and a global routed_scaling_factor. It also updates the Python API and test suite to accommodate these changes while making existing output tensors optional. Feedback suggests optimizing performance by moving the routed_scaling_factor multiplication outside the expert loop to reduce redundant operations and refining a validation check to exclude an unsupported output parameter.
🧹 Nitpick comments (1)
tests/comm/test_trtllm_moe_allreduce_fusion_finalize.py (1)
99-104: Consider adding correctness verification for `quant_out` and `scale_out`. The test allocates these buffers and passes them to the kernel, but their contents are never verified against expected values. The existing TODO at line 12 acknowledges this gap.
For completeness, consider adding a follow-up to verify the quantized output correctness, similar to how `residual_out` and `norm_out` are validated.

Would you like me to open an issue to track adding quant output verification tests?
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Nitpick comments:
In `@tests/comm/test_trtllm_moe_allreduce_fusion_finalize.py`:
- Around line 99-104: Add assertions that verify quant_out and scale_out
contents after the kernel runs: compute the expected quantized outputs and
scales from the same inputs used to produce residual_out/norm_out (using
seq_len, hidden_size, SF_VEC_SIZE and dtype/device) and compare them to
quant_out and scale_out with tolerant checks (e.g.,
torch.testing.assert_allclose or equivalent); place these checks immediately
after the kernel invocation in the test (the block that allocates quant_out and
scale_out) so the test fails if quantization or scale computation is incorrect.
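A sketch of what such a reference check could look like, assuming a simple symmetric per-block scheme with block size `SF_VEC_SIZE`. All names and the quantization recipe here are illustrative; the kernel's actual format (e.g. an FP4/FP8 layout) would need to replace `quantize_block`:

```python
# Hypothetical reference for verifying quant_out/scale_out, as the review
# suggests. Assumes symmetric per-block quantization with an int8-like range;
# the real kernel recipe differs and would replace quantize_block().

SF_VEC_SIZE = 16
QMAX = 127.0

def quantize_block(block):
    """Symmetric quantization of one block: scale = absmax / QMAX."""
    scale = max(abs(v) for v in block) / QMAX or 1.0   # avoid zero scale
    return [round(v / scale) for v in block], scale

def reference_quant(row):
    """Quantize one hidden-state row block by block."""
    quant, scales = [], []
    for i in range(0, len(row), SF_VEC_SIZE):
        q, s = quantize_block(row[i:i + SF_VEC_SIZE])
        quant.extend(q)
        scales.append(s)
    return quant, scales

row = [float(i - 8) for i in range(32)]     # stand-in for one norm_out row
quant_ref, scale_ref = reference_quant(row)

# In the test, one would then compare kernel outputs against the reference
# with tolerant checks, e.g.:
# torch.testing.assert_close(quant_out[0], torch.tensor(quant_ref), ...)
# torch.testing.assert_close(scale_out[0], torch.tensor(scale_ref), ...)
```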
📒 Files selected for processing (4)

- csrc/trtllm_moe_allreduce_fusion.cu
- flashinfer/comm/trtllm_ar.py
- include/flashinfer/comm/trtllm_moe_allreduce_fusion.cuh
- tests/comm/test_trtllm_moe_allreduce_fusion_finalize.py
/bot run

[SUCCESS] Pipeline #47858231: 10/20 passed

@yzh119 seems this PR is stuck on pending checks. What shall I do to get this landed?
…n API

PR flashinfer-ai#2966 added quant_out, scale_out, and routed_scaling_factor params to trtllm_moe_finalize_allreduce_fusion(). PR flashinfer-ai#2982 (unified API) was developed before flashinfer-ai#2966 merged, and git merge produced no conflict since they touched different files (trtllm_ar.py vs allreduce.py). However, the call in allreduce_fusion() was missing the three new positional args, causing a TypeError at runtime for the kMoEFinalizeARResidualRMSNorm pattern and a mypy failure in pre-commit.

Fix:
- Add quant_out, scale_out, routed_scaling_factor to the finalize call
- Add routed_scaling_factor to the allreduce_fusion() function signature
- Update docstring

AI-assisted
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
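The failure mode is the usual one when a callee grows required parameters but one call site is missed. A minimal stand-in (function names shortened and signatures hypothetical; the real functions live in trtllm_ar.py and allreduce.py):

```python
# Simplified repro of the breakage described in the commit message above.

def finalize(inp, quant_out, scale_out, routed_scaling_factor):
    # stands in for trtllm_moe_finalize_allreduce_fusion() after PR #2966
    return inp * (routed_scaling_factor or 1.0)

def allreduce_fusion_broken(inp):
    return finalize(inp)            # missing the three new args -> TypeError

def allreduce_fusion_fixed(inp, routed_scaling_factor=None):
    # the fix: thread the new parameters through the unified API
    return finalize(inp, quant_out=None, scale_out=None,
                    routed_scaling_factor=routed_scaling_factor)

try:
    allreduce_fusion_broken(2.0)
except TypeError as e:
    print("broken:", e)

print(allreduce_fusion_fixed(2.0, routed_scaling_factor=1.5))   # prints 3.0
```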
…rnel
📌 Description
In some cases, such as Kimi K2.5 and GLM, we require an extra scale factor applied over all routed experts. While we could fold a constant multiplier into the expert scaling factors, it is more efficient to fuse this multiplication inside the kernel.
e.g. https://huggingface.co/moonshotai/Kimi-K2.5/blob/main/config.json#L143
Furthermore, we want to also support quantized outputs for this operator.
Likewise, I can also open a PR against TRT-LLM for functional equivalence.
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- Tests have been added or updated as needed (unittest, etc.).

Reviewer Notes