diff --git a/tests/model_executor/test_routed_experts_capture.py b/tests/model_executor/test_routed_experts_capture.py
new file mode 100644
index 000000000000..45bf4bcac6a8
--- /dev/null
+++ b/tests/model_executor/test_routed_experts_capture.py
@@ -0,0 +1,160 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import types
+
+import pytest
+import torch
+
+from vllm.distributed.eplb.eplb_state import EplbLayerState
+from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
+from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter
+
+pytestmark = pytest.mark.cpu_test
+
+
+class DummyRouter(BaseRouter):
+    @property
+    def routing_method_type(self) -> RoutingMethodType:
+        return RoutingMethodType.FUSED_TOPK
+
+    def _compute_routing(self, hidden_states, router_logits, indices_type):
+        topk_ids = torch.tensor([[1, 2], [3, 4]], dtype=torch.int64)
+        topk_weights = torch.ones_like(topk_ids, dtype=torch.float32)
+        return topk_weights, topk_ids
+
+    def _apply_eplb_mapping(self, topk_ids: torch.Tensor) -> torch.Tensor:
+        # Make mapping observable without requiring CUDA EPLB path.
+        return topk_ids + 10
+
+
+def _make_router() -> DummyRouter:
+    return DummyRouter(
+        top_k=2,
+        global_num_experts=16,
+        eplb_state=EplbLayerState(),
+        enable_eplb=False,
+        indices_type_getter=None,
+    )
+
+
+def test_base_router_capture_pre_eplb_mapping():
+    router = _make_router()
+    captured = []
+
+    def capture_fn(ids):
+        captured.append(ids.clone())
+
+    router.set_capture_fn(capture_fn)
+    topk_weights, topk_ids = router.select_experts(
+        hidden_states=torch.empty(1),
+        router_logits=torch.empty(1),
+    )
+
+    assert topk_weights.shape == topk_ids.shape
+    assert len(captured) == 1
+    assert torch.equal(captured[0], torch.tensor([[1, 2], [3, 4]]))
+    assert torch.equal(topk_ids, torch.tensor([[11, 12], [13, 14]]))
+
+
+def test_base_router_capture_with_eplb_enabled():
+    router = _make_router()
+    router.enable_eplb = True
+    router.eplb_state.expert_load_view = torch.zeros(32, dtype=torch.int64)
+    router.eplb_state.logical_to_physical_map = torch.arange(32).view(32, 1)
+    router.eplb_state.logical_replica_count = torch.ones(32, dtype=torch.int64)
+
+    captured = []
+
+    def capture_fn(ids):
+        captured.append(ids.clone())
+
+    router.set_capture_fn(capture_fn)
+    _, topk_ids = router.select_experts(
+        hidden_states=torch.empty(1),
+        router_logits=torch.empty(1),
+    )
+
+    assert len(captured) == 1
+    # Capture should see logical ids pre-EPLB mapping.
+    assert torch.equal(captured[0], torch.tensor([[1, 2], [3, 4]]))
+    # Our DummyRouter mapping adds +10.
+    assert torch.equal(topk_ids, torch.tensor([[11, 12], [13, 14]]))
+
+
+def test_gpu_model_runner_binds_router_capture(monkeypatch):
+    from vllm.v1.worker import gpu_model_runner as gmr
+
+    class DummyFusedMoE:
+        def __init__(self):
+            self.layer_id = 7
+            self.router = _make_router()
+
+    class DummyCapturer:
+        def __init__(self):
+            self.calls = []
+
+        def capture(self, layer_id, topk_ids):
+            self.calls.append((layer_id, topk_ids))
+
+    dummy_module = DummyFusedMoE()
+
+    # Patch the runtime import inside _bind_routed_experts_capturer.
+    import vllm.model_executor.layers.fused_moe.layer as fused_moe_layer
+
+    monkeypatch.setattr(fused_moe_layer, "FusedMoE", DummyFusedMoE)
+
+    dummy_self = types.SimpleNamespace(
+        compilation_config=types.SimpleNamespace(
+            static_forward_context={"dummy": dummy_module}
+        )
+    )
+
+    capturer = DummyCapturer()
+    gmr.GPUModelRunner._bind_routed_experts_capturer(dummy_self, capturer)
+
+    assert dummy_module.router.capture_fn is not None
+    dummy_module.router.capture_fn(torch.tensor([[5, 6]]))
+
+    assert len(capturer.calls) == 1
+    layer_id, topk_ids = capturer.calls[0]
+    assert layer_id == 7
+    assert torch.equal(topk_ids, torch.tensor([[5, 6]]))
+
+
+def test_gpu_model_runner_binding_stage(monkeypatch):
+    from vllm.v1.worker import gpu_model_runner as gmr
+
+    class DummyFusedMoE:
+        def __init__(self):
+            self.layer_id = 11
+            self.router = _make_router()
+
+    class DummyCapturer:
+        def __init__(self):
+            self.calls = []
+
+        def capture(self, layer_id, topk_ids):
+            self.calls.append((layer_id, topk_ids))
+
+    dummy_module = DummyFusedMoE()
+
+    import vllm.model_executor.layers.fused_moe.layer as fused_moe_layer
+
+    monkeypatch.setattr(fused_moe_layer, "FusedMoE", DummyFusedMoE)
+
+    dummy_self = types.SimpleNamespace(
+        compilation_config=types.SimpleNamespace(
+            static_forward_context={"dummy": dummy_module}
+        )
+    )
+
+    # Before binding, no capture hook.
+    assert dummy_module.router.capture_fn is None
+
+    capturer = DummyCapturer()
+    gmr.GPUModelRunner._bind_routed_experts_capturer(dummy_self, capturer)
+
+    # After binding, hook should exist and be callable.
+    assert callable(dummy_module.router.capture_fn)
+    dummy_module.router.capture_fn(torch.tensor([[9, 10]]))
+    assert len(capturer.calls) == 1
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 538089882231..c814e716d4f9 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -44,9 +44,6 @@
 from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
     init_aiter_topK_meta_data,
 )
