
Fused moe all-reduce routed scaling factor + quant support#2966

Merged
aleozlx merged 7 commits into flashinfer-ai:main from murphymatt:fused-ar-routed-scaling-factor
Apr 13, 2026

Conversation

@murphymatt
Contributor

@murphymatt murphymatt commented Apr 3, 2026


📌 Description

In some models, such as Kimi K2.5 and GLM, an extra scale factor must be applied across all routed experts. While we could fold this in as a constant multiplier over the per-expert scaling factors, it is more efficient to fuse the multiplication inside the kernel.

e.g. https://huggingface.co/moonshotai/Kimi-K2.5/blob/main/config.json#L143

Furthermore, we also want to support quantized outputs for this operator.

Likewise, I can open a corresponding PR against TRT-LLM for functional equivalence.
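To illustrate why fusing the routed scaling factor is cheaper than folding it into every expert scale, here is a minimal Python sketch of the finalize reduction for a single element. This is an illustrative reference only, not the FlashInfer CUDA kernel; the function names are hypothetical:

```python
# Reference for the finalize reduction with a routed scaling factor.
# Illustrative sketch only; names are hypothetical, not FlashInfer APIs.

def finalize_unfused(expert_outputs, expert_scales, routed_scaling_factor):
    # Baseline: fold the routed scaling factor into every expert scale
    # (one extra multiply per expert contribution).
    acc = 0.0
    for out, scale in zip(expert_outputs, expert_scales):
        acc += out * (scale * routed_scaling_factor)
    return acc

def finalize_fused(expert_outputs, expert_scales, routed_scaling_factor):
    # Fused variant: accumulate over top_k experts first, then apply the
    # routed scaling factor once, after accumulation.
    acc = 0.0
    for out, scale in zip(expert_outputs, expert_scales):
        acc += out * scale
    return acc * routed_scaling_factor

outs = [1.0, 2.0, 4.0]       # per-expert outputs for one element
scales = [0.5, 0.25, 0.125]  # per-expert routing weights
rsf = 2.5                    # a routed scaling factor, as in the Kimi K2.5 config
assert finalize_unfused(outs, scales, rsf) == finalize_fused(outs, scales, rsf)
```

Both variants produce the same result; the fused form trades `top_k` extra multiplies per element for a single one.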

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features
    • Added optional outputs for quantization and scaling and an optional routed-scaling parameter to the MOE allreduce fusion API.
  • Refactor
    • Adjusted accumulation and routed-scaling application so routed scaling can be applied post-accumulation; added validation requiring at least one output buffer.
  • Tests
    • Updated tests to cover the new optional outputs and routed-scaling behavior in warmup and capture paths.

@coderabbitai
Contributor

coderabbitai bot commented Apr 3, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 8226c888-e835-43d0-8f8d-213dcb8690c1

📥 Commits

Reviewing files that changed from the base of the PR and between a41fd42 and a1450e4.

📒 Files selected for processing (1)
  • flashinfer/comm/trtllm_ar.py

📝 Walkthrough

Walkthrough

This change extends the MOE all-reduce finalization to accept optional output buffers (norm_out, residual_out, quant_out, scale_out) and an optional scalar routed_scaling_factor; kernel params, kernel accumulation logic, launcher/API wiring, and tests were updated accordingly.

Changes

Cohort / File(s) — Summary

  • CUDA Launcher / Op (csrc/trtllm_moe_allreduce_fusion.cu): Signature updated to take Optional<TensorView> outputs and Optional<float> routed_scaling_factor; params fields set conditionally (device ptr or nullptr); params.routed_scaling_factor defaults to 1.0f when absent.
  • CUDA Kernel & Params Header (include/flashinfer/comm/trtllm_moe_allreduce_fusion.cuh): Added float routed_scaling_factor to the params struct; kernel accumulation revised to handle per-expert scaling vs direct accumulation and to apply routed scaling once after accumulation; op-level validation requires at least one of residual_out/norm_out/quant_out.
  • Python Wrapper / Op Registration (flashinfer/comm/trtllm_ar.py): Custom-op registration and Python wrapper updated: norm_out/residual_out → Optional[...]; added quant_out, scale_out, and routed_scaling_factor; mutates_args extended to include quant_out and scale_out; new args forwarded to the native op.
  • Tests (tests/comm/test_trtllm_moe_allreduce_fusion_finalize.py): Tests allocate and pass quant_out, scale_out, and routed_scaling_factor (2.5); reference reduction updated to include routed scaling; shapes for new tensors inferred from norm_out.
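The op-level validation mentioned above (at least one of residual_out/norm_out/quant_out must be provided) can be sketched in pure Python. This is a hypothetical stand-in for illustration; the real check lives in the CUDA op:

```python
# Sketch of the "at least one output buffer" validation described in the
# walkthrough. Hypothetical pure-Python stand-in, not the actual CUDA check.

def check_outputs(residual_out=None, norm_out=None, quant_out=None):
    if residual_out is None and norm_out is None and quant_out is None:
        raise ValueError(
            "at least one of residual_out, norm_out, or quant_out must be provided"
        )

check_outputs(norm_out=object())  # ok: one buffer present
try:
    check_outputs()               # rejected: no output buffer at all
except ValueError as e:
    print(f"validation rejected the all-None case: {e}")
```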

Sequence Diagram(s)

sequenceDiagram
    participant Py as Python Wrapper
    participant OP as CustomOp Module
    participant Kernel as CUDA Kernel
    participant GPU as GPU Memory

    Py->>OP: call trtllm_moe_finalize_allreduce_fusion(..., norm_out?, residual_out?, quant_out?, scale_out?, routed_scaling_factor?)
    OP->>Kernel: build params (device ptrs or nullptr, routed_scaling_factor)
    Kernel->>GPU: read inputs (allreduce_in, idx maps, expert_scale_factor?)
    Kernel->>Kernel: for each top_k: accumulate (apply expert_scale_factor? per-element)
    Kernel->>Kernel: if routed_scaling_factor != 1.0: accumulator *= routed_scaling_factor
    Kernel->>GPU: write outputs to provided buffers (norm_out?, residual_out?, quant_out?, scale_out?)
    Kernel-->>OP: return status
    OP-->>Py: return / raise on error

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Suggested labels

op: moe, op: moe-routing

Suggested reviewers

  • yzh119
  • aleozlx
  • yongwww
  • cyx-6
  • jimmyzho

Poem

🐰 I hopped into the fusion land,
Optional buffers in my hand,
Quant and scale now join the run,
Routed factor doubles the fun,
Hop, compute, and off we stand.

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage — ⚠️ Warning: Docstring coverage is 20.00%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)

  • Title check — ✅ Passed: The title accurately describes the main changes: adding routed scaling factor fusion and quantization support to the MOE all-reduce operator.
  • Description check — ✅ Passed: The PR description covers the motivation (efficiency for specific model architectures), implementation details (kernel fusion), and testing completeness, with pre-commit and test checkboxes marked complete.


Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request extends the MoE finalize all-reduce fusion kernel to support optional quantization outputs (quant_out, scale_out) and a global routed_scaling_factor. It also updates the Python API and test suite to accommodate these changes while making existing output tensors optional. Feedback suggests optimizing performance by moving the routed_scaling_factor multiplication outside the expert loop to reduce redundant operations and refining a validation check to exclude an unsupported output parameter.

Contributor

@coderabbitai coderabbitai bot left a comment


🧹 Nitpick comments (1)
tests/comm/test_trtllm_moe_allreduce_fusion_finalize.py (1)

99-104: Consider adding correctness verification for quant_out and scale_out.

The test allocates these buffers and passes them to the kernel, but their contents are never verified against expected values. The existing TODO at line 12 acknowledges this gap.

For completeness, consider adding a follow-up to verify the quantized output correctness, similar to how residual_out and norm_out are validated.

Would you like me to open an issue to track adding quant output verification tests?
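One way the suggested verification could look is a small quantization reference checked against the kernel output. The sketch below is heavily hedged: the per-vector max-abs scheme over SF_VEC_SIZE elements with an FP8 E4M3-style max of 448, and the name quantize_reference, are assumptions for illustration, not necessarily what the kernel implements:

```python
# Hedged sketch of a quant_out/scale_out reference for the test. The
# quantization scheme here (per-vector max-abs scale, E4M3-style max 448)
# is assumed for illustration only.

SF_VEC_SIZE = 16
FP8_MAX = 448.0

def quantize_reference(vec):
    """Quantize one SF_VEC_SIZE-long vector; return (quantized, scale)."""
    amax = max(abs(x) for x in vec)
    scale = amax / FP8_MAX if amax > 0 else 1.0
    q = [max(-FP8_MAX, min(FP8_MAX, round(x / scale))) for x in vec]
    return q, scale

vec = [0.5 * i for i in range(SF_VEC_SIZE)]
q, scale = quantize_reference(vec)
# Tolerant check, in the spirit of torch.testing.assert_close: dequantized
# values differ from the input by at most one quantization step.
for orig, quant in zip(vec, q):
    assert abs(orig - quant * scale) <= scale
```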

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/comm/test_trtllm_moe_allreduce_fusion_finalize.py` around lines 99 -
104, Add assertions that verify quant_out and scale_out contents after the
kernel runs: compute the expected quantized outputs and scales from the same
inputs used to produce residual_out/norm_out (using seq_len, hidden_size,
SF_VEC_SIZE and dtype/device) and compare them to quant_out and scale_out with
tolerant checks (e.g., torch.testing.assert_allclose or equivalent); place these
checks immediately after the kernel invocation in the test (the block that
allocates quant_out and scale_out) so the test fails if quantization or scale
computation is incorrect.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 3e6ee43d-e260-4ae2-bfb4-7ef262d30998

📥 Commits

Reviewing files that changed from the base of the PR and between fe05393 and b45d317.

📒 Files selected for processing (4)
  • csrc/trtllm_moe_allreduce_fusion.cu
  • flashinfer/comm/trtllm_ar.py
  • include/flashinfer/comm/trtllm_moe_allreduce_fusion.cuh
  • tests/comm/test_trtllm_moe_allreduce_fusion_finalize.py

@murphymatt murphymatt changed the title from "Fused ar routed scaling factor" to "Fused moe all-reduce routed scaling factor + quant support" on Apr 3, 2026
@murphymatt
Contributor Author

cc. @yyihuang @yzh119

Collaborator

@yzh119 yzh119 left a comment


LGTM

@yzh119 yzh119 added the run-ci label Apr 6, 2026
@yzh119
Collaborator

yzh119 commented Apr 6, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !515 has been created, and the CI pipeline #47858231 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[SUCCESS] Pipeline #47858231: 10/20 passed

@murphymatt
Contributor Author

@yzh119 seems this PR is stuck on pending checks. What shall I do to get this landed?

@aleozlx aleozlx merged commit a2d2b25 into flashinfer-ai:main Apr 13, 2026
29 of 30 checks passed
samuellees added a commit to samuellees/flashinfer that referenced this pull request Apr 13, 2026
…n API

PR flashinfer-ai#2966 added quant_out, scale_out, and routed_scaling_factor params
to trtllm_moe_finalize_allreduce_fusion(). PR flashinfer-ai#2982 (unified API) was
developed before flashinfer-ai#2966 merged, and the git merge produced no conflict
since they touched different files (trtllm_ar.py vs allreduce.py). However,
the call in allreduce_fusion() was missing the three new positional args,
causing a TypeError at runtime for the kMoEFinalizeARResidualRMSNorm
pattern and a mypy failure in pre-commit.

Fix:
- Add quant_out, scale_out, routed_scaling_factor to the finalize call
- Add routed_scaling_factor to allreduce_fusion() function signature
- Update docstring
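The breakage described above can be sketched in a few lines: a callee gains new required positional parameters while a call site still passes the old argument list. The names below are hypothetical stand-ins, not the real FlashInfer signatures:

```python
# Illustrative sketch of the missing-positional-args TypeError described
# above. All names are hypothetical stand-ins for the real functions.

def finalize_op(allreduce_in, norm_out, residual_out,
                quant_out, scale_out, routed_scaling_factor):
    return "ok"

def old_call_site():
    # Pre-fix argument list: missing the three new positionals.
    return finalize_op("in", "norm", "res")

def fixed_call_site():
    # Post-fix: forward the new args (None / 1.0 when unused).
    return finalize_op("in", "norm", "res", None, None, 1.0)

try:
    old_call_site()
except TypeError:
    print("old call site raises TypeError, as described in the commit")
assert fixed_call_site() == "ok"
```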

AI-assisted

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
