
[Feature] Support recording expert indices for rollout router replay #28284

Merged
22quinn merged 33 commits into vllm-project:main from xhx1022:dev
Jan 12, 2026
Conversation

@xhx1022 (Contributor) commented Nov 7, 2025

Purpose

This PR introduces Rollout Router Replay (R3) support into the vLLM runtime.
Inspired by recent research on reinforcement-learning alignment for MoE-based LLMs (arXiv:2510.11370, arXiv:2507.18071), this implementation records the expert routing decision for every token at every layer during model inference. The recorded routing traces can then be used to replay the expert routing process during RL post-training.

The initial version supports:

  • ✅ Tensor Parallel mode;
  • ✅ Prefix Cache mode;
  • ✅ Compatibility with CUDA Graph execution;

Below is a minimal reproducible example for running Qwen3-30B-A3B with tensor_parallel_size = 8 and async concurrent inference.
The example also prints the shape of the returned routed_experts tensor.

import asyncio

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.sampling_params import RequestOutputKind
from vllm.v1.engine.async_llm import AsyncLLM

# Configure engine args with TP=8
ENGINE_ARGS = AsyncEngineArgs(
    model="/dev/shm/Qwen3-30B-A3B",
    tensor_parallel_size=8,
    enable_return_routed_experts=True,
    # enforce_eager=True,
)

async def test_qwen3_tp8_concurrent():
    """Test concurrent inference of Qwen3-30B-A3B with TP=8"""

    # Create AsyncLLM instance
    async_llm = AsyncLLM.from_engine_args(ENGINE_ARGS)

    # Configure sampling parameters
    sampling_params = SamplingParams(
        max_tokens=100,
        temperature=0.8,
        top_p=0.95,
        output_kind=RequestOutputKind.FINAL_ONLY,
    )

    NUM_REQUESTS = 10
    prompts = [
        f"Request {i}: Hello, please introduce yourself."
        for i in range(NUM_REQUESTS)
    ]

    async def generate_single(request_id: str, prompt: str):
        outputs = []
        async for output in async_llm.generate(
            request_id=request_id,
            prompt=prompt,
            sampling_params=sampling_params
        ):
            outputs.append(output)
        return outputs[-1]  # Return the final output

    # Create concurrent tasks
    tasks = [
        generate_single(f"request-{i}", prompts[i])
        for i in range(NUM_REQUESTS)
    ]

    # Wait for all tasks to finish
    results = await asyncio.gather(*tasks)

    # Validate results
    for i, result in enumerate(results):
        assert result.finished
        assert len(result.outputs[0].text) > 0
        print(result.outputs[0].routed_experts.shape)  # [seq_len, layer_num, topk]

    # Cleanup
    async_llm.shutdown()

# Run test
if __name__ == "__main__":
    asyncio.run(test_qwen3_tp8_concurrent())

Reminder

The number of tokens with recorded expert indices can be one less than prompt_length + response_token_count.
This gap of one is expected: the final generated token is produced by sampling and is never passed through another forward step, so no routing decision is recorded for it.
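
For instance, the invariant can be checked inside the validation loop of the example above (an illustrative snippet, not part of the PR; attribute names follow the example):

prompt_len = len(result.prompt_token_ids)
gen_len = len(result.outputs[0].token_ids)
experts = result.outputs[0].routed_experts
# The sampled last token has no routing record, hence the -1.
assert experts.shape[0] == prompt_len + gen_len - 1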

Acknowledgments

This work is inspired by and builds upon the implementation from SGLang PR #12162.
Special thanks to @ocss884 and the SGLang RL team for their valuable discussions and contributions.


Note

Enables returning MoE routed expert indices when requested, plumbed across config → execution → outputs.

  • Adds enable_return_routed_experts to ModelConfig, surfaced via EngineArgs/CLI and LLM; included in config logs
  • New fused_moe/routed_experts_capturer.py (capturer/reader singletons using shared memory); fused_moe/layer.py records gate topk_ids per layer_id
  • GPUModelRunner initializes/clears/saves captured experts and computes token slot_mapping (TP rank 0 writes)
  • Scheduler reconstructs token slots from KV blocks and attaches routed_experts to engine outputs; asserts no context parallelism
  • Output pipeline: adds CompletionOutput.routed_experts (numpy array) and threads through v1 engine/output processor/structures

Written by Cursor Bugbot for commit aeb469e.
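
For intuition, below is a minimal sketch of the single-writer shared-memory pattern the summary describes. All names, shapes, and the writer/reader split are illustrative assumptions, not the actual RoutedExpertsCapturer code:

import numpy as np
from multiprocessing import shared_memory

# Illustrative capacity: slots x layers x top-k expert ids per token.
NUM_SLOTS, NUM_LAYERS, TOPK = 1024, 48, 8
SHAPE = (NUM_SLOTS, NUM_LAYERS, TOPK)
NBYTES = int(np.prod(SHAPE)) * np.dtype(np.int32).itemsize

# Writer side (in vLLM, only TP rank 0 writes captured indices).
shm = shared_memory.SharedMemory(create=True, size=NBYTES, name="r3_demo")
buf = np.ndarray(SHAPE, dtype=np.int32, buffer=shm.buf)
slot_indices = np.array([0, 1, 2])  # per-token slots, cf. slot_mapping
buf[slot_indices] = np.random.randint(0, 128, size=(3, NUM_LAYERS, TOPK))

# Reader side (in vLLM, the scheduler attaches a reader) opens by name.
reader = shared_memory.SharedMemory(name="r3_demo")
view = np.ndarray(SHAPE, dtype=np.int32, buffer=reader.buf)
routed_experts = view[slot_indices].copy()  # [seq_len, layer_num, topk]

reader.close()
shm.close()
shm.unlink()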

@gemini-code-assist (Bot) left a comment

Code Review

This pull request introduces support for recording expert routing decisions for MoE models, a feature named Rollout Router Replay (R3). The implementation is comprehensive, touching configuration, engine arguments, the model executor, and the scheduler. My review has identified a few critical issues, primarily concerning race conditions due to missing locks for shared memory access and a method signature mismatch that will lead to a TypeError. There are also some code quality suggestions to improve maintainability, such as removing a redundant argument and moving a local import.

Resolved (outdated) comment threads: vllm/model_executor/layers/fused_moe/routed_experts_capturer.py (×2), vllm/engine/arg_utils.py
@@ -1283,6 +1294,8 @@ def __init__(
raise ValueError("Duplicate layer name: {}".format(prefix))
compilation_config.static_forward_context[prefix] = self
self.layer_name = prefix
from vllm.model_executor.models.utils import extract_layer_index
Severity: high

The import from vllm.model_executor.models.utils import extract_layer_index is performed inside the __init__ method. It is a best practice to place all imports at the top of the file for better readability, performance, and to avoid potential circular import issues. Please move this import to the top of the file.
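
For illustration, the suggested change looks roughly like this (class signature and body abbreviated; only the import placement matters):

# At the top of the module, with the other imports:
from vllm.model_executor.models.utils import extract_layer_index

class FusedMoE:  # abbreviated for this sketch
    def __init__(self, prefix: str = "") -> None:
        self.layer_name = prefix
        # Previously a local import lived here; now resolved at module load.
        self.layer_id = extract_layer_index(prefix)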

Resolved (outdated) comment thread: vllm/v1/core/sched/scheduler.py
@chatgpt-codex-connector

💡 Codex Review

https://github.com/vllm-project/vllm/blob/611bc69292546334ddbcc52689ffe86f91da41e1/vllm/v1/worker/gpu_model_runner.py#L2737-L2738
P1: Convert slot_mapping to numpy before saving experts

When routed expert recording is enabled, GPUModelRunner.execute_model passes self.slot_mapping directly into RoutedExpertsCapturer.save_captured_experts (save_captured_experts(indices=self.slot_mapping)). self.slot_mapping is a CPU torch.Tensor (it is assigned as slot_mapping.cpu() earlier in the method), but _RoutedExpertsCapturerReal.save_captured_experts indexes a NumPy buffer with the indices argument (self._host_buffer_view[indices, :, :] = data). NumPy does not accept PyTorch tensors as advanced indices, so this call will raise a TypeError: only integer scalar arrays can be converted to a scalar index the first time the feature is exercised. Converting the tensor to a NumPy array (e.g. indices.numpy()) before the call would avoid the crash.
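
A small self-contained sketch of the suggested fix (buffer shape and variable names are assumed for illustration):

import numpy as np
import torch

num_slots, num_layers, topk = 16, 4, 2
host_buffer_view = np.zeros((num_slots, num_layers, topk), dtype=np.int32)
data = np.ones((3, num_layers, topk), dtype=np.int32)

slot_mapping = torch.tensor([3, 7, 8])   # CPU tensor, as in execute_model
indices = slot_mapping.numpy()           # convert before NumPy fancy indexing
host_buffer_view[indices, :, :] = data   # works with an integer ndarray index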



github-actions Bot commented Nov 7, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they run only fastcheck CI, a small but essential subset of CI tests that quickly catches errors.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@mergify (Bot) added the performance (Performance-related issues) label on Nov 7, 2025
@xhx1022 force-pushed the dev branch 3 times, most recently from 04141a2 to f4e5998 on November 7, 2025 15:10
arlenxu and others added 6 commits January 12, 2026 19:20
Signed-off-by: arlenxu <arlenxu@tencent.com>
Signed-off-by: xhx1022 <1737006628@qq.com>
Signed-off-by: xhx1022 <1737006628@qq.com>
Comment thread on vllm/config/vllm.py:
f"disable_custom_all_reduce={self.parallel_config.disable_custom_all_reduce}, " # noqa
f"quantization={self.model_config.quantization}, "
f"enforce_eager={self.model_config.enforce_eager}, "
f"enable_return_routed_experts={self.model_config.enable_return_routed_experts}, " # noqa

Missing async scheduling disable for routed experts feature

Severity: medium

The PR notes explicitly state that async scheduling should be disabled when enable_return_routed_experts=True, but this isn't implemented. The async scheduling logic in VllmConfig.__post_init__ handles various incompatibility cases (PP > 1, speculative decoding, executor backend) but doesn't include any handling for enable_return_routed_experts. Users in the PR discussion are reporting significant latency issues (10X slower), which could be related to async scheduling interference with the capture/save operations for routed experts.
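
A sketch of the kind of guard the comment asks for, modeled on the existing incompatibility checks in VllmConfig.__post_init__ (attribute and logger names are assumed, not verified against main):

# Hypothetical addition to VllmConfig.__post_init__:
if (self.model_config.enable_return_routed_experts
        and self.scheduler_config.async_scheduling):
    logger.info(
        "Disabling async scheduling: it is incompatible with "
        "enable_return_routed_experts.")
    self.scheduler_config.async_scheduling = False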


@22quinn merged commit 49e6b86 into vllm-project:main on Jan 12, 2026
62 checks passed
TomerBN-Nvidia pushed a commit to TomerBN-Nvidia/vllm that referenced this pull request Jan 13, 2026
…llm-project#28284)

Signed-off-by: xhx1022 <1737006628@qq.com>
Signed-off-by: Hongxin Xu <70438206+xhx1022@users.noreply.github.com>
Signed-off-by: arlenxu <arlenxu@tencent.com>
Co-authored-by: 22quinn <33176974+22quinn@users.noreply.github.com>
Co-authored-by: arlenxu <arlenxu@tencent.com>
Signed-off-by: Tomer Natan <tbarnatan@computelab-frontend-8.nvidia.com>
sammysun0711 pushed a commit to sammysun0711/vllm that referenced this pull request Jan 16, 2026
…llm-project#28284)

Signed-off-by: xhx1022 <1737006628@qq.com>
Signed-off-by: Hongxin Xu <70438206+xhx1022@users.noreply.github.com>
Signed-off-by: arlenxu <arlenxu@tencent.com>
Co-authored-by: 22quinn <33176974+22quinn@users.noreply.github.com>
Co-authored-by: arlenxu <arlenxu@tencent.com>
akh64bit pushed a commit to akh64bit/vllm that referenced this pull request Jan 16, 2026
…llm-project#28284)

Signed-off-by: xhx1022 <1737006628@qq.com>
Signed-off-by: Hongxin Xu <70438206+xhx1022@users.noreply.github.com>
Signed-off-by: arlenxu <arlenxu@tencent.com>
Co-authored-by: 22quinn <33176974+22quinn@users.noreply.github.com>
Co-authored-by: arlenxu <arlenxu@tencent.com>
pjin-nvidia added a commit to pjin-nvidia/vllm that referenced this pull request Jan 21, 2026
Signed-off-by: Peter Jin <pjin@nvidia.com>


Labels

frontend · performance (Performance-related issues) · ready (ONLY add when PR is ready to merge/full CI is needed) · rl (Related to RL workflows) · v1
