diff --git a/docs/features/routed_experts_replay.md b/docs/features/routed_experts_replay.md new file mode 100644 index 000000000000..bfc6fc3b9568 --- /dev/null +++ b/docs/features/routed_experts_replay.md @@ -0,0 +1,285 @@ +# Routed Experts Replay + +## Overview + +Routed experts replay captures which MoE (Mixture of Experts) experts process each token during inference and returns this information alongside the generated text. This is essential for **reinforcement learning (RL) training pipelines** (such as GRPO and RLHF) where the training step needs to reconstruct expert routing decisions from the inference pass. + +When enabled, each API response includes: + +- **`prompt_routed_experts`**: A `[prompt_len, num_moe_layers, top_k]` array of expert IDs for the prompt tokens (at the response level, shared across completions). +- **`routed_experts`**: A `[gen_len, num_moe_layers, top_k]` array of expert IDs for the generated tokens (per completion). + +For example, a model with 40 MoE layers and top-22 routing that processes a 100-token prompt and generates 50 tokens would return: + +- `prompt_routed_experts`: shape `[100, 40, 22]` +- `routed_experts`: shape `[50, 40, 22]` + +Each value is an int16 expert ID in the range `[0, num_experts)`. + +## Quickstart + +### OpenAI API Server + +```bash +vllm serve \ + --enable-return-routed-experts \ + --tensor-parallel-size 4 \ + --enable-expert-parallel +``` + +Then query the `/v1/completions` endpoint as usual. The response includes routing data: + +```python +import requests + +resp = requests.post("http://localhost:8000/v1/completions", json={ + "model": "", + "prompt": "Explain quantum computing.", + "max_tokens": 64, + "temperature": 0.0, +}).json() + +# Generation routing (per completion choice) +gen_routing = resp["choices"][0]["routed_experts"] # [gen_len, layers, top_k] + +# Prompt routing (shared across all choices) +prompt_routing = resp["prompt_routed_experts"] # [prompt_len, layers, top_k] + +print(f"Prompt routing shape: [{len(prompt_routing)}, " + f"{len(prompt_routing[0])}, {len(prompt_routing[0][0])}]") +print(f"Gen routing shape: [{len(gen_routing)}, " + f"{len(gen_routing[0])}, {len(gen_routing[0][0])}]") +``` + +### Python SDK (Offline Inference) + +```python +from vllm import LLM, SamplingParams + +llm = LLM( + model="", + enable_return_routed_experts=True, + tensor_parallel_size=4, + enable_expert_parallel=True, +) + +outputs = llm.generate( + ["Explain quantum computing."], + SamplingParams(temperature=0, max_tokens=64), +) + +result = outputs[0] + +# Prompt routing: numpy array, shape [prompt_len, num_moe_layers, top_k] +prompt_routing = result.prompt_routed_experts +print(f"Prompt routing: {prompt_routing.shape}, dtype={prompt_routing.dtype}") + +# Generation routing: numpy array, shape [gen_len, num_moe_layers, top_k] +gen_routing = result.outputs[0].routed_experts +print(f"Gen routing: {gen_routing.shape}, dtype={gen_routing.dtype}") +``` + +## Output Format + +### `CompletionOutput.routed_experts` + +- **Type**: `numpy.ndarray` (Python SDK) or `list[list[list[int]]]` (JSON API) +- **Shape**: `[gen_len, num_moe_layers, top_k]` +- **Dtype**: `int16` +- **Content**: Expert IDs for **generated tokens only**. `gen_len` matches the number of generated tokens (i.e., `usage.completion_tokens` or fewer). 
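+
+JSON consumers typically convert this field back into a compact array before use. Below is a minimal sketch, assuming a `resp` dict shaped like the quickstart response above (the helper name is hypothetical):
+
+```python
+import numpy as np
+
+def gen_routing_array(resp: dict, choice_idx: int = 0) -> np.ndarray:
+    """Convert one choice's JSON routing field into a compact int16 array."""
+    gen_routing = np.asarray(
+        resp["choices"][choice_idx]["routed_experts"], dtype=np.int16
+    )  # shape: [gen_len, num_moe_layers, top_k]
+    # gen_len can be smaller than max_tokens if generation stopped early.
+    assert gen_routing.shape[0] <= resp["usage"]["completion_tokens"]
+    return gen_routing
+```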
+ +### `RequestOutput.prompt_routed_experts` + +- **Type**: `numpy.ndarray` (Python SDK) or `list[list[list[int]]]` (JSON API) +- **Shape**: `[prompt_len, num_moe_layers, top_k]` +- **Dtype**: `int16` +- **Content**: Expert IDs for **prompt tokens only**. `prompt_len` matches `usage.prompt_tokens`. This field lives on the request-level response (not per-choice), because prompt routing is shared across all completions when `n > 1`. + +### Why Separate Prompt and Generation Routing? + +When a request has multiple completions (`n > 1`), each completion shares the same prompt but produces different generated text. Storing prompt routing once on the `RequestOutput` (rather than duplicating it on every `CompletionOutput`) avoids redundant data. For RL training, the consumer typically needs: + +1. The prompt routing (once) to reconstruct the forward pass for the shared prefix. +2. The per-completion generation routing to reconstruct each completion's forward pass. + +## Architecture + +### Data Flow + +```text +Forward Pass Async D2H Pipeline Output +───────────── ────────────────── ────── +FusedMoE layer After forward pass: On request finish: +writes topk_ids ──────► D2H copy to pinned ──────► Extract from host cache +to device buffer staging buffer Split at prompt_len +(L, N, K) int16 (via CUDA stream) Trim gen to output len + Scatter to per-request Serialize to API response + host cache (numpy) +``` + +### Device Cache + +A pre-allocated GPU buffer with layout `(L, N, K)` where: + +- `L` = number of MoE layers +- `N` = `max_num_batched_tokens` +- `K` = `num_experts_per_tok` (top-k) + +The `(L, N, K)` layout ensures that `buffer[layer_id]` gives a contiguous `(N, K)` view per layer. Each `FusedMoE` layer gets a persistent reference to its slice via `module._routing_replay_out = buffer[layer_id]`. + +**Dtype**: `int16` — sufficient for expert IDs (max ~512 experts in practice) and half the memory of `int32`. + +### Host Cache + +Per-request numpy arrays for accumulating routing data across decode steps. Each request gets a lazily allocated `(seq_len, L, K)` int16 buffer that grows as the sequence lengthens. Buffers are freed when a request completes. + +### Async D2H Pipeline + +After each forward pass, the model runner issues a non-blocking device-to-host copy on a dedicated CUDA stream: + +1. **Copy**: `pinned_staging[:, :total_tokens, :].copy_(device_buffer[:, :total_tokens, :])` on a separate stream, recorded with a CUDA event. +2. **Scatter** (deferred to next step): On the *next* forward pass, synchronize the event (effectively free — an entire forward pass has elapsed) and scatter the staging data into per-request host cache buffers using the token positions. + +This design ensures the D2H copy overlaps with the next forward pass, minimizing GPU stall time. + +### CUDA Graph Compatibility + +CUDA graph compatibility requires two mechanisms: + +1. **Persistent tensor attribute**: Each `FusedMoE` layer stores a reference to its buffer slice as `module._routing_replay_out`. Because `torch.compile` captures module attributes by reference, graph replay always writes to the live buffer — not a stale snapshot. + +2. **Static marking**: Both the full `(L, N, K)` buffer and each per-layer `(N, K)` view are marked with `cudagraph_mark_tensor_static()`. This prevents CUDA graphs from snapshot/restore behavior that would zero the buffer on replay. 
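+
+In code, the binding step looks roughly like the sketch below — a condensed version of `bind_routing_capture_to_model` (shown in full later in this PR). It assumes each MoE module already carries a `moe_layer_id` attribute and that the buffer uses the `(L, N, K)` layout described above:
+
+```python
+import torch
+
+def bind_routing_buffer(model: torch.nn.Module, buffer: torch.Tensor) -> None:
+    """Condensed sketch; see bind_routing_capture_to_model for the full version."""
+    # Keep CUDA graphs from snapshotting/restoring the buffer on replay.
+    if hasattr(torch.compiler, "cudagraph_mark_tensor_static"):
+        torch.compiler.cudagraph_mark_tensor_static(buffer)
+    for module in model.modules():
+        if hasattr(module, "moe_layer_id"):
+            layer_view = buffer[module.moe_layer_id]  # contiguous (N, K) slice
+            module._routing_replay_out = layer_view   # captured by reference
+            if hasattr(torch.compiler, "cudagraph_mark_tensor_static"):
+                torch.compiler.cudagraph_mark_tensor_static(layer_view)
+```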
+ +### Multi-Node Support + +On multi-node tensor-parallel setups, all TP ranks allocate a device buffer (required for symmetric CUDA graph structure), but only TP rank 0 runs the D2H pipeline and host cache. Routing data flows from the model runner through `ModelRunnerOutput` via Ray DAG to the scheduler — no shared memory or file locks needed. + +### Routing Capture Path + +For the **non-monolithic (Triton) kernel path** (e.g., BF16 MoE), routing is captured after `select_experts()` in the MoE runner: + +```python +routing_replay_out = getattr(layer, "_routing_replay_out", None) +topk_weights, topk_ids = self.router.select_experts(...) + +if routing_replay_out is not None: + routing_replay_out[:topk_ids.shape[0]].copy_(topk_ids.to(torch.int16)) +``` + +For the **monolithic kernel path** (e.g., FP8/MXFP8 via FlashInfer), `routing_replay_out` is threaded through the `apply_monolithic()` call chain and FlashInfer writes expert IDs directly during routing inside the fused kernel. + +### MTP (Multi-Token Prediction) Handling + +With MTP speculative decoding, the model captures routing for all tokens including speculative ones that may later be rejected. When a request finishes, the generation routing is trimmed to match the actual number of accepted output tokens: + +```python +num_gen = self.detokenizer.num_output_tokens() +if gen_routed_experts.shape[0] > num_gen and num_gen > 0: + gen_routed_experts = gen_routed_experts[:num_gen] +``` + +This ensures the routing array length always matches the token IDs in the response. + +## Design Decisions + +### Why Replace SharedMemory with Device Cache? + +The previous implementation used `multiprocessing.SharedMemory` with `fcntl` file locking to transfer routing data from GPU workers to the scheduler. This approach had fundamental problems: + +- **Multi-node**: `SharedMemory` is node-local. On multi-node TP setups (required for 400B+ parameter models), the scheduler on node 0 cannot read shared memory from workers on other nodes. +- **Performance**: Synchronous `.cpu().numpy()` D2H transfers block the GPU. File-based locking adds further overhead. +- **CUDA graphs**: The callback-based capture mechanism bakes tensor references at trace time, causing stale data on graph replay. + +The device cache approach solves all three: data flows through Ray DAG (works multi-node), D2H is async (non-blocking), and persistent tensor attributes work with CUDA graphs. + +### Why `(L, N, K)` Layout Instead of `(N, L, K)`? + +FlashInfer's `routing_replay_out` parameter expects a contiguous `(N, K)` tensor per layer. With `(L, N, K)` layout, `buffer[layer_id]` gives a contiguous `(N, K)` view with zero-copy slicing. The previous `(N, L, K)` layout would require non-contiguous indexing or an explicit copy. + +### Why int16 Instead of int32? + +Expert IDs are small integers (typically 0-255 for models with up to 256 experts). `int16` supports up to 32,767 experts — far more than any current model — while halving GPU memory usage and D2H bandwidth compared to `int32`. + +### Why Split Prompt and Generation Routing? + +RL training pipelines process prompt and generation routing separately: + +- Prompt routing reconstructs the shared forward pass for the input. +- Generation routing reconstructs each sampled trajectory. + +With `n > 1` completions, all completions share the same prompt routing. Duplicating it per completion would waste memory proportional to `n * prompt_len * L * K`. Instead, `prompt_routed_experts` is stored once on `RequestOutput` and shared. 
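+
+For a training consumer that needs full-sequence routing per trajectory, the two fields recombine with a simple concatenation. A minimal sketch, assuming the SDK objects from the quickstart:
+
+```python
+import numpy as np
+
+def full_sequence_routing(request_output) -> list[np.ndarray]:
+    """Rebuild [prompt_len + gen_len, L, K] routing for each completion.
+
+    prompt_routed_experts is stored once on RequestOutput and shared by
+    every completion when n > 1.
+    """
+    prompt_routing = request_output.prompt_routed_experts  # [prompt_len, L, K]
+    return [
+        np.concatenate([prompt_routing, completion.routed_experts], axis=0)
+        for completion in request_output.outputs
+    ]
+```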
+ +### Why Async D2H Instead of Synchronous Copy? + +A synchronous `.cpu()` call forces the GPU to drain its command queue before the copy can begin, stalling the pipeline. The async approach: + +1. Issues the copy on a separate CUDA stream (non-blocking to the main compute stream). +2. Defers the host-side scatter to the *next* step, by which time the copy has finished. + +This means the D2H transfer overlaps entirely with the next forward pass, adding near-zero latency to the critical path. + +### Why All TP Ranks Get a Device Buffer? + +CUDA graph capture records the exact sequence of kernel calls and their arguments. If only rank 0 had a device buffer, the `FusedMoE` layer would take a different code path on rank 0 vs. other ranks (one writes to a buffer, others don't). This asymmetry causes different CUDA graph structures across ranks, which can lead to NCCL deadlocks during collective operations inside the graph. Giving all ranks a real buffer ensures symmetric graph structure. Only rank 0 does the D2H copy and host cache management. + +## Performance + +Routing replay adds a small overhead from the device buffer writes and async D2H copies. On tested configurations: + +- **Throughput overhead** (random data, ISL=1024, OSL=1024): **~2%** +- **Memory overhead** (int16 buffer, 40 layers, 8192 tokens, top-22): **~14 MB per GPU** +- **Accuracy impact** (GSM8K): **Zero** (pass@1 identical with and without routing replay) + +The overhead is dominated by the per-layer `.copy_()` during the forward pass. The async D2H pipeline runs entirely in the background. + +## Supported Configurations + +| Configuration | Supported | +|------------------------------------------|-----------------------------------------------------------| +| BF16 Triton MoE (non-monolithic) | Yes | +| FP8/MXFP8 FlashInfer MoE (monolithic) | Yes (requires FlashInfer with `routing_replay_out`) | +| CUDA graphs | Yes | +| Multi-node tensor parallelism | Yes | +| Data parallelism (DP) | Yes | +| Expert parallelism (EP) | Yes | +| Prefix caching | Yes (cached positions marked with `-1` sentinel) | +| MTP speculative decoding | Yes (gen routing trimmed to accepted tokens) | +| `n > 1` (multiple completions) | Yes (prompt routing shared, gen routing per-completion) | + +## Limitations + +- **Streaming**: Routing data is only available when the request finishes (not streamed incrementally). +- **V1 engine only**: Routing replay is implemented for the vLLM V1 engine. + +## CLI Reference + +| Flag | Description | +|------------------------------------|------------------------------------------------------------------------| +| `--enable-return-routed-experts` | Enable routing replay capture and return expert IDs in API responses. | + +## API Reference + +### Completions (`/v1/completions`) + +**Response-level field:** + +| Field | Type | Description | +|---------------------------|-------------------------------------|-----------------------------------------------------------------------------| +| `prompt_routed_experts` | `list[list[list[int]]]` or `null` | Expert IDs for prompt tokens. Shape: `[prompt_len, num_moe_layers, top_k]`. | + +**Choice-level field:** + +| Field | Type | Description | +|--------------------|-------------------------------------|-------------------------------------------------------------------------------| +| `routed_experts` | `list[list[list[int]]]` or `null` | Expert IDs for generated tokens. Shape: `[gen_len, num_moe_layers, top_k]`. 
| + +### Chat Completions (`/v1/chat/completions`) + +Same fields as above on `ChatCompletionResponse` and `ChatCompletionResponseChoice`. + +### Python SDK + +| Object | Field | Type | Description | +|----------------------|---------------------------|--------------------------|-----------------------------| +| `RequestOutput` | `prompt_routed_experts` | `np.ndarray` or `None` | `[prompt_len, L, K]` i16 | +| `CompletionOutput` | `routed_experts` | `np.ndarray` or `None` | `[gen_len, L, K]` int16 | diff --git a/tests/model_executor/test_routed_experts_capture.py b/tests/model_executor/test_routed_experts_capture.py index 770a3fa53850..58dfdb302f38 100644 --- a/tests/model_executor/test_routed_experts_capture.py +++ b/tests/model_executor/test_routed_experts_capture.py @@ -1,8 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import types -from types import SimpleNamespace -from unittest.mock import patch import pytest import torch @@ -185,59 +183,88 @@ def capture(self, layer_id, topk_ids): assert len(capturer.calls) == 1 -def test_routed_experts_capturer_single_dp_no_metadata(): - """dp_metadata is None: capture writes the full topk_ids rows.""" - capturer = _capturer_with_buffer(dp_rank=0) - topk = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=torch.int32) - ctx = SimpleNamespace(dp_metadata=None) - with patch(f"{_REC_MODULE}.get_forward_context", return_value=ctx): - capturer.capture(layer_id=0, topk_ids=topk) - assert torch.equal(capturer._device_buffer[:3, 0, :], topk) - assert capturer._device_buffer[3, 0, 0].item() == -1 +# ========================================================================= +# Tests for device-cache routing replay architecture +# ========================================================================= -def test_routed_experts_capturer_dp_naive_concatenated_all_ranks(): - """n == sum(num_tokens_dp): slice this rank's segment from concatenated topk.""" - capturer = _capturer_with_buffer(dp_rank=1) - num_tokens_dp = torch.tensor([2, 3], dtype=torch.int32) - ctx = SimpleNamespace( - dp_metadata=SimpleNamespace(num_tokens_across_dp_cpu=num_tokens_dp) - ) - # Concatenated order: rank0 rows then rank1 rows. - topk = torch.tensor( - [[0, 1], [2, 3], [10, 11], [12, 13], [14, 15]], dtype=torch.int32 - ) - with patch(f"{_REC_MODULE}.get_forward_context", return_value=ctx): - capturer.capture(layer_id=0, topk_ids=topk) - want = topk[2:5] - assert torch.equal(capturer._device_buffer[:3, 0, :], want) - - -def test_routed_experts_capturer_dp_modular_local_tokens(): - """n == token_num_per_dp: topk is already local to this DP rank.""" - capturer = _capturer_with_buffer(dp_rank=1) - num_tokens_dp = torch.tensor([2, 3], dtype=torch.int32) - ctx = SimpleNamespace( - dp_metadata=SimpleNamespace(num_tokens_across_dp_cpu=num_tokens_dp) - ) - topk = torch.tensor([[10, 11], [12, 13], [14, 15]], dtype=torch.int32) - with patch(f"{_REC_MODULE}.get_forward_context", return_value=ctx): - capturer.capture(layer_id=0, topk_ids=topk) - assert torch.equal(capturer._device_buffer[:3, 0, :], topk) - - -def test_routed_experts_capturer_dp_unexpected_batch_raises(): - """Mismatch between topk batch dim and DP layout: fail fast.""" - capturer = _capturer_with_buffer(dp_rank=0) - num_tokens_dp = torch.tensor([2, 3], dtype=torch.int32) - ctx = SimpleNamespace( - dp_metadata=SimpleNamespace(num_tokens_across_dp_cpu=num_tokens_dp) - ) - # total=5, local=2: n=1 matches neither naive (5) nor modular (2). 
- topk = torch.tensor([[1, 2]], dtype=torch.int32) - with ( - patch(f"{_REC_MODULE}.get_forward_context", return_value=ctx), - pytest.raises(AssertionError, match="unexpected topk_ids batch dim"), - ): - capturer.capture(layer_id=0, topk_ids=topk) - assert capturer._device_buffer[0, 0, 0].item() == -1 +class TestRoutedExpertsDeviceCache: + """Tests for _RoutedExpertsDeviceCache (GPU buffer for routing data).""" + + def test_allocation_shape_and_dtype(self): + """Device cache allocates (L, N, K) int16 buffer.""" + from vllm.model_executor.layers.fused_moe.routed_experts_capturer import ( + _RoutedExpertsDeviceCache, + ) + + cache = _RoutedExpertsDeviceCache( + num_hidden_layers=40, + max_num_batched_tokens=8192, + num_experts_per_tok=8, + ) + assert cache.buffer.shape == (40, 8192, 8) + assert cache.buffer.dtype == torch.int16 + + def test_per_layer_view_is_contiguous(self): + """buffer[layer_id] gives contiguous (N, K) view for FlashInfer.""" + from vllm.model_executor.layers.fused_moe.routed_experts_capturer import ( + _RoutedExpertsDeviceCache, + ) + + cache = _RoutedExpertsDeviceCache( + num_hidden_layers=40, + max_num_batched_tokens=8192, + num_experts_per_tok=8, + ) + layer_view = cache.buffer[0] + assert layer_view.is_contiguous() + assert layer_view.shape == (8192, 8) + + +class TestRoutedExpertsHostCache: + """Tests for _RoutedExpertsHostCache (per-request numpy buffer).""" + + def test_sentinel_initialization(self): + """Host cache initializes with zeros by default.""" + import numpy as np + + from vllm.model_executor.layers.fused_moe.routed_experts_capturer import ( + _RoutedExpertsHostCache, + ) + + cache = _RoutedExpertsHostCache( + num_hidden_layers=40, + num_experts_per_tok=8, + ) + buf = cache.get_or_grow_buffer("req1", max_pos=100) + assert buf.dtype == np.int16 + assert (buf == 0).all(), "Host cache must initialize with zeros" + + def test_grow_preserves_existing_data(self): + """Growing the buffer preserves previously written data.""" + from vllm.model_executor.layers.fused_moe.routed_experts_capturer import ( + _RoutedExpertsHostCache, + ) + + cache = _RoutedExpertsHostCache( + num_hidden_layers=40, + num_experts_per_tok=8, + ) + buf = cache.get_or_grow_buffer("req1", max_pos=50) + buf[0, 0, 0] = 42 + buf2 = cache.get_or_grow_buffer("req1", max_pos=200) + assert buf2[0, 0, 0] == 42, "Data lost during buffer grow" + + def test_free_request_removes_buffer(self): + """Freeing a request removes its buffer.""" + from vllm.model_executor.layers.fused_moe.routed_experts_capturer import ( + _RoutedExpertsHostCache, + ) + + cache = _RoutedExpertsHostCache( + num_hidden_layers=40, + num_experts_per_tok=8, + ) + cache.get_or_grow_buffer("req1", max_pos=50) + cache.free_request("req1") + assert cache.get_buffer("req1") is None diff --git a/vllm/entrypoints/openai/chat_completion/protocol.py b/vllm/entrypoints/openai/chat_completion/protocol.py index aacac38e07fc..ceba1f6e6778 100644 --- a/vllm/entrypoints/openai/chat_completion/protocol.py +++ b/vllm/entrypoints/openai/chat_completion/protocol.py @@ -92,12 +92,16 @@ class ChatCompletionResponseChoice(OpenAIBaseModel): # not part of the OpenAI spec but is useful for tracing the tokens # in agent scenarios token_ids: list[int] | None = None + routed_experts: list[list[list[int]]] | None = None # [gen_len, num_layers, top_k] class ChatCompletionResponse(OpenAIBaseModel): id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}") object: Literal["chat.completion"] = "chat.completion" created: int = 
Field(default_factory=lambda: int(time.time())) + prompt_routed_experts: list[list[list[int]]] | None = ( + None # [prompt_len, num_layers, top_k] + ) model: str choices: list[ChatCompletionResponseChoice] service_tier: Literal["auto", "default", "flex", "scale", "priority"] | None = None diff --git a/vllm/entrypoints/openai/chat_completion/serving.py b/vllm/entrypoints/openai/chat_completion/serving.py index 446f127a91e3..9471a2f10e4e 100644 --- a/vllm/entrypoints/openai/chat_completion/serving.py +++ b/vllm/entrypoints/openai/chat_completion/serving.py @@ -1325,6 +1325,11 @@ async def chat_completion_full_generator( token_ids=( as_list(output.token_ids) if request.return_token_ids else None ), + routed_experts=( + output.routed_experts.tolist() + if output.routed_experts is not None + else None + ), ) choices.append(choice_data) continue @@ -1541,6 +1546,11 @@ async def chat_completion_full_generator( token_ids=( as_list(output.token_ids) if request.return_token_ids else None ), + routed_experts=( + output.routed_experts.tolist() + if output.routed_experts is not None + else None + ), ) choice_data = maybe_filter_parallel_tool_calls(choice_data, request) @@ -1580,6 +1590,10 @@ async def chat_completion_full_generator( request_metadata.final_usage_info = usage + prompt_routed_experts = None + if final_res.prompt_routed_experts is not None: + prompt_routed_experts = final_res.prompt_routed_experts.tolist() + response = ChatCompletionResponse( id=request_id, created=created_time, @@ -1591,6 +1605,7 @@ async def chat_completion_full_generator( final_res.prompt_token_ids if request.return_token_ids else None ), kv_transfer_params=final_res.kv_transfer_params, + prompt_routed_experts=prompt_routed_experts, ) # Log complete response if output logging is enabled diff --git a/vllm/entrypoints/openai/completion/protocol.py b/vllm/entrypoints/openai/completion/protocol.py index c785d254084d..f103bb0202a3 100644 --- a/vllm/entrypoints/openai/completion/protocol.py +++ b/vllm/entrypoints/openai/completion/protocol.py @@ -468,12 +468,16 @@ class CompletionResponseChoice(OpenAIBaseModel): token_ids: list[int] | None = None # For response prompt_logprobs: list[dict[int, Logprob] | None] | None = None prompt_token_ids: list[int] | None = None # For prompt + routed_experts: list[list[list[int]]] | None = None # [gen_len, num_layers, top_k] class CompletionResponse(OpenAIBaseModel): id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}") object: Literal["text_completion"] = "text_completion" created: int = Field(default_factory=lambda: int(time.time())) + prompt_routed_experts: list[list[list[int]]] | None = ( + None # [prompt_len, num_layers, top_k] + ) model: str choices: list[CompletionResponseChoice] service_tier: Literal["auto", "default", "flex", "scale", "priority"] | None = None diff --git a/vllm/entrypoints/openai/completion/serving.py b/vllm/entrypoints/openai/completion/serving.py index fb7f253c7ea3..2bc0e4e3bb98 100644 --- a/vllm/entrypoints/openai/completion/serving.py +++ b/vllm/entrypoints/openai/completion/serving.py @@ -531,6 +531,11 @@ def request_output_to_completion_response( token_ids=( as_list(output.token_ids) if request.return_token_ids else None ), + routed_experts=( + output.routed_experts.tolist() + if output.routed_experts is not None + else None + ), ) choices.append(choice_data) @@ -554,8 +559,13 @@ def request_output_to_completion_response( ) request_metadata.final_usage_info = usage + prompt_routed_experts = None if final_res_batch: kv_transfer_params = 
final_res_batch[0].kv_transfer_params + pre = final_res_batch[0].prompt_routed_experts + if pre is not None: + prompt_routed_experts = pre.tolist() + return CompletionResponse( id=request_id, created=created_time, @@ -563,6 +573,7 @@ def request_output_to_completion_response( choices=choices, usage=usage, kv_transfer_params=kv_transfer_params, + prompt_routed_experts=prompt_routed_experts, ) def _create_completion_logprobs( diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 190a9cc3b5d7..0afe72ff5986 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -237,6 +237,9 @@ class FusedMoE(PluggableLayer): router_logits_dtype: Data type for router logits buffers. """ + # Auto-incrementing layer ID for routing replay buffer binding. + _next_moe_layer_id: int = 0 + # --8<-- [end:fused_moe] def __init__( @@ -280,6 +283,10 @@ def __init__( self._routed_input_transform = routed_input_transform + # Assign unique layer ID for routing replay buffer binding. + self.moe_layer_id = FusedMoE._next_moe_layer_id + FusedMoE._next_moe_layer_id += 1 + if params_dtype is None: params_dtype = torch.get_default_dtype() self.params_dtype = params_dtype diff --git a/vllm/model_executor/layers/fused_moe/routed_experts_capturer.py b/vllm/model_executor/layers/fused_moe/routed_experts_capturer.py index 5b93b3d5c6ea..6c157e8d1c21 100644 --- a/vllm/model_executor/layers/fused_moe/routed_experts_capturer.py +++ b/vllm/model_executor/layers/fused_moe/routed_experts_capturer.py @@ -1,353 +1,770 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# Adapted from -# https://github.com/sgl-project/sglang/blob/bed301a5acaa9577c9aa706468bdf242f6a43051/python/sglang/srt/layers/moe/routed_experts_capturer.py - from __future__ import annotations -import fcntl +import contextlib import logging -import os -import tempfile -from collections.abc import Generator -from contextlib import contextmanager -from multiprocessing import shared_memory -from unittest.mock import patch +from abc import ABC, abstractmethod import numpy as np import torch +import torch.distributed -from vllm.config import VllmConfig -from vllm.distributed import get_tensor_model_parallel_rank -from vllm.forward_context import get_forward_context -from vllm.platforms import current_platform +from vllm.config.model import ModelConfig logger = logging.getLogger(__name__) -# Constants -_TMP_DIR = tempfile.gettempdir() -_LOCK_FILE_PREFIX = os.path.join(_TMP_DIR, "vllm_routed_experts") -_BUFFER_PREFIX = "vllm_routed_experts_buffer" - -# Global singleton instances -_global_experts_capturer: RoutedExpertsCapturer | None = None -_global_experts_reader: RoutedExpertsReader | None = None - - -@contextmanager -def _file_lock(lock_file: str, mode: str = "wb+") -> Generator[None, None, None]: - """Context manager for file-based locking.""" - with open(lock_file, mode) as fp: - fcntl.flock(fp, fcntl.LOCK_EX) - try: - yield - finally: - fcntl.flock(fp, fcntl.LOCK_UN) - - -def _create_or_attach_shared_memory( - name: str, size: int, lock_file: str -) -> shared_memory.SharedMemory: - """Create or attach to shared memory with proper locking.""" - # Ensure lock file exists before acquiring lock - with open(lock_file, "wb"): - pass - with _file_lock(lock_file): - try: - shm = shared_memory.SharedMemory(name=name, create=True, size=size) - except FileExistsError: - shm = shared_memory.SharedMemory(name=name, 
create=False, size=size) +# --------------------------------------------------------------------------- +# Custom op for routing capture -- traceable by torch.compile / Dynamo. +# +# Registered as a formal custom op so that torch.compile traces through it +# cleanly without graph breaks. ALL TP ranks call this op with a real +# device buffer to ensure identical CUDA graph structure (symmetry). +# Non-rank-0 buffers are written but never read for D2H. +# --------------------------------------------------------------------------- - if shm.size != size: - logger.warning( - "Shared memory %s size mismatch; recreating", - name, - ) - shm.close() - shm.unlink() - try: - shm = shared_memory.SharedMemory(name=name, create=True, size=size) - logger.info("Created shared memory %s", name) - except FileExistsError: - shm = shared_memory.SharedMemory(name=name, create=False, size=size) - logger.info("Linked to existing shared memory %s", name) - return shm +@torch.library.custom_op("vllm::capture_routing", mutates_args={"buffer"}) +def capture_routing_op( + buffer: torch.Tensor, + topk_ids: torch.Tensor, + layer_id: int, + batch_size: int, +) -> None: + buffer[layer_id, :batch_size, :].copy_( + topk_ids[:batch_size].to(buffer.dtype), non_blocking=True + ) -class RoutedExpertsCapturer: - """ - Capturer for routed experts with device and optional shared memory buffer. +@capture_routing_op.register_fake +def _capture_routing_op_fake( + buffer: torch.Tensor, + topk_ids: torch.Tensor, + layer_id: int, + batch_size: int, +) -> None: + pass + - This class captures expert routing decisions during model forward passes - and optionally stores them in shared memory for cross-process access. +_MB = 1024 * 1024 + + +class _RoutedExpertsDeviceCache: + """Per-device (GPU) cache for capturing routed expert IDs during forward + pass. Always writes at row 0 so that CUDA graph replay sees the same + addresses that were recorded at capture time. """ - _instance: RoutedExpertsCapturer | None = None + DTYPE = torch.int16 + + def __init__( + self, + num_batched_tokens: int, + num_hidden_layers: int, + num_experts_per_tok: int, + device: str, + ) -> None: + # Layout: (L, N, K) so that buffer[layer_id] is a contiguous (N, K) + # view — required by the FlashInfer routing-replay kernel which + # writes expert IDs assuming contiguous row-major memory. + self.num_hidden_layers = num_hidden_layers + self.buffer = torch.zeros( + (num_hidden_layers, num_batched_tokens, num_experts_per_tok), + dtype=self.DTYPE, + device=device, + ) + self._finalize_allocation_log() + + def get_buffer_size_bytes(self): + return self.buffer.nbytes + + def capture_fwd_routed_experts(self, layer_id: int, topk_ids: torch.Tensor): + assert layer_id is not None, "capturing routing experts but get layer_id None" + batch, _ = topk_ids.shape + self.buffer[layer_id, :batch, :].copy_(topk_ids, non_blocking=True) + + def _finalize_allocation_log(self): + buf_mb = self.get_buffer_size_bytes() / _MB + logger.info( + "Routing experts device buffer allocated. 
shape=%s, size=%.2f MB", + tuple(self.buffer.shape), + buf_mb, + ) - def __init__(self) -> None: - self._device_buffer: torch.Tensor | None = None - self._shm: shared_memory.SharedMemory | None = None - self._host_buffer_view: np.ndarray | None = None - self._lock_file: str | None = None - @classmethod - def create(cls) -> RoutedExpertsCapturer: - """Create a global singleton instance.""" - global _global_experts_capturer - if _global_experts_capturer is not None: - raise RuntimeError("Experts capturer already created.") +class _RoutedExpertsHostCache: + """Host (CPU) cache using numpy arrays for per-request routing data. - _global_experts_capturer = cls() - return _global_experts_capturer + Numpy arrays avoid torch dispatcher overhead for scatter operations. + Lazy per-request allocation avoids a massive up-front buffer. + """ - @staticmethod - def get_instance() -> RoutedExpertsCapturer | None: - """Get the global singleton instance.""" - return _global_experts_capturer + DTYPE = np.int16 - def init_buffer( + def __init__( self, - max_num_batched_tokens: int, - max_num_kv_tokens: int, - vllm_config: VllmConfig, + num_hidden_layers: int, + num_experts_per_tok: int, + max_model_len: int, ) -> None: - """ - Initialize the device buffer and optionally shared memory buffer. + self.max_model_len = max_model_len + self.num_hidden_layers = num_hidden_layers + self.num_experts_per_tok = num_experts_per_tok - Args: - max_num_batched_tokens: Maximum number of tokens in a batch. - max_num_kv_tokens: Maximum number of KV tokens for shared memory. - vllm_config: vllm configuration containing layer and expert info. - """ + self._req_buffers: dict[str, np.ndarray] = {} + self._filled_len: dict[str, int] = {} + self._total_allocated_bytes = 0 + + self._finalize_allocation_log() - if self._device_buffer is not None: - raise RuntimeError("Device buffer has already been initialized") + def get_buffer_size_bytes(self) -> int: + return self._total_allocated_bytes - hf_config = vllm_config.model_config.hf_text_config - num_layers = hf_config.num_hidden_layers - num_experts_per_tok = hf_config.num_experts_per_tok + def get_or_grow_buffer(self, req_id: str, max_pos: int) -> np.ndarray: + required_len = max_pos + 1 - # Initialize device buffer - self._device_buffer = torch.zeros( - (max_num_batched_tokens, num_layers, num_experts_per_tok), - dtype=torch.int32, - device=current_platform.device_type, + if req_id not in self._req_buffers: + buf = np.zeros( + (required_len, self.num_hidden_layers, self.num_experts_per_tok), + dtype=self.DTYPE, + ) + self._req_buffers[req_id] = buf + self._total_allocated_bytes += buf.nbytes + return buf + + buf = self._req_buffers[req_id] + if buf.shape[0] >= required_len: + return buf + + new_len = min(max(required_len, buf.shape[0] * 2), self.max_model_len) + new_buf = np.zeros( + (new_len, self.num_hidden_layers, self.num_experts_per_tok), + dtype=self.DTYPE, + ) + new_buf[: buf.shape[0]] = buf + self._total_allocated_bytes += new_buf.nbytes - buf.nbytes + self._req_buffers[req_id] = new_buf + return new_buf + + def get_buffer(self, req_id: str) -> np.ndarray | None: + return self._req_buffers.get(req_id) + + def update_filled_len(self, req_id: str, max_pos: int) -> None: + new_len = max_pos + 1 + self._filled_len[req_id] = max(self._filled_len.get(req_id, 0), new_len) + + def get_filled_len(self, req_id: str) -> int: + return self._filled_len.get(req_id, 0) + + def free_request(self, req_id: str) -> None: + if req_id in self._req_buffers: + self._total_allocated_bytes -= 
self._req_buffers.pop(req_id).nbytes + self._filled_len.pop(req_id, None) + + def _finalize_allocation_log(self): + logger.info( + "Routing experts host cache initialized (lazy allocation). " + "max_model_len=%s, layers=%s, experts_per_tok=%s", + self.max_model_len, + self.num_hidden_layers, + self.num_experts_per_tok, ) - self.dp_rank = vllm_config.parallel_config.data_parallel_rank - if get_tensor_model_parallel_rank() != 0: - return - # Initialize shared memory - shape = (max_num_kv_tokens, num_layers, num_experts_per_tok) - buffer_size = int(np.prod(shape)) * np.dtype(np.int32).itemsize - instance_id = vllm_config.instance_id - self._lock_file = f"{_LOCK_FILE_PREFIX}_{instance_id}_{self.dp_rank}.lock" - shm_name = f"{_BUFFER_PREFIX}_{instance_id}_{self.dp_rank}" +class RoutedExpertsCapturer(ABC): + @staticmethod + def create( + enable: bool, + model_config: ModelConfig, + num_fused_shared_experts: int, + num_batched_tokens: int, + max_model_len: int, + device: str, + shared_host_cache: _RoutedExpertsHostCache | None = None, + skip_host_cache: bool = False, + ): + if enable: + return _RoutedExpertsCapturerReal( + model_config, + num_batched_tokens=num_batched_tokens, + num_fused_shared_experts=num_fused_shared_experts, + max_model_len=max_model_len, + device=device, + shared_host_cache=shared_host_cache, + skip_host_cache=skip_host_cache, + ) + return _RoutedExpertsCapturerNoop() + + @abstractmethod + def capture(self, layer_id: int, topk_ids: torch.Tensor): + raise NotImplementedError + + def get_routed_experts( + self, req_id: str, seqlen: int | None = None, free_slot: bool = True + ): + raise NotImplementedError + + def sync_fwd_experts_buffer_DtoH( + self, + positions: torch.Tensor, + num_scheduled_tokens: dict[str, int], + ): + raise NotImplementedError + + def finalize_pending_copy(self): + raise NotImplementedError + + def get_host_cache(self): + raise NotImplementedError + + def get_device_cache(self): + raise NotImplementedError + - self._shm = _create_or_attach_shared_memory( - shm_name, buffer_size, self._lock_file +class _RoutedExpertsCapturerReal(RoutedExpertsCapturer): + """Capturer with GPU device cache and CPU host cache. + + Performance strategy -- async D2H with optimized host-cache scatter: + + Every decode step we issue a non-blocking D2H copy on a dedicated + CUDA stream. The scatter into per-request host-cache buffers is + deferred to the start of the NEXT step (by which time the copy has + finished). The scatter loop is optimized with direct scalar access + to avoid numpy slice views, int() conversions, and .max() calls. + + At extraction time (when a request finishes), data is already in a + contiguous host buffer -- just a numpy slice, no concatenation. 
+ """ + + def __init__( + self, + model_config: ModelConfig, + num_batched_tokens: int, + num_fused_shared_experts: int, + max_model_len: int, + device: str, + shared_host_cache: _RoutedExpertsHostCache | None = None, + skip_host_cache: bool = False, + ): + self.num_fused_shared_experts = num_fused_shared_experts + self.num_hidden_layers = model_config.hf_text_config.layers_block_type.count( + "moe" ) - self._host_buffer_view = np.ndarray(shape, dtype=np.int32, buffer=self._shm.buf) - self._host_buffer_view.fill(0) + self.num_experts_per_tok = model_config.hf_text_config.num_experts_per_tok + self.num_batched_tokens = num_batched_tokens + self.max_model_len = max_model_len + self._skip_host_cache = skip_host_cache + + if skip_host_cache: + self.host_cache = None + logger.info("Skipping host cache for device %s (non-rank-0)", device) + elif shared_host_cache is not None: + self.host_cache = shared_host_cache + else: + self.host_cache = _RoutedExpertsHostCache( + num_hidden_layers=self.num_hidden_layers, + num_experts_per_tok=self.num_experts_per_tok, + max_model_len=self.max_model_len, + ) - logger.debug( - "Created shared memory buffer '%s' with shape %s", - shm_name, - shape, + self.device_cache = _RoutedExpertsDeviceCache( + num_batched_tokens=self.num_batched_tokens, + num_hidden_layers=self.num_hidden_layers, + num_experts_per_tok=self.num_experts_per_tok, + device=device, ) - def capture(self, layer_id: int, topk_ids: torch.Tensor) -> None: - """ - Capture expert routing decisions for a specific layer. + # ---- Async D2H pipeline (rank-0 only) ---- + # Non-rank-0 workers only need the device buffer for symmetric + # CUDA graph capture; they skip the D2H pipeline entirely. + self._has_pending_copy = False + self._pending_positions: np.ndarray | None = None + self._pending_num_scheduled: dict[str, int] | None = None + self._pending_total_tokens: int = 0 + + if not skip_host_cache: + # Same (L, N, K) layout as device_cache.buffer. + self._pinned_staging = torch.zeros( + (self.num_hidden_layers, num_batched_tokens, self.num_experts_per_tok), + dtype=_RoutedExpertsDeviceCache.DTYPE, + pin_memory=True, + ) + self._copy_stream = torch.cuda.Stream(device=device) + self._copy_event = torch.cuda.Event() + + pinned_mb = self._pinned_staging.nbytes / _MB + logger.info( + "Routing experts pinned staging buffer allocated. " + "shape=%s, size=%.2f MB", + tuple(self._pinned_staging.shape), + pinned_mb, + ) + else: + self._pinned_staging = None + self._copy_stream = None + self._copy_event = None + logger.info( + "Routing experts device-only capturer (rank != 0). " + "Device buffer shape=%s", + tuple(self.device_cache.buffer.shape), + ) - Args: - layer_id: The layer index. - topk_ids: Tensor of shape (batch_size, num_routed_experts). - """ - if self._device_buffer is None: - raise RuntimeError("Buffer not initialized. Call init_buffer() first.") - - ctx = get_forward_context() - if ctx.dp_metadata is None: # single dp - start_loc = 0 - end_loc = topk_ids.shape[0] - token_num_per_dp = topk_ids.shape[0] - else: # multi dp - num_tokens_dp = ctx.dp_metadata.num_tokens_across_dp_cpu - token_num_per_dp = int(num_tokens_dp[self.dp_rank].item()) - total = int(num_tokens_dp.sum().item()) - n = topk_ids.shape[0] - - if n == total: - # Naive dispatch: all DP ranks' tokens concatenated before routing. 
- cumsum = torch.cumsum(num_tokens_dp, dim=0) - end_loc = int(cumsum[self.dp_rank].item()) - start_loc = end_loc - token_num_per_dp - elif n == token_num_per_dp: - # Modular-kernel path: DP combine happens inside quant_method.apply; - # select_experts only sees this rank's tokens. - start_loc = 0 - end_loc = token_num_per_dp - else: - raise AssertionError( - "RoutedExpertsCapturer: unexpected topk_ids batch dim " - f"{n} (expected {total} or {token_num_per_dp} " - f"for dp_rank={self.dp_rank})" - ) + def capture(self, layer_id: int, topk_ids: torch.Tensor): + self.device_cache.capture_fwd_routed_experts(layer_id, topk_ids) - if layer_id >= self._device_buffer.shape[1]: + # ------------------------------------------------------------------ + # sync_fwd_experts_buffer_DtoH -- called AFTER the forward pass + # ------------------------------------------------------------------ + + def sync_fwd_experts_buffer_DtoH( + self, + positions: torch.Tensor, + num_scheduled_tokens: dict[str, int], + ): + if self.host_cache is None: return - self._device_buffer[:token_num_per_dp, layer_id, :] = topk_ids[ - start_loc:end_loc, : - ] + # 1. Finalize previous async copy -- the copy had an entire + # forward pass to complete so event.synchronize() is ~free. + if self._has_pending_copy: + self._copy_event.synchronize() + self._scatter_to_host() + self._has_pending_copy = False - def clear_buffer(self) -> None: - """Clear the device buffer.""" - if self._device_buffer is not None: - self._device_buffer.zero_() + total_tokens = sum(num_scheduled_tokens.values()) + if total_tokens == 0: + return - def save_captured_experts(self, indices: np.ndarray) -> None: - """ - Save captured experts from device buffer to shared memory. + # 2. Issue new async D2H copy on a dedicated stream. + # Device buffer layout is (L, N, K); copy the first total_tokens + # along the N dimension for every layer. + main_stream = torch.cuda.current_stream(self._copy_stream.device) + with torch.cuda.stream(self._copy_stream): + self._copy_stream.wait_stream(main_stream) + self._pinned_staging[:, :total_tokens, :].copy_( + self.device_cache.buffer[:, :total_tokens, :], non_blocking=True + ) + self._copy_event.record() - Args: - indices: Array of indices indicating where to store the data. - """ - if get_tensor_model_parallel_rank() != 0: - return - if self._lock_file is None: - raise RuntimeError("Shared memory not initialized.") - if self._host_buffer_view is None: - return - if self._device_buffer is None: - raise RuntimeError("Device buffer not initialized.") + # 3. Save metadata for deferred scatter. + self._pending_positions = positions.numpy().copy() + self._pending_num_scheduled = num_scheduled_tokens + self._pending_total_tokens = total_tokens + self._has_pending_copy = True - num_tokens = len(indices) - data = self._device_buffer[:num_tokens, :, :].cpu().numpy() + # ------------------------------------------------------------------ + # Optimized scatter into pre-allocated host-cache buffers + # ------------------------------------------------------------------ - with _file_lock(self._lock_file): - self._host_buffer_view[indices, :, :] = data + def _scatter_to_host(self): + """Scatter D2H data into per-request host cache buffers. - def cleanup(self) -> None: - """Explicitly clean up shared memory resources.""" - if self._shm is not None: - try: - self._shm.close() - self._shm.unlink() - except Exception: - logger.debug("Exception during cleanup for capturer", exc_info=True) - finally: - self._shm = None + Staging layout is (L, N, K). 
Host cache layout is (seq_len, L, K). + We transpose the staging slice to (N, L, K) before scattering so + that indexing by token position naturally yields (L, K) rows. + """ + # Transpose (L, N, K) -> (N, L, K) for the active token range. + host_values = ( + self._pinned_staging[:, : self._pending_total_tokens, :] + .numpy() + .transpose(1, 0, 2) + ) + positions_np = self._pending_positions + host_cache = self.host_cache + assert self._pending_num_scheduled is not None + assert positions_np is not None + assert host_cache is not None + + offset = 0 + for req_id, n_tokens in self._pending_num_scheduled.items(): + if n_tokens == 0: + continue + + if n_tokens == 1: + pos_val = int(positions_np[offset]) + buf = host_cache.get_or_grow_buffer(req_id, pos_val) + buf[pos_val] = host_values[offset] + host_cache.update_filled_len(req_id, pos_val) + else: + pos = positions_np[offset : offset + n_tokens] + max_pos = int(pos[-1]) if n_tokens > 0 else 0 + if n_tokens > 1: + max_pos = int(pos.max()) + buf = host_cache.get_or_grow_buffer(req_id, max_pos) + buf[pos] = host_values[offset : offset + n_tokens] + host_cache.update_filled_len(req_id, max_pos) + + offset += n_tokens + + self._pending_positions = None + self._pending_num_scheduled = None + self._pending_total_tokens = 0 + + # ------------------------------------------------------------------ + # finalize_pending_copy -- call before reading host cache + # ------------------------------------------------------------------ + + def finalize_pending_copy(self): + """Ensure the most recent async D2H copy has been scattered into + host cache buffers. Call before get_routed_experts.""" + if self._has_pending_copy: + self._copy_event.synchronize() + self._scatter_to_host() + self._has_pending_copy = False + + # ------------------------------------------------------------------ + # Extraction -- O(1), just a numpy slice + # ------------------------------------------------------------------ + + def get_routed_experts( + self, + req_id: str, + seqlen: int | None = None, + free_slot: bool = True, + ): + if self.host_cache is None: + return None + buf = self.host_cache.get_buffer(req_id) + if buf is None: + return None + filled = self.host_cache.get_filled_len(req_id) + if filled <= 0: + return None + effective_len = min(filled, seqlen) if seqlen is not None else filled + result = buf[:effective_len].copy() + if free_slot: + self.host_cache.free_request(req_id) + return result + + def get_host_cache(self): + return self.host_cache + + def get_device_cache(self): + return self.device_cache + + +class _RoutedExpertsCapturerNoop(RoutedExpertsCapturer): + def __init__(self): + pass - def __del__(self) -> None: - """Clean up shared memory on destruction.""" - self.cleanup() + def capture(self, layer_id: int, topk_ids: torch.Tensor): + pass + def get_routed_experts(self, req_id: str, seqlen=None, free_slot=True): + return None -class RoutedExpertsReader: - """ - Reader for routed experts from shared memory. + def sync_fwd_experts_buffer_DtoH(self, positions, num_scheduled_tokens): + pass - This class attaches to shared memory created by RoutedExpertsCapturer - and reads expert routing decisions. 
- """ + def finalize_pending_copy(self): + pass - _instance: RoutedExpertsReader | None = None + def get_host_cache(self): + return None - def __init__(self) -> None: - self._shm: shared_memory.SharedMemory | None = None - self._host_buffer_view: np.ndarray | None = None - self._lock_file: str | None = None + def get_device_cache(self): + pass - @classmethod - def create(cls) -> RoutedExpertsReader: - """Create a global singleton instance.""" - global _global_experts_reader - if _global_experts_reader is not None: - raise RuntimeError("Experts reader already created.") - _global_experts_reader = cls() - return _global_experts_reader +# Global capturer instance (per-process) +_global_expert_capturer: RoutedExpertsCapturer | None = _RoutedExpertsCapturerNoop() +_shared_host_cache: _RoutedExpertsHostCache | None = None - @staticmethod - def get_instance() -> RoutedExpertsReader | None: - """Get the global singleton instance.""" - if _global_experts_reader is None: - logger.info("Experts reader not initialized.") - return _global_experts_reader - def attach_buffer( - self, - max_num_kv_tokens: int, - vllm_config: VllmConfig, - ) -> None: - """ - Attach to an existing shared memory buffer. +def get_global_experts_capturer(): + return _global_expert_capturer - Args: - max_num_kv_tokens: Maximum number of KV tokens. - vllm_config: vllm configuration. - """ - if self._shm is not None: - logger.warning("Already attached to shared memory buffer.") - return # Already attached - - hf_config = vllm_config.model_config.hf_text_config - shape = ( - max_num_kv_tokens, - hf_config.num_hidden_layers, - hf_config.num_experts_per_tok, - ) - self.dp_rank = vllm_config.parallel_config.data_parallel_rank - instance_id = vllm_config.instance_id - self._lock_file = f"{_LOCK_FILE_PREFIX}_{instance_id}_{self.dp_rank}.lock" - shm_name = f"{_BUFFER_PREFIX}_{instance_id}_{self.dp_rank}" - - with _file_lock(self._lock_file, mode="rb+"): - # Avoid resource_tracker registering the shared memory - with patch( - "multiprocessing.resource_tracker.register", - lambda *args, **kwargs: None, - ): - self._shm = shared_memory.SharedMemory(name=shm_name) - - self._host_buffer_view = np.ndarray( - shape, dtype=np.int32, buffer=self._shm.buf - ) +def set_global_experts_capturer(capturer: RoutedExpertsCapturer): + global _global_expert_capturer + _global_expert_capturer = capturer - def get_routed_experts(self, indices: np.ndarray) -> np.ndarray: - """ - Read routed expert data from shared memory. - Args: - indices: Array of indices to read. +def extract_routed_experts_for_current_batch( + req_ids: list[str], + requests: dict, + req_id_to_index: dict[str, int], + num_tokens_no_spec: np.ndarray, + max_model_len: int, +) -> dict[str, tuple] | None: + """Extract routed experts for requests predicted to finish this step. - Returns: - Copy of the expert routing data for the given indices. - """ - if self._host_buffer_view is None: - raise RuntimeError("Buffer not attached. 
Call attach_buffer() first.") - if self._lock_file is None: - raise RuntimeError("Lock file not initialized.") - - with _file_lock(self._lock_file, mode="rb+"): - return self._host_buffer_view[indices, :, :].copy() - - def cleanup(self) -> None: - """Explicitly clean up resources (close without unlink).""" - if self._shm is not None: - try: - self._shm.close() - except Exception: - logger.debug("Exception during cleanup for reader", exc_info=True) - finally: - self._shm = None - - def __del__(self) -> None: - """Close shared memory on destruction (do not unlink).""" - self.cleanup() + Checks all stop conditions the scheduler will check (max_tokens, + EOS token, stop tokens, max_model_len) so that every finished + request gets its routing data attached to the ModelRunnerOutput. + + Args: + req_ids: Ordered request IDs for the current batch. + requests: Map of req_id to CachedRequestState (read-only). + req_id_to_index: Map of req_id to input batch index. + num_tokens_no_spec: Array of total token counts per request index. + max_model_len: Maximum model sequence length. + """ + capturer = get_global_experts_capturer() + if capturer is None: + return None + host_cache = capturer.get_host_cache() + if host_cache is None: + return None + + finishing_req_ids: list[str] = [] + for req_id in req_ids: + req_state = requests.get(req_id) + if req_state is None: + continue + sp = req_state.sampling_params + if sp is None: + continue + output_ids = req_state.output_token_ids + if not output_ids: + continue + if len(output_ids) < sp.min_tokens: + continue + + finishing = False + last_token = output_ids[-1] + + # EOS token (mirrors check_stop: eos_token_id is None + # when ignore_eos=True, so this naturally respects that) + if last_token == sp.eos_token_id: + finishing = True + + # Explicit stop token IDs + if not finishing and sp.stop_token_ids and last_token in sp.stop_token_ids: + finishing = True + + # max_tokens / max_model_len length cap + if not finishing: + if sp.max_tokens is not None and len(output_ids) >= sp.max_tokens: + finishing = True + else: + req_idx = req_id_to_index.get(req_id) + if req_idx is not None: + total = num_tokens_no_spec[req_idx] + if total >= max_model_len: + finishing = True + + if finishing: + finishing_req_ids.append(req_id) + + if not finishing_req_ids: + return None + + # At least one request is finishing: ensure the latest async D2H + # copy has been scattered into the host cache. + capturer.finalize_pending_copy() + + result: dict[str, tuple] = {} + for req_id in finishing_req_ids: + seqlen = host_cache.get_filled_len(req_id) + if seqlen <= 0: + continue + experts = capturer.get_routed_experts(req_id, seqlen=seqlen, free_slot=False) + if experts is not None: + result[req_id] = (experts.shape, experts.tobytes()) + + return result if result else None + + +def free_routing_buffers( + finished_req_ids: set[str], + preempted_req_ids: set[str] | None = None, +) -> None: + """Free host cache buffers for finished and preempted requests. + + Finished requests had their routing data extracted in the previous + step; preempted requests will be re-prefilled from scratch. 
+ """ + capturer = get_global_experts_capturer() + if capturer is None: + return + host_cache = capturer.get_host_cache() + if host_cache is None: + return + + for req_id in finished_req_ids: + host_cache.free_request(req_id) + if preempted_req_ids: + for req_id in preempted_req_ids: + host_cache.free_request(req_id) + + +def issue_routing_d2h_copy( + input_batch_req_ids: list[str], + num_scheduled_tokens: dict[str, int], + positions: torch.Tensor, + positions_cpu: torch.Tensor, +) -> None: + """Issue async D2H copy of routed experts after the forward pass. + + Called EARLY in the execute_model epilogue so the copy overlaps with + eplb, kv_connector finalization, and draft work. + finalize_pending_copy() + get_routed_experts() happen later in + extract_routed_experts_for_current_batch(). + """ + capturer = get_global_experts_capturer() + if capturer is None: + return + + ordered = { + req_id: num_scheduled_tokens[req_id] + for req_id in input_batch_req_ids + if req_id in num_scheduled_tokens + } + n = sum(ordered.values()) + positions_cpu[:n].copy_(positions[:n]) + capturer.sync_fwd_experts_buffer_DtoH( + positions=positions_cpu[:n], + num_scheduled_tokens=ordered, + ) + + +def split_routed_experts( + routed_experts: np.ndarray, + prompt_len: int, + num_output_tokens: int | None = None, +) -> tuple[np.ndarray | None, np.ndarray | None]: + """Split routing data into prompt and generation portions. + + Args: + routed_experts: Full routing array of shape (seq_len, L, K). + prompt_len: Number of prompt tokens for the request. + num_output_tokens: Actual number of generated tokens (from + detokenizer). When provided, the generation portion is + clipped to this length — necessary with MTP where the model + runner may capture routing for more tokens than the final + output contains. + + Returns: + (prompt_routed_experts, gen_routed_experts) numpy arrays, either + of which may be None if the corresponding portion is empty. + """ + prompt_routed_experts = routed_experts[:prompt_len] + gen_routed_experts = routed_experts[prompt_len:] + + # Clip generation routing to match actual output tokens. 
+ if ( + num_output_tokens is not None + and gen_routed_experts.shape[0] > num_output_tokens + and num_output_tokens > 0 + ): + gen_routed_experts = gen_routed_experts[:num_output_tokens] + + if prompt_routed_experts.size == 0: + prompt_routed_experts = None + if gen_routed_experts.size == 0: + gen_routed_experts = None + + return prompt_routed_experts, gen_routed_experts + + +def get_shared_host_cache() -> _RoutedExpertsHostCache | None: + return _shared_host_cache + + +def create_shared_host_cache( + model_config: ModelConfig, + max_model_len: int, +) -> _RoutedExpertsHostCache: + global _shared_host_cache + num_hidden_layers = model_config.hf_text_config.layers_block_type.count("moe") + num_experts_per_tok = model_config.hf_text_config.num_experts_per_tok + _shared_host_cache = _RoutedExpertsHostCache( + num_hidden_layers=num_hidden_layers, + num_experts_per_tok=num_experts_per_tok, + max_model_len=max_model_len, + ) + return _shared_host_cache + + +def init_routed_experts_capturer_with_shared_cache( + enable: bool, + model_config: ModelConfig, + num_fused_shared_experts: int, + num_batched_tokens: int, + max_model_len: int, + device: str, + rank: int = 0, + world_size: int = 1, +) -> RoutedExpertsCapturer: + """Initialize capturer with rank-aware handling (only rank 0 captures).""" + if not enable: + capturer = _RoutedExpertsCapturerNoop() + set_global_experts_capturer(capturer) + return capturer + + if world_size > 1 and rank != 0: + # Non-rank-0 workers get a device-only capturer (no host cache, + # no D2H pipeline) so that ALL ranks have a real device buffer. + # This ensures the custom op call in every MoE layer produces + # identical CUDA graph structure across TP ranks. + logger.info("Creating device-only routed experts capturer for rank %s", rank) + capturer = RoutedExpertsCapturer.create( + enable=True, + model_config=model_config, + num_fused_shared_experts=num_fused_shared_experts, + num_batched_tokens=num_batched_tokens, + max_model_len=max_model_len, + device=device, + skip_host_cache=True, + ) + set_global_experts_capturer(capturer) + return capturer + + capturer = RoutedExpertsCapturer.create( + enable=True, + model_config=model_config, + num_fused_shared_experts=num_fused_shared_experts, + num_batched_tokens=num_batched_tokens, + max_model_len=max_model_len, + device=device, + skip_host_cache=False, + ) + set_global_experts_capturer(capturer) + return capturer + + +def bind_routing_capture_to_model(model) -> None: + """Bind routing capture buffers to all FusedMoE layers in the model. + + Must be called AFTER init_routed_experts_capturer_with_shared_cache() + and BEFORE CUDA graph capture. All TP ranks get a real buffer so + that the custom op call produces identical graph structure. + """ + from vllm.model_executor.layers.fused_moe.layer import FusedMoE + + capturer = get_global_experts_capturer() + device_cache = capturer.get_device_cache() + if device_cache is None: + return # routing capture not enabled + + buffer = device_cache.buffer + + # Mark the buffer so CUDA graphs do NOT snapshot/restore its contents. 
+ if hasattr(torch.compiler, "cudagraph_mark_tensor_static"): + torch.compiler.cudagraph_mark_tensor_static(buffer) + elif hasattr(torch._C, "_set_static_address_tag"): + torch._C._set_static_address_tag(buffer, True) + with contextlib.suppress(Exception): + torch._dynamo.mark_static_address(buffer) + + bound = 0 + for module in model.modules(): + if isinstance(module, FusedMoE) and hasattr(module, "moe_layer_id"): + layer_id = module.moe_layer_id + layer_buf = buffer[layer_id] # (N_max, K) + module._routing_replay_out = layer_buf + # Mark each per-layer view as static so CUDA graphs don't + # snapshot/restore or relocate the buffer during replay. + if hasattr(torch.compiler, "cudagraph_mark_tensor_static"): + torch.compiler.cudagraph_mark_tensor_static(layer_buf) + with contextlib.suppress(Exception): + torch._dynamo.mark_static_address(layer_buf) + bound += 1 + + logger.info( + "Bound routing capture buffer to %s FusedMoE layers. Buffer shape=%s", + bound, + tuple(buffer.shape), + ) diff --git a/vllm/model_executor/layers/fused_moe/runner/moe_runner_base.py b/vllm/model_executor/layers/fused_moe/runner/moe_runner_base.py index 692d45d34607..bb6ca82b5066 100644 --- a/vllm/model_executor/layers/fused_moe/runner/moe_runner_base.py +++ b/vllm/model_executor/layers/fused_moe/runner/moe_runner_base.py @@ -395,6 +395,10 @@ def _apply_quant_method( shared_experts_input, SharedExpertsOrder.NO_OVERLAP ) + # Get routing replay buffer from persistent layer attribute + # (set by bind_routing_capture_to_model during capturer init) + routing_replay_out = getattr(layer, "_routing_replay_out", None) + if self.quant_method.is_monolithic: fused_out = self.quant_method.apply_monolithic( layer=layer, @@ -407,6 +411,10 @@ def _apply_quant_method( router_logits=router_logits, ) + # Write routing data for non-monolithic path (Triton, etc.) + if routing_replay_out is not None: + routing_replay_out[: topk_ids.shape[0]].copy_(topk_ids.to(torch.int16)) + # Passing shared_experts_input in case SharedExpertsOrder is # NO_OVERLAP or MK_INTERNAL_OVERLAPPED. fused_out = self.quant_method.apply( diff --git a/vllm/outputs.py b/vllm/outputs.py index 2c71d2afb1b5..aa6d12768ccc 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -121,6 +121,7 @@ def __init__( num_cached_tokens: int | None = None, *, kv_transfer_params: dict[str, Any] | None = None, + prompt_routed_experts: np.ndarray | None = None, # Forward compatibility, code that uses args added in new release can # still run with older versions of vLLM without breaking. 
**kwargs: Any, @@ -141,12 +142,15 @@ def __init__( self.encoder_prompt_token_ids = encoder_prompt_token_ids self.num_cached_tokens = num_cached_tokens self.kv_transfer_params = kv_transfer_params + self.prompt_routed_experts = prompt_routed_experts def add(self, next_output: "RequestOutput", aggregate: bool) -> None: """Merge subsequent RequestOutput into this one""" self.finished |= next_output.finished self.kv_transfer_params = next_output.kv_transfer_params + if next_output.prompt_routed_experts is not None: + self.prompt_routed_experts = next_output.prompt_routed_experts for next_completion in next_output.outputs: for i, completion in enumerate(self.outputs): diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 40b5899f0457..c43d1f476d02 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -7,8 +7,6 @@ from dataclasses import replace from typing import Any -import numpy as np - from vllm import envs from vllm.compilation.cuda_graph import CUDAGraphStat from vllm.config import VllmConfig @@ -27,9 +25,6 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe.routed_experts_capturer import ( - RoutedExpertsReader, -) from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal.encoder_budget import MultiModalBudget from vllm.v1.core.encoder_cache_manager import ( @@ -52,7 +47,7 @@ ) from vllm.v1.core.sched.utils import check_stop, remove_all from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs -from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig +from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.metrics.perf import ModelMetrics, PerfStats from vllm.v1.metrics.stats import PrefixCacheStats, SchedulerStats from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput @@ -256,43 +251,6 @@ def __init__( if self.log_stats and vllm_config.observability_config.enable_mfu_metrics: self.perf_metrics = ModelMetrics(vllm_config) - if self.vllm_config.model_config.enable_return_routed_experts: - assert self.dcp_world_size == 1 and self.pcp_world_size == 1, ( - "enable_return_routed_experts does not support context parallelism " - "(dcp_world_size > 1 or pcp_world_size > 1)" - ) - - self.routed_experts_reader = RoutedExpertsReader.create() - - assert len(kv_cache_config.kv_cache_groups) > 0, ( - "enable_return_routed_experts requires at least one kv cache group" - ) - # Find the attention group for routed experts indexing. 
- self.routed_experts_attn_gid = 0 - for gid, group in enumerate(kv_cache_config.kv_cache_groups): - if isinstance(group.kv_cache_spec, AttentionSpec): - self.routed_experts_attn_gid = gid - break - min_block_size = min( - [ - group.kv_cache_spec.block_size - for group in kv_cache_config.kv_cache_groups - ] - ) - num_groups = len(kv_cache_config.kv_cache_groups) - self.max_num_kv_tokens = ( - kv_cache_config.num_blocks // num_groups - ) * min_block_size - dcp_size = self.vllm_config.parallel_config.decode_context_parallel_size - pcp_size = self.vllm_config.parallel_config.prefill_context_parallel_size - if pcp_size * dcp_size > 1: - self.max_num_kv_tokens *= pcp_size * dcp_size - - self.routed_experts_reader.attach_buffer( - max_num_kv_tokens=self.max_num_kv_tokens, - vllm_config=self.vllm_config, - ) - self._pause_state: PauseState = PauseState.UNPAUSED def _mamba_block_aligned_split( @@ -1424,11 +1382,15 @@ def update_from_output( request.resumable = False stopped = True + # Get routing data from ModelRunnerOutput (via worker D2H pipeline) routed_experts = None + if ( + model_runner_output.routed_experts_dict is not None + and req_id in model_runner_output.routed_experts_dict + ): + routed_experts = model_runner_output.routed_experts_dict[req_id] finish_reason = None if stopped: - routed_experts = self._get_routed_experts(request) - # Capture finish_reason BEFORE _handle_stopped_request, which may # reset the status to WAITING for streaming requests that continue. finish_reason = request.get_finished_reason() @@ -1603,31 +1565,6 @@ def _handle_stopped_request(self, request: Request) -> bool: self._enqueue_waiting_request(request) return False - def _get_routed_experts(self, request: Request) -> np.ndarray | None: - if not self.vllm_config.model_config.enable_return_routed_experts: - return None - - kv_blocks = self.kv_cache_manager.get_blocks(request.request_id) - block_ids = kv_blocks.get_block_ids()[self.routed_experts_attn_gid] - num_tokens = request.num_tokens - 1 - - # compute slot mapping using attention group's block_size - block_ids_array = np.array(block_ids, dtype=np.int32) - num_blocks = len(block_ids) - attn_group = self.kv_cache_config.kv_cache_groups[self.routed_experts_attn_gid] - block_size = attn_group.kv_cache_spec.block_size - - # generate block offsets - block_offsets = np.arange(0, block_size) - - # compute slot mapping: slot = block_id * block_size + offset - slot_mapping = ( - block_offsets.reshape((1, block_size)) - + block_ids_array.reshape((num_blocks, 1)) * block_size - ).flatten()[:num_tokens] - - return self.routed_experts_reader.get_routed_experts(indices=slot_mapping) - def _update_request_with_output( self, request: Request, new_token_ids: list[int] ) -> tuple[list[int], bool]: diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index d5c5dba63475..120939fc5b61 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -8,7 +8,6 @@ from typing import Any, Literal import msgspec -import numpy as np import torch from vllm.lora.request import LoRARequest @@ -174,7 +173,7 @@ class EngineCoreOutput( prefill_stats: PrefillStats | None = None - routed_experts: np.ndarray | None = None + routed_experts: tuple | None = None # The number of NaNs in logits. # A value greater than 0 indicates that the output is corrupted. 
num_nans_in_logits: int = 0 diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 1ae89ae19680..107e6805ce94 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -11,6 +11,9 @@ import torch from vllm.lora.request import LoRARequest +from vllm.model_executor.layers.fused_moe.routed_experts_capturer import ( + split_routed_experts, +) from vllm.outputs import ( STREAM_FINISHED, CompletionOutput, @@ -314,8 +317,24 @@ def make_request_output( finished, ) + # Split routing data into prompt and generation portions. + # Prompt routing lives on RequestOutput (shared across n>1 + # completions); generation routing lives on each CompletionOutput. + prompt_routed_experts = None + gen_routed_experts = None + if routed_experts is not None: + prompt_len = len(self.prompt_token_ids) if self.prompt_token_ids else 0 + num_gen = ( + self.detokenizer.num_output_tokens() + if self.detokenizer is not None + else None + ) + prompt_routed_experts, gen_routed_experts = split_routed_experts( + routed_experts, prompt_len, num_gen + ) + output = self._new_completion_output( - new_token_ids, finish_reason, stop_reason, routed_experts + new_token_ids, finish_reason, stop_reason, gen_routed_experts ) if self.parent_req is None: @@ -327,7 +346,11 @@ def make_request_output( external_req_id = self.parent_req.external_req_id return self._new_request_output( - external_req_id, outputs, finished, kv_transfer_params + external_req_id, + outputs, + finished, + kv_transfer_params, + prompt_routed_experts, ) def _new_request_output( @@ -336,6 +359,7 @@ def _new_request_output( outputs: list[CompletionOutput] | list[PoolingOutput], finished: bool, kv_transfer_params: dict[str, Any] | None = None, + prompt_routed_experts: np.ndarray | None = None, ) -> RequestOutput | PoolingRequestOutput: # If prompt embeds were used, put placeholder prompt token ids prompt_token_ids = self.prompt_token_ids @@ -371,6 +395,7 @@ def _new_request_output( kv_transfer_params=kv_transfer_params, num_cached_tokens=self.num_cached_tokens, metrics=self.stats, + prompt_routed_experts=prompt_routed_experts, ) def _new_completion_output( @@ -618,6 +643,10 @@ def process_outputs( kv_transfer_params = engine_core_output.kv_transfer_params routed_experts = engine_core_output.routed_experts + if routed_experts is not None: + shape, data = routed_experts + routed_experts = np.frombuffer(data, dtype=np.int16).reshape(shape) + if req_state.is_prefilling: if engine_core_output.prefill_stats is not None: req_state.num_cached_tokens = ( diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index 1f102ec61783..b9b7fea39d10 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -198,6 +198,9 @@ class ModelRunnerOutput: # req_id -> num_nans_in_logits num_nans_in_logits: dict[str, int] | None = None + # req_id -> routed experts data (shape, bytes) tuples + routed_experts_dict: dict[str, tuple] | None = None + # information related to cudagraph execution cudagraph_stats: CUDAGraphStat | None = None diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 4c573f79e97a..1a5eb7bd9cfa 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -54,7 +54,11 @@ from vllm.model_executor.layers.attention import Attention, MLAAttention from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.fused_moe.routed_experts_capturer import ( - RoutedExpertsCapturer, + 
extract_routed_experts_for_current_batch, + free_routing_buffers, + get_global_experts_capturer, + init_routed_experts_capturer_with_shared_cache, + issue_routing_d2h_copy, ) from vllm.model_executor.layers.mamba.ops.ssu_dispatch import ( initialize_mamba_ssu_backend, @@ -1082,6 +1086,12 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> Callable | None for req_id in scheduler_output.finished_req_ids: self.input_batch.remove_request(req_id) + if self.routed_experts_initialized: + free_routing_buffers( + scheduler_output.finished_req_ids, + scheduler_output.preempted_req_ids, + ) + # Zero GPU memory for freshly allocated cache blocks to prevent # stale NaN/data from corrupting attention or SSM computation. if scheduler_output.new_block_ids_to_zero: @@ -2144,10 +2154,7 @@ def _get_block_table(kv_cache_gid: int): block_table_gid_0 = _get_block_table(0) slot_mapping_gid_0 = slot_mappings[0] - if self.routed_experts_initialized: - attn_gid = self.routed_experts_attn_gid - slot_mapping_attn = slot_mappings[attn_gid] - self.slot_mapping = slot_mapping_attn[:num_tokens].cpu().numpy() + # routing replay uses device cache approach (no slot_mapping needed) num_computed_tokens_cpu = self.input_batch.num_computed_tokens_cpu_tensor[ :num_reqs_padded ] @@ -3771,11 +3778,9 @@ def execute_model( ) if self.routed_experts_initialized: - capturer = RoutedExpertsCapturer.get_instance() + capturer = get_global_experts_capturer() if capturer is not None: - capturer.clear_buffer() # noqa - else: - logger.error("RoutedExpertsCapturer not initialized.") + capturer.finalize_pending_copy() # If ngram_gpu is used, we need to copy the scheduler_output to avoid # the modification has influence on the scheduler_output in engine core process. @@ -4285,6 +4290,14 @@ def propose_draft_token_ids(sampled_token_ids): spec_decode_metadata, ) + if self.routed_experts_initialized: + issue_routing_d2h_copy( + input_batch_req_ids=self.input_batch.req_ids, + num_scheduled_tokens=scheduler_output.num_scheduled_tokens, + positions=self.positions, + positions_cpu=self._positions_cpu, + ) + if propose_drafts_after_bookkeeping: # ngram and other speculative decoding methods use the sampled # tokens on the CPU, so they are run after bookkeeping. @@ -4304,12 +4317,15 @@ def propose_draft_token_ids(sampled_token_ids): self.kv_connector_output = None with record_function_or_nullcontext("gpu_model_runner: ModelRunnerOutput"): + routed_experts_dict = None if self.routed_experts_initialized: - capturer = RoutedExpertsCapturer.get_instance() - if capturer is not None: - capturer.save_captured_experts(indices=self.slot_mapping) # noqa - else: - logger.error("RoutedExpertsCapturer not initialized.") + routed_experts_dict = extract_routed_experts_for_current_batch( + req_ids=req_ids_output_copy, + requests=self.requests, + req_id_to_index=self.input_batch.req_id_to_index, + num_tokens_no_spec=self.input_batch.num_tokens_no_spec, + max_model_len=self.max_model_len, + ) output = ModelRunnerOutput( req_ids=req_ids_output_copy, @@ -4323,6 +4339,7 @@ def propose_draft_token_ids(sampled_token_ids): else None, num_nans_in_logits=num_nans_in_logits, cudagraph_stats=cudagraph_stats, + routed_experts_dict=routed_experts_dict, ) if not self.use_async_scheduling: @@ -5992,6 +6009,7 @@ def capture_model(self) -> int: "Skipping CUDA graph capture. To turn on CUDA graph capture, " "ensure `cudagraph_mode` was not manually set to `NONE`" ) + self.init_routed_experts_capturer() return 0 # Initialize encoder CUDA graph manager if enabled. 
@@ -6025,6 +6043,13 @@ def capture_model(self) -> int: start_time = time.perf_counter() + # Initialize the routed experts capturer once before any CUDA graph + # capture. Must happen before graphs are captured so the buffer + # address is baked into the graph. Do NOT call this inside + # _capture_cudagraphs() -- creating the capturer twice replaces + # the device buffer, causing graphs to write to a dead buffer. + self.init_routed_experts_capturer() + # Trigger CUDA graph capture for specific shapes. # Capture the large shapes first so that the smaller shapes # can reuse the memory pool allocated for the large shapes. @@ -6790,45 +6815,40 @@ def init_routed_experts_capturer(self): "Initializing routed experts capturer, enable_return_routed_experts: %s", self.model_config.enable_return_routed_experts, ) - routed_experts_capturer = RoutedExpertsCapturer.create() - self.routed_experts_attn_gid = self._get_attention_kv_cache_gid() - min_block_size = min( - [ - group.kv_cache_spec.block_size - for group in self.kv_cache_config.kv_cache_groups - ] + from vllm.distributed import get_tp_group + + if hasattr(self.model_config.hf_text_config, "n_shared_experts"): + num_fused_shared_experts = 1 + else: + num_fused_shared_experts = 0 + + tp_group = get_tp_group() + init_routed_experts_capturer_with_shared_cache( + enable=self.model_config.enable_return_routed_experts, + model_config=self.model_config, + num_fused_shared_experts=num_fused_shared_experts, + num_batched_tokens=self.scheduler_config.max_num_batched_tokens, + max_model_len=self.max_model_len, + device=self.device, + rank=tp_group.rank_in_group, + world_size=tp_group.world_size, ) - num_groups = len(self.kv_cache_config.kv_cache_groups) - self.max_num_kv_tokens = ( - self.kv_cache_config.num_blocks // num_groups - ) * min_block_size - dcp_size = self.vllm_config.parallel_config.decode_context_parallel_size - pcp_size = self.vllm_config.parallel_config.prefill_context_parallel_size - if pcp_size * dcp_size > 1: - self.max_num_kv_tokens *= pcp_size * dcp_size - - routed_experts_capturer.init_buffer( - max_num_batched_tokens=self.scheduler_config.max_num_batched_tokens, - max_num_kv_tokens=self.max_num_kv_tokens, - vllm_config=self.vllm_config, - ) - self._bind_routed_experts_capturer(routed_experts_capturer) + self._bind_routed_experts_capturer() self.routed_experts_initialized = True - def _bind_routed_experts_capturer(self, capturer: RoutedExpertsCapturer) -> None: - from vllm.model_executor.layers.fused_moe.layer import FusedMoE - from vllm.model_executor.layers.fused_moe.router.base_router import ( - BaseRouter, + # Pinned CPU buffer for async positions D2H (avoids sync .cpu() call) + self._positions_cpu = torch.empty( + self.scheduler_config.max_num_batched_tokens, + dtype=torch.long, + pin_memory=True, ) - for module in self.compilation_config.static_forward_context.values(): - if isinstance(module, FusedMoE) and isinstance(module.router, BaseRouter): - layer_id = module.layer_id - - def _capture_fn(topk_ids, _layer_id=layer_id, _capturer=capturer): - _capturer.capture(_layer_id, topk_ids) + def _bind_routed_experts_capturer(self) -> None: + from vllm.model_executor.layers.fused_moe.routed_experts_capturer import ( + bind_routing_capture_to_model, + ) - module.router.set_capture_fn(_capture_fn) + bind_routing_capture_to_model(self.model) def may_add_encoder_only_layers_to_kv_cache_config(self) -> None: """
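For readers tracing the engine-side data path end to end, here is a minimal sketch of how the per-request routing payload is reconstructed and split in the output processor. The array contents and dimensions are invented for illustration, and the worker-side `tobytes()` serialization is an assumption; only the `(shape, bytes)` tuple format and the `np.frombuffer` / `split_routed_experts` consumer path appear in this diff.

```python
import numpy as np

from vllm.model_executor.layers.fused_moe.routed_experts_capturer import (
    split_routed_experts,
)

# Hypothetical capture for one request: 10 prompt tokens + 2 generated tokens,
# 4 MoE layers, top-2 routing -> shape (12, 4, 2), dtype int16.
captured = np.random.randint(0, 64, size=(12, 4, 2), dtype=np.int16)

# Worker / engine-core side (assumed): ship the array as a (shape, bytes) tuple,
# matching ModelRunnerOutput.routed_experts_dict / EngineCoreOutput.routed_experts.
payload = (captured.shape, captured.tobytes())

# Output-processor side (as in process_outputs above): rebuild the array...
shape, data = payload
routed = np.frombuffer(data, dtype=np.int16).reshape(shape)

# ...then split it into the prompt portion (attached to RequestOutput) and the
# generation portion (attached to each CompletionOutput), clipped to the number
# of detokenized output tokens.
prompt_part, gen_part = split_routed_experts(routed, prompt_len=10, num_output_tokens=2)
assert prompt_part.shape == (10, 4, 2)
assert gen_part.shape == (2, 4, 2)
```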