diff --git a/tests/kernels/moe/test_zero_expert_moe.py b/tests/kernels/moe/test_zero_expert_moe.py
new file mode 100644
index 000000000000..d8f900256ec3
--- /dev/null
+++ b/tests/kernels/moe/test_zero_expert_moe.py
@@ -0,0 +1,282 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Tests for FusedMoE with zero experts.
+
+Verifies that:
+- The ZeroExpertRouter is properly created and used as the layer router.
+- A forward pass through FusedMoE with zero experts produces correct output.
+- The output decomposes correctly into real expert + zero expert contributions.
+
+Note: tests generated with Claude.
+"""
+
+import pytest
+import torch
+
+from vllm.config import VllmConfig, set_current_vllm_config
+from vllm.forward_context import get_forward_context, set_forward_context
+from vllm.model_executor.layers.fused_moe.layer import FusedMoE
+from vllm.model_executor.layers.fused_moe.router.zero_expert_router import (
+    ZeroExpertRouter,
+)
+from vllm.v1.worker.workspace import init_workspace_manager
+
+
+@pytest.fixture
+def zero_expert_moe(dist_init, default_vllm_config):
+    """Create a FusedMoE layer with zero experts."""
+    num_experts = 4
+    top_k = 2
+    # hidden_size must be >= 256 for the zero expert identity kernel to
+    # produce output (its BLOCK_SIZE=256 causes grid=0 when hidden_dim<256).
+    hidden_size = 256
+    intermediate_size = 512
+    zero_expert_num = 1
+
+    e_score_correction_bias = torch.zeros(
+        num_experts + zero_expert_num,
+        dtype=torch.float32,
+        device="cuda",
+    )
+
+    vllm_config = VllmConfig()
+    vllm_config.compilation_config.static_forward_context = dict()
+
+    with set_current_vllm_config(vllm_config), set_forward_context(None, vllm_config):
+        init_workspace_manager(torch.accelerator.current_device_index())
+
+        layer = FusedMoE(
+            zero_expert_type="identity",
+            e_score_correction_bias=e_score_correction_bias,
+            num_experts=num_experts,
+            top_k=top_k,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            params_dtype=torch.bfloat16,
+            prefix="test_zero_expert_moe",
+            renormalize=False,
+            routed_scaling_factor=1.0,
+            scoring_func="softmax",
+        ).cuda()
+
+        layer.quant_method.process_weights_after_loading(layer)
+
+        yield layer, vllm_config
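Illustrative sketch, not part of the patch: the expert-slot layout the fixture above sets up. The bias and logits tensors in these tests cover real plus zero expert slots; with `num_experts=4` and `zero_expert_num=1`, indices 0-3 are real experts and index 4 is the zero expert (the identity test below biases `router_logits[:, num_experts]`).

```python
import torch

# Slots 0-3: real experts routed through MoE weights; slot 4: zero expert.
num_experts, zero_expert_num = 4, 1
bias = torch.zeros(num_experts + zero_expert_num)
assert bias.shape[0] == 5
```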
+
+
+@pytest.mark.parametrize("num_tokens", [1, 32])
+def test_zero_expert_moe_router_is_zero_expert_router(zero_expert_moe, num_tokens):
+    """Verify that FusedMoE with zero_expert_type creates a ZeroExpertRouter."""
+    layer, _ = zero_expert_moe
+    assert isinstance(layer.router, ZeroExpertRouter), (
+        f"Expected ZeroExpertRouter but got {type(layer.router).__name__}."
+    )
+
+
+@pytest.mark.parametrize("num_tokens", [1, 32])
+def test_zero_expert_moe_no_custom_routing_fn(zero_expert_moe, num_tokens):
+    """Verify that custom_routing_function is not set (routing is handled
+    by ZeroExpertRouter, not a memoizing closure)."""
+    layer, _ = zero_expert_moe
+    assert layer.custom_routing_function is None
+
+
+@pytest.mark.parametrize("num_tokens", [1, 32])
+def test_zero_expert_moe_forward(zero_expert_moe, num_tokens):
+    """Run a forward pass through FusedMoE with zero experts and verify
+    output shape."""
+    layer, vllm_config = zero_expert_moe
+
+    hidden_size = layer.hidden_size
+    num_experts = 4
+    zero_expert_num = 1
+    total_experts = num_experts + zero_expert_num
+
+    hidden_states = torch.randn(
+        num_tokens, hidden_size, dtype=torch.bfloat16, device="cuda"
+    )
+    router_logits = torch.randn(
+        num_tokens, total_experts, dtype=torch.float32, device="cuda"
+    )
+
+    # Initialize weights to small random values to avoid NaN from
+    # uninitialized memory.
+    with torch.no_grad():
+        for param in layer.parameters():
+            if param.dtype.is_floating_point:
+                param.normal_(0, 0.01)
+
+    with set_current_vllm_config(vllm_config), set_forward_context(None, vllm_config):
+        get_forward_context().all_moe_layers = None
+        output = layer.forward(hidden_states, router_logits)
+
+    assert output.shape == hidden_states.shape, (
+        f"Expected output shape {hidden_states.shape}, got {output.shape}"
+    )
+    assert output.dtype == hidden_states.dtype
+    assert not torch.isnan(output).any(), "Output contains NaN values"
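Illustrative numeric sketch, not part of the patch: the identity zero expert semantics the two tests below rely on. Each token's zero expert contribution is `hidden_states` scaled by the total top-k routing weight that landed on zero expert slots; the mask-and-sum arithmetic mirrors `zero_weight_per_token` in the identity test.

```python
import torch

hidden = torch.tensor([[1.0, 2.0]])
topk_weights = torch.tensor([[0.7, 0.3]])
topk_ids = torch.tensor([[4, 1]])  # slot 4 is the zero expert here
zero_w = (topk_weights * (topk_ids >= 4).float()).sum(-1, keepdim=True)  # 0.7
assert torch.allclose(hidden * zero_w, torch.tensor([[0.7, 1.4]]))
```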
+
+
+@pytest.mark.parametrize("num_tokens", [1, 32])
+def test_zero_expert_moe_output_decomposition(zero_expert_moe, num_tokens):
+    """Validate that the FusedMoE output equals a plain FusedMoE
+    output (real experts only) plus the zero expert contribution.
+
+    The key invariant is:
+        zero_layer.forward(h, r_full) == plain_layer.forward(h, r_real)
+                                         + zero_expert_output
+
+    We create a plain FusedMoE layer with the same weights and real-expert-only
+    router logits, compute the zero expert output via the ZeroExpertRouter, and
+    verify the sum matches the FusedMoE output.
+    """
+    layer, vllm_config = zero_expert_moe
+    num_experts = 4
+    zero_expert_num = 1
+    total_experts = num_experts + zero_expert_num
+
+    hidden_states = torch.randn(
+        num_tokens, layer.hidden_size, dtype=torch.bfloat16, device="cuda"
+    )
+    router_logits = torch.randn(
+        num_tokens, total_experts, dtype=torch.float32, device="cuda"
+    )
+
+    with torch.no_grad():
+        for param in layer.parameters():
+            if param.dtype.is_floating_point:
+                param.normal_(0, 0.01)
+
+    with set_current_vllm_config(vllm_config), set_forward_context(None, vllm_config):
+        get_forward_context().all_moe_layers = None
+
+        # Create a plain FusedMoE layer with the same config but no zero
+        # experts. Use a separate prefix to avoid collision.
+        plain_layer = FusedMoE(
+            num_experts=num_experts,
+            top_k=layer.top_k,
+            hidden_size=layer.hidden_size,
+            intermediate_size=layer.intermediate_size_per_partition,
+            params_dtype=torch.bfloat16,
+            prefix="test_zero_expert_moe_plain",
+            renormalize=False,
+            scoring_func="softmax",
+            e_score_correction_bias=layer.e_score_correction_bias,
+        ).cuda()
+
+        # Share weights from the zero expert layer.
+        plain_layer.w13_weight.data.copy_(layer.w13_weight.data)
+        plain_layer.w2_weight.data.copy_(layer.w2_weight.data)
+        plain_layer.quant_method.process_weights_after_loading(plain_layer)
+
+        # Compute routing via the ZeroExpertRouter. This produces masked
+        # topk_weights/topk_ids (zero expert entries have weight=0, id=0)
+        # and stores zero_expert_output as a side effect.
+        topk_weights, topk_ids = layer.router.select_experts(
+            hidden_states, router_logits
+        )
+        zero_output = layer.router.zero_expert_output
+
+        # Compute real expert output using the plain layer with the masked
+        # routing from the ZeroExpertRouter.
+        real_output = plain_layer.quant_method.apply(
+            layer=plain_layer,
+            x=hidden_states,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            shared_experts_input=None,
+        )
+
+        # Get the combined output from the zero expert layer.
+        full_output = layer.forward(hidden_states, router_logits)
+
+    assert zero_output is not None, "Zero expert output should not be None"
+    assert not torch.isnan(real_output).any(), "Real expert output has NaN"
+    assert not torch.isnan(zero_output).any(), "Zero expert output has NaN"
+    assert not torch.isnan(full_output).any(), "Full output has NaN"
+
+    expected = real_output + zero_output
+    torch.testing.assert_close(
+        full_output,
+        expected,
+        atol=0,
+        rtol=0,
+        msg="FusedMoE output should equal plain FusedMoE output "
+        "plus zero expert contribution",
+    )
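Illustrative sketch, not part of the patch: `ZeroExpertRouter.zero_expert_output` is consume-once — the property returns the stored tensor and clears it (see `zero_expert_router.py` below), which is why the next test calls `select_experts` directly instead of going through the runner. A minimal stand-in class demonstrating the pattern:

```python
class ConsumeOnce:
    """Toy model of the consume-once property; not vLLM code."""

    def __init__(self):
        self._value = "zero-expert tensor"

    @property
    def value(self):
        out, self._value = self._value, None
        return out

holder = ConsumeOnce()
assert holder.value == "zero-expert tensor"
assert holder.value is None  # state was cleared by the first read
```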
+
+
+@pytest.mark.parametrize("num_tokens", [1, 32])
+def test_zero_expert_moe_zero_expert_is_identity(zero_expert_moe, num_tokens):
+    """Validate zero expert identity behavior.
+
+    When routing strongly favors the zero expert, its contribution should
+    be a scaled version of hidden_states (identity operation). We verify
+    this by manually computing the expected zero expert output from the
+    routing weights and comparing against what the router produces.
+    """
+    layer, vllm_config = zero_expert_moe
+    num_experts = 4
+    zero_expert_num = 1
+    total_experts = num_experts + zero_expert_num
+
+    hidden_states = torch.randn(
+        num_tokens, layer.hidden_size, dtype=torch.bfloat16, device="cuda"
+    )
+    # Strongly bias toward the zero expert (index 4).
+    router_logits = torch.full(
+        (num_tokens, total_experts), -10.0, dtype=torch.float32, device="cuda"
+    )
+    router_logits[:, num_experts] = 10.0  # zero expert gets high logit
+
+    with torch.no_grad():
+        for param in layer.parameters():
+            if param.dtype.is_floating_point:
+                param.normal_(0, 0.01)
+
+    with set_current_vllm_config(vllm_config), set_forward_context(None, vllm_config):
+        get_forward_context().all_moe_layers = None
+
+        # Run routing to get topk_weights/topk_ids before masking.
+        from vllm.model_executor.layers.fused_moe.router.fused_topk_bias_router import (
+            fused_topk_bias,
+        )
+
+        topk_weights, topk_ids = fused_topk_bias(
+            hidden_states=hidden_states,
+            gating_output=router_logits,
+            e_score_correction_bias=layer.router.e_score_correction_bias.data,
+            topk=layer.top_k,
+            renormalize=layer.router.renormalize,
+            scoring_func=layer.router.scoring_func,
+        )
+
+        # Manually compute expected zero expert identity output:
+        # for each token, sum routing weights assigned to zero expert slots,
+        # then multiply by hidden_states.
+        zero_mask = topk_ids >= num_experts
+        zero_weight_per_token = (topk_weights * zero_mask.float()).sum(
+            dim=-1, keepdim=True
+        )
+        expected_zero_output = (hidden_states.float() * zero_weight_per_token).to(
+            hidden_states.dtype
+        )
+
+        # Run routing directly to trigger zero expert computation
+        # without going through the runner (which consumes the output).
+        layer.router.select_experts(hidden_states, router_logits)
+        actual_zero_output = layer.router.zero_expert_output
+
+    assert actual_zero_output is not None
+    assert zero_mask.any(), (
+        "With high zero expert logit, at least some slots should route "
+        "to the zero expert"
+    )
+
+    torch.testing.assert_close(
+        actual_zero_output,
+        expected_zero_output,
+        atol=1e-3,
+        rtol=1e-3,
+        msg="Zero expert identity output should equal "
+        "hidden_states * sum(zero_expert_weights)",
+    )
diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py
index b342e0c6e918..926f0d1d0154 100644
--- a/vllm/model_executor/layers/fused_moe/__init__.py
+++ b/vllm/model_executor/layers/fused_moe/__init__.py
@@ -33,9 +33,6 @@
 from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import (
     UnquantizedFusedMoEMethod,
 )
-from vllm.model_executor.layers.fused_moe.zero_expert_fused_moe import (
-    ZeroExpertFusedMoE,
-)
 from vllm.triton_utils import HAS_TRITON
 
 _config: dict[str, Any] | None = None
@@ -68,7 +65,6 @@ def get_config() -> dict[str, Any] | None:
     "GateLinear",
     "RoutingMethodType",
     "SharedFusedMoE",
-    "ZeroExpertFusedMoE",
    "activation_without_mul",
     "apply_moe_activation",
     "override_config",
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index fad84923bf5a..190a9cc3b5d7 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -274,6 +274,7 @@ def __init__(
         gate: torch.nn.Module | None = None,
         shared_experts: torch.nn.Module | None = None,
         routed_input_transform: torch.nn.Module | None = None,
+        zero_expert_type: str | None = None,
     ):
         super().__init__()
 
@@ -462,6 +463,8 @@
             # TODO(bnell): once we can construct the MK at init time, we
             # can make this a value.
             indices_type_getter=lambda: self.quant_method.topk_indices_dtype,
+            zero_expert_type=zero_expert_type,
+            num_logical_experts=self.logical_num_experts,
         )
         self.routing_method_type: RoutingMethodType = self.router.routing_method_type
 
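Hedged construction sketch, not part of the patch: the new `FusedMoE` keyword in use, with values borrowed from the test fixture. The layer forwards `zero_expert_type` to the router factory together with `logical_num_experts`; note that construction only succeeds inside an initialized vLLM config/forward context, as the fixture shows.

```python
import torch
from vllm.model_executor.layers.fused_moe.layer import FusedMoE

# zero_expert_type switches the layer onto ZeroExpertRouter; the bias tensor
# must cover real + zero expert slots (5 entries here: 4 real + 1 zero).
layer = FusedMoE(
    zero_expert_type="identity",
    e_score_correction_bias=torch.zeros(5, device="cuda"),
    num_experts=4,
    top_k=2,
    hidden_size=256,
    intermediate_size=512,
    params_dtype=torch.bfloat16,
    prefix="example_moe",
    renormalize=False,
    scoring_func="softmax",
)
```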
diff --git a/vllm/model_executor/layers/fused_moe/router/router_factory.py b/vllm/model_executor/layers/fused_moe/router/router_factory.py
index 11027e894bee..42d418d7e537 100644
--- a/vllm/model_executor/layers/fused_moe/router/router_factory.py
+++ b/vllm/model_executor/layers/fused_moe/router/router_factory.py
@@ -25,6 +25,9 @@
 from vllm.model_executor.layers.fused_moe.router.routing_simulator_router import (
     RoutingSimulatorRouter,
 )
+from vllm.model_executor.layers.fused_moe.router.zero_expert_router import (
+    ZeroExpertRouter,
+)
 
 EMPTY_EPLB_STATE: EplbLayerState = EplbLayerState()
 
@@ -49,6 +52,9 @@ def create_fused_moe_router(
     # eplb parameters
     enable_eplb: bool = False,
     eplb_state: EplbLayerState = EMPTY_EPLB_STATE,
+    # zero expert parameters
+    zero_expert_type: str | None = None,
+    num_logical_experts: int | None = None,
 ) -> FusedMoERouter:
     """
     Factory function to create the appropriate FusedMoERouter subclass based on
@@ -56,10 +62,11 @@
     The selection logic follows this priority order:
     1. RoutingSimulatorRouter - if VLLM_MOE_ROUTING_SIMULATION_STRATEGY env var
        is set
-    2. GroupedTopKRouter - if use_grouped_topk is True
-    3. CustomRoutingRouter - if custom_routing_function is not None
-    4. FusedTopKBiasRouter - if e_score_correction_bias is not None
-    5. FusedTopKRouter - default fallback
+    2. ZeroExpertRouter - if zero_expert_type is not None
+    3. GroupedTopKRouter - if use_grouped_topk is True
+    4. CustomRoutingRouter - if custom_routing_function is not None
+    5. FusedTopKBiasRouter - if e_score_correction_bias is not None
+    6. FusedTopKRouter - default fallback
 
     Common arguments:
         top_k: Number of experts to select per token
@@ -86,6 +93,12 @@
         enable_eplb: Whether EPLB is enabled
         eplb_state: EPLB (Expert Parallelism Load Balancing) state
 
+    Zero expert arguments:
+        zero_expert_type: Type of zero expert (e.g. identity). If not None,
+            creates a ZeroExpertRouter.
+        num_logical_experts: Number of real (non-zero) experts. Required when
+            zero_expert_type is not None.
+
     Returns:
         An instance of the appropriate FusedMoERouter subclass
     """
@@ -100,6 +113,27 @@
             indices_type_getter=indices_type_getter,
         )
 
+    if zero_expert_type is not None:
+        assert num_logical_experts is not None, (
+            "num_logical_experts is required when zero_expert_type is set"
+        )
+        assert e_score_correction_bias is not None, (
+            "e_score_correction_bias is required when zero_expert_type is set"
+        )
+        return ZeroExpertRouter(
+            top_k=top_k,
+            global_num_experts=global_num_experts,
+            eplb_state=eplb_state,
+            e_score_correction_bias=e_score_correction_bias,
+            num_logical_experts=num_logical_experts,
+            zero_expert_type=zero_expert_type,
+            scoring_func=scoring_func,
+            renormalize=renormalize,
+            routed_scaling_factor=routed_scaling_factor,
+            enable_eplb=enable_eplb,
+            indices_type_getter=indices_type_getter,
+        )
+
     if use_grouped_topk:
         assert custom_routing_function is None
         if num_expert_group is None or topk_group is None:
diff --git a/vllm/model_executor/layers/fused_moe/router/zero_expert_router.py b/vllm/model_executor/layers/fused_moe/router/zero_expert_router.py
new file mode 100644
index 000000000000..c87070bc5acf
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/router/zero_expert_router.py
@@ -0,0 +1,115 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from collections.abc import Callable
+
+import torch
+
+from vllm.distributed.eplb.eplb_state import EplbLayerState
+from vllm.model_executor.layers.fused_moe.config import (
+    RoutingMethodType,
+    get_routing_method_type,
+)
+from vllm.model_executor.layers.fused_moe.fused_moe import (
+    zero_experts_compute_triton,
+)
+from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter
+from vllm.model_executor.layers.fused_moe.router.fused_topk_bias_router import (
+    fused_topk_bias,
+)
+
+
+class ZeroExpertRouter(BaseRouter):
+    """Router that handles zero expert computation as part of routing.
+
+    Routes over all experts (real + zero) using the full e_score_correction_bias.
+    Computes zero expert identity contributions as a side effect during routing.
+    Remaps zero expert IDs to real expert ID 0 (with weight 0) so downstream
+    MoE computation can ignore them.
+    """
+
+    def __init__(
+        self,
+        top_k: int,
+        global_num_experts: int,
+        eplb_state: EplbLayerState,
+        e_score_correction_bias: torch.Tensor,
+        num_logical_experts: int,
+        zero_expert_type: str,
+        scoring_func: str = "softmax",
+        renormalize: bool = False,
+        routed_scaling_factor: float = 1.0,
+        enable_eplb: bool = False,
+        indices_type_getter: Callable[[], torch.dtype | None] | None = None,
+    ):
+        super().__init__(
+            top_k=top_k,
+            global_num_experts=global_num_experts,
+            eplb_state=eplb_state,
+            enable_eplb=enable_eplb,
+            indices_type_getter=indices_type_getter,
+        )
+        self.e_score_correction_bias = e_score_correction_bias
+        self.num_logical_experts = num_logical_experts
+        self.zero_expert_type = zero_expert_type
+        self.scoring_func = scoring_func
+        self.renormalize = renormalize
+        self.routed_scaling_factor = routed_scaling_factor
+        self._zero_expert_output: torch.Tensor | None = None
+
+    @property
+    def routing_method_type(self) -> RoutingMethodType:
+        return get_routing_method_type(
+            scoring_func=self.scoring_func,
+            top_k=self.top_k,
+            renormalize=self.renormalize,
+            num_expert_group=None,
+            has_e_score_bias=True,
+        )
+
+    def _compute_routing(
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+        indices_type: torch.dtype | None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Compute routing with the full bias, compute the zero expert
+        output, and mask zero expert IDs."""
+        topk_weights, topk_ids = fused_topk_bias(
+            hidden_states=hidden_states,
+            gating_output=router_logits,
+            e_score_correction_bias=self.e_score_correction_bias.data,
+            topk=self.top_k,
+            renormalize=self.renormalize,
+            scoring_func=self.scoring_func,
+            indices_type=indices_type,
+        )
+
+        if self.routed_scaling_factor != 1.0:
+            topk_weights *= self.routed_scaling_factor
+
+        # Compute zero expert output using pre-EPLB topk_ids/weights.
+        # zero_experts_compute_triton modifies its inputs in-place, so
+        # pass clones.
+        self._zero_expert_output = zero_experts_compute_triton(
+            expert_indices=topk_ids.clone(),
+            expert_scales=topk_weights.clone(),
+            num_experts=self.num_logical_experts,
+            zero_expert_type=self.zero_expert_type,
+            hidden_states=hidden_states,
+        )
+
+        # Mask zero expert entries: remap zero expert IDs to 0 with weight 0
+        # so downstream MoE computation ignores them.
+        zero_mask = topk_ids >= self.num_logical_experts
+        topk_ids[zero_mask] = 0
+        topk_weights[zero_mask] = 0.0
+
+        return topk_weights, topk_ids
+
+    @property
+    def zero_expert_output(self) -> torch.Tensor | None:
+        """Retrieve and clear the zero expert output."""
+        output = self._zero_expert_output
+        self._zero_expert_output = None
+        return output
diff --git a/vllm/model_executor/layers/fused_moe/runner/moe_runner_base.py b/vllm/model_executor/layers/fused_moe/runner/moe_runner_base.py
index a881c27d542c..692d45d34607 100644
--- a/vllm/model_executor/layers/fused_moe/runner/moe_runner_base.py
+++ b/vllm/model_executor/layers/fused_moe/runner/moe_runner_base.py
@@ -25,6 +25,9 @@
 from vllm.model_executor.layers.fused_moe.router.fused_moe_router import (
     FusedMoERouter,
 )
+from vllm.model_executor.layers.fused_moe.router.zero_expert_router import (
+    ZeroExpertRouter,
+)
 from vllm.model_executor.layers.fused_moe.runner.moe_runner import MoERunner
 from vllm.model_executor.layers.fused_moe.runner.shared_experts import (
     SharedExperts,
@@ -443,6 +446,19 @@
         if self._shared_experts is not None:
             self._shared_experts.maybe_sync_shared_experts_stream(shared_experts_input)
 
+    def _maybe_add_zero_expert_output(
+        self,
+        result: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
+    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+        if isinstance(self.router, ZeroExpertRouter):
+            zero_expert_output = self.router.zero_expert_output
+            assert zero_expert_output is not None
+            if isinstance(result, tuple):
+                result = (result[0], result[1] + zero_expert_output)
+            else:
+                result = result + zero_expert_output
+        return result
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -494,7 +510,9 @@
             self._encode_layer_name(),
         )
 
-        return self._maybe_reduce_output(fused_output, og_hidden_dims)
+        result = self._maybe_reduce_output(fused_output, og_hidden_dims)
+
+        return self._maybe_add_zero_expert_output(result)
 
     def forward_dispatch(
         self,
- """ - originals = {key: getattr(self, key) for key in attrs} - try: - for key, value in attrs.items(): - object.__setattr__(self, key, value) - yield - finally: - for key, value in originals.items(): - object.__setattr__(self, key, value) - - def _compute_zero_expert_result( - self, - hidden_states: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - ) -> torch.Tensor | None: - """Compute zero expert results using pre-computed routing.""" - if ( - self._actual_zero_expert_num is None - or self._actual_zero_expert_num <= 0 - or self._actual_zero_expert_type is None - ): - return None - - return zero_experts_compute_triton( - expert_indices=topk_ids.clone(), - expert_scales=topk_weights.clone(), - num_experts=self.logical_num_experts, - zero_expert_type=self._actual_zero_expert_type, - hidden_states=hidden_states, - ) - - def forward( - self, - hidden_states: torch.Tensor, - router_logits: torch.Tensor, # Full logits including zero experts - ) -> torch.Tensor: - """ - Forward pass with zero expert support and routing memoization. - - Args: - hidden_states: Input hidden states - router_logits: Full router logits (including zero experts) - - Returns: - Combined output from real experts and zero experts - """ - # Prepare temporary attribute overrides for routing computation - temp_attrs = { - "custom_routing_function": None, # Disable for first routing - } - if self._router is not None: - temp_attrs["e_score_correction_bias"] = self._router.e_score_correction_bias - - # Compute routing with temporary attributes - # Pass full router_logits (including zero experts) so that zero experts - # can be properly identified in topk_ids - with self._temporarily_set_attrs(**temp_attrs): - topk_weights, topk_ids = self.select_experts( - hidden_states=hidden_states, - router_logits=router_logits, # Full logits (includes zero experts) - ) - - # Compute zero expert result if needed - zero_expert_result = self._compute_zero_expert_result( - hidden_states=hidden_states, - topk_weights=topk_weights, - topk_ids=topk_ids, - ) - - # Memoize routing results for reuse in super().forward() - self._memoized_topk_weights = topk_weights - self._memoized_topk_ids = topk_ids - - # Slice router_logits for real experts only - router_logits_sliced = router_logits[..., : self.logical_num_experts] - - # Compute real expert results (will reuse memoized routing via - # custom_routing_function) - # zero_expert_num is already 0, so FusedMoE won't handle zero experts - fused_out = super().forward( - hidden_states=hidden_states, - router_logits=router_logits_sliced, - ) - - # Combine results - # Both zero_expert_result and fused_out are computed from the same - # hidden_states, so they should be on the same device. 
- if zero_expert_result is not None: - fused_out = fused_out + zero_expert_result - - # Clear memoization after use - self._memoized_topk_weights = None - self._memoized_topk_ids = None - - return fused_out diff --git a/vllm/model_executor/models/longcat_flash.py b/vllm/model_executor/models/longcat_flash.py index a9e2c2268ee1..375b0b69b1f9 100644 --- a/vllm/model_executor/models/longcat_flash.py +++ b/vllm/model_executor/models/longcat_flash.py @@ -46,7 +46,7 @@ from vllm.distributed import get_pp_group from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe import FusedMoE, ZeroExpertFusedMoE +from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( MergedColumnParallelLinear, @@ -292,12 +292,10 @@ def __init__( prefix=f"{prefix}.gate", ) - assert config.zero_expert_num is not None assert config.zero_expert_type is not None - self.experts = ZeroExpertFusedMoE( - zero_expert_num=config.zero_expert_num, + self.experts = FusedMoE( zero_expert_type=config.zero_expert_type, - router=self.router, + e_score_correction_bias=self.router.e_score_correction_bias, num_experts=num_experts, top_k=top_k, hidden_size=hidden_size, @@ -332,7 +330,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states_padded.to(self.router_params_dtype) ) - # ZeroExpertFusedMoE handles routing memoization and zero expert computation + # FusedMoE handles routing memoization and zero expert computation # internally. Pass full router_logits (including zero experts) so that # zero experts can be properly identified in routing. final_hidden_states = self.experts(