diff --git a/tests/model_executor/test_routed_experts_capture.py b/tests/model_executor/test_routed_experts_capture.py index 770a3fa53850..9570e09fc982 100644 --- a/tests/model_executor/test_routed_experts_capture.py +++ b/tests/model_executor/test_routed_experts_capture.py @@ -185,59 +185,109 @@ def capture(self, layer_id, topk_ids): assert len(capturer.calls) == 1 -def test_routed_experts_capturer_single_dp_no_metadata(): - """dp_metadata is None: capture writes the full topk_ids rows.""" - capturer = _capturer_with_buffer(dp_rank=0) - topk = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=torch.int32) - ctx = SimpleNamespace(dp_metadata=None) - with patch(f"{_REC_MODULE}.get_forward_context", return_value=ctx): - capturer.capture(layer_id=0, topk_ids=topk) - assert torch.equal(capturer._device_buffer[:3, 0, :], topk) - assert capturer._device_buffer[3, 0, 0].item() == -1 - - -def test_routed_experts_capturer_dp_naive_concatenated_all_ranks(): - """n == sum(num_tokens_dp): slice this rank's segment from concatenated topk.""" - capturer = _capturer_with_buffer(dp_rank=1) - num_tokens_dp = torch.tensor([2, 3], dtype=torch.int32) - ctx = SimpleNamespace( - dp_metadata=SimpleNamespace(num_tokens_across_dp_cpu=num_tokens_dp) - ) - # Concatenated order: rank0 rows then rank1 rows. - topk = torch.tensor( - [[0, 1], [2, 3], [10, 11], [12, 13], [14, 15]], dtype=torch.int32 - ) - with patch(f"{_REC_MODULE}.get_forward_context", return_value=ctx): - capturer.capture(layer_id=0, topk_ids=topk) - want = topk[2:5] - assert torch.equal(capturer._device_buffer[:3, 0, :], want) - - -def test_routed_experts_capturer_dp_modular_local_tokens(): - """n == token_num_per_dp: topk is already local to this DP rank.""" - capturer = _capturer_with_buffer(dp_rank=1) - num_tokens_dp = torch.tensor([2, 3], dtype=torch.int32) - ctx = SimpleNamespace( - dp_metadata=SimpleNamespace(num_tokens_across_dp_cpu=num_tokens_dp) - ) - topk = torch.tensor([[10, 11], [12, 13], [14, 15]], dtype=torch.int32) - with patch(f"{_REC_MODULE}.get_forward_context", return_value=ctx): - capturer.capture(layer_id=0, topk_ids=topk) - assert torch.equal(capturer._device_buffer[:3, 0, :], topk) - - -def test_routed_experts_capturer_dp_unexpected_batch_raises(): - """Mismatch between topk batch dim and DP layout: fail fast.""" - capturer = _capturer_with_buffer(dp_rank=0) - num_tokens_dp = torch.tensor([2, 3], dtype=torch.int32) - ctx = SimpleNamespace( - dp_metadata=SimpleNamespace(num_tokens_across_dp_cpu=num_tokens_dp) - ) - # total=5, local=2: n=1 matches neither naive (5) nor modular (2). 
- topk = torch.tensor([[1, 2]], dtype=torch.int32) - with ( - patch(f"{_REC_MODULE}.get_forward_context", return_value=ctx), - pytest.raises(AssertionError, match="unexpected topk_ids batch dim"), - ): - capturer.capture(layer_id=0, topk_ids=topk) - assert capturer._device_buffer[0, 0, 0].item() == -1 +# ========================================================================= +# Tests for device-cache routing replay architecture +# ========================================================================= + + +class TestRoutedExpertsDeviceCache: + """Tests for _RoutedExpertsDeviceCache (GPU buffer for routing data).""" + + def test_allocation_shape_and_dtype(self): + """Device cache allocates (L, N, K) int16 buffer.""" + from vllm.model_executor.layers.fused_moe.routed_experts_capturer import ( + _RoutedExpertsDeviceCache, + ) + + cache = _RoutedExpertsDeviceCache( + num_hidden_layers=40, + max_num_batched_tokens=8192, + num_experts_per_tok=8, + ) + assert cache.buffer.shape == (40, 8192, 8) + assert cache.buffer.dtype == torch.int16 + + def test_per_layer_view_is_contiguous(self): + """buffer[layer_id] gives contiguous (N, K) view for FlashInfer.""" + from vllm.model_executor.layers.fused_moe.routed_experts_capturer import ( + _RoutedExpertsDeviceCache, + ) + + cache = _RoutedExpertsDeviceCache( + num_hidden_layers=40, + max_num_batched_tokens=8192, + num_experts_per_tok=8, + ) + layer_view = cache.buffer[0] + assert layer_view.is_contiguous() + assert layer_view.shape == (8192, 8) + + +class TestRoutedExpertsHostCache: + """Tests for _RoutedExpertsHostCache (per-request numpy buffer).""" + + def test_sentinel_initialization(self): + """Host cache initializes with -1 sentinel (not zeros). + + This is critical: expert ID 0 is valid, so we use -1 to mark + positions with no routing data (e.g., prefix-cached tokens). 
+ """ + from vllm.model_executor.layers.fused_moe.routed_experts_capturer import ( + _RoutedExpertsHostCache, + ) + import numpy as np + + cache = _RoutedExpertsHostCache( + num_hidden_layers=40, + num_experts_per_tok=8, + ) + buf = cache.get_or_grow_buffer("req1", max_pos=100) + assert buf.dtype == np.int16 + assert (buf == -1).all(), "Host cache must initialize with -1, not zeros" + + def test_grow_preserves_existing_data(self): + """Growing the buffer preserves previously written data.""" + from vllm.model_executor.layers.fused_moe.routed_experts_capturer import ( + _RoutedExpertsHostCache, + ) + + cache = _RoutedExpertsHostCache( + num_hidden_layers=40, + num_experts_per_tok=8, + ) + buf = cache.get_or_grow_buffer("req1", max_pos=50) + buf[0, 0, 0] = 42 + buf2 = cache.get_or_grow_buffer("req1", max_pos=200) + assert buf2[0, 0, 0] == 42, "Data lost during buffer grow" + + def test_free_request_removes_buffer(self): + """Freeing a request removes its buffer.""" + from vllm.model_executor.layers.fused_moe.routed_experts_capturer import ( + _RoutedExpertsHostCache, + ) + + cache = _RoutedExpertsHostCache( + num_hidden_layers=40, + num_experts_per_tok=8, + ) + cache.get_or_grow_buffer("req1", max_pos=50) + cache.free_request("req1") + assert cache.get_buffer("req1") is None + + +class TestMonolithicWritesFlag: + """Test that quant methods correctly declare routing replay capability.""" + + def test_fp8_has_flag(self): + from vllm.model_executor.layers.quantization.fp8 import Fp8MoEMethod + + assert getattr(Fp8MoEMethod, "_monolithic_writes_routing_replay", False) + + def test_base_class_no_flag(self): + from vllm.model_executor.layers.fused_moe.fused_moe_method_base import ( + FusedMoEMethodBase, + ) + + assert not getattr( + FusedMoEMethodBase, "_monolithic_writes_routing_replay", False + ) diff --git a/vllm/entrypoints/openai/chat_completion/protocol.py b/vllm/entrypoints/openai/chat_completion/protocol.py index 2bc1b6e08750..943701135acc 100644 --- a/vllm/entrypoints/openai/chat_completion/protocol.py +++ b/vllm/entrypoints/openai/chat_completion/protocol.py @@ -92,6 +92,7 @@ class ChatCompletionResponseChoice(OpenAIBaseModel): # not part of the OpenAI spec but is useful for tracing the tokens # in agent scenarios token_ids: list[int] | None = None + routed_experts: list[list[list[int]]] | None = None # [seq_len, num_layers, top_k] class ChatCompletionResponse(OpenAIBaseModel): diff --git a/vllm/entrypoints/openai/chat_completion/serving.py b/vllm/entrypoints/openai/chat_completion/serving.py index 0b8dd0aa28ef..3d3d6fa57e04 100644 --- a/vllm/entrypoints/openai/chat_completion/serving.py +++ b/vllm/entrypoints/openai/chat_completion/serving.py @@ -1283,6 +1283,11 @@ async def chat_completion_full_generator( token_ids=( as_list(output.token_ids) if request.return_token_ids else None ), + routed_experts=( + output.routed_experts.tolist() + if output.routed_experts is not None + else None + ), ) choices.append(choice_data) continue @@ -1482,6 +1487,11 @@ async def chat_completion_full_generator( token_ids=( as_list(output.token_ids) if request.return_token_ids else None ), + routed_experts=( + output.routed_experts.tolist() + if output.routed_experts is not None + else None + ), ) choice_data = maybe_filter_parallel_tool_calls(choice_data, request) diff --git a/vllm/entrypoints/openai/completion/protocol.py b/vllm/entrypoints/openai/completion/protocol.py index c785d254084d..e13631a99430 100644 --- a/vllm/entrypoints/openai/completion/protocol.py +++ 
b/vllm/entrypoints/openai/completion/protocol.py @@ -468,6 +468,7 @@ class CompletionResponseChoice(OpenAIBaseModel): token_ids: list[int] | None = None # For response prompt_logprobs: list[dict[int, Logprob] | None] | None = None prompt_token_ids: list[int] | None = None # For prompt + routed_experts: list[list[list[int]]] | None = None # [seq_len, num_layers, top_k] class CompletionResponse(OpenAIBaseModel): diff --git a/vllm/entrypoints/openai/completion/serving.py b/vllm/entrypoints/openai/completion/serving.py index fb7f253c7ea3..abf4bd08b3ee 100644 --- a/vllm/entrypoints/openai/completion/serving.py +++ b/vllm/entrypoints/openai/completion/serving.py @@ -531,6 +531,11 @@ def request_output_to_completion_response( token_ids=( as_list(output.token_ids) if request.return_token_ids else None ), + routed_experts=( + output.routed_experts.tolist() + if output.routed_experts is not None + else None + ), ) choices.append(choice_data) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py index a239dfea92e4..edc3c562c731 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py @@ -169,5 +169,6 @@ def apply_monolithic( layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 x: torch.Tensor, router_logits: torch.Tensor, + routing_replay_out: torch.Tensor | None = None, ) -> torch.Tensor: raise NotImplementedError diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 190a9cc3b5d7..49fc70251b49 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -215,6 +215,8 @@ def get_compressed_expert_map(expert_map: torch.Tensor) -> str: # --8<-- [start:fused_moe] @PluggableLayer.register("fused_moe") class FusedMoE(PluggableLayer): + # Auto-incrementing layer ID for routing replay buffer binding. + _next_moe_layer_id: int = 0 """FusedMoE layer for MoE models. This layer contains both MergedColumnParallel weights (gate_up_proj / @@ -280,6 +282,10 @@ def __init__( self._routed_input_transform = routed_input_transform + # Assign unique layer ID for routing replay buffer binding. 
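+        # These IDs follow FusedMoE construction order and directly index the
+        # layer dimension of the (L, N, K) device buffer in
+        # bind_routing_capture_to_model().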
+ self.moe_layer_id = FusedMoE._next_moe_layer_id + FusedMoE._next_moe_layer_id += 1 + if params_dtype is None: params_dtype = torch.get_default_dtype() self.params_dtype = params_dtype diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index f2e6e2560e70..1c302396a150 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -987,6 +987,7 @@ def apply( e_score_correction_bias: torch.Tensor | None = None, routed_scaling_factor: float | None = None, topk_group: int | None = None, + routing_replay_out: torch.Tensor | None = None, ) -> torch.Tensor: """ Same as apply(), except uses router_logits as opposed diff --git a/vllm/model_executor/layers/fused_moe/routed_experts_capturer.py b/vllm/model_executor/layers/fused_moe/routed_experts_capturer.py index 5b93b3d5c6ea..32b9ca9296b0 100644 --- a/vllm/model_executor/layers/fused_moe/routed_experts_capturer.py +++ b/vllm/model_executor/layers/fused_moe/routed_experts_capturer.py @@ -1,353 +1,608 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# Adapted from -# https://github.com/sgl-project/sglang/blob/bed301a5acaa9577c9aa706468bdf242f6a43051/python/sglang/srt/layers/moe/routed_experts_capturer.py - -from __future__ import annotations - -import fcntl import logging -import os -import tempfile -from collections.abc import Generator -from contextlib import contextmanager -from multiprocessing import shared_memory -from unittest.mock import patch +from abc import ABC +from typing import Optional import numpy as np import torch +import torch.distributed +from vllm.config.model import ModelConfig -from vllm.config import VllmConfig -from vllm.distributed import get_tensor_model_parallel_rank -from vllm.forward_context import get_forward_context -from vllm.platforms import current_platform logger = logging.getLogger(__name__) -# Constants -_TMP_DIR = tempfile.gettempdir() -_LOCK_FILE_PREFIX = os.path.join(_TMP_DIR, "vllm_routed_experts") -_BUFFER_PREFIX = "vllm_routed_experts_buffer" - -# Global singleton instances -_global_experts_capturer: RoutedExpertsCapturer | None = None -_global_experts_reader: RoutedExpertsReader | None = None - - -@contextmanager -def _file_lock(lock_file: str, mode: str = "wb+") -> Generator[None, None, None]: - """Context manager for file-based locking.""" - with open(lock_file, mode) as fp: - fcntl.flock(fp, fcntl.LOCK_EX) - try: - yield - finally: - fcntl.flock(fp, fcntl.LOCK_UN) - - -def _create_or_attach_shared_memory( - name: str, size: int, lock_file: str -) -> shared_memory.SharedMemory: - """Create or attach to shared memory with proper locking.""" - # Ensure lock file exists before acquiring lock - with open(lock_file, "wb"): - pass - with _file_lock(lock_file): - try: - shm = shared_memory.SharedMemory(name=name, create=True, size=size) - except FileExistsError: - shm = shared_memory.SharedMemory(name=name, create=False, size=size) +# --------------------------------------------------------------------------- +# Custom op for routing capture -- traceable by torch.compile / Dynamo. +# +# Registered as a formal custom op so that torch.compile traces through it +# cleanly without graph breaks. ALL TP ranks call this op with a real +# device buffer to ensure identical CUDA graph structure (symmetry). +# Non-rank-0 buffers are written but never read for D2H. 
+# --------------------------------------------------------------------------- - if shm.size != size: - logger.warning( - "Shared memory %s size mismatch; recreating", - name, - ) - shm.close() - shm.unlink() - try: - shm = shared_memory.SharedMemory(name=name, create=True, size=size) - logger.info("Created shared memory %s", name) - except FileExistsError: - shm = shared_memory.SharedMemory(name=name, create=False, size=size) - logger.info("Linked to existing shared memory %s", name) - return shm +@torch.library.custom_op("vllm::capture_routing", mutates_args={"buffer"}) +def capture_routing_op( + buffer: torch.Tensor, + topk_ids: torch.Tensor, + layer_id: int, + batch_size: int, +) -> None: + buffer[layer_id, :batch_size, :].copy_( + topk_ids[:batch_size].to(buffer.dtype), non_blocking=True + ) -class RoutedExpertsCapturer: - """ - Capturer for routed experts with device and optional shared memory buffer. +@capture_routing_op.register_fake +def _capture_routing_op_fake( + buffer: torch.Tensor, + topk_ids: torch.Tensor, + layer_id: int, + batch_size: int, +) -> None: + pass - This class captures expert routing decisions during model forward passes - and optionally stores them in shared memory for cross-process access. - """ - _instance: RoutedExpertsCapturer | None = None +_GB = 1024 * 1024 * 1024 +_MB = 1024 * 1024 - def __init__(self) -> None: - self._device_buffer: torch.Tensor | None = None - self._shm: shared_memory.SharedMemory | None = None - self._host_buffer_view: np.ndarray | None = None - self._lock_file: str | None = None - @classmethod - def create(cls) -> RoutedExpertsCapturer: - """Create a global singleton instance.""" - global _global_experts_capturer - if _global_experts_capturer is not None: - raise RuntimeError("Experts capturer already created.") +def get_tensor_size_bytes(t: torch.Tensor): + return np.prod(t.shape) * t.dtype.itemsize - _global_experts_capturer = cls() - return _global_experts_capturer - @staticmethod - def get_instance() -> RoutedExpertsCapturer | None: - """Get the global singleton instance.""" - return _global_experts_capturer +class _RoutedExpertsDeviceCache: + """Per-device (GPU) cache for capturing routed expert IDs during forward + pass. Always writes at row 0 so that CUDA graph replay sees the same + addresses that were recorded at capture time. + """ + + DTYPE = torch.int16 - def init_buffer( + def __init__( self, - max_num_batched_tokens: int, - max_num_kv_tokens: int, - vllm_config: VllmConfig, + num_batched_tokens: int, + num_hidden_layers: int, + num_experts_per_tok: int, + num_fused_shared_experts: int, + device: str, ) -> None: - """ - Initialize the device buffer and optionally shared memory buffer. - - Args: - max_num_batched_tokens: Maximum number of tokens in a batch. - max_num_kv_tokens: Maximum number of KV tokens for shared memory. - vllm_config: vllm configuration containing layer and expert info. - """ + # Layout: (L, N, K) so that buffer[layer_id] is a contiguous (N, K) + # view — required by the FlashInfer routing-replay kernel which + # writes expert IDs assuming contiguous row-major memory. 
+ self.num_hidden_layers = num_hidden_layers + self.buffer = torch.zeros( + (num_hidden_layers, num_batched_tokens, num_experts_per_tok), + dtype=self.DTYPE, + device=device, + ) + self._finalize_allocation_log() - if self._device_buffer is not None: - raise RuntimeError("Device buffer has already been initialized") + def get_buffer_size_bytes(self): + return get_tensor_size_bytes(self.buffer) - hf_config = vllm_config.model_config.hf_text_config - num_layers = hf_config.num_hidden_layers - num_experts_per_tok = hf_config.num_experts_per_tok + def capture_fwd_routed_experts(self, layer_id: int, topk_ids: torch.Tensor): + assert layer_id is not None, "capturing routing experts but get layer_id None" + batch, _ = topk_ids.shape + self.buffer[layer_id, :batch, :].copy_(topk_ids, non_blocking=True) - # Initialize device buffer - self._device_buffer = torch.zeros( - (max_num_batched_tokens, num_layers, num_experts_per_tok), - dtype=torch.int32, - device=current_platform.device_type, + def _finalize_allocation_log(self): + buf_mb = self.get_buffer_size_bytes() / _MB + logger.info( + f"Routing experts device buffer allocated. " + f"shape={tuple(self.buffer.shape)}, size={buf_mb:.2f} MB" ) - self.dp_rank = vllm_config.parallel_config.data_parallel_rank - if get_tensor_model_parallel_rank() != 0: - return - # Initialize shared memory - shape = (max_num_kv_tokens, num_layers, num_experts_per_tok) - buffer_size = int(np.prod(shape)) * np.dtype(np.int32).itemsize - instance_id = vllm_config.instance_id - self._lock_file = f"{_LOCK_FILE_PREFIX}_{instance_id}_{self.dp_rank}.lock" - shm_name = f"{_BUFFER_PREFIX}_{instance_id}_{self.dp_rank}" +class _RoutedExpertsHostCache: + """Host (CPU) cache using numpy arrays for per-request routing data. - self._shm = _create_or_attach_shared_memory( - shm_name, buffer_size, self._lock_file - ) - self._host_buffer_view = np.ndarray(shape, dtype=np.int32, buffer=self._shm.buf) - self._host_buffer_view.fill(0) + Numpy arrays avoid torch dispatcher overhead for scatter operations. + Lazy per-request allocation avoids a massive up-front buffer. + """ - logger.debug( - "Created shared memory buffer '%s' with shape %s", - shm_name, - shape, - ) + DTYPE = np.int16 - def capture(self, layer_id: int, topk_ids: torch.Tensor) -> None: - """ - Capture expert routing decisions for a specific layer. + def __init__( + self, + num_hidden_layers: int, + num_experts_per_tok: int, + max_running_requests: int, + max_model_len: int, + use_shared_memory: bool = True, + ) -> None: + self.max_model_len = max_model_len + self.max_running_requests = max_running_requests + self.num_hidden_layers = num_hidden_layers + self.num_experts_per_tok = num_experts_per_tok + self._use_shared_memory = use_shared_memory - Args: - layer_id: The layer index. - topk_ids: Tensor of shape (batch_size, num_routed_experts). - """ - if self._device_buffer is None: - raise RuntimeError("Buffer not initialized. Call init_buffer() first.") - - ctx = get_forward_context() - if ctx.dp_metadata is None: # single dp - start_loc = 0 - end_loc = topk_ids.shape[0] - token_num_per_dp = topk_ids.shape[0] - else: # multi dp - num_tokens_dp = ctx.dp_metadata.num_tokens_across_dp_cpu - token_num_per_dp = int(num_tokens_dp[self.dp_rank].item()) - total = int(num_tokens_dp.sum().item()) - n = topk_ids.shape[0] - - if n == total: - # Naive dispatch: all DP ranks' tokens concatenated before routing. 
- cumsum = torch.cumsum(num_tokens_dp, dim=0) - end_loc = int(cumsum[self.dp_rank].item()) - start_loc = end_loc - token_num_per_dp - elif n == token_num_per_dp: - # Modular-kernel path: DP combine happens inside quant_method.apply; - # select_experts only sees this rank's tokens. - start_loc = 0 - end_loc = token_num_per_dp - else: - raise AssertionError( - "RoutedExpertsCapturer: unexpected topk_ids batch dim " - f"{n} (expected {total} or {token_num_per_dp} " - f"for dp_rank={self.dp_rank})" - ) + self._req_buffers: dict[str, np.ndarray] = {} + self._filled_len: dict[str, int] = {} + self._total_allocated_bytes = 0 - if layer_id >= self._device_buffer.shape[1]: - return + self._finalize_allocation_log() - self._device_buffer[:token_num_per_dp, layer_id, :] = topk_ids[ - start_loc:end_loc, : - ] + def get_buffer_size_bytes(self) -> int: + return self._total_allocated_bytes - def clear_buffer(self) -> None: - """Clear the device buffer.""" - if self._device_buffer is not None: - self._device_buffer.zero_() + def get_or_grow_buffer(self, req_id: str, max_pos: int) -> np.ndarray: + required_len = max_pos + 1 - def save_captured_experts(self, indices: np.ndarray) -> None: - """ - Save captured experts from device buffer to shared memory. + if req_id not in self._req_buffers: + buf = np.full( + (required_len, self.num_hidden_layers, self.num_experts_per_tok), + -1, + dtype=self.DTYPE, + ) + self._req_buffers[req_id] = buf + self._total_allocated_bytes += buf.nbytes + return buf + + buf = self._req_buffers[req_id] + if buf.shape[0] >= required_len: + return buf + + new_len = min(max(required_len, buf.shape[0] * 2), self.max_model_len) + new_buf = np.full( + (new_len, self.num_hidden_layers, self.num_experts_per_tok), + -1, + dtype=self.DTYPE, + ) + new_buf[: buf.shape[0]] = buf + self._total_allocated_bytes += new_buf.nbytes - buf.nbytes + self._req_buffers[req_id] = new_buf + return new_buf + + def get_buffer(self, req_id: str) -> np.ndarray | None: + return self._req_buffers.get(req_id) + + def update_filled_len(self, req_id: str, max_pos: int) -> None: + new_len = max_pos + 1 + self._filled_len[req_id] = max(self._filled_len.get(req_id, 0), new_len) + + def get_filled_len(self, req_id: str) -> int: + return self._filled_len.get(req_id, 0) + + def free_request(self, req_id: str) -> None: + if req_id in self._req_buffers: + self._total_allocated_bytes -= self._req_buffers.pop(req_id).nbytes + self._filled_len.pop(req_id, None) + + def _finalize_allocation_log(self): + logger.info( + f"Routing experts host cache initialized (lazy allocation). " + f"max_model_len={self.max_model_len}, " + f"layers={self.num_hidden_layers}, " + f"experts_per_tok={self.num_experts_per_tok}" + ) - Args: - indices: Array of indices indicating where to store the data. 
- """ - if get_tensor_model_parallel_rank() != 0: - return - if self._lock_file is None: - raise RuntimeError("Shared memory not initialized.") - if self._host_buffer_view is None: - return - if self._device_buffer is None: - raise RuntimeError("Device buffer not initialized.") - num_tokens = len(indices) - data = self._device_buffer[:num_tokens, :, :].cpu().numpy() +class RoutedExpertsCapturer(ABC): + @staticmethod + def create( + enable: bool, + model_config: ModelConfig, + num_fused_shared_experts: int, + num_batched_tokens: int, + max_running_requests: int, + max_model_len: int, + device: str, + shared_host_cache: Optional[_RoutedExpertsHostCache] = None, + skip_host_cache: bool = False, + ): + if enable: + return _RoutedExpertsCapturerReal( + model_config, + num_batched_tokens=num_batched_tokens, + max_running_requests=max_running_requests, + num_fused_shared_experts=num_fused_shared_experts, + max_model_len=max_model_len, + device=device, + shared_host_cache=shared_host_cache, + skip_host_cache=skip_host_cache, + ) + return _RoutedExpertsCapturerNoop() - with _file_lock(self._lock_file): - self._host_buffer_view[indices, :, :] = data + def capture(self, layer_id: int, topk_ids: torch.Tensor): + raise NotImplementedError - def cleanup(self) -> None: - """Explicitly clean up shared memory resources.""" - if self._shm is not None: - try: - self._shm.close() - self._shm.unlink() - except Exception: - logger.debug("Exception during cleanup for capturer", exc_info=True) - finally: - self._shm = None + def get_routed_experts( + self, req_id: str, seqlen: Optional[int] = None, free_slot: bool = True + ): + raise NotImplementedError - def __del__(self) -> None: - """Clean up shared memory on destruction.""" - self.cleanup() + def sync_fwd_experts_buffer_DtoH( + self, + positions: torch.Tensor, + num_scheduled_tokes: dict[str, int], + ): + raise NotImplementedError + def finalize_pending_copy(self): + raise NotImplementedError -class RoutedExpertsReader: - """ - Reader for routed experts from shared memory. + def get_host_cache(self): + raise NotImplementedError - This class attaches to shared memory created by RoutedExpertsCapturer - and reads expert routing decisions. - """ + def get_device_cache(self): + raise NotImplementedError - _instance: RoutedExpertsReader | None = None - def __init__(self) -> None: - self._shm: shared_memory.SharedMemory | None = None - self._host_buffer_view: np.ndarray | None = None - self._lock_file: str | None = None +class _RoutedExpertsCapturerReal(RoutedExpertsCapturer): + """Capturer with GPU device cache and CPU host cache. - @classmethod - def create(cls) -> RoutedExpertsReader: - """Create a global singleton instance.""" - global _global_experts_reader - if _global_experts_reader is not None: - raise RuntimeError("Experts reader already created.") + Performance strategy -- async D2H with optimized host-cache scatter: - _global_experts_reader = cls() - return _global_experts_reader + Every decode step we issue a non-blocking D2H copy on a dedicated + CUDA stream. The scatter into per-request host-cache buffers is + deferred to the start of the NEXT step (by which time the copy has + finished). The scatter loop is optimized with direct scalar access + to avoid numpy slice views, int() conversions, and .max() calls. 
- @staticmethod - def get_instance() -> RoutedExpertsReader | None: - """Get the global singleton instance.""" - if _global_experts_reader is None: - logger.info("Experts reader not initialized.") - return _global_experts_reader + At extraction time (when a request finishes), data is already in a + contiguous host buffer -- just a numpy slice, no concatenation. + """ - def attach_buffer( + def __init__( self, - max_num_kv_tokens: int, - vllm_config: VllmConfig, - ) -> None: - """ - Attach to an existing shared memory buffer. + model_config: ModelConfig, + num_batched_tokens: int, + max_running_requests: int, + num_fused_shared_experts: int, + max_model_len: int, + device: str, + shared_host_cache: Optional[_RoutedExpertsHostCache] = None, + skip_host_cache: bool = False, + ): + self.forward_batch = None + self.num_fused_shared_experts = num_fused_shared_experts + self.num_hidden_layers = model_config.hf_text_config.layers_block_type.count( + "moe" + ) + self.num_experts_per_tok = model_config.hf_text_config.num_experts_per_tok + self.num_batched_tokens = num_batched_tokens + self.max_model_len = max_model_len + self._skip_host_cache = skip_host_cache + + if skip_host_cache: + self.host_cache = None + logger.info(f"Skipping host cache for device {device} (non-rank-0)") + elif shared_host_cache is not None: + self.host_cache = shared_host_cache + else: + self.host_cache = _RoutedExpertsHostCache( + max_running_requests=max_running_requests, + num_hidden_layers=self.num_hidden_layers, + num_experts_per_tok=self.num_experts_per_tok, + max_model_len=self.max_model_len, + use_shared_memory=False, + ) - Args: - max_num_kv_tokens: Maximum number of KV tokens. - vllm_config: vllm configuration. - """ - if self._shm is not None: - logger.warning("Already attached to shared memory buffer.") - return # Already attached - - hf_config = vllm_config.model_config.hf_text_config - shape = ( - max_num_kv_tokens, - hf_config.num_hidden_layers, - hf_config.num_experts_per_tok, + self.device_cache = _RoutedExpertsDeviceCache( + num_batched_tokens=self.num_batched_tokens, + num_hidden_layers=self.num_hidden_layers, + num_experts_per_tok=self.num_experts_per_tok, + num_fused_shared_experts=self.num_fused_shared_experts, + device=device, ) - self.dp_rank = vllm_config.parallel_config.data_parallel_rank - instance_id = vllm_config.instance_id - self._lock_file = f"{_LOCK_FILE_PREFIX}_{instance_id}_{self.dp_rank}.lock" - shm_name = f"{_BUFFER_PREFIX}_{instance_id}_{self.dp_rank}" - - with _file_lock(self._lock_file, mode="rb+"): - # Avoid resource_tracker registering the shared memory - with patch( - "multiprocessing.resource_tracker.register", - lambda *args, **kwargs: None, - ): - self._shm = shared_memory.SharedMemory(name=shm_name) - - self._host_buffer_view = np.ndarray( - shape, dtype=np.int32, buffer=self._shm.buf + # ---- Async D2H pipeline (rank-0 only) ---- + # Non-rank-0 workers only need the device buffer for symmetric + # CUDA graph capture; they skip the D2H pipeline entirely. + self._has_pending_copy = False + self._pending_positions: np.ndarray | None = None + self._pending_num_scheduled: dict[str, int] | None = None + self._pending_total_tokens: int = 0 + + if not skip_host_cache: + # Same (L, N, K) layout as device_cache.buffer. 
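+            # Pinned host memory is what allows the non_blocking D2H copy to
+            # actually run asynchronously with respect to the compute stream.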
+ self._pinned_staging = torch.zeros( + (self.num_hidden_layers, num_batched_tokens, self.num_experts_per_tok), + dtype=_RoutedExpertsDeviceCache.DTYPE, + pin_memory=True, + ) + self._copy_stream = torch.cuda.Stream(device=device) + self._copy_event = torch.cuda.Event() + + pinned_mb = get_tensor_size_bytes(self._pinned_staging) / _MB + logger.info( + f"Routing experts pinned staging buffer allocated. " + f"shape={tuple(self._pinned_staging.shape)}, " + f"size={pinned_mb:.2f} MB" + ) + else: + self._pinned_staging = None + self._copy_stream = None + self._copy_event = None + logger.info( + f"Routing experts device-only capturer (rank != 0). " + f"Device buffer shape={tuple(self.device_cache.buffer.shape)}" ) - def get_routed_experts(self, indices: np.ndarray) -> np.ndarray: - """ - Read routed expert data from shared memory. + def capture(self, layer_id: int, topk_ids: torch.Tensor): + self.device_cache.capture_fwd_routed_experts(layer_id, topk_ids) + + # ------------------------------------------------------------------ + # sync_fwd_experts_buffer_DtoH -- called AFTER the forward pass + # ------------------------------------------------------------------ + + def sync_fwd_experts_buffer_DtoH( + self, + positions: torch.Tensor, + num_scheduled_tokes: dict[str, int], + ): + if self.host_cache is None: + return + + # 1. Finalize previous async copy -- the copy had an entire + # forward pass to complete so event.synchronize() is ~free. + if self._has_pending_copy: + self._copy_event.synchronize() + self._scatter_to_host() + self._has_pending_copy = False + + total_tokens = sum(num_scheduled_tokes.values()) + if total_tokens == 0: + return + + # 2. Issue new async D2H copy on a dedicated stream. + # Device buffer layout is (L, N, K); copy the first total_tokens + # along the N dimension for every layer. + main_stream = torch.cuda.current_stream(self._copy_stream.device) + with torch.cuda.stream(self._copy_stream): + self._copy_stream.wait_stream(main_stream) + self._pinned_staging[:, :total_tokens, :].copy_( + self.device_cache.buffer[:, :total_tokens, :], non_blocking=True + ) + self._copy_event.record() + + # 3. Save metadata for deferred scatter. + self._pending_positions = positions.numpy().copy() + self._pending_num_scheduled = num_scheduled_tokes + self._pending_total_tokens = total_tokens + self._has_pending_copy = True - Args: - indices: Array of indices to read. + # ------------------------------------------------------------------ + # Optimized scatter into pre-allocated host-cache buffers + # ------------------------------------------------------------------ - Returns: - Copy of the expert routing data for the given indices. + def _scatter_to_host(self): + """Scatter D2H data into per-request host cache buffers. + + Staging layout is (L, N, K). Host cache layout is (seq_len, L, K). + We transpose the staging slice to (N, L, K) before scattering so + that indexing by token position naturally yields (L, K) rows. """ - if self._host_buffer_view is None: - raise RuntimeError("Buffer not attached. Call attach_buffer() first.") - if self._lock_file is None: - raise RuntimeError("Lock file not initialized.") + # Transpose (L, N, K) -> (N, L, K) for the active token range. 
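+        # Both .numpy() on the pinned tensor and the transpose are views;
+        # no additional copy of the staging data is made here.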
+ host_values = ( + self._pinned_staging[:, : self._pending_total_tokens, :] + .numpy() + .transpose(1, 0, 2) + ) + positions_np = self._pending_positions + host_cache = self.host_cache + + offset = 0 + for req_id, n_tokens in self._pending_num_scheduled.items(): + if n_tokens == 0: + continue + + if n_tokens == 1: + pos_val = int(positions_np[offset]) + buf = host_cache.get_or_grow_buffer(req_id, pos_val) + buf[pos_val] = host_values[offset] + host_cache.update_filled_len(req_id, pos_val) + else: + pos = positions_np[offset : offset + n_tokens] + max_pos = int(pos[-1]) if n_tokens > 0 else 0 + if n_tokens > 1: + max_pos = int(pos.max()) + buf = host_cache.get_or_grow_buffer(req_id, max_pos) + buf[pos] = host_values[offset : offset + n_tokens] + host_cache.update_filled_len(req_id, max_pos) + + offset += n_tokens + + self._pending_positions = None + self._pending_num_scheduled = None + self._pending_total_tokens = 0 + + # ------------------------------------------------------------------ + # finalize_pending_copy -- call before reading host cache + # ------------------------------------------------------------------ + + def finalize_pending_copy(self): + """Ensure the most recent async D2H copy has been scattered into + host cache buffers. Call before get_routed_experts.""" + if self._has_pending_copy: + self._copy_event.synchronize() + self._scatter_to_host() + self._has_pending_copy = False + + # ------------------------------------------------------------------ + # Extraction -- O(1), just a numpy slice + # ------------------------------------------------------------------ + + def get_routed_experts( + self, + req_id: str, + seqlen: int | None = None, + free_slot: bool = True, + ): + if self.host_cache is None: + return None + buf = self.host_cache.get_buffer(req_id) + if buf is None: + return None + filled = self.host_cache.get_filled_len(req_id) + if filled <= 0: + return None + effective_len = min(filled, seqlen) if seqlen is not None else filled + result = buf[:effective_len].copy() + if free_slot: + self.host_cache.free_request(req_id) + return result + + def get_host_cache(self): + return self.host_cache + + def get_device_cache(self): + return self.device_cache + + +class _RoutedExpertsCapturerNoop(RoutedExpertsCapturer): + def __init__(self): + pass + + def capture(self, layer_id: int, topk_ids: torch.Tensor): + pass + + def get_routed_experts(self, req_id: str, seqlen=None, free_slot=True): + return None + + def sync_fwd_experts_buffer_DtoH(self, positions, num_scheduled_tokes): + pass - with _file_lock(self._lock_file, mode="rb+"): - return self._host_buffer_view[indices, :, :].copy() + def finalize_pending_copy(self): + pass + + def get_host_cache(self): + return None + + def get_device_cache(self): + pass + + +# Global capturer instance (per-process) +_global_expert_capturer: Optional[RoutedExpertsCapturer] = _RoutedExpertsCapturerNoop() +_shared_host_cache: Optional[_RoutedExpertsHostCache] = None + + +def get_global_experts_capturer(): + return _global_expert_capturer + + +def set_global_experts_capturer(capturer: RoutedExpertsCapturer): + global _global_expert_capturer + _global_expert_capturer = capturer + + +def get_shared_host_cache() -> Optional[_RoutedExpertsHostCache]: + return _shared_host_cache + + +def create_shared_host_cache( + model_config: ModelConfig, + max_running_requests: int, + max_model_len: int, +) -> _RoutedExpertsHostCache: + global _shared_host_cache + num_hidden_layers = model_config.hf_text_config.layers_block_type.count("moe") + 
num_experts_per_tok = model_config.hf_text_config.num_experts_per_tok + _shared_host_cache = _RoutedExpertsHostCache( + max_running_requests=max_running_requests, + num_hidden_layers=num_hidden_layers, + num_experts_per_tok=num_experts_per_tok, + max_model_len=max_model_len, + use_shared_memory=False, + ) + return _shared_host_cache + + +def init_routed_experts_capturer_with_shared_cache( + enable: bool, + model_config: ModelConfig, + num_fused_shared_experts: int, + num_batched_tokens: int, + max_running_requests: int, + max_model_len: int, + device: str, + rank: int = 0, + world_size: int = 1, +) -> RoutedExpertsCapturer: + """Initialize capturer with rank-aware handling (only rank 0 captures).""" + if not enable: + capturer = _RoutedExpertsCapturerNoop() + set_global_experts_capturer(capturer) + return capturer + + if world_size > 1 and rank != 0: + # Non-rank-0 workers get a device-only capturer (no host cache, + # no D2H pipeline) so that ALL ranks have a real device buffer. + # This ensures the custom op call in every MoE layer produces + # identical CUDA graph structure across TP ranks. + logger.info(f"Creating device-only routed experts capturer for rank {rank}") + capturer = RoutedExpertsCapturer.create( + enable=True, + model_config=model_config, + num_fused_shared_experts=num_fused_shared_experts, + num_batched_tokens=num_batched_tokens, + max_running_requests=max_running_requests, + max_model_len=max_model_len, + device=device, + skip_host_cache=True, + ) + set_global_experts_capturer(capturer) + return capturer + + capturer = RoutedExpertsCapturer.create( + enable=True, + model_config=model_config, + num_fused_shared_experts=num_fused_shared_experts, + num_batched_tokens=num_batched_tokens, + max_running_requests=max_running_requests, + max_model_len=max_model_len, + device=device, + skip_host_cache=False, + ) + set_global_experts_capturer(capturer) + return capturer + + +def bind_routing_capture_to_model(model) -> None: + """Bind routing capture buffers to all FusedMoE layers in the model. + + Must be called AFTER init_routed_experts_capturer_with_shared_cache() + and BEFORE CUDA graph capture. All TP ranks get a real buffer so + that the custom op call produces identical graph structure. + """ + from vllm.model_executor.layers.fused_moe.layer import FusedMoE + + capturer = get_global_experts_capturer() + device_cache = capturer.get_device_cache() + if device_cache is None: + return # routing capture not enabled + + buffer = device_cache.buffer + + # Mark the buffer so CUDA graphs do NOT snapshot/restore its contents. + if hasattr(torch.compiler, "cudagraph_mark_tensor_static"): + torch.compiler.cudagraph_mark_tensor_static(buffer) + elif hasattr(torch._C, "_set_static_address_tag"): + torch._C._set_static_address_tag(buffer, True) + try: + torch._dynamo.mark_static_address(buffer) + except Exception: + pass - def cleanup(self) -> None: - """Explicitly clean up resources (close without unlink).""" - if self._shm is not None: + bound = 0 + for module in model.modules(): + if isinstance(module, FusedMoE) and hasattr(module, "moe_layer_id"): + layer_id = module.moe_layer_id + layer_buf = buffer[layer_id] # (N_max, K) + module._routing_replay_out = layer_buf + # Mark each per-layer view as static so CUDA graphs don't + # snapshot/restore or relocate the buffer during replay. 
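+            # These marking hooks are not available in every torch build,
+            # hence the feature probes and the silent fallback below.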
+ if hasattr(torch.compiler, "cudagraph_mark_tensor_static"): + torch.compiler.cudagraph_mark_tensor_static(layer_buf) try: - self._shm.close() + torch._dynamo.mark_static_address(layer_buf) except Exception: - logger.debug("Exception during cleanup for reader", exc_info=True) - finally: - self._shm = None + pass + bound += 1 - def __del__(self) -> None: - """Close shared memory on destruction (do not unlink).""" - self.cleanup() + logger.info( + f"Bound routing capture buffer to {bound} FusedMoE layers. " + f"Buffer shape={tuple(buffer.shape)}" + ) diff --git a/vllm/model_executor/layers/fused_moe/runner/moe_runner_base.py b/vllm/model_executor/layers/fused_moe/runner/moe_runner_base.py index 692d45d34607..5d3a61286b5c 100644 --- a/vllm/model_executor/layers/fused_moe/runner/moe_runner_base.py +++ b/vllm/model_executor/layers/fused_moe/runner/moe_runner_base.py @@ -395,18 +395,46 @@ def _apply_quant_method( shared_experts_input, SharedExpertsOrder.NO_OVERLAP ) + # Get routing replay buffer from persistent layer attribute + # (set by bind_routing_capture_to_model during capturer init) + routing_replay_out = getattr(layer, "_routing_replay_out", None) + if self.quant_method.is_monolithic: fused_out = self.quant_method.apply_monolithic( layer=layer, x=hidden_states, router_logits=router_logits, + routing_replay_out=routing_replay_out, ) + # BF16 monolithic: kernel does not write routing data internally, + # so we run select_experts() separately to capture it. + if ( + routing_replay_out is not None + and not getattr( + self.quant_method, + "_monolithic_writes_routing_replay", + False, + ) + ): + _, topk_ids = self.router.select_experts( + hidden_states=hidden_states, + router_logits=router_logits, + ) + routing_replay_out[: topk_ids.shape[0]].copy_( + topk_ids.to(torch.int16) + ) else: topk_weights, topk_ids = self.router.select_experts( hidden_states=hidden_states, router_logits=router_logits, ) + # Write routing data for non-monolithic path (Triton, etc.) + if routing_replay_out is not None: + routing_replay_out[: topk_ids.shape[0]].copy_( + topk_ids.to(torch.int16) + ) + # Passing shared_experts_input in case SharedExpertsOrder is # NO_OVERLAP or MK_INTERNAL_OVERLAPPED. fused_out = self.quant_method.apply( diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index d7920462e613..657bbd56e6da 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -562,6 +562,7 @@ def process_weights_after_loading(self, layer: Module) -> None: class Fp8MoEMethod(FusedMoEMethodBase): + _monolithic_writes_routing_replay = True """MoE method for FP8. Supports loading FP8 checkpoints with static weight scale and dynamic/static activation scale. diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 0b8ad0cbc1ed..60bd0776d2d8 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -737,6 +737,7 @@ def apply( class ModelOptFp8MoEMethod(FusedMoEMethodBase): + _monolithic_writes_routing_replay = True """MoE method for ModelOpt FP8. Supports loading FP8 checkpoints with static weight scale and activation scale. 
@@ -948,6 +949,7 @@ def apply_monolithic( layer: FusedMoE, x: torch.Tensor, router_logits: torch.Tensor, + routing_replay_out: torch.Tensor | None = None, ) -> torch.Tensor: assert self.is_monolithic assert self.moe_kernel is not None @@ -964,6 +966,7 @@ def apply_monolithic( topk_group=layer.topk_group, e_score_correction_bias=layer.e_score_correction_bias, routed_scaling_factor=layer.routed_scaling_factor, + routing_replay_out=routing_replay_out, ) def apply( @@ -1206,6 +1209,7 @@ def apply( class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): + _monolithic_writes_routing_replay = True """ MoE Method for FP4 Quantization. Args: @@ -1440,6 +1444,7 @@ def apply_monolithic( layer: FusedMoE, x: torch.Tensor, router_logits: torch.Tensor, + routing_replay_out: torch.Tensor | None = None, ) -> torch.Tensor: assert self.is_monolithic assert self.moe_kernel is not None @@ -1456,6 +1461,7 @@ def apply_monolithic( topk_group=layer.topk_group, e_score_correction_bias=layer.e_score_correction_bias, routed_scaling_factor=layer.routed_scaling_factor, + routing_replay_out=routing_replay_out, ) def apply( @@ -1670,6 +1676,7 @@ def apply( class ModelOptMxFp8FusedMoE(FusedMoEMethodBase): + _monolithic_writes_routing_replay = True """FlashInfer TRTLLM MXFP8 block-scale MoE for ModelOpt checkpoints.""" def __init__( @@ -1918,6 +1925,7 @@ def apply_monolithic( layer: FusedMoE, x: torch.Tensor, router_logits: torch.Tensor, + routing_replay_out: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: from flashinfer.fused_moe.core import ( ActivationType, @@ -1986,6 +1994,7 @@ def apply_monolithic( use_shuffled_weight=True, weight_layout=0, fp8_quantization_type=Fp8QuantizationType.MxFp8, + routing_replay_out=routing_replay_out, ) if fi_activation_type != ActivationType.Swiglu: diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 40b5899f0457..6da064e6d6a9 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -27,9 +27,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe.routed_experts_capturer import ( - RoutedExpertsReader, -) +# RoutedExpertsReader removed - routing data flows through ModelRunnerOutput from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal.encoder_budget import MultiModalBudget from vllm.v1.core.encoder_cache_manager import ( @@ -52,7 +50,7 @@ ) from vllm.v1.core.sched.utils import check_stop, remove_all from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs -from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig +from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.metrics.perf import ModelMetrics, PerfStats from vllm.v1.metrics.stats import PrefixCacheStats, SchedulerStats from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput @@ -256,43 +254,7 @@ def __init__( if self.log_stats and vllm_config.observability_config.enable_mfu_metrics: self.perf_metrics = ModelMetrics(vllm_config) - if self.vllm_config.model_config.enable_return_routed_experts: - assert self.dcp_world_size == 1 and self.pcp_world_size == 1, ( - "enable_return_routed_experts does not support context parallelism " - "(dcp_world_size > 1 or pcp_world_size > 1)" - ) - - self.routed_experts_reader = RoutedExpertsReader.create() - - assert 
len(kv_cache_config.kv_cache_groups) > 0, ( - "enable_return_routed_experts requires at least one kv cache group" - ) - # Find the attention group for routed experts indexing. - self.routed_experts_attn_gid = 0 - for gid, group in enumerate(kv_cache_config.kv_cache_groups): - if isinstance(group.kv_cache_spec, AttentionSpec): - self.routed_experts_attn_gid = gid - break - min_block_size = min( - [ - group.kv_cache_spec.block_size - for group in kv_cache_config.kv_cache_groups - ] - ) - num_groups = len(kv_cache_config.kv_cache_groups) - self.max_num_kv_tokens = ( - kv_cache_config.num_blocks // num_groups - ) * min_block_size - dcp_size = self.vllm_config.parallel_config.decode_context_parallel_size - pcp_size = self.vllm_config.parallel_config.prefill_context_parallel_size - if pcp_size * dcp_size > 1: - self.max_num_kv_tokens *= pcp_size * dcp_size - - self.routed_experts_reader.attach_buffer( - max_num_kv_tokens=self.max_num_kv_tokens, - vllm_config=self.vllm_config, - ) - + # Routed experts init removed - data flows through ModelRunnerOutput self._pause_state: PauseState = PauseState.UNPAUSED def _mamba_block_aligned_split( @@ -1424,10 +1386,14 @@ def update_from_output( request.resumable = False stopped = True + # Get routing data from ModelRunnerOutput (via worker D2H pipeline) routed_experts = None + if (model_runner_output.routed_experts_dict is not None + and req_id in model_runner_output.routed_experts_dict): + routed_experts = model_runner_output.routed_experts_dict[req_id] finish_reason = None if stopped: - routed_experts = self._get_routed_experts(request) + pass # routing data already extracted above # Capture finish_reason BEFORE _handle_stopped_request, which may # reset the status to WAITING for streaming requests that continue. @@ -1606,27 +1572,8 @@ def _handle_stopped_request(self, request: Request) -> bool: def _get_routed_experts(self, request: Request) -> np.ndarray | None: if not self.vllm_config.model_config.enable_return_routed_experts: return None - - kv_blocks = self.kv_cache_manager.get_blocks(request.request_id) - block_ids = kv_blocks.get_block_ids()[self.routed_experts_attn_gid] - num_tokens = request.num_tokens - 1 - - # compute slot mapping using attention group's block_size - block_ids_array = np.array(block_ids, dtype=np.int32) - num_blocks = len(block_ids) - attn_group = self.kv_cache_config.kv_cache_groups[self.routed_experts_attn_gid] - block_size = attn_group.kv_cache_spec.block_size - - # generate block offsets - block_offsets = np.arange(0, block_size) - - # compute slot mapping: slot = block_id * block_size + offset - slot_mapping = ( - block_offsets.reshape((1, block_size)) - + block_ids_array.reshape((num_blocks, 1)) * block_size - ).flatten()[:num_tokens] - - return self.routed_experts_reader.get_routed_experts(indices=slot_mapping) + # Routing data stored per-request from ModelRunnerOutput + return getattr(request, '_routed_experts_data', None) def _update_request_with_output( self, request: Request, new_token_ids: list[int] diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index d5c5dba63475..7fa9236bd202 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -174,7 +174,7 @@ class EngineCoreOutput( prefill_stats: PrefillStats | None = None - routed_experts: np.ndarray | None = None + routed_experts: tuple | None = None # The number of NaNs in logits. # A value greater than 0 indicates that the output is corrupted. 
num_nans_in_logits: int = 0 diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 1ae89ae19680..1efba1ab7eaf 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -618,6 +618,11 @@ def process_outputs( kv_transfer_params = engine_core_output.kv_transfer_params routed_experts = engine_core_output.routed_experts + if routed_experts is not None: + shape, data = routed_experts + routed_experts = np.frombuffer(data, dtype=np.int16).reshape( + shape) + if req_state.is_prefilling: if engine_core_output.prefill_stats is not None: req_state.num_cached_tokens = ( diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index 1f102ec61783..b9b7fea39d10 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -198,6 +198,9 @@ class ModelRunnerOutput: # req_id -> num_nans_in_logits num_nans_in_logits: dict[str, int] | None = None + # req_id -> routed experts data (shape, bytes) tuples + routed_experts_dict: dict[str, tuple] | None = None + # information related to cudagraph execution cudagraph_stats: CUDAGraphStat | None = None diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 4c573f79e97a..b7d4fd7bc324 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -54,7 +54,8 @@ from vllm.model_executor.layers.attention import Attention, MLAAttention from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.fused_moe.routed_experts_capturer import ( - RoutedExpertsCapturer, + get_global_experts_capturer, + init_routed_experts_capturer_with_shared_cache, ) from vllm.model_executor.layers.mamba.ops.ssu_dispatch import ( initialize_mamba_ssu_backend, @@ -2144,10 +2145,7 @@ def _get_block_table(kv_cache_gid: int): block_table_gid_0 = _get_block_table(0) slot_mapping_gid_0 = slot_mappings[0] - if self.routed_experts_initialized: - attn_gid = self.routed_experts_attn_gid - slot_mapping_attn = slot_mappings[attn_gid] - self.slot_mapping = slot_mapping_attn[:num_tokens].cpu().numpy() + # routing replay uses device cache approach (no slot_mapping needed) num_computed_tokens_cpu = self.input_batch.num_computed_tokens_cpu_tensor[ :num_reqs_padded ] @@ -3007,6 +3005,86 @@ def get_model(self) -> nn.Module: return self.model.unwrap() return self.model + def _extract_routed_experts_for_current_batch( + self, + req_ids: list[str], + ) -> dict[str, tuple] | None: + """Extract routed experts for requests predicted to finish this step. + + Checks all stop conditions the scheduler will check (max_tokens, + EOS token, stop tokens, max_model_len) so that every finished + request gets its routing data attached to the ModelRunnerOutput. + + Note: We don't free buffers here -- that happens in + ``_update_states`` for requests that finished in the PREVIOUS step. 
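+
+        The returned dict maps req_id to a ``(shape, bytes)`` tuple so the
+        routing data serializes cheaply in ModelRunnerOutput and can be
+        rebuilt downstream with ``np.frombuffer(...).reshape(shape)``.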
+ """ + capturer = get_global_experts_capturer() + if capturer is None: + return None + host_cache = capturer.get_host_cache() + if host_cache is None: + return None + + finishing_req_ids: list[str] = [] + for req_id in req_ids: + req_state = self.requests.get(req_id) + if req_state is None: + continue + sp = req_state.sampling_params + if sp is None: + continue + output_ids = req_state.output_token_ids + if not output_ids: + continue + if len(output_ids) < sp.min_tokens: + continue + + finishing = False + last_token = output_ids[-1] + + # EOS token (mirrors check_stop: eos_token_id is None + # when ignore_eos=True, so this naturally respects that) + if last_token == sp.eos_token_id: + finishing = True + + # Explicit stop token IDs + if not finishing and sp.stop_token_ids and last_token in sp.stop_token_ids: + finishing = True + + # max_tokens / max_model_len length cap + if not finishing: + if sp.max_tokens is not None and len(output_ids) >= sp.max_tokens: + finishing = True + else: + req_idx = self.input_batch.req_id_to_index.get(req_id) + if req_idx is not None: + total = self.input_batch.num_tokens_no_spec[req_idx] + if total >= self.max_model_len: + finishing = True + + if finishing: + finishing_req_ids.append(req_id) + + if not finishing_req_ids: + return None + + # At least one request is finishing: ensure the latest async D2H + # copy has been scattered into the host cache. + capturer.finalize_pending_copy() + + result: dict[str, tuple] = {} + for req_id in finishing_req_ids: + seqlen = host_cache.get_filled_len(req_id) + if seqlen <= 0: + continue + experts = capturer.get_routed_experts( + req_id, seqlen=seqlen, free_slot=False + ) + if experts is not None: + result[req_id] = (experts.shape, experts.tobytes()) + + return result if result else None + def get_supported_generation_tasks(self) -> list[GenerationTask]: model = self.get_model() supported_tasks = list[GenerationTask]() @@ -3771,11 +3849,9 @@ def execute_model( ) if self.routed_experts_initialized: - capturer = RoutedExpertsCapturer.get_instance() + capturer = get_global_experts_capturer() if capturer is not None: - capturer.clear_buffer() # noqa - else: - logger.error("RoutedExpertsCapturer not initialized.") + capturer.finalize_pending_copy() # If ngram_gpu is used, we need to copy the scheduler_output to avoid # the modification has influence on the scheduler_output in engine core process. @@ -4285,6 +4361,25 @@ def propose_draft_token_ids(sampled_token_ids): spec_decode_metadata, ) + # Issue async D2H copy of routed experts EARLY so that the copy + # overlaps with eplb, kv_connector finalization, and draft work. + # finalize_pending_copy() + get_routed_experts() happen later in + # _extract_routed_experts_for_current_batch(). + if self.routed_experts_initialized: + capturer = get_global_experts_capturer() + if capturer is not None: + ordered_num_scheduled = { + req_id: scheduler_output.num_scheduled_tokens[req_id] + for req_id in self.input_batch.req_ids + if req_id in scheduler_output.num_scheduled_tokens + } + n = sum(ordered_num_scheduled.values()) + self._positions_cpu[:n].copy_(self.positions[:n]) + capturer.sync_fwd_experts_buffer_DtoH( + positions=self._positions_cpu[:n], + num_scheduled_tokes=ordered_num_scheduled, + ) + if propose_drafts_after_bookkeeping: # ngram and other speculative decoding methods use the sampled # tokens on the CPU, so they are run after bookkeeping. 
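# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): the deferred D2H pattern used
# by sync_fwd_experts_buffer_DtoH() / finalize_pending_copy(), reduced to a
# standalone example. Names, shapes, and dtypes below are hypothetical.
# ---------------------------------------------------------------------------
import torch

device_buf = torch.zeros(8192, 8, dtype=torch.int16, device="cuda")
pinned = torch.empty(8192, 8, dtype=torch.int16, pin_memory=True)
copy_stream = torch.cuda.Stream()
copy_event = torch.cuda.Event()


def issue_copy(n: int) -> None:
    # Right after the forward pass of step N: queue the copy on a side stream
    # so it overlaps with sampling, bookkeeping, and the next forward pass.
    copy_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(copy_stream):
        pinned[:n].copy_(device_buf[:n], non_blocking=True)
        copy_event.record()


def finalize(n: int) -> torch.Tensor:
    # At the start of step N+1 (or before a request is extracted): the copy
    # has had an entire forward pass to finish, so this wait is usually ~free.
    copy_event.synchronize()
    return pinned[:n]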
@@ -4304,12 +4399,11 @@ def propose_draft_token_ids(sampled_token_ids): self.kv_connector_output = None with record_function_or_nullcontext("gpu_model_runner: ModelRunnerOutput"): + routed_experts_dict = None if self.routed_experts_initialized: - capturer = RoutedExpertsCapturer.get_instance() - if capturer is not None: - capturer.save_captured_experts(indices=self.slot_mapping) # noqa - else: - logger.error("RoutedExpertsCapturer not initialized.") + routed_experts_dict = ( + self._extract_routed_experts_for_current_batch( + req_ids_output_copy)) output = ModelRunnerOutput( req_ids=req_ids_output_copy, @@ -4323,6 +4417,7 @@ def propose_draft_token_ids(sampled_token_ids): else None, num_nans_in_logits=num_nans_in_logits, cudagraph_stats=cudagraph_stats, + routed_experts_dict=routed_experts_dict, ) if not self.use_async_scheduling: @@ -5992,6 +6087,7 @@ def capture_model(self) -> int: "Skipping CUDA graph capture. To turn on CUDA graph capture, " "ensure `cudagraph_mode` was not manually set to `NONE`" ) + self.init_routed_experts_capturer() return 0 # Initialize encoder CUDA graph manager if enabled. @@ -6025,6 +6121,13 @@ def capture_model(self) -> int: start_time = time.perf_counter() + # Initialize the routed experts capturer once before any CUDA graph + # capture. Must happen before graphs are captured so the buffer + # address is baked into the graph. Do NOT call this inside + # _capture_cudagraphs() -- creating the capturer twice replaces + # the device buffer, causing graphs to write to a dead buffer. + self.init_routed_experts_capturer() + # Trigger CUDA graph capture for specific shapes. # Capture the large shapes first so that the smaller shapes # can reuse the memory pool allocated for the large shapes. @@ -6790,45 +6893,47 @@ def init_routed_experts_capturer(self): "Initializing routed experts capturer, enable_return_routed_experts: %s", self.model_config.enable_return_routed_experts, ) - routed_experts_capturer = RoutedExpertsCapturer.create() - self.routed_experts_attn_gid = self._get_attention_kv_cache_gid() - min_block_size = min( - [ - group.kv_cache_spec.block_size - for group in self.kv_cache_config.kv_cache_groups - ] - ) - num_groups = len(self.kv_cache_config.kv_cache_groups) - self.max_num_kv_tokens = ( - self.kv_cache_config.num_blocks // num_groups - ) * min_block_size - dcp_size = self.vllm_config.parallel_config.decode_context_parallel_size - pcp_size = self.vllm_config.parallel_config.prefill_context_parallel_size - if pcp_size * dcp_size > 1: - self.max_num_kv_tokens *= pcp_size * dcp_size - - routed_experts_capturer.init_buffer( - max_num_batched_tokens=self.scheduler_config.max_num_batched_tokens, - max_num_kv_tokens=self.max_num_kv_tokens, - vllm_config=self.vllm_config, - ) - self._bind_routed_experts_capturer(routed_experts_capturer) - self.routed_experts_initialized = True + from vllm.distributed import get_tp_group - def _bind_routed_experts_capturer(self, capturer: RoutedExpertsCapturer) -> None: - from vllm.model_executor.layers.fused_moe.layer import FusedMoE - from vllm.model_executor.layers.fused_moe.router.base_router import ( - BaseRouter, + max_running_requests = ( + self.max_num_tokens // 2 + if self.max_num_reqs is None + else self.max_num_reqs + // self.vllm_config.parallel_config.data_parallel_size ) - for module in self.compilation_config.static_forward_context.values(): - if isinstance(module, FusedMoE) and isinstance(module.router, BaseRouter): - layer_id = module.layer_id + if hasattr(self.model_config.hf_text_config, "n_shared_experts"): 
+ num_fused_shared_experts = 1 + else: + num_fused_shared_experts = 0 - def _capture_fn(topk_ids, _layer_id=layer_id, _capturer=capturer): - _capturer.capture(_layer_id, topk_ids) + tp_group = get_tp_group() + init_routed_experts_capturer_with_shared_cache( + enable=self.model_config.enable_return_routed_experts, + model_config=self.model_config, + num_fused_shared_experts=num_fused_shared_experts, + num_batched_tokens=self.scheduler_config.max_num_batched_tokens, + max_running_requests=max_running_requests, + max_model_len=self.max_model_len, + device=self.device, + rank=tp_group.rank_in_group, + world_size=tp_group.world_size, + ) + self._bind_routed_experts_capturer() + self.routed_experts_initialized = True - module.router.set_capture_fn(_capture_fn) + # Pinned CPU buffer for async positions D2H (avoids sync .cpu() call) + self._positions_cpu = torch.empty( + self.scheduler_config.max_num_batched_tokens, + dtype=torch.long, + pin_memory=True, + ) + + def _bind_routed_experts_capturer(self) -> None: + from vllm.model_executor.layers.fused_moe.routed_experts_capturer import ( + bind_routing_capture_to_model, + ) + bind_routing_capture_to_model(self.model) def may_add_encoder_only_layers_to_kv_cache_config(self) -> None: """