[feat] Add routing_replay_out support to MoE kernels and Python API #3024
aleozlx merged 24 commits into flashinfer-ai:main from
Conversation
Note: Reviews paused. This branch appears to be under active development; to avoid overwhelming you with review comments from an influx of new commits, CodeRabbit has automatically paused this review. Use the following commands to manage reviews:
📝 Walkthrough
This PR adds optional routing replay capture functionality across the MoE kernel stack, enabled via an optional output buffer.
Changes
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~28 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Code Review
This pull request introduces a new "routing replay" feature for MoE (Mixture of Experts) kernels, allowing the recording of selected expert IDs into an optional int16 tensor (routing_replay_out) during the routing process. This functionality is integrated across various CUDA kernels (noAuxTcKernels.cu, trtllm_fused_moe_routing_custom.cu, trtllm_fused_moe_routing_deepseek.cu, trtllm_fused_moe_routing_llama4.cu), the FusedMoeLauncher and Runner classes, and exposed through Python bindings. New test cases verify the feature, and a documentation file (vllm_routing_replay_integration.md) explains its usage. Feedback indicates a duplicate test function definition in test_dsv3_fused_routing.py, and suggests refactoring duplicated validation logic and replay_ptr retrieval into helper functions/methods in csrc/trtllm_fused_moe_kernel_launcher.cu, along with adding a missing int16 dtype check to the validation.
Actionable comments posted: 5
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (3)
flashinfer/fused_moe/core.py (3)
2428-2451: ⚠️ Potential issue | 🟠 Major
`trtllm_bf16_routed_moe` never forwards the new buffer. The function now accepts `routing_replay_out`, but the delegated `trtllm_bf16_moe(...)` call still stops at `norm_topk_prob`. Routed BF16 callers will silently get no replay data even when they pass a buffer.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/fused_moe/core.py` around lines 2428 - 2451, The call inside trtllm_bf16_routed_moe currently invokes trtllm_bf16_moe but stops at the norm_topk_prob argument and therefore never forwards the new routing_replay_out buffer; update the delegated call to pass the routing_replay_out argument through to trtllm_bf16_moe (i.e., add the routing_replay_out parameter to the argument list after the norm_topk_prob/place where additional trailing args belong) so routed BF16 callers receive the replay data; check the trtllm_bf16_routed_moe function and the trtllm_bf16_moe call site to ensure the parameter name routing_replay_out matches and ordering aligns with trtllm_bf16_moe's signature.
2523-2548: ⚠️ Potential issue | 🔴 Critical
`routing_replay_out` is undefined at line 2547 — this will raise `NameError` at runtime. The function signature (line 2463) does not include `routing_replay_out` as a parameter, yet line 2547 forwards it to the CUDA kernel call. Either remove this argument from the call, or add the parameter to the function signature and update the corresponding op registration (`trtllm_fp8_per_tensor_scale_moe_op` and `_fake_trtllm_fp8_per_tensor_scale_moe`) to support replay if that is intended.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/fused_moe/core.py` around lines 2523 - 2548, The call to get_trtllm_moe_sm100_module().trtllm_fp8_per_tensor_scale_moe(...) passes routing_replay_out which is not defined in the surrounding function signature, causing a NameError; either remove routing_replay_out from that call so the kernel is invoked with the existing parameters, or add a routing_replay_out parameter to the enclosing function signature and propagate it into the op registration (trtllm_fp8_per_tensor_scale_moe_op and _fake_trtllm_fp8_per_tensor_scale_moe) so the backend and fake op accept and forward the replay output consistently.
1995-2029: ⚠️ Potential issue | 🔴 Critical
Add missing `routing_replay_out` parameter to FP4 and MxInt4 block-scale MoE functions. The real op functions `trtllm_fp4_block_scale_moe_op` and `trtllm_mxint4_block_scale_moe_op` reference `routing_replay_out` in their calls to the underlying C++ ops (lines 2029 and 2214 respectively), but neither function nor their corresponding fake ops declare this parameter. This will cause `NameError: name 'routing_replay_out' is not defined` at runtime. Add `routing_replay_out: Optional[torch.Tensor] = None` to both real op signatures and both fake op signatures (matching the pattern in `_fake_trtllm_fp8_block_scale_moe`), and include `_ = routing_replay_out` in the fake op bodies to suppress linting warnings.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/fused_moe/core.py` around lines 1995 - 2029, The real op functions trtllm_fp4_block_scale_moe_op and trtllm_mxint4_block_scale_moe_op are missing the routing_replay_out parameter used when calling the underlying C++ ops; add routing_replay_out: Optional[torch.Tensor] = None to both real op signatures and propagate it into the calls (where routing_replay_out is passed); also update the corresponding fake ops (_fake_trtllm_fp4_block_scale_moe and _fake_trtllm_mxint4_block_scale_moe) to accept routing_replay_out: Optional[torch.Tensor] = None and add a no-op reference (_ = routing_replay_out) in each fake body to suppress lint warnings. Ensure the torch and typing.Optional annotations are consistent with existing file style.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@csrc/fused_moe/noAuxTcKernels.cu`:
- Around line 352-367: The routing_replay_out validation allows undersized dim0;
ensure the buffer is at least num_tokens tall before kernel launch by replacing
the equality check with a lower-bound check: when
routing_replay_out.has_value(), validate replay.sizes()[0] >= num_tokens (in
addition to existing checks on dim1/topk and dtype) and only then set replay_ptr
so the kernel cannot write past the end of the routing_replay_out buffer;
reference the variables replay, replay.sizes(), routing_replay_out, replay_ptr,
num_tokens, and topk when making the change.
In `@csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_llama4.cu`:
- Around line 468-471: The replay buffer is only written in the large-token
kernel path, leaving mPtrRoutingReplayOut uninitialized for the common Llama4
score paths; update the single-warp and single-cluster kernels
(routingIndicesWarpKernel and routingIndicesClusterKernel) to check
params.mPtrRoutingReplayOut and write the selected expert id (cast to int16_t)
for the corresponding token index (use the same tokenIdx/warpMaxExpertIdx logic
used in the large-token kernel), so callers requesting replay get a populated
buffer across all execution paths.
In `@csrc/trtllm_fused_moe_kernel_launcher.cu`:
- Around line 228-230: The setter set_routing_replay_out(const
Optional<TensorView>&) is currently protected but is invoked by free functions
(not class members), causing access errors; move the declaration/definition of
set_routing_replay_out and the member routing_replay_out into the public section
of the class (or alternatively wire routing_replay_out via the constructor/init
path and remove the protected setter) so free functions can call it; locate the
symbol set_routing_replay_out and routing_replay_out in
trtllm_fused_moe_kernel_launcher.cu and change the access specifier to public
(or add a public wrapper/constructor parameter) to resolve the compilation
error.
In `@tests/model_optimizations/test_dsv3_fused_routing.py`:
- Around line 559-566: The test currently compares routing_replay_out and
topk_indices using sets which ignores ordering; change the check to assert
positional equality per token by comparing the ordered sequences (e.g., compare
routing_replay_out[t].tolist() to topk_indices[t].tolist() or use an array_equal
on routing_replay_out[t] and topk_indices[t]) and update the assertion message
to show the ordered lists for failing tokens; apply the same fix to the
duplicate check at the later block referenced (lines 649-656) so both places
validate replay order, not just membership.
- Around line 594-680: The test function test_routing_replay_out is defined
twice causing the first definition to be overwritten; fix by either renaming one
of the duplicates (e.g., rename the second occurrence to
test_routing_replay_out_extended or test_routing_replay_out_stress) or by
merging their parametrize decorators into a single test (combine the differing
num_tokens lists and other params) so both parameter sets are covered; update
any references accordingly and ensure unique test names in the module.
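The first comment's point, that a set comparison only checks membership while replay must match rank order, can be shown with plain lists (values are invented):

```python
# Per-token expert rankings with identical members but different order.
replay_row = [7, 3, 12, 5]   # order the routing kernel wrote
topk_row = [3, 7, 12, 5]     # order the reference top-k produced

# Set equality passes despite the rank-order mismatch.
assert set(replay_row) == set(topk_row)

# Positional equality, the check the review asks for, catches it.
assert replay_row != topk_row

def assert_replay_matches(replay_row, topk_row, token_idx):
    # Ordered, per-token comparison with a message showing both lists.
    assert replay_row == topk_row, (
        f"token {token_idx}: replay {replay_row} != topk {topk_row}"
    )
```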
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 1ee1a5d5-e45a-46d8-a2c3-947375e46edb
📒 Files selected for processing (14)
- csrc/fused_moe/noAuxTcKernels.cu
- csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_custom.cu
- csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_deepseek.cu
- csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_llama4.cu
- csrc/trtllm_fused_moe_kernel_launcher.cu
- csrc/trtllm_fused_moe_runner.cu
- docs/vllm_routing_replay_integration.md
- flashinfer/fused_moe/core.py
- flashinfer/fused_moe/fused_routing_dsv3.py
- include/flashinfer/trtllm/fused_moe/RoutingKernel.h
- include/flashinfer/trtllm/fused_moe/noAuxTcKernels.h
- include/flashinfer/trtllm/fused_moe/runner.h
- tests/model_optimizations/test_dsv3_fused_routing.py
- tests/moe/test_trtllm_gen_routed_fused_moe.py
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer-cubin/flashinfer_cubin/__init__.py`:
- Line 81: The file currently hard-codes __version__ = "0.6.7", which overwrites
the dynamic value from _get_version(); remove the hard-coded assignment and
ensure __version__ is set only via the call to _get_version() (or a proper
fallback if _get_version() fails) so that the module-level __version__ uses the
dynamically computed version; update the module to rely on the existing
_get_version() function and delete or replace the literal assignment to
__version__.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 69c8437a-bbcb-4983-9d4b-26873653142d
📒 Files selected for processing (1)
flashinfer-cubin/flashinfer_cubin/__init__.py
c0961c9 to e439466
Add optional int16_t* routing_replay_out parameter to the standalone DSV3 fused routing kernel (noAuxTcKernels). When provided, writes selected expert IDs per token during routing. Includes input validation in the entry point.
Add optional routing_replay_out parameter to trtllm_fp8_block_scale_moe(), trtllm_bf16_moe(), trtllm_bf16_routed_moe(), and fused_topk_deepseek(). Relaxed dim0 validation (>= instead of ==) for CUDA graph compatibility. Thread through autotuner and launcher classes via kwargs.
Add test_routing_replay_out for DSV3 fused routing: verifies replay tensor matches topk_indices and passing None has no side effects. Add test_fp8_block_scale_moe_routing_replay for FP8 MoE: verifies replay has no effect on MoE output, expert IDs are valid, and each token has exactly top_k unique experts.
- Add int16 dtype check to all 5 launcher validation blocks (was missing, could silently corrupt memory on wrong dtype)
- Add routing_replay_out param to trtllm_fp4_block_scale_moe(), trtllm_mxint4_block_scale_moe(), trtllm_fp8_per_tensor_scale_moe() public functions
60b321d to 5a8c656
Addressing CodeRabbit "Outside diff range" findings (review 4081573685)
All three Python API issues from the second review are fixed in 60b321d.
Also applied pre-commit formatting fixes (clang-format, ruff) to ensure CI passes.
…, duplicate test, Llama4 replay gaps
- Rename duplicate test_routing_replay_out → test_routing_replay_out_extended (pytest collision)
- Forward routing_replay_out in trtllm_bf16_routed_moe (was silently dropped)
- Add routing_replay_out to FP4/MXINT4 op signatures, fake ops, and public APIs (NameError)
- Move set_routing_replay_out() from protected to public in FusedMoeLauncher
- Add routing replay writes to Llama4 warp and cluster kernel paths
- Add explanatory comment for intentional dim0 validation omission (CUDA graph pre-alloc)
- Apply pre-commit formatting fixes (clang-format, ruff)
5a8c656 to 1fb8454
…wap condition order
- Move using tvm::ffi::Optional from shared header to noAuxTcKernels.cu only
- Add _validate_routing_replay_out() Python helper with shape/dtype/contiguity checks
- Call validation in all 6 public API functions before C++ dispatch
- Swap condition order in DeepSeek routing kernel: nullptr check first (cheaper)
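A validation helper of the kind listed above might look like the following sketch. It is duck-typed on purpose so it is not tied to torch, and it is an assumption about the checks (dtype, rank, top_k width, a `>=` dim0 lower bound for CUDA graph pre-allocation, contiguity), not the actual `_validate_routing_replay_out()` source:

```python
def validate_routing_replay_out(replay, num_tokens: int, top_k: int) -> None:
    """Validate an optional routing replay buffer before kernel dispatch."""
    if replay is None:
        return  # None means the kernel skips replay writes entirely
    if str(replay.dtype) != "torch.int16":
        raise ValueError(f"routing_replay_out must be int16, got {replay.dtype}")
    if replay.dim() != 2 or replay.shape[1] != top_k:
        raise ValueError(
            f"routing_replay_out must be [N, {top_k}], got {tuple(replay.shape)}"
        )
    # dim0 uses >= so CUDA-graph-sized buffers larger than num_tokens pass.
    if replay.shape[0] < num_tokens:
        raise ValueError("routing_replay_out dim0 smaller than num_tokens")
    if not replay.is_contiguous():
        raise ValueError("routing_replay_out must be contiguous")
```

Calling this in each public API before C++ dispatch turns a silent memory-corruption hazard (wrong dtype or undersized buffer) into a Python-level error.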
/bot run
…to upstream-routing-replay
/bot run
/bot run
Inserting the field in the middle of DataBase and KernelParamsBase shifted memory offsets for all subsequent fields, causing GEMM crashes in FP8/FP4 autotuner tests (11/15 failures). Moving to end preserves the original layout for all existing fields. Also adds missing routing_replay_out arg to MXINT4 and FP4 paths in MoERunner._run(). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
fix: move mPtrRoutingReplayOut to end of routing structs
tests seem clean
…eplay_out) Merged origin/main into nv-yunzheq/DSR1_shared_expert_fusion. New conflicts from PR flashinfer-ai#3024 (routing_replay_out support) resolved by keeping both PR's shared-expert fields and main's routing replay fields. Co-Authored-By: Claude <noreply@anthropic.com>
Extend routing replay to support monolithic (FP8/MXFP8) kernel paths and prefix caching:
Monolithic kernel support:
- Thread routing_replay_out through apply_monolithic() chain (modelopt.py, fp8.py, modular_kernel.py, fused_moe_method_base.py)
- Add _monolithic_writes_routing_replay flag to quant methods
- BF16 monolithic fallback: run select_experts() separately when kernel does not write routing data internally
Prefix caching:
- Initialize host cache with -1 sentinel instead of 0 (expert ID 0 is valid; -1 marks cache-hit positions)
Tests:
- Add TestMonolithicWritesFlag tests
- Update host cache sentinel test for -1 initialization
Depends on: flashinfer-ai/flashinfer#3024 (routing_replay_out param)
📌 Description
Add an optional `routing_replay_out` parameter to all MoE kernel entry points. When provided (an int16 tensor of shape `[N, top_k]`), the routing kernel writes selected expert IDs per token during MoE routing — inside the same fused kernel that computes the MoE output. When `None` (default), there is zero overhead.
This enables routing replay for RL training workflows: an inference engine (e.g., vLLM) captures which experts were selected for each token and returns the data alongside the model output, so the training pipeline can replay the same routing decisions.
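The buffer contract sketched in this description (pre-allocated int16 buffer whose dim0 may exceed the live token count) can be mocked in NumPy. This is an illustrative sketch only; the real write happens inside the fused CUDA kernel:

```python
import numpy as np

num_experts, top_k = 32, 8
max_tokens = 64   # CUDA-graph-sized pre-allocation
num_tokens = 5    # actual batch this step (may be smaller than dim0)

# Pre-allocated replay buffer: int16, shape (num_tokens_or_larger, top_k).
routing_replay_out = np.full((max_tokens, top_k), -1, dtype=np.int16)

# Mock routing: pick top_k expert IDs per token from random scores,
# highest score first, mimicking what the routing kernel records.
rng = np.random.default_rng(0)
scores = rng.random((num_tokens, num_experts))
topk_ids = np.argsort(-scores, axis=1)[:, :top_k]

# The "kernel" writes only the first num_tokens rows; the rest are untouched,
# which is why the dim0 check is >= rather than ==.
routing_replay_out[:num_tokens] = topk_ids.astype(np.int16)

# Consumer slices down to the live batch after the launch.
replay = routing_replay_out[:num_tokens]
assert replay.shape == (num_tokens, top_k)
```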
The previous Python callback approach in vLLM's router breaks under `torch.compile` + CUDA graphs (the callback is traced once, with the tensor reference baked in at trace time). This kernel-level approach works correctly with pre-allocated buffers and CUDA graph replay.
Changes
C++ kernel changes:
- `trtllm_fused_moe_kernel_launcher.cu`: Add `routing_replay_out` field to the `FusedMoeLauncher` base class. Add validation + passthrough in all entry points (`trtllm_fp8_block_scale_moe`, `trtllm_bf16_moe`, `trtllm_fp8_per_tensor_scale_moe`, `trtllm_fp4_block_scale_moe`, `trtllm_mxint4_block_scale_moe`).
- `trtllm_fused_moe_runner.cu`: Add `int16_t* routing_replay_out` to `Runner::run()`. Pass through to the routing data struct for all routing method paths (DeepSeek, Llama4, Custom/Renormalize, MiniMax2).
- `noAuxTcKernels.cu`: Add `routing_replay_out` to the standalone DSV3 fused routing kernel and the `NoAuxTc` entry point.
- Routing kernel writes in `trtllm_fused_moe_routing_deepseek.cu`, `trtllm_fused_moe_routing_custom.cu`, `trtllm_fused_moe_routing_llama4.cu`, and `noAuxTcKernels.cu`.
Python API changes:
- `core.py`: Add `routing_replay_out` parameter to `trtllm_fp8_block_scale_moe()`, `trtllm_bf16_moe()`, `trtllm_bf16_routed_moe()`, internal op functions, fake functions, and autotuner/launcher kwargs threading.
- `fused_routing_dsv3.py`: Add `routing_replay_out` to `fused_topk_deepseek()` with int16 dtype validation and a relaxed dim0 check (`>=` instead of `==`).
`routing_replay_out` spec:
- **dtype**: `torch.int16`
- **shape**: `(num_tokens_or_larger, top_k)` — buffer may be larger than `num_tokens` for CUDA graph pre-allocation
- **layout**: row-major; `replay[t, k]` = k-th ranked expert ID for token `t`
- **when None**: zero overhead, kernel skips the write entirely
- **dim0 validation**: `>=` (not `==`) — the kernel determines write extent from `routing_logits.shape[0]`
🔍 Related Issues
vLLM's `--enable-return-routed-experts` for RL training pipelines. The previous Python callback approach breaks under `torch.compile` + CUDA graphs.
🚀 Pull Request Checklist
✅ Pre-commit Checks
- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.
🧪 Tests
- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).
New tests:
- `test_dsv3_fused_routing.py::test_routing_replay_out` — verifies replay matches `topk_indices` and `None` has no side effects
- `test_trtllm_gen_routed_fused_moe.py::test_fp8_block_scale_moe_routing_replay` — verifies replay has zero effect on MoE output, IDs are valid, and each token has `top_k` unique experts
Test plan:
- [ ] `pytest tests/model_optimizations/test_dsv3_fused_routing.py::test_routing_replay_out -v`
- [ ] `pytest tests/moe/test_trtllm_gen_routed_fused_moe.py::test_fp8_block_scale_moe_routing_replay -v`
- [ ] Verify zero overhead when `routing_replay_out=None` (benchmark with and without)
- [ ] Validated end-to-end with vLLM on Super MXFP8 (2 nodes, TP=4, DP=2) — non-zero routing data at 256 concurrency
Reviewer Notes
- The dim0 validation uses `>=` instead of `==` intentionally — this allows CUDA graph pre-allocation of oversized buffers. The kernel determines write extent from `routing_logits.shape[0]`.
- Also includes `docs/vllm_routing_replay_integration.md` as an integration guide for vLLM consumers.
Summary by CodeRabbit
New Features
- Added optional `routing_replay_out` parameter to TensorRT-LLM MoE operations, enabling recording of selected expert indices during routing for each token.
Documentation
- Added a guide for integrating routing replay with vLLM's MoE workflow.
Tests
- Added validation tests for routing replay functionality across DeepSeek V3 and FP8 block-scale MoE kernels.