Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions docs/training/routed_experts_replay.md
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ When a request has multiple completions (`n > 1`), each completion shares the sa
```text
Forward Pass Async D2H Pipeline Output
───────────── ────────────────── ──────
FusedMoE layer After forward pass: On request finish:
FusedMoERouter After forward pass: On request finish:
writes topk_ids ──────► D2H copy to pinned ──────► Extract from host cache
to device buffer staging buffer Split at prompt_len
(L, N, K) int16 (via CUDA stream) Trim gen to output len
Expand All @@ -125,7 +125,9 @@ A pre-allocated GPU buffer with layout `(L, N, K)` where:
- `N` = `max_num_batched_tokens`
- `K` = `num_experts_per_tok` (top-k)

The `(L, N, K)` layout ensures that `buffer[layer_id]` gives a contiguous `(N, K)` view per layer. Each `FusedMoE` layer gets a persistent reference to its slice via `module._routing_replay_out = buffer[layer_id]`.
The `(L, N, K)` layout ensures that `buffer[layer_id]` gives a contiguous `(N, K)` view per layer. Each `FusedMoE` layer's `FusedMoERouter` gets a persistent reference to its slice via `router._routing_replay_out = buffer[layer_id]`.

Each `layer_id` is assigned by `RoutedExpertsCapturer.map_layer_to_id()`, keyed by `FusedMoE.layer_name`; ids are handed out sequentially from 0 in the order layer names are first seen.

**Dtype**: `int16` — sufficient for expert IDs (max ~512 experts in practice) and half the memory of `int32`.

Expand All @@ -146,7 +148,7 @@ This design ensures the D2H copy overlaps with the next forward pass, minimizing

CUDA graph compatibility requires two mechanisms:

1. **Persistent tensor attribute**: Each `FusedMoE` layer stores a reference to its buffer slice as `module._routing_replay_out`. Because `torch.compile` captures module attributes by reference, graph replay always writes to the live buffer — not a stale snapshot.
1. **Persistent tensor attribute**: Each `FusedMoERouter` stores a reference to its buffer slice as `router._routing_replay_out`. Because `torch.compile` captures module attributes by reference, graph replay always writes to the live buffer — not a stale snapshot.

2. **Static marking**: Both the full `(L, N, K)` buffer and each per-layer `(N, K)` view are marked with `cudagraph_mark_tensor_static()`. This stops CUDA graphs from snapshotting and restoring the buffer, which would zero it on replay.

Expand Down
21 changes: 13 additions & 8 deletions tests/model_executor/test_routed_experts_capture.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@ class _DummyMoEConfig:
class _DummyQuantMethod:
supports_internal_mk = True

class DummyFusedMoE:
_routing_replay_out: torch.Tensor
class _DummyRouter:
_routing_replay_out: torch.Tensor | None = None

def __init__(self, moe_layer_id):
self.moe_layer_id = moe_layer_id
class DummyFusedMoE:
def __init__(self, name: str):
self.layer_name = name
self.moe_config = _DummyMoEConfig()
self.quant_method = _DummyQuantMethod()
self.router = _DummyRouter()

monkeypatch.setattr(fused_moe_layer, "FusedMoE", DummyFusedMoE)

Expand All @@ -39,19 +41,22 @@ class DummyCapturer:
def get_device_cache(self):
return DummyDeviceCache(buffer)

def map_layer_to_id(self, name: str) -> int:
# Test stub: the dummy layers are named "0" and "2", so the layer name
# parses directly to the buffer index the assertions check against.
return int(name)

monkeypatch.setattr(rec_mod, "get_global_experts_capturer", lambda: DummyCapturer())

m0 = DummyFusedMoE(moe_layer_id=0)
m2 = DummyFusedMoE(moe_layer_id=2)
m0 = DummyFusedMoE(name="0")
m2 = DummyFusedMoE(name="2")

class DummyModel:
def modules(self):
return iter([m0, m2])

rec_mod.bind_routing_capture_to_model(DummyModel())

assert torch.equal(m0._routing_replay_out, buffer[0])
assert torch.equal(m2._routing_replay_out, buffer[2])
assert torch.equal(m0.router._routing_replay_out, buffer[0])
assert torch.equal(m2.router._routing_replay_out, buffer[2])


def test_bind_routing_capture_to_model_noop_when_disabled(monkeypatch):
Expand Down
7 changes: 0 additions & 7 deletions vllm/model_executor/layers/fused_moe/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,6 @@ class FusedMoE(PluggableLayer):
not supported by the router (or the experts).
"""

# Auto-incrementing layer ID for routing replay buffer binding.
_next_moe_layer_id: int = 0

# --8<-- [end:fused_moe]

def __init__(
Expand Down Expand Up @@ -148,10 +145,6 @@ def __init__(
):
super().__init__()

# Assign unique layer ID for routing replay buffer binding.
self.moe_layer_id = FusedMoE._next_moe_layer_id
FusedMoE._next_moe_layer_id += 1

if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
Expand Down
21 changes: 18 additions & 3 deletions vllm/model_executor/layers/fused_moe/routed_experts_capturer.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,9 @@ def get_host_cache(self):
def get_device_cache(self):
raise NotImplementedError

def map_layer_to_id(self, layer_name: str) -> int:
# Interface hook: concrete capturers return the integer id for `layer_name`.
# bind_routing_capture_to_model uses that id to index the device buffer
# (`buffer[layer_id]`).
raise NotImplementedError


def _count_moe_layers(hf_config) -> int:
"""Count the number of MoE layers in a model.
Expand Down Expand Up @@ -293,6 +296,8 @@ def __init__(
device=device,
)

self._id_map: dict[str, int] = {}

# ---- Async D2H pipeline (rank-0 only) ----
# Non-rank-0 workers only need the device buffer for symmetric
# CUDA graph capture; they skip the D2H pipeline entirely.
Expand Down Expand Up @@ -476,6 +481,13 @@ def get_host_cache(self):
def get_device_cache(self):
return self.device_cache

def map_layer_to_id(self, layer_name: str) -> int:
    """Return the stable integer id for ``layer_name``.

    Ids are handed out sequentially starting at 0, in the order layer
    names are first seen; repeated calls with the same name return the
    same id.

    Args:
        layer_name: Name identifying a layer (callers pass
            ``FusedMoE.layer_name``).

    Returns:
        The id previously assigned to ``layer_name``, or a newly assigned
        one equal to the number of names registered so far.
    """
    # setdefault stores the candidate id only when the name is new, so the
    # current map size is exactly the next free id.
    return self._id_map.setdefault(layer_name, len(self._id_map))


class _RoutedExpertsCapturerNoop(RoutedExpertsCapturer):
def __init__(self):
Expand All @@ -499,6 +511,9 @@ def get_host_cache(self):
def get_device_cache(self):
pass

def map_layer_to_id(self, layer_name: str) -> int:
# No-op variant: ignores `layer_name` and always returns 0.
return 0


# Global capturer instance (per-process)
_global_expert_capturer: RoutedExpertsCapturer | None = _RoutedExpertsCapturerNoop()
Expand Down Expand Up @@ -794,7 +809,7 @@ def bind_routing_capture_to_model(model) -> None:

bound = 0
for module in model.modules():
if isinstance(module, FusedMoE) and hasattr(module, "moe_layer_id"):
if isinstance(module, FusedMoE):
# Per-FusedMoE configurations not yet validated for routing
# capture. These signals are only set after model init, so a
# config-level guard cannot see them.
Expand All @@ -815,9 +830,9 @@ def bind_routing_capture_to_model(model) -> None:
f"dp_size={module.moe_config.dp_size})."
)

layer_id = module.moe_layer_id
layer_id = capturer.map_layer_to_id(module.layer_name)
layer_buf = buffer[layer_id] # (N_max, K)
module._routing_replay_out = layer_buf
module.router._routing_replay_out = layer_buf
# Mark each per-layer view as static so CUDA graphs don't
# snapshot/restore or relocate the buffer during replay.
if hasattr(torch.compiler, "cudagraph_mark_tensor_static"):
Expand Down
11 changes: 1 addition & 10 deletions vllm/model_executor/layers/fused_moe/router/base_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,11 +170,6 @@ def __init__(
self.global_num_experts = global_num_experts
self.eplb_state = eplb_state
self.indices_type_getter = indices_type_getter
self.capture_fn: Callable[[torch.Tensor], None] | None = None

def set_capture_fn(self, capture_fn: Callable[[torch.Tensor], None] | None) -> None:
"""Set a capture callback for logical routed expert IDs."""
self.capture_fn = capture_fn

def _validate_eplb_state(self) -> None:
"""Validate that EPLB state is properly initialized if EPLB is enabled."""
Expand Down Expand Up @@ -247,7 +242,7 @@ def _compute_routing(
"""
raise NotImplementedError

def select_experts(
def _select_experts(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
Expand Down Expand Up @@ -285,10 +280,6 @@ def select_experts(
hidden_states, router_logits, indices_type, input_ids=input_ids
)

# Capture logical ids before EPLB mapping.
if self.capture_fn is not None:
self.capture_fn(topk_ids)

# Step 4: Apply EPLB mapping
topk_ids = self._apply_eplb_mapping(topk_ids)

Expand Down
34 changes: 26 additions & 8 deletions vllm/model_executor/layers/fused_moe/router/fused_moe_router.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Callable

import torch

Expand All @@ -14,19 +13,24 @@ class FusedMoERouter(ABC):
method that is used for routing hidden states based on router logits.
"""

@abstractmethod
def set_capture_fn(
self,
capture_fn: Callable[[torch.Tensor], None] | None,
) -> None:
raise NotImplementedError
def __init__(self):
# Destination for captured top-k expert ids. While it is None,
# select_experts skips the copy; bind_routing_capture_to_model assigns a
# per-layer buffer slice here.
# NOTE(review): subclasses presumably must call super().__init__() so this
# attribute exists before select_experts reads it; confirm for every
# router implementation.
self._routing_replay_out: torch.Tensor | None = None

@property
@abstractmethod
def routing_method_type(self) -> RoutingMethodType:
raise NotImplementedError

@abstractmethod
def _select_experts(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
*,
input_ids: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
# Subclass hook that performs the actual routing. select_experts calls it
# with the same hidden_states, router_logits and input_ids and unpacks the
# result as (topk_weights, topk_ids).
raise NotImplementedError

def select_experts(
self,
hidden_states: torch.Tensor,
Expand All @@ -47,4 +51,18 @@ def select_experts(
equivalent to global logical ids, so should be compatible with
plain MoE implementations without redundant experts.
"""
raise NotImplementedError

topk_weights, topk_ids = self._select_experts(
hidden_states,
router_logits,
input_ids=input_ids,
)

# Write routing data into the replay buffer (non-monolithic path: Triton, etc.).
# `_routing_replay_out` is set by bind_routing_capture_to_model during
# capturer init; it is None when no buffer has been bound.
if self._routing_replay_out is not None:
self._routing_replay_out[: topk_ids.shape[0]].copy_(
topk_ids.to(torch.int16)
)

return topk_weights, topk_ids
8 changes: 0 additions & 8 deletions vllm/model_executor/layers/fused_moe/runner/moe_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,10 +511,6 @@ def _apply_quant_method(
shared_experts_input, SharedExpertsOrder.NO_OVERLAP
)

# Get routing replay buffer from persistent layer attribute
# (set by bind_routing_capture_to_model during capturer init)
routing_replay_out = getattr(layer, "_routing_replay_out", None)

if self._quant_method.is_monolithic:
fused_out = self._quant_method.apply_monolithic(
layer=layer,
Expand All @@ -529,10 +525,6 @@ def _apply_quant_method(
input_ids=input_ids,
)

# Write routing data for non-monolithic path (Triton, etc.)
if routing_replay_out is not None:
routing_replay_out[: topk_ids.shape[0]].copy_(topk_ids.to(torch.int16))

# Passing shared_experts_input in case SharedExpertsOrder is
# MK_INTERNAL_OVERLAPPED.
fused_out = self._quant_method.apply(
Expand Down
Loading