Skip to content
Closed
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions docs/training/routed_experts_replay.md
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ When a request has multiple completions (`n > 1`), each completion shares the sa
```text
Forward Pass Async D2H Pipeline Output
───────────── ────────────────── ──────
FusedMoE layer After forward pass: On request finish:
FusedMoERouter After forward pass: On request finish:
writes topk_ids ──────► D2H copy to pinned ──────► Extract from host cache
to device buffer staging buffer Split at prompt_len
(L, N, K) int16 (via CUDA stream) Trim gen to output len
Expand All @@ -125,7 +125,9 @@ A pre-allocated GPU buffer with layout `(L, N, K)` where:
- `N` = `max_num_batched_tokens`
- `K` = `num_experts_per_tok` (top-k)

The `(L, N, K)` layout ensures that `buffer[layer_id]` gives a contiguous `(N, K)` view per layer. Each `FusedMoE` layer gets a persistent reference to its slice via `module._routing_replay_out = buffer[layer_id]`.
The `(L, N, K)` layout ensures that `buffer[layer_id]` gives a contiguous `(N, K)` view per layer. Each `FusedMoERouter` layer gets a persistent reference to its slice via `router._routing_replay_out = buffer[layer_id]`.

The `layer_id`s are managed by `RoutedExpertsCapturer` and keyed by `FusedMoE.layer_name`.

**Dtype**: `int16` — sufficient for expert IDs (max ~512 experts in practice) and half the memory of `int32`.

Expand All @@ -146,7 +148,7 @@ This design ensures the D2H copy overlaps with the next forward pass, minimizing

CUDA graph compatibility requires two mechanisms:

1. **Persistent tensor attribute**: Each `FusedMoE` layer stores a reference to its buffer slice as `module._routing_replay_out`. Because `torch.compile` captures module attributes by reference, graph replay always writes to the live buffer — not a stale snapshot.
1. **Persistent tensor attribute**: Each `FusedMoERouter` stores a reference to its buffer slice as `router._routing_replay_out`. Because `torch.compile` captures module attributes by reference, graph replay always writes to the live buffer — not a stale snapshot.

2. **Static marking**: Both the full `(L, N, K)` buffer and each per-layer `(N, K)` view are marked with `cudagraph_mark_tensor_static()`. This prevents CUDA graphs from snapshot/restore behavior that would zero the buffer on replay.

Expand Down
21 changes: 13 additions & 8 deletions tests/model_executor/test_routed_experts_capture.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@ class _DummyMoEConfig:
class _DummyQuantMethod:
supports_internal_mk = True

class DummyFusedMoE:
_routing_replay_out: torch.Tensor
class _DummyRouter:
_routing_replay_out: torch.Tensor | None = None

def __init__(self, moe_layer_id):
self.moe_layer_id = moe_layer_id
class DummyFusedMoE:
def __init__(self, name: str):
self.layer_name = name
self.moe_config = _DummyMoEConfig()
self.quant_method = _DummyQuantMethod()
self.router = _DummyRouter()

monkeypatch.setattr(fused_moe_layer, "FusedMoE", DummyFusedMoE)

Expand All @@ -39,19 +41,22 @@ class DummyCapturer:
def get_device_cache(self):
return DummyDeviceCache(buffer)

def map_layer_to_id(self, name: str) -> int:
return int(name)

monkeypatch.setattr(rec_mod, "get_global_experts_capturer", lambda: DummyCapturer())

m0 = DummyFusedMoE(moe_layer_id=0)
m2 = DummyFusedMoE(moe_layer_id=2)
m0 = DummyFusedMoE(name="0")
m2 = DummyFusedMoE(name="2")

class DummyModel:
def modules(self):
return iter([m0, m2])

rec_mod.bind_routing_capture_to_model(DummyModel())

assert torch.equal(m0._routing_replay_out, buffer[0])
assert torch.equal(m2._routing_replay_out, buffer[2])
assert torch.equal(m0.router._routing_replay_out, buffer[0])
assert torch.equal(m2.router._routing_replay_out, buffer[2])


def test_bind_routing_capture_to_model_noop_when_disabled(monkeypatch):
Expand Down
9 changes: 2 additions & 7 deletions vllm/model_executor/layers/fused_moe/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,9 +246,6 @@ class FusedMoE(PluggableLayer):
not supported by the router (or the experts).
"""

# Auto-incrementing layer ID for routing replay buffer binding.
_next_moe_layer_id: int = 0

# --8<-- [end:fused_moe]

def __init__(
Expand Down Expand Up @@ -294,10 +291,6 @@ def __init__(
):
super().__init__()

# Assign unique layer ID for routing replay buffer binding.
self.moe_layer_id = FusedMoE._next_moe_layer_id
FusedMoE._next_moe_layer_id += 1

if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
Expand Down Expand Up @@ -341,6 +334,8 @@ def __init__(
# Expert mapping used in self.load_weights
self.expert_mapping = expert_mapping

print(f"PREFIX = {prefix}")

Comment thread
bnellnm marked this conversation as resolved.
Outdated
# For smuggling this layer into the fused moe custom op
compilation_config = vllm_config.compilation_config
if prefix in compilation_config.static_forward_context:
Expand Down
21 changes: 18 additions & 3 deletions vllm/model_executor/layers/fused_moe/routed_experts_capturer.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,9 @@ def get_host_cache(self):
def get_device_cache(self):
raise NotImplementedError

def map_layer_to_id(self, layer_name: str) -> int:
raise NotImplementedError


def _count_moe_layers(hf_config) -> int:
"""Count the number of MoE layers in a model.
Expand Down Expand Up @@ -293,6 +296,8 @@ def __init__(
device=device,
)

self._id_map: dict[str, int] = {}

# ---- Async D2H pipeline (rank-0 only) ----
# Non-rank-0 workers only need the device buffer for symmetric
# CUDA graph capture; they skip the D2H pipeline entirely.
Expand Down Expand Up @@ -476,6 +481,13 @@ def get_host_cache(self):
def get_device_cache(self):
return self.device_cache

def map_layer_to_id(self, layer_name: str) -> int:
if layer_name not in self._id_map:
next_id = len(self._id_map)
self._id_map[layer_name] = next_id
return next_id
return self._id_map[layer_name]


class _RoutedExpertsCapturerNoop(RoutedExpertsCapturer):
def __init__(self):
Expand All @@ -499,6 +511,9 @@ def get_host_cache(self):
def get_device_cache(self):
pass

def map_layer_to_id(self, layer_name: str) -> int:
return 0


# Global capturer instance (per-process)
_global_expert_capturer: RoutedExpertsCapturer | None = _RoutedExpertsCapturerNoop()
Expand Down Expand Up @@ -794,7 +809,7 @@ def bind_routing_capture_to_model(model) -> None:

bound = 0
for module in model.modules():
if isinstance(module, FusedMoE) and hasattr(module, "moe_layer_id"):
if isinstance(module, FusedMoE):
# Per-FusedMoE configurations not yet validated for routing
# capture. These signals are only set after model init, so a
# config-level guard cannot see them.
Expand All @@ -815,9 +830,9 @@ def bind_routing_capture_to_model(model) -> None:
f"dp_size={module.moe_config.dp_size})."
)

layer_id = module.moe_layer_id
layer_id = capturer.map_layer_to_id(module.layer_name)
layer_buf = buffer[layer_id] # (N_max, K)
module._routing_replay_out = layer_buf
module.router._routing_replay_out = layer_buf
# Mark each per-layer view as static so CUDA graphs don't
# snapshot/restore or relocate the buffer during replay.
if hasattr(torch.compiler, "cudagraph_mark_tensor_static"):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def _compute_routing(
"""
raise NotImplementedError

def select_experts(
def _select_experts(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ class FusedMoERouter(ABC):
method that is used for routing hidden states based on router logits.
"""

def __init__(self):
self._routing_replay_out: torch.Tensor | None = None

@abstractmethod
def set_capture_fn(
self,
Expand All @@ -27,6 +30,15 @@ def routing_method_type(self) -> RoutingMethodType:
raise NotImplementedError

@abstractmethod
def _select_experts(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
*,
input_ids: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError

def select_experts(
self,
hidden_states: torch.Tensor,
Expand All @@ -47,4 +59,18 @@ def select_experts(
equivalent to global logical ids, so should be compatible with
plain MoE implementations without redundant experts.
"""
raise NotImplementedError

topk_weights, topk_ids = self._select_experts(
hidden_states,
router_logits,
input_ids=input_ids,
)

# Write routing data for non-monolithic path (Triton, etc.)
# (set by bind_routing_capture_to_model during capturer init)
if self._routing_replay_out is not None:
self._routing_replay_out[: topk_ids.shape[0]].copy_(
topk_ids.to(torch.int16)
)

return topk_weights, topk_ids
8 changes: 0 additions & 8 deletions vllm/model_executor/layers/fused_moe/runner/moe_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,10 +511,6 @@ def _apply_quant_method(
shared_experts_input, SharedExpertsOrder.NO_OVERLAP
)

# Get routing replay buffer from persistent layer attribute
# (set by bind_routing_capture_to_model during capturer init)
routing_replay_out = getattr(layer, "_routing_replay_out", None)

if self._quant_method.is_monolithic:
fused_out = self._quant_method.apply_monolithic(
layer=layer,
Expand All @@ -529,10 +525,6 @@ def _apply_quant_method(
input_ids=input_ids,
)

# Write routing data for non-monolithic path (Triton, etc.)
if routing_replay_out is not None:
routing_replay_out[: topk_ids.shape[0]].copy_(topk_ids.to(torch.int16))

# Passing shared_experts_input in case SharedExpertsOrder is
# MK_INTERNAL_OVERLAPPED.
fused_out = self._quant_method.apply(
Expand Down
Loading