-from vllm.model_executor.layers.fused_moe.routed_experts_capturer import (
-    RoutedExpertsCapturer,
-)
 from vllm.model_executor.layers.fused_moe.router.router_factory import (
     create_fused_moe_router,
 )
@@ -523,18 +520,6 @@ def __init__(
         self.apply_router_weight_on_input = apply_router_weight_on_input
         self.activation = activation
 
-        self.capture: Callable[[torch.Tensor], None] | None = None
-        if (
-            self.vllm_config.model_config is not None
-            and self.vllm_config.model_config.enable_return_routed_experts
-        ):
-            # In dummy runs, the capturer is not initialized.
-            capturer = RoutedExpertsCapturer.get_instance()
-            if capturer is not None:
-                self.capture = lambda topk_ids: capturer.capture(
-                    self.layer_id, topk_ids
-                )
-
         self.router = create_fused_moe_router(
             top_k=top_k,
             global_num_experts=self.global_num_experts,
@@ -1680,9 +1665,6 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):
                 router_logits=staged_router_logits,
             )
 
-            if self.capture is not None:
-                self.capture(topk_ids)
-
             final_hidden_states = self.quant_method.apply(
                 layer=self,
                 x=staged_hidden_states,
@@ -1875,9 +1857,6 @@ def forward_impl(
             router_logits=router_logits,
         )
 
-        if self.capture is not None:
-            self.capture(topk_ids)
-
         final_hidden_states = self.quant_method.apply(
             layer=self,
             x=x,  # The type signture of this is wrong due to the hack.
diff --git a/vllm/model_executor/layers/fused_moe/router/base_router.py b/vllm/model_executor/layers/fused_moe/router/base_router.py
index 9969818abfd6..52005d40d525 100644
--- a/vllm/model_executor/layers/fused_moe/router/base_router.py
+++ b/vllm/model_executor/layers/fused_moe/router/base_router.py
@@ -127,6 +127,11 @@ def __init__(
         self.eplb_state = eplb_state
         self.enable_eplb = enable_eplb
         self.indices_type_getter = indices_type_getter
+        self.capture_fn: Callable[[torch.Tensor], None] | None = None
+
+    def set_capture_fn(self, capture_fn: Callable[[torch.Tensor], None] | None) -> None:
+        """Set a capture callback for logical routed expert IDs."""
+        self.capture_fn = capture_fn
 
     def _validate_eplb_state(self) -> None:
         """Validate that EPLB state is properly initialized if EPLB is enabled."""
@@ -231,6 +236,10 @@ def select_experts(
             hidden_states, router_logits, indices_type
         )
 
+        # Capture logical ids before EPLB mapping.
+        if self.capture_fn is not None:
+            self.capture_fn(topk_ids)
+
         # Step 4: Apply EPLB mapping
         topk_ids = self._apply_eplb_mapping(topk_ids)
 
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 3ab7fcad7642..9b5a97ca7cdc 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -5974,6 +5974,22 @@ def init_routed_experts_capturer(self):
             max_num_kv_tokens=self.max_num_kv_tokens,
             vllm_config=self.vllm_config,
         )
+        self._bind_routed_experts_capturer(routed_experts_capturer)
+
+    def _bind_routed_experts_capturer(self, capturer: RoutedExpertsCapturer) -> None:
+        from vllm.model_executor.layers.fused_moe.layer import FusedMoE
+        from vllm.model_executor.layers.fused_moe.router.base_router import (
+            BaseRouter,
+        )
+
+        for module in self.compilation_config.static_forward_context.values():
+            if isinstance(module, FusedMoE) and isinstance(module.router, BaseRouter):
+                layer_id = module.layer_id
+
+                def _capture_fn(topk_ids, _layer_id=layer_id, _capturer=capturer):
+                    _capturer.capture(_layer_id, topk_ids)
+
+                module.router.set_capture_fn(_capture_fn)
 
     def may_add_encoder_only_layers_to_kv_cache_config(self) -> None:
         """