feat: support routing replay in trtllm_fp8_block_scale_moe and fused_topk_deepseek#2685
TomerBN-Nvidia wants to merge 2 commits into flashinfer-ai:main from
Conversation
📝 Walkthrough

Threads an optional `int16_t*` routing replay output through Python APIs → launcher → runner → kernel. When provided, kernels record selected expert IDs per token in a `[num_tokens, topk]` int16 buffer; behavior is unchanged when the pointer is null.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Python
    participant Launcher
    participant Runner
    participant Kernel
    participant ReplayBuf as "Replay Buffer (int16_t*)"
    Python->>Launcher: trtllm_..._moe(..., routing_replay_out)
    Launcher->>Runner: set_routing_replay_out(...) / run(..., routingReplayOut)
    Runner->>Kernel: invokeNoAuxTc(..., routing_replay_out)
    Kernel->>ReplayBuf: write expertIdx at [token, topk] if pointer non-null
    Kernel-->>Runner: return topk_values/topk_indices
    Runner-->>Launcher: propagate outputs
    Launcher-->>Python: return (topk_values, topk_indices, mutated routing_replay_out)
```
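The per-token write contract in the diagram above can be sketched in plain Python. This is an illustrative stand-in for the CUDA kernel, not the PR's actual code: the caller optionally supplies a pre-allocated `[num_tokens, topk]` buffer, and the routing step records each token's selected expert IDs into it only when the buffer is present.

```python
def route_with_replay(scores, topk, replay_out=None):
    """scores: list of per-token expert-score lists. Returns top-k expert IDs
    per token; mirrors the kernel's 'write only if pointer non-null' behavior."""
    topk_ids = []
    for t, row in enumerate(scores):
        ids = sorted(range(len(row)), key=lambda e: row[e], reverse=True)[:topk]
        topk_ids.append(ids)
        if replay_out is not None:  # the kernel's single null-check per thread
            replay_out[t][:] = ids
    return topk_ids

scores = [[0.1, 0.9, 0.3, 0.5], [0.7, 0.2, 0.8, 0.1]]
replay = [[0, 0] for _ in scores]  # caller-allocated [num_tokens, topk]
ids = route_with_replay(scores, topk=2, replay_out=replay)
# replay now mirrors ids; passing replay_out=None gives identical ids with no writes
```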
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Summary of Changes

This pull request addresses a limitation in vLLM's expert-parallel routing replay feature when used with FlashInfer's TRTLLM-GEN fused Mixture-of-Experts (MoE) backend. By introducing an optional output tensor,

Highlights
Changelog
Activity
Code Review
This pull request introduces a routing_replay_out parameter to trtllm_fp8_block_scale_moe and fused_topk_deepseek to enable routing replay, a valuable feature for analyzing Mixture-of-Experts models. While the implementation is generally well-executed, a critical vulnerability exists: the trtllm_fp8_block_scale_moe path lacks validation for the new routing_replay_out tensor. This missing validation could lead to out-of-bounds memory writes in the CUDA kernel if an incorrectly sized or typed tensor is provided. It is recommended to add shape and dtype checks in both the Python wrapper and the C++ FFI implementation for the FP8 block-scale MoE path to mitigate this risk.
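The wrapper-side checks the review recommends are cheap. Here is a minimal self-contained sketch of that validation logic — plain Python with shape and dtype passed explicitly for illustration, not the PR's actual wrapper code:

```python
def check_routing_replay_out(shape, dtype, num_tokens, top_k):
    """Reject replay buffers that could cause out-of-bounds kernel writes."""
    if dtype != "int16":
        raise TypeError("routing_replay_out must be int16, got " + dtype)
    if tuple(shape) != (num_tokens, top_k):
        raise ValueError(
            f"routing_replay_out must have shape ({num_tokens}, {top_k}), "
            f"got {tuple(shape)}"
        )

check_routing_replay_out((8, 4), "int16", num_tokens=8, top_k=4)  # passes silently
```

In the real code path the same checks would read `tensor.shape` and `tensor.dtype` off the incoming tensor (plus a device check) before the raw pointer ever reaches the kernel.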
Actionable comments posted: 4
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
csrc/trtllm_fused_moe_kernel_launcher.cu (1)
1807-1901: ⚠️ Potential issue | 🔴 Critical — Validate `routing_replay_out` before passing raw pointer to kernels.

`routing_replay_out` is forwarded without dtype/shape/device checks. A mismatched buffer can cause out-of-bounds writes or invalid-device access when the routing kernel writes replay IDs.

🛡️ Proposed fix

```diff
 Array<Tensor> trtllm_fp8_block_scale_moe(
 @@
   auto const num_tokens = hidden_states.size(0);
   auto const hidden_size = hidden_states.size(1);
+  if (routing_replay_out.has_value()) {
+    auto const& replay = routing_replay_out.value();
+    TVM_FFI_ICHECK_EQ(replay.dtype(), dl_int16)
+        << "routing_replay_out must be int16.";
+    TVM_FFI_ICHECK_EQ(replay.ndim(), 2)
+        << "routing_replay_out must be 2D.";
+    TVM_FFI_ICHECK_EQ(replay.size(0), num_tokens)
+        << "routing_replay_out dim0 must match num_tokens.";
+    TVM_FFI_ICHECK_EQ(replay.size(1), top_k)
+        << "routing_replay_out dim1 must match top_k.";
+    TVM_FFI_ICHECK_EQ(replay.device().device_type, hidden_states.device().device_type)
+        << "routing_replay_out must be on the same device type as hidden_states.";
+    TVM_FFI_ICHECK_EQ(replay.device().device_id, hidden_states.device().device_id)
+        << "routing_replay_out must be on the same device as hidden_states.";
+  }
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@csrc/trtllm_fused_moe_kernel_launcher.cu` around lines 1807 - 1901, The review points out that routing_replay_out is forwarded into the kernels without validation; before calling launcher->set_routing_replay_out(routing_replay_out) (and before constructing/initializing Fp8BlockScaleLauncher), validate that routing_replay_out.has_value() only when its TensorView has the expected dtype, shape and device (e.g., correct ndim/size matching num_tokens/top_k and expected integer dtype), and if invalid either clear the Optional or throw with a clear message; update the call site around set_routing_replay_out and any code that dereferences routing_replay_out inside Fp8BlockScaleLauncher to rely on this validated contract.
🧹 Nitpick comments (2)
tests/model_optimizations/test_dsv3_fused_routing.py (1)
560-566: Use tensor equality instead of set equality for replay validation.

Set comparison can miss multiplicity bugs (e.g., duplicate IDs). Since both tensors should represent the same top-k picks, compare tensors directly.
🔍 Suggested fix
```diff
-for t in range(num_tokens):
-    replay_set = set(routing_replay_out[t].tolist())
-    indices_set = set(topk_indices[t].tolist())
-    assert replay_set == indices_set, (
-        f"Token {t}: routing_replay_out experts {replay_set} "
-        f"!= topk_indices experts {indices_set}"
-    )
+torch.testing.assert_close(
+    routing_replay_out.to(torch.int32),
+    topk_indices,
+    msg="routing_replay_out must match topk_indices element-wise",
+)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/model_optimizations/test_dsv3_fused_routing.py` around lines 560 - 566, Replace the set-based comparison with a direct tensor equality check between routing_replay_out and topk_indices for each token: for each t in the loop, assert that the tensors for that token are equal (e.g., use torch.equal or (tensor_a == tensor_b).all()) so duplicates and ordering mismatches are caught; update the assertion message to include the tensors (routing_replay_out[t], topk_indices[t]) for clearer failure diagnostics.

tests/moe/test_trtllm_gen_routed_fused_moe.py (1)
655-666: Strengthen replay correctness check beyond range/uniqueness.

These assertions can still pass with wrong expert IDs. Consider validating replay IDs against reference routing outputs for the same logits/top-k.
✅ Suggested enhancement
```diff
+permute_info, _ = routing_reference_renormalize(
+    routing_logits, top_k, num_experts, 8
+)
+expected_topk_ids = permute_info["topKIndices"].to(torch.int16)
+
 # Each token should have top_k unique experts
 for t in range(num_tokens):
     unique_experts = routing_replay_out[t].unique()
     assert unique_experts.numel() == top_k, (
         f"Token {t}: expected {top_k} unique experts, got {unique_experts.numel()}"
     )
+    assert set(routing_replay_out[t].tolist()) == set(expected_topk_ids[t].tolist())
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/moe/test_trtllm_gen_routed_fused_moe.py` around lines 655 - 666, Compute a reference routing for the same logits and top_k and assert routing_replay_out equals that reference instead of only range/uniqueness checks: use the test's logits and the same top_k to produce expected_routing (e.g., by applying torch.topk on logits along expert dimension and mapping to expert IDs), then check expected_routing.shape == routing_replay_out.shape and torch.equal(routing_replay_out, expected_routing) (or per-token equality with clear assertion messages); reference symbols: routing_replay_out, logits, top_k, num_tokens, num_experts.
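Why the nitpicks above prefer element-wise checks: set equality discards rank order and multiplicity, so a replay buffer holding the right experts in the wrong order still passes a set-based assertion. A tiny self-contained demonstration:

```python
reference = [3, 7, 1]  # expected top-k expert IDs for one token, rank-ordered
replay = [1, 7, 3]     # same experts, wrong rank order — a real bug

set_check_passes = set(replay) == set(reference)  # True: the bug slips through
exact_check_passes = replay == reference          # False: the bug is caught
```

The same reasoning applies to duplicated IDs: `set` collapses repeats, while an element-wise comparison (e.g. `torch.equal`) does not.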
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@csrc/fused_moe/noAuxTcKernels.cu`:
- Around line 353-364: The code validates routing_replay_out's shape and dtype
but does not check its device, allowing a CPU tensor or a tensor on a different
CUDA device to be cast to replay_ptr and passed to the kernel; update the block
that handles routing_replay_out (references: routing_replay_out, replay,
replay_ptr, encode_dlpack_dtype, int16_code, TVM_FFI_ICHECK) to assert the
tensor is on the expected CUDA device before extracting the raw pointer — e.g.,
verify replay.device().device_type is CUDA and replay.device().device_id equals
the current/target CUDA device (or error out via TVM_FFI_ICHECK with a clear
message) so the kernel receives only pointers to memory on the correct device.
In `@csrc/trtllm_fused_moe_runner.cu`:
- Around line 61-63: The code only sets routingData.mPtrRoutingReplayOut when
routingMethodType == RoutingMethodType::DeepSeekV3, so routingReplayOut passed
into the API is ignored for other methods; fix by unconditionally wiring
routingReplayOut into routingData.mPtrRoutingReplayOut before or outside the
routing-method switch/if (so Llama4/Renormalize/TopK also get the pointer), and
only use it later if non-null in the DeepSeekV3-specific logic; reference
routingReplayOut, RoutingMethodType, DeepSeekV3, and
routingData.mPtrRoutingReplayOut to locate the assignment to change.
In `@flashinfer/fused_moe/core.py`:
- Line 1829: The new fake-op parameter routing_replay_out is unused and triggers
ARG001; inside the fake-op function (the function that declares
routing_replay_out) explicitly consume it to silence the linter by adding a
no-op usage such as "_ = routing_replay_out" (or "del routing_replay_out" / "if
False: _ = routing_replay_out") at the start of the function body so the symbol
is referenced but behavior remains unchanged.
In `@tests/model_optimizations/test_dsv3_fused_routing.py`:
- Around line 504-533: The new parameterized test test_routing_replay_out lacks
a GPU compute-capability guard—use the flashinfer.utils helpers to skip
unsupported architectures: call get_compute_capability() at the start of
test_routing_replay_out and use is_sm90a_supported(cc) / is_sm100a_supported(cc)
(or their negations) to pytest.skip when the current GPU isn't supported; add
this check near the top of the function alongside the existing pytest.skip
conditions so the test is skipped on unsupported SM versions instead of failing.
---
Outside diff comments:
In `@csrc/trtllm_fused_moe_kernel_launcher.cu`:
- Around line 1807-1901: The review points out that routing_replay_out is
forwarded into the kernels without validation; before calling
launcher->set_routing_replay_out(routing_replay_out) (and before
constructing/initializing Fp8BlockScaleLauncher), validate that
routing_replay_out.has_value() only when its TensorView has the expected dtype,
shape and device (e.g., correct ndim/size matching num_tokens/top_k and expected
integer dtype), and if invalid either clear the Optional or throw with a clear
message; update the call site around set_routing_replay_out and any code that
dereferences routing_replay_out inside Fp8BlockScaleLauncher to rely on this
validated contract.
---
Nitpick comments:
In `@tests/model_optimizations/test_dsv3_fused_routing.py`:
- Around line 560-566: Replace the set-based comparison with a direct tensor
equality check between routing_replay_out and topk_indices for each token: for
each t in the loop, assert that the tensors for that token are equal (e.g., use
torch.equal or (tensor_a == tensor_b).all()) so duplicates and ordering
mismatches are caught; update the assertion message to include the tensors
(routing_replay_out[t], topk_indices[t]) for clearer failure diagnostics.
In `@tests/moe/test_trtllm_gen_routed_fused_moe.py`:
- Around line 655-666: Compute a reference routing for the same logits and top_k
and assert routing_replay_out equals that reference instead of only
range/uniqueness checks: use the test's logits and the same top_k to produce
expected_routing (e.g., by applying torch.topk on logits along expert dimension
and mapping to expert IDs), then check expected_routing.shape ==
routing_replay_out.shape and torch.equal(routing_replay_out, expected_routing)
(or per-token equality with clear assertion messages); reference symbols:
routing_replay_out, logits, top_k, num_tokens, num_experts.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: ee2eb0d1-eac2-4ab5-a15f-95ec508c12ea
📒 Files selected for processing (12)
- csrc/fused_moe/noAuxTcKernels.cu
- csrc/trtllm_fused_moe_kernel_launcher.cu
- csrc/trtllm_fused_moe_routing_deepseek.cu
- csrc/trtllm_fused_moe_runner.cu
- csrc/tvm_ffi_utils.h
- flashinfer/fused_moe/core.py
- flashinfer/fused_moe/fused_routing_dsv3.py
- include/flashinfer/trtllm/fused_moe/RoutingKernel.h
- include/flashinfer/trtllm/fused_moe/noAuxTcKernels.h
- include/flashinfer/trtllm/fused_moe/runner.h
- tests/model_optimizations/test_dsv3_fused_routing.py
- tests/moe/test_trtllm_gen_routed_fused_moe.py
Actionable comments posted: 1
🧹 Nitpick comments (1)
csrc/trtllm_fused_moe_kernel_launcher.cu (1)
385-389: Consider centralizing replay pointer extraction to remove duplication.

The same optional-to-pointer conversion appears in three `run()` implementations. A tiny helper (e.g., `get_routing_replay_ptr()`) would reduce drift risk across launcher variants.

Also applies to: 1091-1095, 1563-1567
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@csrc/trtllm_fused_moe_kernel_launcher.cu` around lines 385 - 389, Three run() implementations duplicate conversion from routing_replay_out to an int16_t* (e.g., the snippet that sets replay_ptr from routing_replay_out.value().data_ptr()); extract this logic into a small helper function such as get_routing_replay_ptr(const c10::optional<Tensor>& routing_replay_out) that returns int16_t* (nullptr when not present) and replace the inline conversions in the run() methods (and the occurrences around the other blocks you noted) with calls to that helper to centralize behavior and prevent drift.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@csrc/trtllm_fused_moe_kernel_launcher.cu`:
- Around line 1869-1880: The validation for routing_replay_out currently checks
shape/dtype/device but does not ensure it's only used with the DeepSeekV3
routing implementation; add a fail-fast guard that checks the active routing
method (compare against RoutingMethodType::DeepSeekV3) before accepting
routing_replay_out and raise a clear error (TVM_FFI_ICHECK or equivalent) if
routing_replay_out.has_value() while the routing method is not DeepSeekV3 so we
don't silently accept unsupported replay buffers; place this check adjacent to
the existing routing_replay_out block (same scope) referencing
routing_replay_out and RoutingMethodType::DeepSeekV3.
---
Nitpick comments:
In `@csrc/trtllm_fused_moe_kernel_launcher.cu`:
- Around line 385-389: Three run() implementations duplicate conversion from
routing_replay_out to an int16_t* (e.g., the snippet that sets replay_ptr from
routing_replay_out.value().data_ptr()); extract this logic into a small helper
function such as get_routing_replay_ptr(const c10::optional<Tensor>&
routing_replay_out) that returns int16_t* (nullptr when not present) and replace
the inline conversions in the run() methods (and the occurrences around the
other blocks you noted) with calls to that helper to centralize behavior and
prevent drift.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 75a19f9d-6934-42a2-83e0-f704c3efa2ec
📒 Files selected for processing (2)
- csrc/fused_moe/noAuxTcKernels.cu
- csrc/trtllm_fused_moe_kernel_launcher.cu
```cpp
if (routing_replay_out.has_value()) {
  auto replay = routing_replay_out.value();
  TVM_FFI_ICHECK(replay.device().device_type == kDLCUDA)
      << "routing_replay_out must be a CUDA tensor";
  TVM_FFI_ICHECK(replay.device().device_id == hidden_states.device().device_id)
      << "routing_replay_out must be on the same device as hidden_states";
  TVM_FFI_ICHECK(replay.ndim() == 2) << "routing_replay_out must be 2D [num_tokens, top_k]";
  TVM_FFI_ICHECK(replay.size(0) == num_tokens) << "routing_replay_out dim0 must equal num_tokens";
  TVM_FFI_ICHECK(replay.size(1) == top_k) << "routing_replay_out dim1 must equal top_k";
  TVM_FFI_ICHECK(encode_dlpack_dtype(replay.dtype()) == int16_code)
      << "routing_replay_out must be int16 dtype";
}
```
Fail fast when replay is requested for unsupported routing methods.

`routing_replay_out` is validated for shape/dtype/device, but there is no guard that it is only used with `RoutingMethodType::DeepSeekV3`. Today this can silently accept replay buffers for unsupported routing methods and produce no replay writes.

🔧 Suggested guard

```diff
 if (routing_replay_out.has_value()) {
+  TVM_FFI_ICHECK_EQ(static_cast<RoutingMethodType>(routing_method_type),
+                    RoutingMethodType::DeepSeekV3)
+      << "routing_replay_out is currently supported only for DeepSeekV3 routing";
   auto replay = routing_replay_out.value();
   TVM_FFI_ICHECK(replay.device().device_type == kDLCUDA)
       << "routing_replay_out must be a CUDA tensor";
```

Based on learnings: In csrc/trtllm_fused_moe_runner.cu, `routingReplayOut` is intentionally wired only into the DeepSeekV3 routing path.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@csrc/trtllm_fused_moe_kernel_launcher.cu` around lines 1869 - 1880, The
validation for routing_replay_out currently checks shape/dtype/device but does
not ensure it's only used with the DeepSeekV3 routing implementation; add a
fail-fast guard that checks the active routing method (compare against
RoutingMethodType::DeepSeekV3) before accepting routing_replay_out and raise a
clear error (TVM_FFI_ICHECK or equivalent) if routing_replay_out.has_value()
while the routing method is not DeepSeekV3 so we don't silently accept
unsupported replay buffers; place this check adjacent to the existing
routing_replay_out block (same scope) referencing routing_replay_out and
RoutingMethodType::DeepSeekV3.
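The guard behavior described above amounts to this minimal Python sketch. The enum values and names are illustrative stand-ins, not the library's actual constants:

```python
from enum import Enum


class RoutingMethod(Enum):  # illustrative stand-in for RoutingMethodType
    DeepSeekV3 = 0
    Renormalize = 1
    TopK = 2


def check_replay_supported(method, routing_replay_out):
    """Fail fast instead of silently ignoring the replay buffer."""
    if routing_replay_out is not None and method is not RoutingMethod.DeepSeekV3:
        raise ValueError(
            "routing_replay_out is currently supported only for DeepSeekV3 routing"
        )
```

The point of failing fast is that a caller who passes a buffer for an unsupported routing method gets an immediate error rather than a buffer full of stale data.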
Add an optional `routing_replay_out` parameter to the fused routing and MoE kernels. When a pre-allocated int16 tensor of shape [num_tokens, topk] is provided, the CUDA routing kernel writes selected expert IDs per token directly into it during routing — with zero overhead when None. This enables downstream frameworks (e.g. vLLM) to capture expert routing decisions from the monolithic fused MoE code path without modifying the router.

Changes:
- CUDA kernels: plumb routing_replay_out through deepseek_v3_topk_kernel, routingMainKernel, invokeNoAuxTc, and all three launcher classes (FusedMoeLauncher, Fp8BlockScaleLauncher, FP4BlockScaleLauncher)
- C++ bindings: add int16_code constant, Optional<TensorView> parameter with shape/dtype validation in NoAuxTc and trtllm_fp8_block_scale_moe
- Python API: add routing_replay_out to fused_topk_deepseek, trtllm_fp8_block_scale_moe, and their torch.compile registrations (custom_op, fake_op, mutates_args)
- Input validation in _check_dsv3_fused_routing_supported
- Tests for both standalone routing and end-to-end FP8 block-scale MoE

Signed-off-by: Tomer Natan <tbarnatan@nvidia.com>
Signed-off-by: Tomer Natan <tbarnatan@computelab-frontend-8.nvidia.com>
Made-with: Cursor
Signed-off-by: Tomer Natan <tbarnatan@computelab-frontend-7.nvidia.com>
Made-with: Cursor
…entry points

Address review feedback: validate routing_replay_out tensor at both C++ entry points (NoAuxTc and trtllm_fp8_block_scale_moe) to prevent out-of-bounds writes from incorrectly sized, typed, or off-device tensors.

Signed-off-by: Tomer Natan <tbarnatan@nvidia.com>
Signed-off-by: Tomer Natan <tbarnatan@computelab-frontend-7.nvidia.com>
Made-with: Cursor
Force-pushed from ddfc58b to 4695a72
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/fused_moe/core.py (1)
1633-1661: ⚠️ Potential issue | 🟠 Major — Reject `routing_replay_out` for non-DeepSeek routing modes.

This now advertises replay for every `routing_method_type`, but the backend only writes replay IDs on the `DeepSeekV3` path. `Renormalize`, `RenormalizeNaive`, `TopK`, and `Llama4` callers will silently get untouched/stale data instead of expert IDs.

🛠️ Suggested guard

```diff
 def trtllm_fp8_block_scale_moe(
 @@
     fp8_quantization_type: Fp8QuantizationType = Fp8QuantizationType.DeepSeekFp8,
     routing_replay_out: Optional[torch.Tensor] = None,
 ) -> Union[List[torch.Tensor], torch.Tensor]:
+    if (
+        routing_replay_out is not None
+        and routing_method_type != RoutingMethodType.DeepSeekV3.value
+    ):
+        raise NotImplementedError(
+            "routing_replay_out is only supported with RoutingMethodType.DeepSeekV3."
+        )
+
     """FP8 block scale MoE operation.
```

Based on learnings: "In csrc/trtllm_fused_moe_runner.cu, `routingReplayOut` is intentionally wired only into the `DeepSeekV3` routing path. The Llama4 and Renormalize/TopK CUDA kernels do not implement the replay write logic, so passing the pointer to those data structs would silently have no effect. Extension to those routing methods is planned as a future follow-up."

Also applies to: 1758-1785, 2525-2619
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tests/moe/test_trtllm_gen_routed_fused_moe.py`:
- Around line 599-645: The two calls to trtllm_fp8_block_scale_moe are passing
enable_pdl positionally into the do_finalize parameter (after weight_layout)
which is incorrect; update both invocations (the first that passes enable_pdl
then routing_replay_out, and the second that passes enable_pdl then
routing_replay_out=None) to pass do_finalize and enable_pdl as explicit keyword
arguments (e.g., do_finalize=<bool>, enable_pdl=<bool>) so the boolean flags
bind to the correct parameters and leave routing_replay_out passed as before.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: bc7cb548-27c8-4f31-9ff9-ca55d21269f7
📒 Files selected for processing (12)
- csrc/fused_moe/noAuxTcKernels.cu
- csrc/trtllm_fused_moe_kernel_launcher.cu
- csrc/trtllm_fused_moe_routing_deepseek.cu
- csrc/trtllm_fused_moe_runner.cu
- csrc/tvm_ffi_utils.h
- flashinfer/fused_moe/core.py
- flashinfer/fused_moe/fused_routing_dsv3.py
- include/flashinfer/trtllm/fused_moe/RoutingKernel.h
- include/flashinfer/trtllm/fused_moe/noAuxTcKernels.h
- include/flashinfer/trtllm/fused_moe/runner.h
- tests/model_optimizations/test_dsv3_fused_routing.py
- tests/moe/test_trtllm_gen_routed_fused_moe.py
🚧 Files skipped from review as they are similar to previous changes (5)
- include/flashinfer/trtllm/fused_moe/RoutingKernel.h
- csrc/trtllm_fused_moe_runner.cu
- include/flashinfer/trtllm/fused_moe/runner.h
- csrc/tvm_ffi_utils.h
- tests/model_optimizations/test_dsv3_fused_routing.py
```python
output_with_replay = trtllm_fp8_block_scale_moe(
    routing_logits,
    None,  # routing_bias
    hidden_states,
    hidden_states_scale,
    gemm1_weights,
    gemm1_weights_scale,
    gemm2_weights,
    gemm2_weights_scale,
    num_experts,
    top_k,
    None,  # n_group
    None,  # topk_group
    intermediate_size,
    0,  # local_expert_offset
    num_experts,
    None,  # routed_scaling_factor
    RoutingMethodType.Renormalize.value,
    False,  # use_shuffled_weight
    0,  # weight_layout
    enable_pdl,
    routing_replay_out=routing_replay_out,
)

output_without_replay = trtllm_fp8_block_scale_moe(
    routing_logits,
    None,
    hidden_states,
    hidden_states_scale,
    gemm1_weights,
    gemm1_weights_scale,
    gemm2_weights,
    gemm2_weights_scale,
    num_experts,
    top_k,
    None,
    None,
    intermediate_size,
    0,
    num_experts,
    None,
    RoutingMethodType.Renormalize.value,
    False,
    0,
    enable_pdl,
    routing_replay_out=None,
)
```
🧩 Analysis chain
🏁 Script executed:

```shell
#!/bin/bash
set -euo pipefail
printf '\nSignature:\n'
rg -n -A12 -B2 '^def trtllm_fp8_block_scale_moe\(' flashinfer/fused_moe/core.py
printf '\nNew test call sites:\n'
rg -n -A24 -B2 'output_(with|without)_replay = trtllm_fp8_block_scale_moe\(' tests/moe/test_trtllm_gen_routed_fused_moe.py
```

Repository: flashinfer-ai/flashinfer
Length of output: 2059
🏁 Script executed:

```shell
#!/bin/bash
set -euo pipefail
# Get the complete function signature with all parameters
rg -n -A30 '^def trtllm_fp8_block_scale_moe\(' flashinfer/fused_moe/core.py
```

Repository: flashinfer-ai/flashinfer
Length of output: 1343
Pass do_finalize and enable_pdl by keyword.
After weight_layout, the next positional parameter is do_finalize, not enable_pdl. Currently enable_pdl is being passed into the do_finalize parameter slot at lines 619 and 643, binding the flag to the wrong parameter and causing incorrect behavior.
🐛 Proposed fix

```diff
 output_with_replay = trtllm_fp8_block_scale_moe(
     routing_logits,
     None,  # routing_bias
     hidden_states,
     hidden_states_scale,
     gemm1_weights,
     gemm1_weights_scale,
     gemm2_weights,
     gemm2_weights_scale,
     num_experts,
     top_k,
     None,  # n_group
     None,  # topk_group
     intermediate_size,
     0,  # local_expert_offset
     num_experts,
     None,  # routed_scaling_factor
     RoutingMethodType.Renormalize.value,
     False,  # use_shuffled_weight
     0,  # weight_layout
-    enable_pdl,
+    do_finalize=True,
+    enable_pdl=enable_pdl,
     routing_replay_out=routing_replay_out,
 )
 output_without_replay = trtllm_fp8_block_scale_moe(
     routing_logits,
     None,
     hidden_states,
     hidden_states_scale,
     gemm1_weights,
     gemm1_weights_scale,
     gemm2_weights,
     gemm2_weights_scale,
     num_experts,
     top_k,
     None,
     None,
     intermediate_size,
     0,
     num_experts,
     None,
     RoutingMethodType.Renormalize.value,
     False,
     0,
-    enable_pdl,
+    do_finalize=True,
+    enable_pdl=enable_pdl,
     routing_replay_out=None,
 )
```
)🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/moe/test_trtllm_gen_routed_fused_moe.py` around lines 599 - 645, The
two calls to trtllm_fp8_block_scale_moe are passing enable_pdl positionally into
the do_finalize parameter (after weight_layout) which is incorrect; update both
invocations (the first that passes enable_pdl then routing_replay_out, and the
second that passes enable_pdl then routing_replay_out=None) to pass do_finalize
and enable_pdl as explicit keyword arguments (e.g., do_finalize=<bool>,
enable_pdl=<bool>) so the boolean flags bind to the correct parameters and leave
routing_replay_out passed as before.
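The misbinding the comment above describes is easy to reproduce with a toy signature (this is not flashinfer's real signature, just an illustration): with a long positional tail, a boolean intended for one parameter silently lands in its neighbor, and only keyword arguments pin it down.

```python
def moe(x, weight_layout=0, do_finalize=True, enable_pdl=False):
    """Toy stand-in with the same parameter ordering pitfall."""
    return do_finalize, enable_pdl

# Positional: the caller meant enable_pdl=True, but True binds to do_finalize.
positional = moe(1, 0, True)
# Keyword: the flag binds where intended.
keyword = moe(1, 0, enable_pdl=True)
```

Because both parameters are booleans with defaults, the call still type-checks and runs; the bug only shows up as changed behavior, which is why keyword arguments are the safer style for long signatures.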
📌 Description
Motivation
vLLM's expert-parallel routing replay feature (`RoutedExpertsCapturer`) needs to record which experts each token is routed to during MoE inference. This works fine for the non-fused path where `BaseRouter.select_experts()` is called and `capture_fn` fires. However, the TRTLLM-GEN fused MoE path is monolithic — `Fp8MoEMethod.apply_monolithic()` calls FlashInfer's `trtllm_fp8_block_scale_moe` directly, bypassing the router entirely. This means `capture_fn` is never invoked and the `RoutedExpertsCapturer` never receives expert IDs from this code path.

Without this change, any vLLM deployment using the TRTLLM-GEN fused MoE backend cannot use routing replay, which is required for expert-parallel inference and load-balancing analytics.
What this PR does
Adds an optional `routing_replay_out` parameter to FlashInfer's fused routing and MoE kernels. When a pre-allocated `int16` tensor of shape `[num_tokens, topk]` is provided, the CUDA routing kernel writes selected expert IDs per token directly into it during routing — inside the same fused kernel call that computes the MoE output. When `None` (the default), the kernel skips the write entirely with zero overhead.

API surface (all backward-compatible, default `None`):

- `flashinfer.fused_moe.fused_topk_deepseek(..., routing_replay_out=None)`
- `flashinfer.fused_moe.trtllm_fp8_block_scale_moe(..., routing_replay_out=None)`

Changes across the stack:

- CUDA kernels: plumb `routing_replay_out` through `deepseek_v3_topk_kernel`, `routingMainKernel`, `invokeNoAuxTc`, and all three launcher classes (`FusedMoeLauncher`, `Fp8BlockScaleLauncher`, `FP4BlockScaleLauncher`)
- C++ bindings: add `int16_code` constant, `Optional<TensorView>` parameter with shape/dtype validation in `NoAuxTc` and `trtllm_fp8_block_scale_moe`
- Python API: add `routing_replay_out` to `fused_topk_deepseek`, `trtllm_fp8_block_scale_moe`, and their `torch.compile` registrations (`custom_op`, `fake_op`, `mutates_args`)
- Input validation in `_check_dsv3_fused_routing_supported`

vLLM integration status
This feature has been tested end-to-end with vLLM's `RoutedExpertsCapturer`, and the corresponding vLLM integration PR is in progress toward merging into vLLM main. The vLLM-side change is minimal (3 files, ~15 lines) — it slices the existing `_RoutedExpertsDeviceCache` buffer and passes it as `routing_replay_out`.
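The vLLM-side wiring described above can be sketched like this. The cache layout, function names, and call shape are assumptions for illustration, not vLLM's actual API; only the "slice a pre-allocated buffer and hand it to the fused kernel" pattern comes from the description:

```python
def capture_routed_experts(device_cache, step, num_tokens, run_moe):
    """Slice a per-step [num_tokens, topk] view out of a pre-allocated cache
    and let the fused MoE call fill it in place via routing_replay_out."""
    replay_view = device_cache[step][:num_tokens]
    run_moe(routing_replay_out=replay_view)  # fused kernel writes expert IDs
    return replay_view
```

Because the kernel mutates the buffer in place, the capturer never copies expert IDs back from a separate return value — the cache slice it handed in is already populated when the call returns.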
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- Installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- Installed the hooks with `pre-commit install`.
- Ran `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
unittest, etc.).

New tests added:
- `test_routing_replay_out` — standalone `fused_topk_deepseek` routing kernel: verifies `routing_replay_out` matches `topk_indices` per token and that `None` produces identical results (no side effects)
- `test_fp8_block_scale_moe_routing_replay` — end-to-end FP8 block-scale MoE: verifies replay IDs are valid, unique per token, and MoE output is bit-identical with/without replay (SM100+)

Reviewer Notes
- `routing_replay_out` defaults to `None`, so existing callers are unaffected.
- The cost when `routing_replay_out is None` is a single pointer null-check per thread — effectively zero.

Summary by CodeRabbit
New Features
API
Documentation
Tests