
feat: support routing replay in trtllm_fp8_block_scale_moe and fused_topk_deepseek #2685

Open
TomerBN-Nvidia wants to merge 2 commits into flashinfer-ai:main from TomerBN-Nvidia:support-routing-replay

Conversation


@TomerBN-Nvidia TomerBN-Nvidia commented Mar 4, 2026

📌 Description

Motivation

vLLM's expert-parallel routing replay feature (RoutedExpertsCapturer) needs to record which experts each token is routed to during MoE inference. This works fine for the non-fused path, where BaseRouter.select_experts() is called and capture_fn fires. However, the TRTLLM-GEN fused MoE path is monolithic: Fp8MoEMethod.apply_monolithic() calls FlashInfer's trtllm_fp8_block_scale_moe directly, bypassing the router entirely. As a result, capture_fn is never invoked and the RoutedExpertsCapturer never receives expert IDs from this code path.

Without this change, any vLLM deployment using the TRTLLM-GEN fused MoE backend cannot use routing replay, which is required for expert-parallel inference and load-balancing analytics.

What this PR does

Adds an optional routing_replay_out parameter to FlashInfer's fused routing and MoE kernels. When a pre-allocated int16 tensor of shape [num_tokens, topk] is provided, the CUDA routing kernel writes selected expert IDs per token directly into it during routing — inside the same fused kernel call that computes the MoE output. When None (the default), the kernel skips the write entirely with zero overhead.

API surface (all backward-compatible, default None):

  • flashinfer.fused_moe.fused_topk_deepseek(..., routing_replay_out=None)
  • flashinfer.fused_moe.trtllm_fp8_block_scale_moe(..., routing_replay_out=None)
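The in-place semantics described above can be modeled in plain Python. This is an illustrative sketch, not FlashInfer code: `route_topk` is a hypothetical stand-in for the fused routing kernel, and the int16 buffer is modeled as a list of lists.

```python
def route_topk(scores, topk, routing_replay_out=None):
    """Pick the top-k experts per token; optionally record their IDs in place."""
    topk_indices = []
    for t, row in enumerate(scores):
        # Indices of the k largest scores for this token.
        picked = sorted(range(len(row)), key=lambda e: row[e], reverse=True)[:topk]
        topk_indices.append(picked)
        if routing_replay_out is not None:   # single "null check"; no-op when None
            routing_replay_out[t][:] = picked  # in-place write, like the kernel
    return topk_indices

scores = [[0.1, 0.9, 0.3, 0.5],   # token 0
          [0.7, 0.2, 0.8, 0.1]]   # token 1
replay = [[0, 0] for _ in scores]  # caller pre-allocates [num_tokens, topk]
indices = route_topk(scores, topk=2, routing_replay_out=replay)
assert replay == indices  # replay buffer mirrors the selected expert IDs
```

As in the real API, the caller owns the buffer and the routine mutates it rather than returning a new tensor.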

Changes across the stack:

  • CUDA kernels: plumb routing_replay_out through deepseek_v3_topk_kernel, routingMainKernel, invokeNoAuxTc, and all three launcher classes (FusedMoeLauncher, Fp8BlockScaleLauncher, FP4BlockScaleLauncher)
  • C++ bindings: add int16_code constant, Optional<TensorView> parameter with shape/dtype validation in NoAuxTc and trtllm_fp8_block_scale_moe
  • Python API: add routing_replay_out to fused_topk_deepseek, trtllm_fp8_block_scale_moe, and their torch.compile registrations (custom_op, fake_op, mutates_args)
  • Input validation: dtype and shape checks in _check_dsv3_fused_routing_supported
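The validation contract for the last bullet can be sketched in pure Python (the function name and error messages here are hypothetical; the real checks live in _check_dsv3_fused_routing_supported and the C++ bindings):

```python
def check_routing_replay_out(shape, dtype, num_tokens, topk):
    """Raise if a replay buffer does not match the [num_tokens, topk] int16 contract."""
    if dtype != "int16":
        raise ValueError(f"routing_replay_out must be int16, got {dtype}")
    if shape != (num_tokens, topk):
        raise ValueError(
            f"routing_replay_out must have shape {(num_tokens, topk)}, got {shape}"
        )

check_routing_replay_out((8, 4), "int16", num_tokens=8, topk=4)  # passes silently
try:
    check_routing_replay_out((8, 8), "int16", num_tokens=8, topk=4)
except ValueError as e:
    print(e)  # wrong dim1 is rejected before reaching the kernel
```

Rejecting mismatched buffers at the Python layer is what prevents the out-of-bounds kernel writes flagged later in review.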

vLLM integration status

This feature has been tested end-to-end with vLLM's RoutedExpertsCapturer and the corresponding vLLM integration PR is in progress toward merging into vLLM main. The vLLM-side change is minimal (3 files, ~15 lines) — it slices the existing _RoutedExpertsDeviceCache buffer and passes it as routing_replay_out.

🔍 Related Issues

  • vLLM expert-parallel routing replay for the TRTLLM-GEN fused MoE path

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

New tests added:

  • test_routing_replay_out — standalone fused_topk_deepseek routing kernel: verifies routing_replay_out matches topk_indices per token and that None produces identical results (no side effects)
  • test_fp8_block_scale_moe_routing_replay — end-to-end FP8 block-scale MoE: verifies replay IDs are valid, unique per token, and MoE output is bit-identical with/without replay (SM100+)

Reviewer Notes

  • All changes are backward-compatible: routing_replay_out defaults to None, so existing callers are unaffected.
  • The overhead when routing_replay_out is None is a single pointer null-check per thread — effectively zero.
  • The write when enabled is K coalesced int16 stores per token (e.g., 8 stores for DeepSeek-V3's top-8), negligible relative to the routing computation.
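The "no side effects" guarantee in the notes above can be illustrated with a toy model (plain Python; `moe_forward` is a stand-in, not the actual kernel): the primary output is computed identically whether or not a replay buffer is supplied.

```python
def moe_forward(tokens, routing_replay_out=None):
    out = []
    for t, x in enumerate(tokens):
        experts = [t % 4, (t + 1) % 4]          # pretend routing: top-2 experts
        if routing_replay_out is not None:       # the only replay-dependent branch
            routing_replay_out[t][:] = experts   # K stores per token, nothing else
        out.append(x * 2)                        # "MoE output" is replay-independent
    return out

tokens = [1.0, 2.0, 3.0]
replay = [[0, 0] for _ in tokens]
# Output is identical with and without the replay buffer.
assert moe_forward(tokens) == moe_forward(tokens, routing_replay_out=replay)
```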

Summary by CodeRabbit

  • New Features

    • Optional routing-replay output: record per-token selected expert IDs (num_tokens × topk) for replay/debug; written only when provided.
  • API

    • Public ops and launchers accept an optional routing_replay_out tensor/pointer (int16) and propagate it through routing paths without changing default behavior.
  • Documentation

    • Docstrings and operator metadata updated to describe routing_replay_out and its in-place semantics.
  • Tests

    • Added tests confirming replay contents and that providing replay produces no side effects.

coderabbitai bot commented Mar 4, 2026

📝 Walkthrough


Threads an optional int16_t* routing replay output through Python APIs → launcher → runner → kernel. When provided, kernels record selected expert IDs per token in a [num_tokens, topk] int16 buffer; behavior is unchanged when the pointer is null.

Changes

  • Routing Kernel Implementation (csrc/fused_moe/noAuxTcKernels.cu, csrc/trtllm_fused_moe_routing_deepseek.cu): extended deepseek_v3_topk_kernel to accept an int16_t* routingReplayOut and optionally write expertIdx per token/top-k slot; updated the kernel invocation to pass the replay pointer.
  • Launcher & Runner Implementation (csrc/trtllm_fused_moe_kernel_launcher.cu, csrc/trtllm_fused_moe_runner.cu): added an Optional<TensorView> routing_replay_out member and setter to FusedMoeLauncher; threaded routing_replay_out through the launcher run paths and passed an int16_t* replay pointer into Runner::run and the kernel calls.
  • Kernel & Runner Headers (include/flashinfer/trtllm/fused_moe/noAuxTcKernels.h, include/flashinfer/trtllm/fused_moe/runner.h, include/flashinfer/trtllm/fused_moe/RoutingKernel.h): updated the invokeNoAuxTc and Runner::run signatures to accept the routing replay pointer; added mPtrRoutingReplayOut members and parameter wiring in the kernel/params structures.
  • Python API Bindings & Wiring (flashinfer/fused_moe/core.py, flashinfer/fused_moe/fused_routing_dsv3.py): added routing_replay_out: Optional[torch.Tensor] to the public ops and fused routing APIs; validate dtype/shape (int16, [num_tokens, topk]), add it to mutates_args, and forward it to the native module.
  • Utilities (csrc/tvm_ffi_utils.h): added a constexpr int64_t int16_code for DLPack int16 dtype encoding.
  • Tests (tests/model_optimizations/test_dsv3_fused_routing.py, tests/moe/test_trtllm_gen_routed_fused_moe.py): added test_routing_replay_out and test_fp8_block_scale_moe_routing_replay to validate replay buffer contents and ensure no side effects on primary outputs.

Sequence Diagram(s)

sequenceDiagram
  participant Python
  participant Launcher
  participant Runner
  participant Kernel
  participant ReplayBuf as "Replay Buffer (int16_t*)"

  Python->>Launcher: trtllm_..._moe(..., routing_replay_out)
  Launcher->>Runner: set_routing_replay_out(...) / run(..., routingReplayOut)
  Runner->>Kernel: invokeNoAuxTc(..., routing_replay_out)
  Kernel->>ReplayBuf: write expertIdx at [token, topk] if pointer non-null
  Kernel-->>Runner: return topk_values/topk_indices
  Runner-->>Launcher: propagate outputs
  Launcher-->>Python: return (topk_values, topk_indices, mutated routing_replay_out)

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Suggested labels

run-ci, op: moe

Suggested reviewers

  • cyx-6
  • jiahanc
  • yzh119
  • djmmoss
  • aleozlx
  • nvmbreughe
  • wenscarl


🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 44.44%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)

  • Title check: ✅ Passed. The PR title accurately describes the main feature addition (adding routing replay support to two key MoE functions) and is concise and specific.
  • Description check: ✅ Passed. The PR description comprehensively covers motivation, detailed implementation changes across the stack, new tests, and backward-compatibility guarantees, following the template structure.


@gemini-code-assist
Contributor

Summary of Changes

This pull request addresses a limitation in vLLM's expert-parallel routing replay feature when used with FlashInfer's TRTLLM-GEN fused Mixture-of-Experts (MoE) backend. By introducing an optional output tensor, routing_replay_out, to the fused routing kernels, the system can now capture and record the expert IDs selected for each token. This enhancement allows for critical expert-parallel inference and load-balancing analytics that were previously unavailable in the fused MoE path, without impacting performance for existing use cases.

Highlights

  • Routing Replay Support: Implemented an optional routing_replay_out parameter in FlashInfer's fused routing and MoE kernels (fused_topk_deepseek and trtllm_fp8_block_scale_moe) to record selected expert IDs per token.
  • vLLM Integration Enablement: This feature enables vLLM's expert-parallel routing replay (RoutedExpertsCapturer) for the TRTLLM-GEN fused MoE path, which previously bypassed the necessary expert ID capture.
  • Backward Compatibility and Performance: The new routing_replay_out parameter defaults to None, ensuring backward compatibility with existing callers and introducing zero overhead when the replay feature is not utilized.
  • Comprehensive Testing: Added new tests to verify the correctness of routing_replay_out in standalone routing kernels and end-to-end FP8 block-scale MoE, confirming valid expert IDs and bit-identical MoE output with and without replay.


Changelog
  • csrc/fused_moe/noAuxTcKernels.cu
    • Added routingReplayOut parameter to deepseek_v3_topk_kernel to enable expert ID recording.
    • Implemented logic within deepseek_v3_topk_kernel to write selected expert IDs to routingReplayOut.
    • Updated invokeNoAuxTc function signature and calls to pass the new routing_replay_out parameter.
  • csrc/trtllm_fused_moe_kernel_launcher.cu
    • Introduced routing_replay_out as an optional TensorView member in FusedMoeLauncher and its subclasses.
    • Added a set_routing_replay_out method to FusedMoeLauncher for setting the replay tensor.
    • Plumbed the routing_replay_out pointer through the run methods of FusedMoeLauncher, Fp8BlockScaleLauncher, and FP4BlockScaleLauncher to the underlying routing kernels.
    • Modified trtllm_fp8_block_scale_moe function signature to accept routing_replay_out and pass it to the launcher.
  • csrc/trtllm_fused_moe_routing_deepseek.cu
    • Added mPtrRoutingReplayOut member to the KernelParams struct for routing data.
    • Implemented expert ID recording logic within routingMainKernel to write to mPtrRoutingReplayOut.
  • csrc/trtllm_fused_moe_runner.cu
    • Updated Runner::run function signature to accept routingReplayOut.
    • Passed routingReplayOut to the moe::dev::routing::routingDeepSeek::run function.
  • csrc/tvm_ffi_utils.h
    • Added int16_code constant for DLPack dtype encoding, enabling int16 tensor validation.
  • flashinfer/fused_moe/core.py
    • Updated trtllm_fp8_block_scale_moe Python API to include routing_replay_out as an optional parameter.
    • Registered routing_replay_out as a mutating argument for torch.compile integration.
  • flashinfer/fused_moe/fused_routing_dsv3.py
    • Added routing_replay_out parameter to _check_dsv3_fused_routing_supported for input validation.
    • Implemented shape and dtype validation for routing_replay_out in _check_dsv3_fused_routing_supported.
    • Updated NoAuxTc and fused_topk_deepseek Python APIs to accept routing_replay_out.
    • Registered routing_replay_out as a mutating argument for flashinfer::NoAuxTc in torch.compile.
  • include/flashinfer/trtllm/fused_moe/RoutingKernel.h
    • Added mPtrRoutingReplayOut to DataBase and KernelParamsBase structs to hold the replay tensor pointer.
  • include/flashinfer/trtllm/fused_moe/noAuxTcKernels.h
    • Updated the declaration of invokeNoAuxTc to include the routing_replay_out parameter.
  • include/flashinfer/trtllm/fused_moe/runner.h
    • Updated the declaration of Runner::run to include the routingReplayOut parameter.
  • tests/model_optimizations/test_dsv3_fused_routing.py
    • Added test_routing_replay_out to verify that routing_replay_out correctly records expert IDs for fused_topk_deepseek.
  • tests/moe/test_trtllm_gen_routed_fused_moe.py
    • Added test_fp8_block_scale_moe_routing_replay to validate end-to-end FP8 block-scale MoE with routing replay, ensuring output consistency and correct expert ID recording.
Activity
  • The author identified a gap in vLLM's expert-parallel routing replay for the TRTLLM-GEN fused MoE path, where expert IDs were not being captured.
  • The pull request introduces an optional routing_replay_out parameter to the relevant FlashInfer kernels and APIs to address this.
  • Changes were made across CUDA kernels, C++ bindings, and Python APIs to plumb this new parameter and implement the recording logic.
  • New tests were added to validate the functionality of the routing_replay_out parameter, both for standalone routing and end-to-end MoE operations.
  • The author confirmed that the changes are backward-compatible and incur negligible overhead when the replay feature is not used.
  • Pre-commit checks were run and passed, and all existing and new tests are passing.
  • The author noted that a corresponding vLLM integration PR is in progress, indicating a coordinated effort for this feature.

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces a routing_replay_out parameter to trtllm_fp8_block_scale_moe and fused_topk_deepseek to enable routing replay, a valuable feature for analyzing Mixture-of-Experts models. While the implementation is generally well-executed, a critical vulnerability exists: the trtllm_fp8_block_scale_moe path lacks validation for the new routing_replay_out tensor. This missing validation could lead to out-of-bounds memory writes in the CUDA kernel if an incorrectly sized or typed tensor is provided. It is recommended to add shape and dtype checks in both the Python wrapper and the C++ FFI implementation for the FP8 block-scale MoE path to mitigate this risk.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 4

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
csrc/trtllm_fused_moe_kernel_launcher.cu (1)

1807-1901: ⚠️ Potential issue | 🔴 Critical

Validate routing_replay_out before passing raw pointer to kernels.

routing_replay_out is forwarded without dtype/shape/device checks. A mismatched buffer can cause out-of-bounds writes or invalid-device access when the routing kernel writes replay IDs.

🛡️ Proposed fix
 Array<Tensor> trtllm_fp8_block_scale_moe(
@@
   auto const num_tokens = hidden_states.size(0);
   auto const hidden_size = hidden_states.size(1);
+  if (routing_replay_out.has_value()) {
+    auto const& replay = routing_replay_out.value();
+    TVM_FFI_ICHECK_EQ(replay.dtype(), dl_int16)
+        << "routing_replay_out must be int16.";
+    TVM_FFI_ICHECK_EQ(replay.ndim(), 2)
+        << "routing_replay_out must be 2D.";
+    TVM_FFI_ICHECK_EQ(replay.size(0), num_tokens)
+        << "routing_replay_out dim0 must match num_tokens.";
+    TVM_FFI_ICHECK_EQ(replay.size(1), top_k)
+        << "routing_replay_out dim1 must match top_k.";
+    TVM_FFI_ICHECK_EQ(replay.device().device_type, hidden_states.device().device_type)
+        << "routing_replay_out must be on the same device type as hidden_states.";
+    TVM_FFI_ICHECK_EQ(replay.device().device_id, hidden_states.device().device_id)
+        << "routing_replay_out must be on the same device as hidden_states.";
+  }
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/trtllm_fused_moe_kernel_launcher.cu` around lines 1807 - 1901, The
review points out that routing_replay_out is forwarded into the kernels without
validation; before calling launcher->set_routing_replay_out(routing_replay_out)
(and before constructing/initializing Fp8BlockScaleLauncher), validate that
routing_replay_out.has_value() only when its TensorView has the expected dtype,
shape and device (e.g., correct ndim/size matching num_tokens/top_k and expected
integer dtype), and if invalid either clear the Optional or throw with a clear
message; update the call site around set_routing_replay_out and any code that
dereferences routing_replay_out inside Fp8BlockScaleLauncher to rely on this
validated contract.
🧹 Nitpick comments (2)
tests/model_optimizations/test_dsv3_fused_routing.py (1)

560-566: Use tensor equality instead of set equality for replay validation.

Set comparison can miss multiplicity bugs (e.g., duplicate IDs). Since both tensors should represent the same top-k picks, compare tensors directly.

🔍 Suggested fix
-    for t in range(num_tokens):
-        replay_set = set(routing_replay_out[t].tolist())
-        indices_set = set(topk_indices[t].tolist())
-        assert replay_set == indices_set, (
-            f"Token {t}: routing_replay_out experts {replay_set} "
-            f"!= topk_indices experts {indices_set}"
-        )
+    torch.testing.assert_close(
+        routing_replay_out.to(torch.int32),
+        topk_indices,
+        msg="routing_replay_out must match topk_indices element-wise",
+    )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/model_optimizations/test_dsv3_fused_routing.py` around lines 560 - 566,
Replace the set-based comparison with a direct tensor equality check between
routing_replay_out and topk_indices for each token: for each t in the loop,
assert that the tensors for that token are equal (e.g., use torch.equal or
(tensor_a == tensor_b).all()) so duplicates and ordering mismatches are caught;
update the assertion message to include the tensors (routing_replay_out[t],
topk_indices[t]) for clearer failure diagnostics.
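The reviewer's point above, that a set comparison can hide multiplicity bugs, is easy to reproduce with toy data (illustrative values, not actual test output):

```python
# A replay row with a duplicated expert ID can produce the same *set* as a
# differently-degenerate reference row, so set equality passes while
# element-wise equality correctly fails.
replay_row = [3, 3, 5]    # buggy: expert 3 recorded twice
indices_row = [3, 5, 5]   # a different multiset with the same underlying set

assert set(replay_row) == set(indices_row)  # set check passes: {3, 5} == {3, 5}
assert replay_row != indices_row            # element-wise check catches the bug
```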
tests/moe/test_trtllm_gen_routed_fused_moe.py (1)

655-666: Strengthen replay correctness check beyond range/uniqueness.

These assertions can still pass with wrong expert IDs. Consider validating replay IDs against reference routing outputs for the same logits/top-k.

✅ Suggested enhancement
+    permute_info, _ = routing_reference_renormalize(
+        routing_logits, top_k, num_experts, 8
+    )
+    expected_topk_ids = permute_info["topKIndices"].to(torch.int16)
+
     # Each token should have top_k unique experts
     for t in range(num_tokens):
         unique_experts = routing_replay_out[t].unique()
         assert unique_experts.numel() == top_k, (
             f"Token {t}: expected {top_k} unique experts, got {unique_experts.numel()}"
         )
+        assert set(routing_replay_out[t].tolist()) == set(expected_topk_ids[t].tolist())
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/moe/test_trtllm_gen_routed_fused_moe.py` around lines 655 - 666,
Compute a reference routing for the same logits and top_k and assert
routing_replay_out equals that reference instead of only range/uniqueness
checks: use the test's logits and the same top_k to produce expected_routing
(e.g., by applying torch.topk on logits along expert dimension and mapping to
expert IDs), then check expected_routing.shape == routing_replay_out.shape and
torch.equal(routing_replay_out, expected_routing) (or per-token equality with
clear assertion messages); reference symbols: routing_replay_out, logits, top_k,
num_tokens, num_experts.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@csrc/fused_moe/noAuxTcKernels.cu`:
- Around line 353-364: The code validates routing_replay_out's shape and dtype
but does not check its device, allowing a CPU tensor or a tensor on a different
CUDA device to be cast to replay_ptr and passed to the kernel; update the block
that handles routing_replay_out (references: routing_replay_out, replay,
replay_ptr, encode_dlpack_dtype, int16_code, TVM_FFI_ICHECK) to assert the
tensor is on the expected CUDA device before extracting the raw pointer — e.g.,
verify replay.device().device_type is CUDA and replay.device().device_id equals
the current/target CUDA device (or error out via TVM_FFI_ICHECK with a clear
message) so the kernel receives only pointers to memory on the correct device.

In `@csrc/trtllm_fused_moe_runner.cu`:
- Around line 61-63: The code only sets routingData.mPtrRoutingReplayOut when
routingMethodType == RoutingMethodType::DeepSeekV3, so routingReplayOut passed
into the API is ignored for other methods; fix by unconditionally wiring
routingReplayOut into routingData.mPtrRoutingReplayOut before or outside the
routing-method switch/if (so Llama4/Renormalize/TopK also get the pointer), and
only use it later if non-null in the DeepSeekV3-specific logic; reference
routingReplayOut, RoutingMethodType, DeepSeekV3, and
routingData.mPtrRoutingReplayOut to locate the assignment to change.

In `@flashinfer/fused_moe/core.py`:
- Line 1829: The new fake-op parameter routing_replay_out is unused and triggers
ARG001; inside the fake-op function (the function that declares
routing_replay_out) explicitly consume it to silence the linter by adding a
no-op usage such as "_ = routing_replay_out" (or "del routing_replay_out" / "if
False: _ = routing_replay_out") at the start of the function body so the symbol
is referenced but behavior remains unchanged.

In `@tests/model_optimizations/test_dsv3_fused_routing.py`:
- Around line 504-533: The new parameterized test test_routing_replay_out lacks
a GPU compute-capability guard—use the flashinfer.utils helpers to skip
unsupported architectures: call get_compute_capability() at the start of
test_routing_replay_out and use is_sm90a_supported(cc) / is_sm100a_supported(cc)
(or their negations) to pytest.skip when the current GPU isn't supported; add
this check near the top of the function alongside the existing pytest.skip
conditions so the test is skipped on unsupported SM versions instead of failing.

---

Outside diff comments:
In `@csrc/trtllm_fused_moe_kernel_launcher.cu`:
- Around line 1807-1901: The review points out that routing_replay_out is
forwarded into the kernels without validation; before calling
launcher->set_routing_replay_out(routing_replay_out) (and before
constructing/initializing Fp8BlockScaleLauncher), validate that
routing_replay_out.has_value() only when its TensorView has the expected dtype,
shape and device (e.g., correct ndim/size matching num_tokens/top_k and expected
integer dtype), and if invalid either clear the Optional or throw with a clear
message; update the call site around set_routing_replay_out and any code that
dereferences routing_replay_out inside Fp8BlockScaleLauncher to rely on this
validated contract.

---

Nitpick comments:
In `@tests/model_optimizations/test_dsv3_fused_routing.py`:
- Around line 560-566: Replace the set-based comparison with a direct tensor
equality check between routing_replay_out and topk_indices for each token: for
each t in the loop, assert that the tensors for that token are equal (e.g., use
torch.equal or (tensor_a == tensor_b).all()) so duplicates and ordering
mismatches are caught; update the assertion message to include the tensors
(routing_replay_out[t], topk_indices[t]) for clearer failure diagnostics.

In `@tests/moe/test_trtllm_gen_routed_fused_moe.py`:
- Around line 655-666: Compute a reference routing for the same logits and top_k
and assert routing_replay_out equals that reference instead of only
range/uniqueness checks: use the test's logits and the same top_k to produce
expected_routing (e.g., by applying torch.topk on logits along expert dimension
and mapping to expert IDs), then check expected_routing.shape ==
routing_replay_out.shape and torch.equal(routing_replay_out, expected_routing)
(or per-token equality with clear assertion messages); reference symbols:
routing_replay_out, logits, top_k, num_tokens, num_experts.


📥 Commits

Reviewing files that changed from the base of the PR and between 635505f and d8d8638.

📒 Files selected for processing (12)
  • csrc/fused_moe/noAuxTcKernels.cu
  • csrc/trtllm_fused_moe_kernel_launcher.cu
  • csrc/trtllm_fused_moe_routing_deepseek.cu
  • csrc/trtllm_fused_moe_runner.cu
  • csrc/tvm_ffi_utils.h
  • flashinfer/fused_moe/core.py
  • flashinfer/fused_moe/fused_routing_dsv3.py
  • include/flashinfer/trtllm/fused_moe/RoutingKernel.h
  • include/flashinfer/trtllm/fused_moe/noAuxTcKernels.h
  • include/flashinfer/trtllm/fused_moe/runner.h
  • tests/model_optimizations/test_dsv3_fused_routing.py
  • tests/moe/test_trtllm_gen_routed_fused_moe.py

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (1)
csrc/trtllm_fused_moe_kernel_launcher.cu (1)

385-389: Consider centralizing replay pointer extraction to remove duplication.

The same optional-to-pointer conversion appears in three run() implementations. A tiny helper (e.g., get_routing_replay_ptr()) would reduce drift risk across launcher variants.

Also applies to: 1091-1095, 1563-1567

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/trtllm_fused_moe_kernel_launcher.cu` around lines 385 - 389, Three run()
implementations duplicate conversion from routing_replay_out to an int16_t*
(e.g., the snippet that sets replay_ptr from
routing_replay_out.value().data_ptr()); extract this logic into a small helper
function such as get_routing_replay_ptr(const c10::optional<Tensor>&
routing_replay_out) that returns int16_t* (nullptr when not present) and replace
the inline conversions in the run() methods (and the occurrences around the
other blocks you noted) with calls to that helper to centralize behavior and
prevent drift.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@csrc/trtllm_fused_moe_kernel_launcher.cu`:
- Around line 1869-1880: The validation for routing_replay_out currently checks
shape/dtype/device but does not ensure it's only used with the DeepSeekV3
routing implementation; add a fail-fast guard that checks the active routing
method (compare against RoutingMethodType::DeepSeekV3) before accepting
routing_replay_out and raise a clear error (TVM_FFI_ICHECK or equivalent) if
routing_replay_out.has_value() while the routing method is not DeepSeekV3 so we
don't silently accept unsupported replay buffers; place this check adjacent to
the existing routing_replay_out block (same scope) referencing
routing_replay_out and RoutingMethodType::DeepSeekV3.

---

Nitpick comments:
In `@csrc/trtllm_fused_moe_kernel_launcher.cu`:
- Around line 385-389: Three run() implementations duplicate conversion from
routing_replay_out to an int16_t* (e.g., the snippet that sets replay_ptr from
routing_replay_out.value().data_ptr()); extract this logic into a small helper
function such as get_routing_replay_ptr(const c10::optional<Tensor>&
routing_replay_out) that returns int16_t* (nullptr when not present) and replace
the inline conversions in the run() methods (and the occurrences around the
other blocks you noted) with calls to that helper to centralize behavior and
prevent drift.


📥 Commits

Reviewing files that changed from the base of the PR and between d8d8638 and ddfc58b.

📒 Files selected for processing (2)
  • csrc/fused_moe/noAuxTcKernels.cu
  • csrc/trtllm_fused_moe_kernel_launcher.cu

Comment on lines +1869 to +1880

    if (routing_replay_out.has_value()) {
      auto replay = routing_replay_out.value();
      TVM_FFI_ICHECK(replay.device().device_type == kDLCUDA)
          << "routing_replay_out must be a CUDA tensor";
      TVM_FFI_ICHECK(replay.device().device_id == hidden_states.device().device_id)
          << "routing_replay_out must be on the same device as hidden_states";
      TVM_FFI_ICHECK(replay.ndim() == 2) << "routing_replay_out must be 2D [num_tokens, top_k]";
      TVM_FFI_ICHECK(replay.size(0) == num_tokens) << "routing_replay_out dim0 must equal num_tokens";
      TVM_FFI_ICHECK(replay.size(1) == top_k) << "routing_replay_out dim1 must equal top_k";
      TVM_FFI_ICHECK(encode_dlpack_dtype(replay.dtype()) == int16_code)
          << "routing_replay_out must be int16 dtype";
    }

⚠️ Potential issue | 🟠 Major

Fail fast when replay is requested for unsupported routing methods.

routing_replay_out is validated for shape/dtype/device, but there is no guard that it is only used with RoutingMethodType::DeepSeekV3. Today this can silently accept replay buffers for unsupported routing methods and produce no replay writes.

🔧 Suggested guard
   if (routing_replay_out.has_value()) {
+    TVM_FFI_ICHECK_EQ(static_cast<RoutingMethodType>(routing_method_type),
+                      RoutingMethodType::DeepSeekV3)
+        << "routing_replay_out is currently supported only for DeepSeekV3 routing";
     auto replay = routing_replay_out.value();
     TVM_FFI_ICHECK(replay.device().device_type == kDLCUDA)
         << "routing_replay_out must be a CUDA tensor";

Based on learnings: In csrc/trtllm_fused_moe_runner.cu, routingReplayOut is intentionally wired only into the DeepSeekV3 routing path.


Add an optional `routing_replay_out` parameter to the fused routing and
MoE kernels. When a pre-allocated int16 tensor of shape [num_tokens, topk]
is provided, the CUDA routing kernel writes selected expert IDs per token
directly into it during routing — with zero overhead when None.

This enables downstream frameworks (e.g. vLLM) to capture expert routing
decisions from the monolithic fused MoE code path without modifying the
router.

Changes:
- CUDA kernels: plumb routing_replay_out through deepseek_v3_topk_kernel,
  routingMainKernel, invokeNoAuxTc, and all three launcher classes
  (FusedMoeLauncher, Fp8BlockScaleLauncher, FP4BlockScaleLauncher)
- C++ bindings: add int16_code constant, Optional<TensorView> parameter
  with shape/dtype validation in NoAuxTc and trtllm_fp8_block_scale_moe
- Python API: add routing_replay_out to fused_topk_deepseek,
  trtllm_fp8_block_scale_moe, and their torch.compile registrations
  (custom_op, fake_op, mutates_args)
- Input validation in _check_dsv3_fused_routing_supported
- Tests for both standalone routing and end-to-end FP8 block-scale MoE

Signed-off-by: Tomer Natan <tbarnatan@nvidia.com>
Signed-off-by: Tomer Natan <tbarnatan@computelab-frontend-8.nvidia.com>
Made-with: Cursor
Signed-off-by: Tomer Natan <tbarnatan@computelab-frontend-7.nvidia.com>
Made-with: Cursor
…entry points

Address review feedback: validate routing_replay_out tensor at both C++
entry points (NoAuxTc and trtllm_fp8_block_scale_moe) to prevent
out-of-bounds writes from incorrectly sized, typed, or off-device tensors.

Signed-off-by: Tomer Natan <tbarnatan@nvidia.com>
Signed-off-by: Tomer Natan <tbarnatan@computelab-frontend-7.nvidia.com>
Made-with: Cursor
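The replay contract described in the commit messages above can be sketched in plain Python. This is a simulation of the semantics only: the real write happens inside the fused CUDA routing kernel, and `topk_route` and its list-based buffer are illustrative stand-ins, not the flashinfer API.

```python
def topk_route(scores, topk, routing_replay_out=None):
    """Select the top-k experts per token; optionally record their IDs.

    scores: per-token lists of expert scores, shape [num_tokens][num_experts].
    routing_replay_out: pre-allocated [num_tokens][topk] buffer, or None.
    """
    selected = []
    for token_idx, token_scores in enumerate(scores):
        # Indices of the k highest-scoring experts for this token.
        ids = sorted(range(len(token_scores)),
                     key=lambda e: token_scores[e], reverse=True)[:topk]
        if routing_replay_out is not None:
            # Kernel-side write: one expert ID per (token, slot).
            routing_replay_out[token_idx][:] = ids
        selected.append(ids)
    return selected


scores = [[0.1, 0.9, 0.3, 0.7],   # token 0 -> experts 1, 3
          [0.8, 0.2, 0.6, 0.4]]   # token 1 -> experts 0, 2
replay = [[0, 0], [0, 0]]         # caller-allocated [num_tokens, topk] buffer
topk_route(scores, topk=2, routing_replay_out=replay)
print(replay)  # [[1, 3], [0, 2]]
```

Passing `routing_replay_out=None` skips the write entirely, mirroring the zero-overhead default path.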
@TomerBN-Nvidia TomerBN-Nvidia force-pushed the support-routing-replay branch from ddfc58b to 4695a72 Compare March 9, 2026 12:01
@coderabbitai coderabbitai bot left a comment
Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/fused_moe/core.py (1)

1633-1661: ⚠️ Potential issue | 🟠 Major

Reject routing_replay_out for non-DeepSeek routing modes.

This now advertises replay for every routing_method_type, but the backend only writes replay IDs on the DeepSeekV3 path. Renormalize, RenormalizeNaive, TopK, and Llama4 callers will silently get untouched/stale data instead of expert IDs.

🛠️ Suggested guard
 def trtllm_fp8_block_scale_moe(
@@
     fp8_quantization_type: Fp8QuantizationType = Fp8QuantizationType.DeepSeekFp8,
     routing_replay_out: Optional[torch.Tensor] = None,
 ) -> Union[List[torch.Tensor], torch.Tensor]:
     """FP8 block scale MoE operation.
@@
+    # Placed after the docstring so the docstring stays attached to the function.
+    if (
+        routing_replay_out is not None
+        and routing_method_type != RoutingMethodType.DeepSeekV3.value
+    ):
+        raise NotImplementedError(
+            "routing_replay_out is only supported with RoutingMethodType.DeepSeekV3."
+        )

Based on learnings: "In csrc/trtllm_fused_moe_runner.cu, routingReplayOut is intentionally wired only into the DeepSeekV3 routing path. The Llama4 and Renormalize/TopK CUDA kernels do not implement the replay write logic, so passing the pointer to those data structs would silently have no effect. Extension to those routing methods is planned as a future follow-up."

Also applies to: 1758-1785, 2525-2619
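The fail-fast pattern suggested above can be exercised in isolation. A minimal sketch, using a stub `RoutingMethodType` enum and a hypothetical `check_routing_replay` helper (not the real flashinfer names):

```python
from enum import Enum


class RoutingMethodType(Enum):
    DeepSeekV3 = 1
    Renormalize = 2


def check_routing_replay(routing_replay_out, routing_method_type):
    # Reject a replay buffer up front when the routing path cannot fill it,
    # instead of silently returning untouched/stale data.
    if (routing_replay_out is not None
            and routing_method_type != RoutingMethodType.DeepSeekV3.value):
        raise NotImplementedError(
            "routing_replay_out is only supported with RoutingMethodType.DeepSeekV3."
        )


check_routing_replay(None, RoutingMethodType.Renormalize.value)   # OK: no buffer
check_routing_replay([[0]], RoutingMethodType.DeepSeekV3.value)   # OK: supported path
try:
    check_routing_replay([[0]], RoutingMethodType.Renormalize.value)
except NotImplementedError as e:
    print("rejected:", e)
```

The key property is that the unsupported combination raises immediately rather than returning a buffer that was never written.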

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tests/moe/test_trtllm_gen_routed_fused_moe.py`:
- Around line 599-645: The two calls to trtllm_fp8_block_scale_moe are passing
enable_pdl positionally into the do_finalize parameter (after weight_layout)
which is incorrect; update both invocations (the first that passes enable_pdl
then routing_replay_out, and the second that passes enable_pdl then
routing_replay_out=None) to pass do_finalize and enable_pdl as explicit keyword
arguments (e.g., do_finalize=<bool>, enable_pdl=<bool>) so the boolean flags
bind to the correct parameters and leave routing_replay_out passed as before.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: bc7cb548-27c8-4f31-9ff9-ca55d21269f7

📥 Commits

Reviewing files that changed from the base of the PR and between ddfc58b and 4695a72.

📒 Files selected for processing (12)
  • csrc/fused_moe/noAuxTcKernels.cu
  • csrc/trtllm_fused_moe_kernel_launcher.cu
  • csrc/trtllm_fused_moe_routing_deepseek.cu
  • csrc/trtllm_fused_moe_runner.cu
  • csrc/tvm_ffi_utils.h
  • flashinfer/fused_moe/core.py
  • flashinfer/fused_moe/fused_routing_dsv3.py
  • include/flashinfer/trtllm/fused_moe/RoutingKernel.h
  • include/flashinfer/trtllm/fused_moe/noAuxTcKernels.h
  • include/flashinfer/trtllm/fused_moe/runner.h
  • tests/model_optimizations/test_dsv3_fused_routing.py
  • tests/moe/test_trtllm_gen_routed_fused_moe.py
🚧 Files skipped from review as they are similar to previous changes (5)
  • include/flashinfer/trtllm/fused_moe/RoutingKernel.h
  • csrc/trtllm_fused_moe_runner.cu
  • include/flashinfer/trtllm/fused_moe/runner.h
  • csrc/tvm_ffi_utils.h
  • tests/model_optimizations/test_dsv3_fused_routing.py

Comment on lines +599 to +645
output_with_replay = trtllm_fp8_block_scale_moe(
    routing_logits,
    None,  # routing_bias
    hidden_states,
    hidden_states_scale,
    gemm1_weights,
    gemm1_weights_scale,
    gemm2_weights,
    gemm2_weights_scale,
    num_experts,
    top_k,
    None,  # n_group
    None,  # topk_group
    intermediate_size,
    0,  # local_expert_offset
    num_experts,
    None,  # routed_scaling_factor
    RoutingMethodType.Renormalize.value,
    False,  # use_shuffled_weight
    0,  # weight_layout
    enable_pdl,
    routing_replay_out=routing_replay_out,
)

output_without_replay = trtllm_fp8_block_scale_moe(
    routing_logits,
    None,
    hidden_states,
    hidden_states_scale,
    gemm1_weights,
    gemm1_weights_scale,
    gemm2_weights,
    gemm2_weights_scale,
    num_experts,
    top_k,
    None,
    None,
    intermediate_size,
    0,
    num_experts,
    None,
    RoutingMethodType.Renormalize.value,
    False,
    0,
    enable_pdl,
    routing_replay_out=None,
)

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -euo pipefail

printf '\nSignature:\n'
rg -n -A12 -B2 '^def trtllm_fp8_block_scale_moe\(' flashinfer/fused_moe/core.py

printf '\nNew test call sites:\n'
rg -n -A24 -B2 'output_(with|without)_replay = trtllm_fp8_block_scale_moe\(' tests/moe/test_trtllm_gen_routed_fused_moe.py

Repository: flashinfer-ai/flashinfer

Length of output: 2059


🏁 Script executed:

#!/bin/bash
set -euo pipefail

# Get the complete function signature with all parameters
rg -n -A30 '^def trtllm_fp8_block_scale_moe\(' flashinfer/fused_moe/core.py

Repository: flashinfer-ai/flashinfer

Length of output: 1343


Pass do_finalize and enable_pdl by keyword.

After weight_layout, the next positional parameter is do_finalize, not enable_pdl. Currently enable_pdl is passed into the do_finalize slot at lines 619 and 643, which silently misbinds both flags.

🐛 Proposed fix
     output_with_replay = trtllm_fp8_block_scale_moe(
         routing_logits,
         None,  # routing_bias
         hidden_states,
         hidden_states_scale,
         gemm1_weights,
         gemm1_weights_scale,
         gemm2_weights,
         gemm2_weights_scale,
         num_experts,
         top_k,
         None,  # n_group
         None,  # topk_group
         intermediate_size,
         0,  # local_expert_offset
         num_experts,
         None,  # routed_scaling_factor
         RoutingMethodType.Renormalize.value,
         False,  # use_shuffled_weight
         0,  # weight_layout
-        enable_pdl,
+        do_finalize=True,
+        enable_pdl=enable_pdl,
         routing_replay_out=routing_replay_out,
     )

     output_without_replay = trtllm_fp8_block_scale_moe(
         routing_logits,
         None,
         hidden_states,
         hidden_states_scale,
         gemm1_weights,
         gemm1_weights_scale,
         gemm2_weights,
         gemm2_weights_scale,
         num_experts,
         top_k,
         None,
         None,
         intermediate_size,
         0,
         num_experts,
         None,
         RoutingMethodType.Renormalize.value,
         False,
         0,
-        enable_pdl,
+        do_finalize=True,
+        enable_pdl=enable_pdl,
         routing_replay_out=None,
     )
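The misbinding bug class flagged here can be reproduced in miniature with a toy function (a hypothetical signature, not the real trtllm_fp8_block_scale_moe): when a positional argument lands one slot early, a later flag absorbs it silently.

```python
def fused_moe(weight_layout=0, do_finalize=True, enable_pdl=False):
    # Toy stand-in that just reports how its arguments were bound.
    return {"weight_layout": weight_layout,
            "do_finalize": do_finalize,
            "enable_pdl": enable_pdl}


enable_pdl = False

# Buggy: enable_pdl passed positionally binds to do_finalize.
buggy = fused_moe(0, enable_pdl)
assert buggy["do_finalize"] is False   # finalize silently disabled
assert buggy["enable_pdl"] is False    # default value, not the caller's intent

# Fixed: keyword arguments bind to the intended parameters.
fixed = fused_moe(0, do_finalize=True, enable_pdl=enable_pdl)
assert fixed["do_finalize"] is True
assert fixed["enable_pdl"] is False
```

Because both parameters are booleans, the misbinding raises no error at the call site, which is why the review asks for explicit keywords.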
