[feat] Add routing_replay_out support to MoE kernels and Python API#3024

Merged
aleozlx merged 24 commits into flashinfer-ai:main from TomerBN-Nvidia:upstream-routing-replay
Apr 14, 2026

Conversation

@TomerBN-Nvidia
Contributor

@TomerBN-Nvidia commented Apr 9, 2026

📌 Description

Add an optional routing_replay_out parameter to all MoE kernel entry points. When provided (an int16 tensor of shape [N, top_k]), the routing kernel writes the selected expert IDs for each token during MoE routing — inside the same fused kernel that computes the MoE output. When None (the default), there is zero overhead.

This enables routing replay for RL training workflows: an inference engine (e.g., vLLM) captures which experts were selected for each token and returns the data alongside the model output, so the training pipeline can replay the same routing decisions.

The previous Python callback approach in vLLM's router breaks under torch.compile + CUDA graphs (callback is traced once, tensor reference baked at trace time). This kernel-level approach works correctly with pre-allocated buffers and CUDA graph replay.

Changes

C++ kernel changes:

  • trtllm_fused_moe_kernel_launcher.cu: Add routing_replay_out field to FusedMoeLauncher base class. Add validation + passthrough in all entry points (trtllm_fp8_block_scale_moe, trtllm_bf16_moe, trtllm_fp8_per_tensor_scale_moe, trtllm_fp4_block_scale_moe, trtllm_mxint4_block_scale_moe).
  • trtllm_fused_moe_runner.cu: Add int16_t* routing_replay_out to Runner::run(). Pass through to routing data struct for all routing method paths (DeepSeek, Llama4, Custom/Renormalize, MiniMax2).
  • noAuxTcKernels.cu: Add routing_replay_out to standalone DSV3 fused routing kernel and NoAuxTc entry point.
  • Routing kernel writes in trtllm_fused_moe_routing_deepseek.cu, trtllm_fused_moe_routing_custom.cu, trtllm_fused_moe_routing_llama4.cu, and noAuxTcKernels.cu.

Python API changes:

  • core.py: Add routing_replay_out parameter to trtllm_fp8_block_scale_moe(), trtllm_bf16_moe(), trtllm_bf16_routed_moe(), internal op functions, fake functions, and autotuner/launcher kwargs threading.
  • fused_routing_dsv3.py: Add routing_replay_out to fused_topk_deepseek() with int16 dtype validation and relaxed dim0 check (>= instead of ==).

routing_replay_out spec:

  • dtype: torch.int16
  • shape: (num_tokens_or_larger, top_k) — buffer may be larger than num_tokens for CUDA graph pre-allocation
  • layout: row-major. replay[t, k] = k-th ranked expert ID for token t
  • when None: zero overhead, kernel skips the write entirely
  • dim0 validation: >= (not ==) — the kernel determines write extent from routing_logits.shape[0]
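As a concrete reference for the contract above, here is a minimal plain-Python model of the write semantics (illustrative only — the real write happens inside the fused CUDA kernel, and the function name here is an assumption, not FlashInfer API):

```python
def write_routing_replay(replay, topk_ids):
    """Reference model of the kernel-side replay write.

    replay: pre-allocated [N_max, top_k] buffer (dim0 may exceed num_tokens);
    topk_ids: per-token ranked expert IDs, shape [num_tokens, top_k].
    Rows beyond num_tokens are left untouched, matching the kernel,
    which derives its write extent from routing_logits.shape[0].
    """
    if replay is None:  # None => the kernel skips the write entirely
        return
    num_tokens = len(topk_ids)
    assert len(replay) >= num_tokens  # oversized buffers are allowed
    for t in range(num_tokens):
        for k, expert_id in enumerate(topk_ids[t]):
            replay[t][k] = expert_id  # replay[t, k] = k-th ranked expert for token t

# Oversized buffer pre-filled with a sentinel, as for CUDA graph capture
replay = [[-1, -1] for _ in range(4)]
write_routing_replay(replay, [[3, 1], [0, 2]])
```

After the call, rows 0 and 1 hold the ranked expert IDs while rows 2 and 3 keep the sentinel, mirroring how a graph-capture-sized buffer behaves with a smaller batch.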

🔍 Related Issues

Supports vLLM's --enable-return-routed-experts flag for RL training pipelines. The previous Python callback approach breaks under torch.compile + CUDA graphs.

🚀 Pull Request Checklist

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

New tests:

  • test_dsv3_fused_routing.py::test_routing_replay_out — Verifies replay matches topk_indices and None has no side effects
  • test_trtllm_gen_routed_fused_moe.py::test_fp8_block_scale_moe_routing_replay — Verifies replay has zero effect on MoE output, IDs are valid, each token has top_k unique experts
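The invariants the FP8 test checks can be sketched as plain-Python assertions (an illustrative helper on made-up data, not the actual test code):

```python
def check_replay_invariants(replay, num_tokens, top_k, num_experts):
    """Invariants a populated replay buffer should satisfy:
    every recorded ID is a valid expert index, and each token
    has exactly top_k distinct experts."""
    for t in range(num_tokens):
        row = replay[t][:top_k]
        assert all(0 <= e < num_experts for e in row), f"invalid expert ID in token {t}"
        assert len(set(row)) == top_k, f"token {t} has duplicate experts"
    return True

# Hypothetical replay contents: 3 tokens, top_k=2, 8 experts
ok = check_replay_invariants([[3, 1], [0, 7], [5, 2]], num_tokens=3, top_k=2, num_experts=8)
```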

Test plan:

  • pytest tests/model_optimizations/test_dsv3_fused_routing.py::test_routing_replay_out -v
  • pytest tests/moe/test_trtllm_gen_routed_fused_moe.py::test_fp8_block_scale_moe_routing_replay -v
  • Verify zero overhead when routing_replay_out=None (benchmark with and without)
  • Validated end-to-end with vLLM on Super MXFP8 (2 nodes, TP=4, DP=2) — non-zero routing data at 256 concurrency

Reviewer Notes

  • The dim0 validation uses >= instead of == intentionally — this allows CUDA graph pre-allocation of oversized buffers. The kernel determines write extent from routing_logits.shape[0].
  • Also includes docs/vllm_routing_replay_integration.md as an integration guide for vLLM consumers.
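A minimal sketch of the relaxed validation in plain Python (names and error messages are assumptions, not the actual FlashInfer code):

```python
def validate_routing_replay_out(replay_shape, replay_dtype, num_tokens, top_k):
    """Sketch of the routing_replay_out validation described above.

    dim1 must equal top_k exactly, but dim0 only needs to satisfy '>=',
    so a buffer pre-allocated for the CUDA-graph max batch size is
    accepted; the kernel derives its actual write extent from
    routing_logits.shape[0], never from the buffer size.
    """
    if replay_dtype != "int16":
        raise TypeError("routing_replay_out must be torch.int16")
    if len(replay_shape) != 2 or replay_shape[1] != top_k:
        raise ValueError(f"expected 2D shape with dim1 == top_k ({top_k})")
    if replay_shape[0] < num_tokens:  # '>=', not '==': oversized is fine
        raise ValueError(f"dim0 must be >= num_tokens ({num_tokens})")
    return True

# A graph-capture buffer sized for 512 tokens still validates for a 3-token batch
ok = validate_routing_replay_out((512, 8), "int16", num_tokens=3, top_k=8)
```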

Summary by CodeRabbit

  • New Features

    • Added optional routing_replay_out parameter to TensorRT-LLM MoE operations, enabling recording of selected expert indices during routing for each token.
  • Documentation

    • Added guide for integrating routing replay with vLLM's MoE workflow.
  • Tests

    • Added validation tests for routing replay functionality across DeepSeek V3 and FP8 block-scale MoE kernels.

@coderabbitai
Contributor

coderabbitai bot commented Apr 9, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

This PR adds optional routing replay capture functionality across the MoE kernel stack. When enabled via an optional routing_replay_out tensor parameter, the fused routing kernels record selected expert IDs per token in int16 format. The feature threads through CUDA kernels, C++ wrappers, and Python APIs, with optional validation and graceful degradation when the parameter is absent.

Changes

  • CUDA Kernel Implementations (csrc/fused_moe/noAuxTcKernels.cu, csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_deepseek.cu): Added conditional int16 writes of selected expert indices to the routing_replay_out buffer when non-null, storing results at per-token, per-rank positions.
  • Kernel Runtime Launchers (csrc/trtllm_fused_moe_runner.cu, csrc/trtllm_fused_moe_kernel_launcher.cu): Extended Runner::run and launcher classes to accept a routing_replay_out pointer, with validation ensuring a CUDA tensor on the same device, int16 dtype, and a 2D shape with size[1] == top_k, plus forwarding to the routing kernels.
  • Kernel Interface Headers (include/flashinfer/trtllm/fused_moe/RoutingKernel.h, include/flashinfer/trtllm/fused_moe/noAuxTcKernels.h, include/flashinfer/trtllm/fused_moe/runner.h): Added an mPtrRoutingReplayOut member to routing data/params structs; updated function signatures with an optional int16 pointer parameter and modified stream parameter defaults.
  • Python API Bindings (flashinfer/fused_moe/core.py, flashinfer/fused_moe/fused_routing_dsv3.py): Extended public API functions and custom op registrations to accept an optional routing_replay_out tensor, marked as a mutated argument in custom op declarations, with dtype and shape validation.
  • Tests (tests/model_optimizations/test_dsv3_fused_routing.py, tests/moe/test_trtllm_gen_routed_fused_moe.py): Added pytest cases validating that routing replay captures the correct expert IDs per token, that the MoE output is invariant with/without replay, and that oversized buffers are handled correctly.
  • Documentation (docs/vllm_routing_replay_integration.md): New guide documenting the tensor contract (int16, 2D shape, per-token expert storage), control-flow semantics, CUDA graph compatibility, and vLLM device buffer layout recommendations.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~28 minutes


Suggested labels

run-ci, op: moe-routing

Suggested reviewers

  • cyx-6
  • yzh119
  • jiahanc
  • aleozlx
  • IwakuraRein
  • nv-yunzheq
  • jimmyzho

Poem

🐰 A rabbit hops through routing kernels deep,
Recording expert choices we must keep,
In int16 arrays, neat and clean,
The finest routing replay ever seen! 🎯✨

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning: docstring coverage is 36.54%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)

  • Title check ✅ Passed: the title clearly identifies the main change — adding routing_replay_out support across MoE kernels and the Python API.
  • Description check ✅ Passed: the description fully addresses the template requirements — it explains what the PR does, links related issues (vLLM RL training), includes pre-commit and test checklist items with status marks, and provides detailed reviewer notes.

Contributor

@gemini-code-assist bot left a comment


Code Review

This pull request introduces a new "routing replay" feature for MoE (Mixture of Experts) kernels, allowing the recording of selected expert IDs into an optional int16 tensor (routing_replay_out) during the routing process. This functionality is integrated across various CUDA kernels (noAuxTcKernels.cu, trtllm_fused_moe_routing_custom.cu, trtllm_fused_moe_routing_deepseek.cu, trtllm_fused_moe_routing_llama4.cu), the FusedMoeLauncher and Runner classes, and exposed through Python bindings. New test cases verify the feature, and a documentation file (vllm_routing_replay_integration.md) explains its usage. Feedback indicates a duplicate test function definition in test_dsv3_fused_routing.py, and suggests refactoring duplicated validation logic and replay_ptr retrieval into helper functions/methods in csrc/trtllm_fused_moe_kernel_launcher.cu, along with adding a missing int16 dtype check to the validation.

Comment thread tests/model_optimizations/test_dsv3_fused_routing.py
Comment thread csrc/trtllm_fused_moe_kernel_launcher.cu
Comment thread csrc/trtllm_fused_moe_kernel_launcher.cu
Contributor

@coderabbitai bot left a comment


Actionable comments posted: 5

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (3)
flashinfer/fused_moe/core.py (3)

2428-2451: ⚠️ Potential issue | 🟠 Major

trtllm_bf16_routed_moe never forwards the new buffer.

The function now accepts routing_replay_out, but the delegated trtllm_bf16_moe(...) call still stops at norm_topk_prob. Routed BF16 callers will silently get no replay data even when they pass a buffer.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/fused_moe/core.py` around lines 2428 - 2451, The call inside
trtllm_bf16_routed_moe currently invokes trtllm_bf16_moe but stops at the
norm_topk_prob argument and therefore never forwards the new routing_replay_out
buffer; update the delegated call to pass the routing_replay_out argument
through to trtllm_bf16_moe (i.e., add the routing_replay_out parameter to the
argument list after the norm_topk_prob/place where additional trailing args
belong) so routed BF16 callers receive the replay data; check the
trtllm_bf16_routed_moe function and the trtllm_bf16_moe call site to ensure the
parameter name routing_replay_out matches and ordering aligns with
trtllm_bf16_moe's signature.

2523-2548: ⚠️ Potential issue | 🔴 Critical

routing_replay_out is undefined at line 2547 — this will raise NameError at runtime.

The function signature (line 2463) does not include routing_replay_out as a parameter, yet line 2547 forwards it to the CUDA kernel call. Either remove this argument from the call, or add the parameter to the function signature and update the corresponding op registration (trtllm_fp8_per_tensor_scale_moe_op and _fake_trtllm_fp8_per_tensor_scale_moe) to support replay if that is intended.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/fused_moe/core.py` around lines 2523 - 2548, The call to
get_trtllm_moe_sm100_module().trtllm_fp8_per_tensor_scale_moe(...) passes
routing_replay_out which is not defined in the surrounding function signature,
causing a NameError; either remove routing_replay_out from that call so the
kernel is invoked with the existing parameters, or add a routing_replay_out
parameter to the enclosing function signature and propagate it into the op
registration (trtllm_fp8_per_tensor_scale_moe_op and
_fake_trtllm_fp8_per_tensor_scale_moe) so the backend and fake op accept and
forward the replay output consistently.

1995-2029: ⚠️ Potential issue | 🔴 Critical

Add missing routing_replay_out parameter to FP4 and MxInt4 block-scale MoE functions.

The real op functions trtllm_fp4_block_scale_moe_op and trtllm_mxint4_block_scale_moe_op reference routing_replay_out in their calls to the underlying C++ ops (lines 2029 and 2214 respectively), but neither function nor their corresponding fake ops declare this parameter. This will cause NameError: name 'routing_replay_out' is not defined at runtime.

Add routing_replay_out: Optional[torch.Tensor] = None to both real op signatures and both fake op signatures (matching the pattern in _fake_trtllm_fp8_block_scale_moe), and include _ = routing_replay_out in the fake op bodies to suppress linting warnings.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/fused_moe/core.py` around lines 1995 - 2029, The real op functions
trtllm_fp4_block_scale_moe_op and trtllm_mxint4_block_scale_moe_op are missing
the routing_replay_out parameter used when calling the underlying C++ ops; add
routing_replay_out: Optional[torch.Tensor] = None to both real op signatures and
propagate it into the calls (where routing_replay_out is passed); also update
the corresponding fake ops (_fake_trtllm_fp4_block_scale_moe and
_fake_trtllm_mxint4_block_scale_moe) to accept routing_replay_out:
Optional[torch.Tensor] = None and add a no-op reference (_ = routing_replay_out)
in each fake body to suppress lint warnings. Ensure the torch and
typing.Optional annotations are consistent with existing file style.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@csrc/fused_moe/noAuxTcKernels.cu`:
- Around line 352-367: The routing_replay_out validation allows undersized dim0;
ensure the buffer is at least num_tokens tall before kernel launch by replacing
the equality check with a lower-bound check: when
routing_replay_out.has_value(), validate replay.sizes()[0] >= num_tokens (in
addition to existing checks on dim1/topk and dtype) and only then set replay_ptr
so the kernel cannot write past the end of the routing_replay_out buffer;
reference the variables replay, replay.sizes(), routing_replay_out, replay_ptr,
num_tokens, and topk when making the change.

In `@csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_llama4.cu`:
- Around line 468-471: The replay buffer is only written in the large-token
kernel path, leaving mPtrRoutingReplayOut uninitialized for the common Llama4
score paths; update the single-warp and single-cluster kernels
(routingIndicesWarpKernel and routingIndicesClusterKernel) to check
params.mPtrRoutingReplayOut and write the selected expert id (cast to int16_t)
for the corresponding token index (use the same tokenIdx/warpMaxExpertIdx logic
used in the large-token kernel), so callers requesting replay get a populated
buffer across all execution paths.

In `@csrc/trtllm_fused_moe_kernel_launcher.cu`:
- Around line 228-230: The setter set_routing_replay_out(const
Optional<TensorView>&) is currently protected but is invoked by free functions
(not class members), causing access errors; move the declaration/definition of
set_routing_replay_out and the member routing_replay_out into the public section
of the class (or alternatively wire routing_replay_out via the constructor/init
path and remove the protected setter) so free functions can call it; locate the
symbol set_routing_replay_out and routing_replay_out in
trtllm_fused_moe_kernel_launcher.cu and change the access specifier to public
(or add a public wrapper/constructor parameter) to resolve the compilation
error.

In `@tests/model_optimizations/test_dsv3_fused_routing.py`:
- Around line 559-566: The test currently compares routing_replay_out and
topk_indices using sets which ignores ordering; change the check to assert
positional equality per token by comparing the ordered sequences (e.g., compare
routing_replay_out[t].tolist() to topk_indices[t].tolist() or use an array_equal
on routing_replay_out[t] and topk_indices[t]) and update the assertion message
to show the ordered lists for failing tokens; apply the same fix to the
duplicate check at the later block referenced (lines 649-656) so both places
validate replay order, not just membership.
- Around line 594-680: The test function test_routing_replay_out is defined
twice causing the first definition to be overwritten; fix by either renaming one
of the duplicates (e.g., rename the second occurrence to
test_routing_replay_out_extended or test_routing_replay_out_stress) or by
merging their parametrize decorators into a single test (combine the differing
num_tokens lists and other params) so both parameter sets are covered; update
any references accordingly and ensure unique test names in the module.

---

Outside diff comments:
In `@flashinfer/fused_moe/core.py`:
- Around line 2428-2451: The call inside trtllm_bf16_routed_moe currently
invokes trtllm_bf16_moe but stops at the norm_topk_prob argument and therefore
never forwards the new routing_replay_out buffer; update the delegated call to
pass the routing_replay_out argument through to trtllm_bf16_moe (i.e., add the
routing_replay_out parameter to the argument list after the norm_topk_prob/place
where additional trailing args belong) so routed BF16 callers receive the replay
data; check the trtllm_bf16_routed_moe function and the trtllm_bf16_moe call
site to ensure the parameter name routing_replay_out matches and ordering aligns
with trtllm_bf16_moe's signature.
- Around line 2523-2548: The call to
get_trtllm_moe_sm100_module().trtllm_fp8_per_tensor_scale_moe(...) passes
routing_replay_out which is not defined in the surrounding function signature,
causing a NameError; either remove routing_replay_out from that call so the
kernel is invoked with the existing parameters, or add a routing_replay_out
parameter to the enclosing function signature and propagate it into the op
registration (trtllm_fp8_per_tensor_scale_moe_op and
_fake_trtllm_fp8_per_tensor_scale_moe) so the backend and fake op accept and
forward the replay output consistently.
- Around line 1995-2029: The real op functions trtllm_fp4_block_scale_moe_op and
trtllm_mxint4_block_scale_moe_op are missing the routing_replay_out parameter
used when calling the underlying C++ ops; add routing_replay_out:
Optional[torch.Tensor] = None to both real op signatures and propagate it into
the calls (where routing_replay_out is passed); also update the corresponding
fake ops (_fake_trtllm_fp4_block_scale_moe and
_fake_trtllm_mxint4_block_scale_moe) to accept routing_replay_out:
Optional[torch.Tensor] = None and add a no-op reference (_ = routing_replay_out)
in each fake body to suppress lint warnings. Ensure the torch and
typing.Optional annotations are consistent with existing file style.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 1ee1a5d5-e45a-46d8-a2c3-947375e46edb

📥 Commits

Reviewing files that changed from the base of the PR and between c2b4db2 and 87e96a3.

📒 Files selected for processing (14)
  • csrc/fused_moe/noAuxTcKernels.cu
  • csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_custom.cu
  • csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_deepseek.cu
  • csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_llama4.cu
  • csrc/trtllm_fused_moe_kernel_launcher.cu
  • csrc/trtllm_fused_moe_runner.cu
  • docs/vllm_routing_replay_integration.md
  • flashinfer/fused_moe/core.py
  • flashinfer/fused_moe/fused_routing_dsv3.py
  • include/flashinfer/trtllm/fused_moe/RoutingKernel.h
  • include/flashinfer/trtllm/fused_moe/noAuxTcKernels.h
  • include/flashinfer/trtllm/fused_moe/runner.h
  • tests/model_optimizations/test_dsv3_fused_routing.py
  • tests/moe/test_trtllm_gen_routed_fused_moe.py

Comment thread csrc/fused_moe/noAuxTcKernels.cu
Comment thread csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_llama4.cu Outdated
Comment thread csrc/trtllm_fused_moe_kernel_launcher.cu Outdated
Comment thread tests/model_optimizations/test_dsv3_fused_routing.py
Comment thread tests/model_optimizations/test_dsv3_fused_routing.py
Contributor

@coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer-cubin/flashinfer_cubin/__init__.py`:
- Line 81: The file currently hard-codes __version__ = "0.6.7", which overwrites
the dynamic value from _get_version(); remove the hard-coded assignment and
ensure __version__ is set only via the call to _get_version() (or a proper
fallback if _get_version() fails) so that the module-level __version__ uses the
dynamically computed version; update the module to rely on the existing
_get_version() function and delete or replace the literal assignment to
__version__.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 69c8437a-bbcb-4983-9d4b-26873653142d

📥 Commits

Reviewing files that changed from the base of the PR and between 375a229 and c0961c9.

📒 Files selected for processing (1)
  • flashinfer-cubin/flashinfer_cubin/__init__.py

Comment thread flashinfer-cubin/flashinfer_cubin/__init__.py Outdated
@TomerBN-Nvidia force-pushed the upstream-routing-replay branch from c0961c9 to e439466 on April 9, 2026 14:29
Tomer Natan added 6 commits April 12, 2026 05:16
Add optional int16_t* routing_replay_out parameter to the standalone DSV3
fused routing kernel (noAuxTcKernels). When provided, writes selected expert
IDs per token during routing. Includes input validation in the entry point.
Add optional routing_replay_out parameter to trtllm_fp8_block_scale_moe(),
trtllm_bf16_moe(), trtllm_bf16_routed_moe(), and fused_topk_deepseek().
Relaxed dim0 validation (>= instead of ==) for CUDA graph compatibility.
Thread through autotuner and launcher classes via kwargs.
Add test_routing_replay_out for DSV3 fused routing: verifies replay tensor
matches topk_indices and passing None has no side effects.

Add test_fp8_block_scale_moe_routing_replay for FP8 MoE: verifies replay
has no effect on MoE output, expert IDs are valid, and each token has
exactly top_k unique experts.
- Add int16 dtype check to all 5 launcher validation blocks (was missing,
  could silently corrupt memory on wrong dtype)
- Add routing_replay_out param to trtllm_fp4_block_scale_moe(),
  trtllm_mxint4_block_scale_moe(), trtllm_fp8_per_tensor_scale_moe()
  public functions
@TomerBN-Nvidia force-pushed the upstream-routing-replay branch from 60b321d to 5a8c656 on April 12, 2026 12:17
@TomerBN-Nvidia
Contributor Author

Addressing CodeRabbit "Outside diff range" findings (review 4081573685)

All three Python API issues from the second review are fixed in 60b321d:

  1. trtllm_bf16_routed_moe silently dropping routing_replay_out — The parameter was accepted but not forwarded to the underlying trtllm_bf16_moe() call. Fixed: added routing_replay_out to the call.

  2. FP4 and MXINT4 ops missing routing_replay_out — The trtllm_fp4_block_scale_moe_op and trtllm_mxint4_block_scale_moe_op custom ops (and their fake ops) were missing routing_replay_out in their signatures, while the function body referenced it — would have been a NameError under torch.compile. Fixed: added the parameter to all four function signatures, added mutates_args=("routing_replay_out",), and forwarded it through the public API wrappers.

  3. trtllm_fp8_per_tensor_scale_moe "undefined variable" — This was a false positive. The parameter IS in the function signature at line 2488 and correctly passed at line 2549. No fix needed.

Also applied pre-commit formatting fixes (clang-format, ruff) to ensure CI passes.

…, duplicate test, Llama4 replay gaps

- Rename duplicate test_routing_replay_out → test_routing_replay_out_extended (pytest collision)
- Forward routing_replay_out in trtllm_bf16_routed_moe (was silently dropped)
- Add routing_replay_out to FP4/MXINT4 op signatures, fake ops, and public APIs (NameError)
- Move set_routing_replay_out() from protected to public in FusedMoeLauncher
- Add routing replay writes to Llama4 warp and cluster kernel paths
- Add explanatory comment for intentional dim0 validation omission (CUDA graph pre-alloc)
- Apply pre-commit formatting fixes (clang-format, ruff)
@TomerBN-Nvidia force-pushed the upstream-routing-replay branch from 5a8c656 to 1fb8454 on April 12, 2026 12:25
@TomerBN-Nvidia requested a review from amirkl94 on April 13, 2026 08:33
Comment thread csrc/tvm_ffi_utils.h Outdated
Comment thread flashinfer/fused_moe/core.py
Comment thread csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_deepseek.cu Outdated
…wap condition order

- Move using tvm::ffi::Optional from shared header to noAuxTcKernels.cu only
- Add _validate_routing_replay_out() Python helper with shape/dtype/contiguity checks
- Call validation in all 6 public API functions before C++ dispatch
- Swap condition order in DeepSeek routing kernel: nullptr check first (cheaper)
@aleozlx added the run-ci label Apr 13, 2026
@aleozlx
Collaborator

aleozlx commented Apr 13, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !539 has been created, and the CI pipeline #48422232 is currently running. I'll report back once the pipeline job completes.

@aleozlx
Collaborator

aleozlx commented Apr 14, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !539 has been updated with latest changes, and the CI pipeline #48453498 is currently running. I'll report back once the pipeline job completes.

@amitz-nv
Contributor

/bot run

@flashinfer-bot
Collaborator

GitLab MR !539 has been updated with latest changes, and the CI pipeline #48479103 is currently running. I'll report back once the pipeline job completes.

Tomer Natan and others added 2 commits April 14, 2026 04:12
Inserting the field in the middle of DataBase and KernelParamsBase
shifted memory offsets for all subsequent fields, causing GEMM crashes
in FP8/FP4 autotuner tests (11/15 failures). Moving to end preserves
the original layout for all existing fields.

Also adds missing routing_replay_out arg to MXINT4 and FP4 paths
in MoERunner._run().

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
fix: move mPtrRoutingReplayOut to end of routing structs
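The offset-shift failure mode behind this fix can be reproduced with a small ctypes sketch (illustrative structs, not the actual DataBase/KernelParamsBase layouts): inserting a member mid-struct moves every subsequent field, so code compiled against the old offsets reads garbage, while appending at the end preserves all existing offsets.

```python
import ctypes

class Original(ctypes.Structure):
    # Stand-in for the pre-change routing params struct
    _fields_ = [("scores", ctypes.c_void_p), ("top_k", ctypes.c_int32)]

class InsertedMiddle(ctypes.Structure):
    # New pointer inserted in the middle: top_k's offset shifts
    _fields_ = [("scores", ctypes.c_void_p),
                ("replay_out", ctypes.c_void_p),
                ("top_k", ctypes.c_int32)]

class AppendedEnd(ctypes.Structure):
    # New pointer appended at the end: existing offsets unchanged
    _fields_ = [("scores", ctypes.c_void_p), ("top_k", ctypes.c_int32),
                ("replay_out", ctypes.c_void_p)]

print(Original.top_k.offset, InsertedMiddle.top_k.offset, AppendedEnd.top_k.offset)
```

Appending keeps `top_k` at its original byte offset, which is why moving mPtrRoutingReplayOut to the end of the structs resolved the GEMM autotuner crashes.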
@aleozlx
Collaborator

aleozlx commented Apr 14, 2026

tests seem clean

@aleozlx merged commit 5056474 into flashinfer-ai:main Apr 14, 2026
28 of 34 checks passed
aleozlx added a commit to nv-yunzheq/flashinfer that referenced this pull request Apr 14, 2026
…eplay_out)

Merged origin/main into nv-yunzheq/DSR1_shared_expert_fusion.
New conflicts from PR flashinfer-ai#3024 (routing_replay_out support) resolved by
keeping both PR's shared-expert fields and main's routing replay fields.

Co-Authored-By: Claude <noreply@anthropic.com>
aleozlx added a commit that referenced this pull request Apr 15, 2026
…3024)

Co-authored-by: Tomer Natan <tbarnatan@oci-hsg-cs-001-login-02.cm.cluster>
Co-authored-by: Alex Yang <aleyang@nvidia.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
TomerBN-Nvidia added a commit to TomerBN-Nvidia/vllm that referenced this pull request Apr 16, 2026
Extend routing replay to support monolithic (FP8/MXFP8) kernel paths
and prefix caching:

Monolithic kernel support:
- Thread routing_replay_out through apply_monolithic() chain
  (modelopt.py, fp8.py, modular_kernel.py, fused_moe_method_base.py)
- Add _monolithic_writes_routing_replay flag to quant methods
- BF16 monolithic fallback: run select_experts() separately when
  kernel does not write routing data internally

Prefix caching:
- Initialize host cache with -1 sentinel instead of 0
  (expert ID 0 is valid; -1 marks cache-hit positions)

Tests:
- Add TestMonolithicWritesFlag tests
- Update host cache sentinel test for -1 initialization

Depends on: flashinfer-ai/flashinfer#3024 (routing_replay_out param)