160 changes: 160 additions & 0 deletions tests/model_executor/test_routed_experts_capture.py
@@ -0,0 +1,160 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import types

import pytest
import torch

from vllm.distributed.eplb.eplb_state import EplbLayerState
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter

pytestmark = pytest.mark.cpu_test


class DummyRouter(BaseRouter):
@property
def routing_method_type(self) -> RoutingMethodType:
return RoutingMethodType.FUSED_TOPK

def _compute_routing(self, hidden_states, router_logits, indices_type):
topk_ids = torch.tensor([[1, 2], [3, 4]], dtype=torch.int64)
topk_weights = torch.ones_like(topk_ids, dtype=torch.float32)
return topk_weights, topk_ids

def _apply_eplb_mapping(self, topk_ids: torch.Tensor) -> torch.Tensor:
# Make mapping observable without requiring CUDA EPLB path.
return topk_ids + 10


def _make_router() -> DummyRouter:
return DummyRouter(
top_k=2,
global_num_experts=16,
eplb_state=EplbLayerState(),
enable_eplb=False,
indices_type_getter=None,
)


def test_base_router_capture_pre_eplb_mapping():
router = _make_router()
captured = []

def capture_fn(ids):
captured.append(ids.clone())

router.set_capture_fn(capture_fn)
topk_weights, topk_ids = router.select_experts(
hidden_states=torch.empty(1),
router_logits=torch.empty(1),
)

assert topk_weights.shape == topk_ids.shape
assert len(captured) == 1
assert torch.equal(captured[0], torch.tensor([[1, 2], [3, 4]]))
assert torch.equal(topk_ids, torch.tensor([[11, 12], [13, 14]]))


def test_base_router_capture_with_eplb_enabled():
router = _make_router()
router.enable_eplb = True
router.eplb_state.expert_load_view = torch.zeros(32, dtype=torch.int64)
router.eplb_state.logical_to_physical_map = torch.arange(32).view(32, 1)
router.eplb_state.logical_replica_count = torch.ones(32, dtype=torch.int64)

captured = []

def capture_fn(ids):
captured.append(ids.clone())

router.set_capture_fn(capture_fn)
_, topk_ids = router.select_experts(
hidden_states=torch.empty(1),
router_logits=torch.empty(1),
)

assert len(captured) == 1
# Capture should see logical ids pre-EPLB mapping.
assert torch.equal(captured[0], torch.tensor([[1, 2], [3, 4]]))
# Our DummyRouter mapping adds +10.
assert torch.equal(topk_ids, torch.tensor([[11, 12], [13, 14]]))


def test_gpu_model_runner_binds_router_capture(monkeypatch):
from vllm.v1.worker import gpu_model_runner as gmr

class DummyFusedMoE:
def __init__(self):
self.layer_id = 7
self.router = _make_router()

class DummyCapturer:
def __init__(self):
self.calls = []

def capture(self, layer_id, topk_ids):
self.calls.append((layer_id, topk_ids))

dummy_module = DummyFusedMoE()

# Patch the runtime import inside _bind_routed_experts_capturer.
import vllm.model_executor.layers.fused_moe.layer as fused_moe_layer

monkeypatch.setattr(fused_moe_layer, "FusedMoE", DummyFusedMoE)

dummy_self = types.SimpleNamespace(
compilation_config=types.SimpleNamespace(
static_forward_context={"dummy": dummy_module}
)
)

capturer = DummyCapturer()
gmr.GPUModelRunner._bind_routed_experts_capturer(dummy_self, capturer)

assert dummy_module.router.capture_fn is not None
dummy_module.router.capture_fn(torch.tensor([[5, 6]]))

assert len(capturer.calls) == 1
layer_id, topk_ids = capturer.calls[0]
assert layer_id == 7
assert torch.equal(topk_ids, torch.tensor([[5, 6]]))


def test_gpu_model_runner_binding_stage(monkeypatch):
from vllm.v1.worker import gpu_model_runner as gmr

class DummyFusedMoE:
def __init__(self):
self.layer_id = 11
self.router = _make_router()

class DummyCapturer:
def __init__(self):
self.calls = []

def capture(self, layer_id, topk_ids):
self.calls.append((layer_id, topk_ids))

dummy_module = DummyFusedMoE()

import vllm.model_executor.layers.fused_moe.layer as fused_moe_layer

monkeypatch.setattr(fused_moe_layer, "FusedMoE", DummyFusedMoE)

dummy_self = types.SimpleNamespace(
compilation_config=types.SimpleNamespace(
static_forward_context={"dummy": dummy_module}
)
)

# Before binding, no capture hook.
assert dummy_module.router.capture_fn is None

capturer = DummyCapturer()
gmr.GPUModelRunner._bind_routed_experts_capturer(dummy_self, capturer)

# After binding, hook should exist and be callable.
assert callable(dummy_module.router.capture_fn)
dummy_module.router.capture_fn(torch.tensor([[9, 10]]))
assert len(capturer.calls) == 1
21 changes: 0 additions & 21 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -44,9 +44,6 @@
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
init_aiter_topK_meta_data,
)
from vllm.model_executor.layers.fused_moe.routed_experts_capturer import (
RoutedExpertsCapturer,
)
from vllm.model_executor.layers.fused_moe.router.router_factory import (
create_fused_moe_router,
)
@@ -523,18 +520,6 @@ def __init__(
self.apply_router_weight_on_input = apply_router_weight_on_input
self.activation = activation

self.capture: Callable[[torch.Tensor], None] | None = None
if (
self.vllm_config.model_config is not None
and self.vllm_config.model_config.enable_return_routed_experts
):
# In dummy runs, the capturer is not initialized.
capturer = RoutedExpertsCapturer.get_instance()
if capturer is not None:
self.capture = lambda topk_ids: capturer.capture(
self.layer_id, topk_ids
)

self.router = create_fused_moe_router(
top_k=top_k,
global_num_experts=self.global_num_experts,
@@ -1680,9 +1665,6 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):
router_logits=staged_router_logits,
)

if self.capture is not None:
self.capture(topk_ids)

final_hidden_states = self.quant_method.apply(
layer=self,
x=staged_hidden_states,
@@ -1875,9 +1857,6 @@ def forward_impl(
router_logits=router_logits,
)

if self.capture is not None:
self.capture(topk_ids)

final_hidden_states = self.quant_method.apply(
layer=self,
x=x, # The type signature of this is wrong due to the hack.
9 changes: 9 additions & 0 deletions vllm/model_executor/layers/fused_moe/router/base_router.py
@@ -127,6 +127,11 @@ def __init__(
self.eplb_state = eplb_state
self.enable_eplb = enable_eplb
self.indices_type_getter = indices_type_getter
self.capture_fn: Callable[[torch.Tensor], None] | None = None

def set_capture_fn(self, capture_fn: Callable[[torch.Tensor], None] | None) -> None:
"""Set a capture callback for logical routed expert IDs."""
self.capture_fn = capture_fn

def _validate_eplb_state(self) -> None:
"""Validate that EPLB state is properly initialized if EPLB is enabled."""
@@ -231,6 +236,10 @@ def select_experts(
hidden_states, router_logits, indices_type
)

# Capture logical ids before EPLB mapping.
if self.capture_fn is not None:
self.capture_fn(topk_ids)

# Step 4: Apply EPLB mapping
topk_ids = self._apply_eplb_mapping(topk_ids)

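As a usage sketch of the new hook (reusing the `_make_router` CPU test double defined in the new test file above; everything else here is illustrative, not part of the diff): the callback fires once per `select_experts` call with the logical top-k IDs, the returned `topk_ids` are the post-EPLB-mapping values, and passing `None` detaches the callback again.

import torch

router = _make_router()  # DummyRouter test double from the new test file

seen: list[torch.Tensor] = []
router.set_capture_fn(lambda ids: seen.append(ids.clone()))

_, topk_ids = router.select_experts(
    hidden_states=torch.empty(1), router_logits=torch.empty(1)
)
assert torch.equal(seen[0], torch.tensor([[1, 2], [3, 4]]))       # logical ids, pre-EPLB
assert torch.equal(topk_ids, torch.tensor([[11, 12], [13, 14]]))  # after _apply_eplb_mapping

# The setter also accepts None, which disables capture again.
router.set_capture_fn(None)
router.select_experts(hidden_states=torch.empty(1), router_logits=torch.empty(1))
assert len(seen) == 1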
16 changes: 16 additions & 0 deletions vllm/v1/worker/gpu_model_runner.py
@@ -5974,6 +5974,22 @@ def init_routed_experts_capturer(self):
max_num_kv_tokens=self.max_num_kv_tokens,
vllm_config=self.vllm_config,
)
self._bind_routed_experts_capturer(routed_experts_capturer)

def _bind_routed_experts_capturer(self, capturer: RoutedExpertsCapturer) -> None:
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.fused_moe.router.base_router import (
BaseRouter,
)

for module in self.compilation_config.static_forward_context.values():
if isinstance(module, FusedMoE) and isinstance(module.router, BaseRouter):
layer_id = module.layer_id

def _capture_fn(topk_ids, _layer_id=layer_id, _capturer=capturer):
_capturer.capture(_layer_id, topk_ids)

module.router.set_capture_fn(_capture_fn)

def may_add_encoder_only_layers_to_kv_cache_config(self) -> None:
"""
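A note on the closure pattern in `_bind_routed_experts_capturer`: the hook binds `layer_id` and `capturer` through default arguments rather than a bare closure, because a closure created inside the loop would resolve `layer_id` lazily and every router would report the ID of the last FusedMoE layer iterated. A minimal sketch of the difference (the `_RecordingCapturer` and helper names here are illustrative, not part of the diff):

import torch


class _RecordingCapturer:
    """Hypothetical stand-in for RoutedExpertsCapturer, used only for illustration."""

    def __init__(self):
        self.calls = []

    def capture(self, layer_id, topk_ids):
        self.calls.append((layer_id, topk_ids))


def make_hooks_late_bound(capturer, layer_ids):
    # Pitfall: a bare closure resolves layer_id when the hook runs, so every
    # hook reports the last value the loop assigned.
    return [lambda topk_ids: capturer.capture(layer_id, topk_ids) for layer_id in layer_ids]


def make_hooks_default_bound(capturer, layer_ids):
    # The pattern used in _bind_routed_experts_capturer: default arguments
    # freeze layer_id (and capturer) per iteration.
    return [
        lambda topk_ids, _layer_id=layer_id, _capturer=capturer: _capturer.capture(
            _layer_id, topk_ids
        )
        for layer_id in layer_ids
    ]


cap = _RecordingCapturer()
for hook in make_hooks_late_bound(cap, [0, 1, 2]):
    hook(torch.tensor([[1, 2]]))
assert [layer_id for layer_id, _ in cap.calls] == [2, 2, 2]  # all hooks saw the last id

cap = _RecordingCapturer()
for hook in make_hooks_default_bound(cap, [0, 1, 2]):
    hook(torch.tensor([[1, 2]]))
assert [layer_id for layer_id, _ in cap.calls] == [0, 1, 2]  # each hook kept its own id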