
Route the missing parameter for trtllm_fp8_per_tensor_scale_moe_op #3094

Merged
aleozlx merged 4 commits into flashinfer-ai:main from pavanimajety:fix-fp8-ptwrapper
Apr 18, 2026

Conversation

@pavanimajety
Contributor

@pavanimajety pavanimajety commented Apr 16, 2026

📌 Description

Fix vLLM CI failure for 0.6.8 - https://buildkite.com/vllm/ci/builds/61703/steps/canvas?jid=019d97e7-d69d-4b1a-a597-95a021d29060&tab=output#019d97e7-d69d-4b1a-a597-95a021d29060

The public trtllm_fp8_per_tensor_scale_moe wrapper at line 2559 calls into _op via the SimpleNamespace returned by get_trtllm_moe_sm100_module() (line 2315), so user-facing callers hit the same error.

The routing_replay_out trailing argument was added in #3024 (2026-04-15). That PR updated the public wrapper's call list but not the inner _op's call to the C++ binding.

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Refactor

    • Improved MoE routing output handling in TensorRT-LLM FP16/BF16 and FP8 inference paths so optional replay outputs are accepted and processed for more robust Mixture-of-Experts inference.
  • Chores

    • Relaxed version tag validation in the release workflow to accept an additional optional segment after the patch number.

Signed-off-by: Pavani Majety <pmajety@nvidia.com>
@coderabbitai
Contributor

coderabbitai bot commented Apr 16, 2026

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 389ac700-7412-4f24-8944-8f24cf0e0155

📥 Commits

Reviewing files that changed from the base of the PR and between 311ef08 and 27e9b21.

📒 Files selected for processing (1)
  • .github/workflows/release.yml

📝 Walkthrough

Walkthrough

Forward an optional routing_replay_out tensor from the Python wrapper into TensorRT‑LLM MoE C++ ops for the BF16 and FP8-per-tensor-scale execution paths, updating the wrapper call signatures to pass the extra buffer.

Changes

Cohort / File(s) Summary
Fused MoE wrappers
flashinfer/fused_moe/core.py
Pass optional routing_replay_out into TensorRT‑LLM MoE ops: BF16 path uses kwargs.get("routing_replay_out"); FP8 per‑tensor‑scale path adds a routing_replay_out argument to trtllm_*_moe calls.
CI workflow
.github/workflows/release.yml
Relaxed workflow_dispatch tag validation regex to accept an optional additional lowercase segment after vX.Y.Z.
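As a hypothetical illustration of the relaxed tag check (the actual regex in release.yml may differ), a pattern accepting vX.Y.Z plus one optional lowercase segment after the patch number could look like:

```python
import re

# Hypothetical version of the relaxed tag pattern: vX.Y.Z with an optional
# extra dot-separated lowercase/digit segment (e.g. a post-release suffix).
TAG_RE = re.compile(r"^v\d+\.\d+\.\d+(\.[a-z0-9]+)?$")
```

The optional group is what the old pattern (plain vX.Y.Z) would have rejected.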

Sequence Diagram(s)

(Skipped — changes are limited plumbing updates without significant new multi-component control flow.)

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

Possibly related PRs

Suggested labels

op: moe-routing

Suggested reviewers

  • yzh119
  • cyx-6
  • jimmyzho
  • jiahanc
  • nv-yunzheq
  • bkryu
  • sricketts

Poem

🐰 I hopped through wrappers, soft and light,
Tucked replay routing into kernels’ sight.
BF16, FP8 — a tiny thread sewn through,
Little buffer carried, now paths renew. 🥕

🚥 Pre-merge checks | ✅ 2 passed | ❌ 1 warning

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning — Docstring coverage is 0.00%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)

  • Title check ✅ Passed — The PR title directly describes the main change: routing the missing routing_replay_out parameter through the trtllm_fp8_per_tensor_scale_moe_op call, which aligns with the primary objective of fixing the vLLM CI failure.
  • Description check ✅ Passed — The PR description includes a detailed explanation of the issue, links to the related vLLM CI failure, references the prior PR that introduced the parameter, and confirms pre-commit checks and testing were completed.


Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request updates the Fused MoE core logic by converting positional arguments to keyword arguments and adding support for routing_replay_out. Several critical issues were identified in the trtllm_fp8_per_tensor_scale_moe implementation: the manual allocation of the output tensor and the enable_pdl check are redundant and lead to a positional argument mismatch in the subsequent function call. Additionally, the tune_max_num_tokens parameter was incorrectly changed from an integer to a list, which will cause failures in the autotuning logic.

Comment thread flashinfer/fused_moe/core.py Outdated
Comment on lines 2640 to 2641
output,
num_experts,
Contributor


critical

This change introduces a critical bug due to a positional argument mismatch. The target function trtllm_fp8_per_tensor_scale_moe_op does not include output in its signature (defined at line 1503). By inserting output at this position, all subsequent arguments are shifted: num_experts will be passed to the top_k parameter, top_k to n_group, and so on. This will cause a TypeError or incorrect kernel execution. This line should be removed to restore the correct positional mapping.

Suggested change:

```diff
-    output,
-    num_experts,
+    num_experts,
```
Contributor

@coderabbitai coderabbitai bot left a comment


Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/fused_moe/core.py (1)

2628-2657: ⚠️ Potential issue | 🔴 Critical

Critical: positional argument misalignment — this PR's fix will TypeError at runtime.

The public wrapper (lines 2628–2657) passes 25 positional arguments to trtllm_fp8_per_tensor_scale_moe_op, which has 24 parameters. Position 9 inserts output (tensor) where num_experts: int is expected, and position 22 inserts [-1, -1] where tune_max_num_tokens: int belongs. The 25th argument (routing_replay_out) exceeds the function's arity and will raise TypeError: too many positional arguments at runtime.

Meanwhile, the internal C++ binding call inside the _op function (lines 1603–1629) correctly uses keyword arguments and successfully passes output= and config_index= to the C++ binding—those parameters exist at the C++ level but are not exposed as formal parameters in trtllm_fp8_per_tensor_scale_moe_op's Python signature. The _op function allocates output internally (lines 1538–1540) and infers config_index from the AutoTuner (line 1625).

Fix: Remove the output allocation (lines 2628–2633) and the [-1, -1] placeholder (line 2653) from the public wrapper. Call trtllm_fp8_per_tensor_scale_moe_op positionally without these arguments, matching the 24-parameter signature:

```python
result = get_trtllm_moe_sm100_module().trtllm_fp8_per_tensor_scale_moe_op(
    routing_logits,
    routing_bias,
    hidden_states,
    gemm1_weights,
    output1_scales_scalar,
    output1_scales_gate_scalar,
    gemm2_weights,
    output2_scales_scalar,
    num_experts,
    top_k,
    n_group,
    topk_group,
    intermediate_size,
    local_expert_offset,
    local_num_experts,
    routed_scaling_factor,
    use_routing_scales_on_input,
    routing_method_type,
    do_finalize,
    enable_pdl,
    tune_max_num_tokens,  # default=8192 in _op, or parameterize the wrapper
    activation_type,
    norm_topk_prob,
    routing_replay_out,
)
```

The _op will allocate and return output internally; extract it from the returned list as done in lines 2630–2637.

Also add an integration test that invokes this wrapper with a non-None routing_replay_out to prevent regression when this code path is exercised in CI.

🧹 Nitpick comments (1)
flashinfer/fused_moe/core.py (1)

1603-1628: Nice cleanup — keyword-arg invocation is much safer.

Switching moe_op.trtllm_fp8_per_tensor_scale_moe(...) to kwargs (including output=output and routing_replay_out=routing_replay_out) eliminates the exact class of positional-drift bug that caused the vLLM CI regression in the first place. Consider applying the same treatment to the remaining positional C++ calls in this file (trtllm_fp8_block_scale_moe at Line 1821, trtllm_fp4_block_scale_moe at Line 2048, trtllm_mxint4_block_scale_moe at Line 2245) in a follow-up to prevent recurrence.


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 49aada39-68d1-4440-988e-001d72d9e64b

📥 Commits

Reviewing files that changed from the base of the PR and between a99ee72 and 8928433.

📒 Files selected for processing (1)
  • flashinfer/fused_moe/core.py

Signed-off-by: Pavani Majety <pmajety@nvidia.com>
@pavanimajety pavanimajety marked this pull request as draft April 16, 2026 23:33
Signed-off-by: Pavani Majety <pmajety@nvidia.com>
@pavanimajety pavanimajety marked this pull request as ready for review April 16, 2026 23:46
@aleozlx aleozlx added the run-ci label Apr 17, 2026
@aleozlx
Collaborator

aleozlx commented Apr 17, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !564 has been created, and the CI pipeline #48814530 is currently running. I'll report back once the pipeline job completes.

@aleozlx aleozlx merged commit 8559397 into flashinfer-ai:main Apr 18, 2026
66 of 89 checks passed
aleozlx added a commit that referenced this pull request Apr 18, 2026
…_op` (#3094)

Signed-off-by: Pavani Majety <pmajety@nvidia.com>
Co-authored-by: Alex Yang <aleyang@nvidia.com>
