From 9fe13927d6bdbff2ac3be1a6f41d23e49fac31a7 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 27 Apr 2026 20:54:37 +0000 Subject: [PATCH 01/10] eplb manager Signed-off-by: Bill Nell --- .../layers/fused_moe/eplb_manager.py | 246 ++++++++++++++++++ vllm/model_executor/layers/fused_moe/layer.py | 184 +++++-------- 2 files changed, 308 insertions(+), 122 deletions(-) create mode 100644 vllm/model_executor/layers/fused_moe/eplb_manager.py diff --git a/vllm/model_executor/layers/fused_moe/eplb_manager.py b/vllm/model_executor/layers/fused_moe/eplb_manager.py new file mode 100644 index 000000000000..68137ccb01bc --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/eplb_manager.py @@ -0,0 +1,246 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +EPLB (Expert Parallelism Load Balancing) Manager. + +This module provides the EplbManager class which encapsulates all EPLB-related +functionality for MoE layers, including state management, expert weight +collection, and expert parameter mapping. +""" + +from collections.abc import Iterable + +import torch + +from vllm.distributed.eplb.eplb_state import EplbLayerState, EplbState + + +class EplbManager: + """ + Manages Expert Parallelism Load Balancing (EPLB) state and operations + for a MoE layer. + + This class encapsulates all EPLB-related functionality including: + - Runtime state (expert load view, logical-to-physical mapping) + - Expert weight collection for load balancing + - Expert parameter mapping for weight loading with redundant experts + - Validation of EPLB configuration constraints + """ + + def __init__( + self, + ep_size: int, + global_num_experts: int, + logical_num_experts: int, + num_redundant_experts: int = 0, + ): + """ + Initialize EPLB manager. + + Args: + ep_size: Expert parallel world size + global_num_experts: Total number of experts (including redundant) + logical_num_experts: Number of logical (non-redundant) experts + num_redundant_experts: Number of redundant experts + """ + self.ep_size = ep_size + self.global_num_experts = global_num_experts + self.logical_num_experts = logical_num_experts + self.num_redundant_experts = num_redundant_experts + + # Runtime EPLB state. + self.state = EplbLayerState() + + # Validate EPLB configuration. + # EPLB currently only supports even distribution of experts across ranks + if self.global_num_experts % self.ep_size != 0: + raise ValueError( + f"EPLB currently only supports even distribution of " + f"experts across ranks. Got {self.global_num_experts} experts " + f"and {self.ep_size} EP ranks." + ) + + def set_state( + self, + moe_layer_idx: int, + expert_load_view: torch.Tensor, + logical_to_physical_map: torch.Tensor, + logical_replica_count: torch.Tensor, + ) -> None: + """ + Register the EPLB state for this layer. + + This is used later in forward pass, where we get the expert mapping + and record the load metrics in `expert_load_view`. + + Args: + moe_layer_idx: Index of this MoE layer + expert_load_view: View into global expert load tracking tensor + logical_to_physical_map: Mapping from logical to physical expert IDs + logical_replica_count: Number of replicas for each logical expert + """ + self.state.expert_load_view = expert_load_view[moe_layer_idx] + self.state.logical_to_physical_map = logical_to_physical_map[moe_layer_idx] + self.state.logical_replica_count = logical_replica_count[moe_layer_idx] + + @staticmethod + def get_expert_weights( + layer: torch.nn.Module, # FusedMoE + ) -> Iterable[torch.Tensor]: + """ + Collect expert weights from the MoE layer for EPLB. + + Returns weights reshaped as (local_num_experts, -1) for efficient + expert weight swapping during load balancing. + + Args: + layer: The FusedMoE layer to collect weights from + + Returns: + Iterable of expert weight tensors + """ + + def _maybe_make_contiguous( + name: str, p: torch.nn.Parameter + ) -> torch.nn.Parameter: + """ + In some cases, the last 2 dimensions (the non-expert dimensions) + of the weight scale tensor are transposed. This function + transforms the tensor (view update) so the tensor is contiguous(). + Example: A non-contiguous scale tensor, + `x` of shape (E, 32, 16) and stride (512, 1, 32) is transformed to + `x_` of shape (E, 16, 32) and stride (512, 32, 1). + Note that we specifically use torch.transpose() so `x_` refers + to the same underlying memory. The tensors `x` and `x_`, pointing + to the same underlying memory make this transformation safe in the + context of EPLB. i.e. It is the same memory and just the view + is different. + Note: This function handles the "weight_scale" tensors specifically. + This could however be generalized to handle similar tensors. + """ + if p.ndim != 3: + return p + if p.is_contiguous(): + # Already contiguous. do nothing. + return p + # p is non-contiguous. We only handle the case where the last 2 + # dimensions of the scales tensor is transposed. We can handle + # other cases when they become relevant. + is_transposed_12 = p.stride(1) == 1 and p.stride(2) != 1 + if "weight_scale" not in name or not is_transposed_12: + # do nothing. + return p + + # Do not update the layer parameter as the layer's MoE operations would + # expect the parameter's tensor to the same shape / stride. Instead, + # make a new torch.nn.Parameter that is used just in the context of + # EPLB. + return torch.nn.Parameter( + torch.transpose(p.data, 1, 2), requires_grad=False + ) + + weights = list(layer.named_parameters()) + weights = [(name, _maybe_make_contiguous(name, p)) for name, p in weights] + + # `w13_input_scale` and `w2_input_scale` are global per-tensor + # activation scales shared across all experts (e.g. NVFP4). + # They are broadcast views (stride 0) from .expand() and are + # not actual expert weights, so exclude them from EPLB. + NON_EXPERT_WEIGHTS = { + "e_score_correction_bias", + "w13_input_scale", + "w2_input_scale", + } + + assert all( + weight.is_contiguous() + for name, weight in weights + if not ( + name.startswith("_shared_experts.") + or name.startswith("_gate.") + or name.startswith("_routed_input_transform.") + or name.startswith("_routed_output_transform.") + ) + and name not in NON_EXPERT_WEIGHTS + ) + + return [ + weight.view(layer.local_num_experts, -1) + for name, weight in weights + if name not in NON_EXPERT_WEIGHTS + and weight.shape != torch.Size([]) + and not name.startswith("_shared_experts.") + # exclude parameters from non-expert submodules, + # e.g. gate/shared/transforms. + and not name.startswith("_gate.") + and not name.startswith("_routed_input_transform.") + and not name.startswith("_routed_output_transform.") + ] + + @staticmethod + def make_expert_params_mapping( + model: torch.nn.Module, + ckpt_gate_proj_name: str, + ckpt_down_proj_name: str, + ckpt_up_proj_name: str, + num_experts: int, + num_redundant_experts: int = 0, + ) -> list[tuple[str, str, int, str]]: + """ + Create expert parameter mapping for weight loading with redundant experts. + + This mapping handles the physical-to-logical expert ID conversion needed + when loading weights with EPLB redundant experts. + + Args: + model: The model containing the MoE layer + ckpt_gate_proj_name: Name of gate projection in checkpoint + ckpt_down_proj_name: Name of down projection in checkpoint + ckpt_up_proj_name: Name of up projection in checkpoint + num_experts: Number of logical (non-redundant) experts + num_redundant_experts: Number of redundant experts + + Returns: + List of tuples (param_name, weight_name, expert_id, shard_id) + where: + - param_name: Parameter name in the layer + - weight_name: Weight name in checkpoint + - expert_id: Physical expert ID + - shard_id: Shard identifier (w1, w2, w3) + """ + num_physical_experts = num_experts + num_redundant_experts + + # In the returned mapping: + # - `expert_id` is the physical expert id + # - `weight_name` contains the weight name of the logical expert + # So that we should map the expert id to logical in `weight_name` + physical_to_logical_map = ( + EplbState.build_initial_global_physical_to_logical_map( + num_experts, num_redundant_experts + ) + ) + + base_layer = ( + "base_layer." + if any(".base_layer." in name for name, _ in model.named_parameters()) + else "" + ) + + return [ + # (param_name, weight_name, expert_id, shard_id) + ( + f"experts.{base_layer}w13_" + if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name] + else f"experts.{base_layer}w2_", + f"experts.{physical_to_logical_map[expert_id]}.{weight_name}.{base_layer}", + expert_id, + shard_id, + ) + for expert_id in range(num_physical_experts) + for shard_id, weight_name in [ + ("w1", ckpt_gate_proj_name), + ("w2", ckpt_down_proj_name), + ("w3", ckpt_up_proj_name), + ] + ] diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 7174cdd88f25..668a848f3589 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -16,7 +16,6 @@ get_pcp_group, get_tensor_model_parallel_world_size, ) -from vllm.distributed.eplb.eplb_state import EplbLayerState, EplbState from vllm.logger import init_logger from vllm.model_executor.custom_op import PluggableLayer from vllm.model_executor.layers.fused_moe.activation import MoEActivation @@ -26,6 +25,7 @@ FusedMoEQuantConfig, RoutingMethodType, ) +from vllm.model_executor.layers.fused_moe.eplb_manager import EplbManager from vllm.model_executor.layers.fused_moe.fused_moe_method_base import ( FusedMoEMethodBase, ) @@ -342,12 +342,25 @@ def __init__( self.layer_name = prefix self.enable_eplb = enable_eplb - # TODO(bnell): should this be owned by router? - self.eplb_state = EplbLayerState() self.expert_placement_strategy: ExpertPlacementStrategy = ( vllm_config.parallel_config.expert_placement_strategy ) + # Create EPLB manager (always constructed for consistent API) + self.eplb_manager: EplbManager | None = None + if enable_eplb: + self.eplb_manager = EplbManager( + ep_size=self.moe_parallel_config.ep_size, + global_num_experts=self.global_num_experts, + logical_num_experts=self.logical_num_experts, + num_redundant_experts=num_redundant_experts, + ) + else: + # EPLB validation is handled by EplbManager.__init__ + assert not self.use_ep or num_redundant_experts == 0, ( + "Redundant experts are only supported with EPLB." + ) + # ROCm aiter shared experts fusion # AITER only supports gated activations (silu/gelu), so disable it # for non-gated MoE (is_act_and_mul=False) @@ -374,16 +387,6 @@ def __init__( # Determine expert maps if self.use_ep: - if self.enable_eplb: - assert self.global_num_experts % self.ep_size == 0, ( - "EPLB currently only supports even distribution of " - "experts across ranks." - ) - else: - assert num_redundant_experts == 0, ( - "Redundant experts are only supported with EPLB." - ) - self.expert_placement_strategy = determine_expert_placement_strategy( expert_placement_strategy=self.expert_placement_strategy, moe_parallel_config=self.moe_parallel_config, @@ -1435,82 +1438,18 @@ def load_weights( yield param_name def get_expert_weights(self) -> Iterable[torch.Tensor]: - def _maybe_make_contiguous( - name: str, p: torch.nn.Parameter - ) -> torch.nn.Parameter: - """ - In some cases, the last 2 dimensions (the non-expert dimensions) - of the weight scale tensor are transposed. This function - transforms the tensor (view update) so the tensor is contiguous(). - Example: A non-contiguous scale tensor, - `x` of shape (E, 32, 16) and stride (512, 1, 32) is transformed to - `x_` of shape (E, 16, 32) and stride (512, 32, 1). - Note that we specifically use torch.transpose() so `x_` refers - to the same underlying memory. The tensors `x` and `x_`, pointing - to the same underlying memory make this transformation safe in the - context of EPLB. i.e. It is the same memory and just the view - is different. - Note: This function handles the "weight_scale" tensors specifically. - This could however be generalized to handle similar tensors. - """ - if p.ndim != 3: - return p - if p.is_contiguous(): - # Already contiguous. do nothing. - return p - # p is non-contiguous. We only handle the case where the last 2 - # dimensions of the scales tensor is transposed. We can handle - # other cases when they become relevant. - is_transposed_12 = p.stride(1) == 1 and p.stride(2) != 1 - if "weight_scale" not in name or not is_transposed_12: - # do nothing. - return p - - # Do not update the layer parameter as the layer's MoE operations would - # expect the parameter's tensor to the same shape / stride. Instead, - # make a new torch.nn.Parameter that is used just in the context of - # EPLB. - return torch.nn.Parameter( - torch.transpose(p.data, 1, 2), requires_grad=False - ) + """ + Collect expert weights for EPLB load balancing. - weights = list(self.named_parameters()) - weights = [(name, _maybe_make_contiguous(name, p)) for name, p in weights] - - # `w13_input_scale` and `w2_input_scale` are global per-tensor - # activation scales shared across all experts (e.g. NVFP4). - # They are broadcast views (stride 0) from .expand() and are - # not actual expert weights, so exclude them from EPLB. - NON_EXPERT_WEIGHTS = { - "e_score_correction_bias", - "w13_input_scale", - "w2_input_scale", - } + Returns weights reshaped as (local_num_experts, -1) for efficient + expert weight swapping during load balancing. - assert all( - weight.is_contiguous() - for name, weight in weights - if not ( - name.startswith("_shared_experts.") - or name.startswith("_gate.") - or name.startswith("_routed_input_transform.") - or name.startswith("_routed_output_transform.") - ) - and name not in NON_EXPERT_WEIGHTS - ) + Delegates to EplbManager. - return [ - weight.view(self.local_num_experts, -1) - for name, weight in weights - if name not in NON_EXPERT_WEIGHTS - and weight.shape != torch.Size([]) - and not name.startswith("_shared_experts.") - # exclude parameters from non-expert submodules, - # e.g. gate/shared/transforms. - and not name.startswith("_gate.") - and not name.startswith("_routed_input_transform.") - and not name.startswith("_routed_output_transform.") - ] + Returns: + Iterable of expert weight tensors + """ + return EplbManager.get_expert_weights(self) def set_eplb_state( self, @@ -1524,10 +1463,22 @@ def set_eplb_state( This is used later in forward pass, where we get the expert mapping and record the load metrics in `expert_load_view`. + + Delegates to EplbManager for state management. + + Args: + moe_layer_idx: Index of this MoE layer + expert_load_view: View into global expert load tracking tensor + logical_to_physical_map: Mapping from logical to physical expert IDs + logical_replica_count: Number of replicas for each logical expert """ - self.eplb_state.expert_load_view = expert_load_view[moe_layer_idx] - self.eplb_state.logical_to_physical_map = logical_to_physical_map[moe_layer_idx] - self.eplb_state.logical_replica_count = logical_replica_count[moe_layer_idx] + if self.eplb_manager is not None: + self.eplb_manager.set_state( + moe_layer_idx, + expert_load_view, + logical_to_physical_map, + logical_replica_count, + ) def ensure_moe_quant_config_init(self): if self.quant_method.moe_quant_config is None: @@ -1570,41 +1521,30 @@ def make_expert_params_mapping( num_experts: int, num_redundant_experts: int = 0, ) -> list[tuple[str, str, int, str]]: - num_physical_experts = num_experts + num_redundant_experts - - # In the returned mapping: - # - `expert_id` is the physical expert id - # - `weight_name` contains the weight name of the logical expert - # So that we should map the expert id to logical in `weight_name` - physical_to_logical_map = ( - EplbState.build_initial_global_physical_to_logical_map( - num_experts, num_redundant_experts - ) - ) + """ + Create expert parameter mapping for weight loading. - base_layer = ( - "base_layer." - if any(".base_layer." in name for name, _ in model.named_parameters()) - else "" - ) + Delegates to EplbManager for proper handling of redundant experts. - return [ - # (param_name, weight_name, expert_id, shard_id) - ( - f"experts.{base_layer}w13_" - if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name] - else f"experts.{base_layer}w2_", - f"experts.{physical_to_logical_map[expert_id]}.{weight_name}.{base_layer}", - expert_id, - shard_id, - ) - for expert_id in range(num_physical_experts) - for shard_id, weight_name in [ - ("w1", ckpt_gate_proj_name), - ("w2", ckpt_down_proj_name), - ("w3", ckpt_up_proj_name), - ] - ] + Args: + model: The model containing the MoE layer + ckpt_gate_proj_name: Name of gate projection in checkpoint + ckpt_down_proj_name: Name of down projection in checkpoint + ckpt_up_proj_name: Name of up projection in checkpoint + num_experts: Number of logical (non-redundant) experts + num_redundant_experts: Number of redundant experts + + Returns: + List of tuples (param_name, weight_name, expert_id, shard_id) + """ + return EplbManager.make_expert_params_mapping( + model, + ckpt_gate_proj_name, + ckpt_down_proj_name, + ckpt_up_proj_name, + num_experts, + num_redundant_experts, + ) @property def hidden_size(self) -> int: From 933147767d3012a623f464b72597c6edc10649d4 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 27 Apr 2026 21:51:18 +0000 Subject: [PATCH 02/10] eplb manager Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/layer.py | 8 ++-- .../layers/fused_moe/router/base_router.py | 47 +++++++++++-------- .../fused_moe/router/custom_routing_router.py | 11 +++-- .../router/fused_topk_bias_router.py | 11 +++-- .../fused_moe/router/fused_topk_router.py | 11 +++-- .../fused_moe/router/grouped_topk_router.py | 11 +++-- .../layers/fused_moe/router/router_factory.py | 31 +++++------- .../router/routing_simulator_router.py | 13 ++--- .../fused_moe/router/zero_expert_router.py | 11 +++-- .../compressed_tensors_moe_w4a8_fp8.py | 4 -- .../compressed_tensors_moe_w4a8_int8.py | 1 - 11 files changed, 80 insertions(+), 79 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 668a848f3589..069215bd4cc3 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -341,7 +341,6 @@ def __init__( compilation_config.static_all_moe_layers.append(prefix) self.layer_name = prefix - self.enable_eplb = enable_eplb self.expert_placement_strategy: ExpertPlacementStrategy = ( vllm_config.parallel_config.expert_placement_strategy ) @@ -392,7 +391,7 @@ def __init__( moe_parallel_config=self.moe_parallel_config, num_expert_group=num_expert_group, num_redundant_experts=num_redundant_experts, - enable_eplb=self.enable_eplb, + enable_eplb=enable_eplb, ) self._expert_map: torch.Tensor | None @@ -470,7 +469,7 @@ def __init__( self.router = create_fused_moe_router( top_k=top_k, global_num_experts=self.global_num_experts, - eplb_state=self.eplb_state, + eplb_manager=self.eplb_manager, renormalize=renormalize, use_grouped_topk=use_grouped_topk, num_expert_group=num_expert_group, @@ -480,7 +479,6 @@ def __init__( routed_scaling_factor=self.routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, num_fused_shared_experts=self.num_fused_shared_experts, - enable_eplb=enable_eplb, # TODO(bnell): once we can construct the MK at init time, we # can make this a value. indices_type_getter=lambda: self.quant_method.topk_indices_dtype, @@ -546,7 +544,7 @@ def _get_quant_method() -> FusedMoEMethodBase: "is_act_and_mul=False is supported only for CUDA and ROCm for now" ) - if self.enable_eplb and not self.quant_method.supports_eplb: + if enable_eplb and not self.quant_method.supports_eplb: # TODO: Add support for additional quantization methods. # The implementation for other quantization methods does not # contain essential differences, but the current quant API diff --git a/vllm/model_executor/layers/fused_moe/router/base_router.py b/vllm/model_executor/layers/fused_moe/router/base_router.py index 0138eb59c91c..18f19a6eacf1 100644 --- a/vllm/model_executor/layers/fused_moe/router/base_router.py +++ b/vllm/model_executor/layers/fused_moe/router/base_router.py @@ -2,16 +2,19 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import abstractmethod from collections.abc import Callable +from typing import TYPE_CHECKING import torch -from vllm.distributed.eplb.eplb_state import EplbLayerState from vllm.model_executor.layers.fused_moe.router.fused_moe_router import ( FusedMoERouter, ) from vllm.platforms import current_platform from vllm.triton_utils import tl, triton +if TYPE_CHECKING: + from vllm.model_executor.layers.fused_moe.eplb_manager import EplbManager + if current_platform.is_cuda_alike(): @triton.jit @@ -148,8 +151,7 @@ def __init__( self, top_k: int, global_num_experts: int, - eplb_state: EplbLayerState, - enable_eplb: bool = False, + eplb_manager: EplbManager | None = None, # TODO(bnell): Once the MK is constructed at layer init time, we # can make this a plain value instead of a callback. indices_type_getter: Callable[[], torch.dtype | None] | None = None, @@ -159,12 +161,17 @@ def __init__( time, so we need to supply a callback to get it at runtime. This is because the indices type is supplied by modular kernels which are created after MoE layer/router construction. + + Args: + top_k: Number of experts to select per token + global_num_experts: Total number of experts + eplb_manager: Optional EPLB manager for load balancing + indices_type_getter: Optional callback to get indices dtype """ super().__init__() self.top_k = top_k self.global_num_experts = global_num_experts - self.eplb_state = eplb_state - self.enable_eplb = enable_eplb + self.eplb_manager = eplb_manager self.indices_type_getter = indices_type_getter self.capture_fn: Callable[[torch.Tensor], None] | None = None @@ -174,18 +181,19 @@ def set_capture_fn(self, capture_fn: Callable[[torch.Tensor], None] | None) -> N def _validate_eplb_state(self) -> None: """Validate that EPLB state is properly initialized if EPLB is enabled.""" - if self.enable_eplb: - if self.eplb_state.expert_load_view is None: + if self.eplb_manager is not None: + eplb_state = self.eplb_manager.state + if eplb_state.expert_load_view is None: raise ValueError("enable_eplb=True requires expert_load_view != None") - if self.eplb_state.logical_to_physical_map is None: + if eplb_state.logical_to_physical_map is None: raise ValueError( "enable_eplb=True requires logical_to_physical_map != None" ) - if self.eplb_state.logical_replica_count is None: + if eplb_state.logical_replica_count is None: raise ValueError( "enable_eplb=True requires logical_replica_count != None" ) - if self.eplb_state.should_record_tensor is None: + if eplb_state.should_record_tensor is None: raise ValueError( "enable_eplb=True requires should_record_tensor != None" ) @@ -198,17 +206,18 @@ def _get_indices_type(self) -> torch.dtype | None: def _apply_eplb_mapping(self, topk_ids: torch.Tensor) -> torch.Tensor: """Apply EPLB mapping to convert logical expert IDs to physical expert IDs.""" - if self.enable_eplb: - assert self.eplb_state.expert_load_view is not None - assert self.eplb_state.logical_to_physical_map is not None - assert self.eplb_state.logical_replica_count is not None - assert self.eplb_state.should_record_tensor is not None + if self.eplb_manager is not None: + eplb_state = self.eplb_manager.state + assert eplb_state.expert_load_view is not None + assert eplb_state.logical_to_physical_map is not None + assert eplb_state.logical_replica_count is not None + assert eplb_state.should_record_tensor is not None return eplb_map_to_physical_and_record( topk_ids=topk_ids, - logical_to_physical_map=self.eplb_state.logical_to_physical_map, - logical_replica_count=self.eplb_state.logical_replica_count, - expert_load_view=self.eplb_state.expert_load_view, - record_enabled=self.eplb_state.should_record_tensor, + logical_to_physical_map=eplb_state.logical_to_physical_map, + logical_replica_count=eplb_state.logical_replica_count, + expert_load_view=eplb_state.expert_load_view, + record_enabled=eplb_state.should_record_tensor, ) return topk_ids diff --git a/vllm/model_executor/layers/fused_moe/router/custom_routing_router.py b/vllm/model_executor/layers/fused_moe/router/custom_routing_router.py index c1bd7a6993ab..5be080242d10 100644 --- a/vllm/model_executor/layers/fused_moe/router/custom_routing_router.py +++ b/vllm/model_executor/layers/fused_moe/router/custom_routing_router.py @@ -1,13 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Callable +from typing import TYPE_CHECKING import torch -from vllm.distributed.eplb.eplb_state import EplbLayerState from vllm.model_executor.layers.fused_moe.config import RoutingMethodType from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter +if TYPE_CHECKING: + from vllm.model_executor.layers.fused_moe.eplb_manager import EplbManager + class CustomRoutingRouter(BaseRouter): """Router using a custom user-provided routing function.""" @@ -16,17 +19,15 @@ def __init__( self, top_k: int, global_num_experts: int, - eplb_state: EplbLayerState, custom_routing_function: Callable, + eplb_manager: EplbManager | None = None, renormalize: bool = True, - enable_eplb: bool = False, indices_type_getter: Callable[[], torch.dtype | None] | None = None, ): super().__init__( top_k=top_k, global_num_experts=global_num_experts, - eplb_state=eplb_state, - enable_eplb=enable_eplb, + eplb_manager=eplb_manager, indices_type_getter=indices_type_getter, ) self.custom_routing_function = custom_routing_function diff --git a/vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py b/vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py index 84eaad7f65e6..35ef95a85199 100644 --- a/vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py +++ b/vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import functools from collections.abc import Callable +from typing import TYPE_CHECKING import torch import torch.nn.functional as F @@ -9,13 +10,15 @@ import vllm._custom_ops as ops import vllm.envs as envs from vllm._aiter_ops import rocm_aiter_ops -from vllm.distributed.eplb.eplb_state import EplbLayerState from vllm.model_executor.layers.fused_moe.config import ( RoutingMethodType, get_routing_method_type, ) from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter +if TYPE_CHECKING: + from vllm.model_executor.layers.fused_moe.eplb_manager import EplbManager + def vllm_topk_softmax( topk_weights: torch.Tensor, @@ -235,11 +238,10 @@ def __init__( self, top_k: int, global_num_experts: int, - eplb_state: EplbLayerState, e_score_correction_bias: torch.Tensor | None = None, renormalize: bool = True, routed_scaling_factor: float = 1.0, - enable_eplb: bool = False, + eplb_manager: EplbManager | None = None, indices_type_getter: Callable[[], torch.dtype | None] | None = None, *, scoring_func: str = "sigmoid", @@ -248,8 +250,7 @@ def __init__( super().__init__( top_k=top_k, global_num_experts=global_num_experts, - eplb_state=eplb_state, - enable_eplb=enable_eplb, + eplb_manager=eplb_manager, indices_type_getter=indices_type_getter, ) self.e_score_correction_bias = e_score_correction_bias diff --git a/vllm/model_executor/layers/fused_moe/router/fused_topk_router.py b/vllm/model_executor/layers/fused_moe/router/fused_topk_router.py index 45311dba08e3..7c540bdd46ad 100644 --- a/vllm/model_executor/layers/fused_moe/router/fused_topk_router.py +++ b/vllm/model_executor/layers/fused_moe/router/fused_topk_router.py @@ -1,18 +1,21 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Callable +from typing import TYPE_CHECKING import torch import vllm._custom_ops as ops from vllm._aiter_ops import rocm_aiter_ops -from vllm.distributed.eplb.eplb_state import EplbLayerState from vllm.model_executor.layers.fused_moe.config import ( RoutingMethodType, get_routing_method_type, ) from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter +if TYPE_CHECKING: + from vllm.model_executor.layers.fused_moe.eplb_manager import EplbManager + def vllm_topk_softmax( topk_weights: torch.Tensor, @@ -120,17 +123,15 @@ def __init__( self, top_k: int, global_num_experts: int, - eplb_state: EplbLayerState, scoring_func: str = "softmax", renormalize: bool = True, - enable_eplb: bool = False, + eplb_manager: EplbManager | None = None, indices_type_getter: Callable[[], torch.dtype | None] | None = None, ): super().__init__( top_k=top_k, global_num_experts=global_num_experts, - eplb_state=eplb_state, - enable_eplb=enable_eplb, + eplb_manager=eplb_manager, indices_type_getter=indices_type_getter, ) self.renormalize = renormalize diff --git a/vllm/model_executor/layers/fused_moe/router/grouped_topk_router.py b/vllm/model_executor/layers/fused_moe/router/grouped_topk_router.py index 74c3a62a1f11..b622f3bc7f57 100644 --- a/vllm/model_executor/layers/fused_moe/router/grouped_topk_router.py +++ b/vllm/model_executor/layers/fused_moe/router/grouped_topk_router.py @@ -2,13 +2,13 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Callable from functools import partial +from typing import TYPE_CHECKING import torch from vllm import _custom_ops as ops from vllm import envs as envs from vllm._aiter_ops import rocm_aiter_ops -from vllm.distributed.eplb.eplb_state import EplbLayerState from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.fused_moe.config import ( RoutingMethodType, @@ -25,6 +25,9 @@ from vllm.model_executor.utils import maybe_disable_graph_partition from vllm.platforms import current_platform +if TYPE_CHECKING: + from vllm.model_executor.layers.fused_moe.eplb_manager import EplbManager + def fused_grouped_topk( hidden_states: torch.Tensor, @@ -251,7 +254,6 @@ def __init__( self, top_k: int, global_num_experts: int, - eplb_state: EplbLayerState, num_expert_group: int, topk_group: int, renormalize: bool = True, @@ -259,14 +261,13 @@ def __init__( routed_scaling_factor: float = 1.0, e_score_correction_bias: torch.Tensor | None = None, num_fused_shared_experts: int = 0, - enable_eplb: bool = False, + eplb_manager: EplbManager | None = None, indices_type_getter: Callable[[], torch.dtype | None] | None = None, ): super().__init__( top_k=top_k, global_num_experts=global_num_experts, - eplb_state=eplb_state, - enable_eplb=enable_eplb, + eplb_manager=eplb_manager, indices_type_getter=indices_type_getter, ) self.num_expert_group = num_expert_group diff --git a/vllm/model_executor/layers/fused_moe/router/router_factory.py b/vllm/model_executor/layers/fused_moe/router/router_factory.py index da7896de6159..718f734ac43a 100644 --- a/vllm/model_executor/layers/fused_moe/router/router_factory.py +++ b/vllm/model_executor/layers/fused_moe/router/router_factory.py @@ -1,12 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Callable +from typing import TYPE_CHECKING import torch import vllm.envs as envs -from vllm.distributed.eplb.eplb_state import EplbLayerState from vllm.model_executor.layers.fused_moe.config import RoutingMethodType + +if TYPE_CHECKING: + from vllm.model_executor.layers.fused_moe.eplb_manager import EplbManager from vllm.model_executor.layers.fused_moe.router.custom_routing_router import ( CustomRoutingRouter, ) @@ -29,8 +32,6 @@ ZeroExpertRouter, ) -EMPTY_EPLB_STATE: EplbLayerState = EplbLayerState() - def create_fused_moe_router( # common parameters @@ -50,8 +51,7 @@ def create_fused_moe_router( # custom routing parameters custom_routing_function: Callable | None = None, # eplb parameters - enable_eplb: bool = False, - eplb_state: EplbLayerState = EMPTY_EPLB_STATE, + eplb_manager: EplbManager | None = None, # zero expert parameters zero_expert_type: str | None = None, num_logical_experts: int | None = None, @@ -91,8 +91,7 @@ def create_fused_moe_router( custom_routing_function: Optional custom routing function EPLB arguments: - enable_eplb: Whether EPLB is enabled - eplb_state: EPLB (Expert Parallelism Load Balancing) state + eplb_manager: Optional EPLB (Expert Parallelism Load Balancing) manager Zero expert arguments: zero_expert_type: Type of zero expert (e.g. identity). If not None, @@ -112,8 +111,7 @@ def create_fused_moe_router( return RoutingSimulatorRouter( top_k=top_k, global_num_experts=global_num_experts, - eplb_state=eplb_state, - enable_eplb=enable_eplb, + eplb_manager=eplb_manager, indices_type_getter=indices_type_getter, ) @@ -127,14 +125,13 @@ def create_fused_moe_router( return ZeroExpertRouter( top_k=top_k, global_num_experts=global_num_experts, - eplb_state=eplb_state, + eplb_manager=eplb_manager, e_score_correction_bias=e_score_correction_bias, num_logical_experts=num_logical_experts, zero_expert_type=zero_expert_type, scoring_func=scoring_func, renormalize=renormalize, routed_scaling_factor=routed_scaling_factor, - enable_eplb=enable_eplb, indices_type_getter=indices_type_getter, ) @@ -148,7 +145,7 @@ def create_fused_moe_router( grouped_topk_router = GroupedTopKRouter( top_k=top_k, global_num_experts=global_num_experts, - eplb_state=eplb_state, + eplb_manager=eplb_manager, num_expert_group=num_expert_group, topk_group=topk_group, renormalize=renormalize, @@ -156,7 +153,6 @@ def create_fused_moe_router( routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, num_fused_shared_experts=num_fused_shared_experts, - enable_eplb=enable_eplb, indices_type_getter=indices_type_getter, ) if ( @@ -176,10 +172,9 @@ def create_fused_moe_router( return CustomRoutingRouter( top_k=top_k, global_num_experts=global_num_experts, - eplb_state=eplb_state, + eplb_manager=eplb_manager, custom_routing_function=custom_routing_function, renormalize=renormalize, - enable_eplb=enable_eplb, indices_type_getter=indices_type_getter, ) @@ -189,11 +184,10 @@ def create_fused_moe_router( return FusedTopKBiasRouter( top_k=top_k, global_num_experts=global_num_experts, - eplb_state=eplb_state, + eplb_manager=eplb_manager, e_score_correction_bias=e_score_correction_bias, renormalize=renormalize, routed_scaling_factor=routed_scaling_factor, - enable_eplb=enable_eplb, indices_type_getter=indices_type_getter, scoring_func=scoring_func, hash_indices_table=hash_indices_table, @@ -202,9 +196,8 @@ def create_fused_moe_router( return FusedTopKRouter( top_k=top_k, global_num_experts=global_num_experts, - eplb_state=eplb_state, + eplb_manager=eplb_manager, renormalize=renormalize, scoring_func=scoring_func, - enable_eplb=enable_eplb, indices_type_getter=indices_type_getter, ) diff --git a/vllm/model_executor/layers/fused_moe/router/routing_simulator_router.py b/vllm/model_executor/layers/fused_moe/router/routing_simulator_router.py index 8fb36b72cb70..2db45f581634 100644 --- a/vllm/model_executor/layers/fused_moe/router/routing_simulator_router.py +++ b/vllm/model_executor/layers/fused_moe/router/routing_simulator_router.py @@ -2,16 +2,19 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod from collections.abc import Callable -from typing import Any +from typing import TYPE_CHECKING, Any import torch import vllm.envs as envs -from vllm.distributed.eplb.eplb_state import EplbLayerState from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import RoutingMethodType from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter +if TYPE_CHECKING: + from vllm.model_executor.layers.fused_moe.eplb_manager import EplbManager + + logger = init_logger(__name__) @@ -313,15 +316,13 @@ def __init__( self, top_k: int, global_num_experts: int, - eplb_state: EplbLayerState, - enable_eplb: bool = False, + eplb_manager: EplbManager | None = None, indices_type_getter: Callable[[], torch.dtype | None] | None = None, ): super().__init__( top_k=top_k, global_num_experts=global_num_experts, - eplb_state=eplb_state, - enable_eplb=enable_eplb, + eplb_manager=eplb_manager, indices_type_getter=indices_type_getter, ) diff --git a/vllm/model_executor/layers/fused_moe/router/zero_expert_router.py b/vllm/model_executor/layers/fused_moe/router/zero_expert_router.py index 65760727770a..d8057e4300c3 100644 --- a/vllm/model_executor/layers/fused_moe/router/zero_expert_router.py +++ b/vllm/model_executor/layers/fused_moe/router/zero_expert_router.py @@ -2,10 +2,10 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Callable +from typing import TYPE_CHECKING import torch -from vllm.distributed.eplb.eplb_state import EplbLayerState from vllm.model_executor.layers.fused_moe.config import ( RoutingMethodType, get_routing_method_type, @@ -18,6 +18,9 @@ fused_topk_bias, ) +if TYPE_CHECKING: + from vllm.model_executor.layers.fused_moe.eplb_manager import EplbManager + class ZeroExpertRouter(BaseRouter): """Router that handles zero expert computation as part of routing. @@ -32,21 +35,19 @@ def __init__( self, top_k: int, global_num_experts: int, - eplb_state: EplbLayerState, e_score_correction_bias: torch.Tensor, num_logical_experts: int, zero_expert_type: str, scoring_func: str = "softmax", renormalize: bool = False, routed_scaling_factor: float = 1.0, - enable_eplb: bool = False, + eplb_manager: EplbManager | None = None, indices_type_getter: Callable[[], torch.dtype | None] | None = None, ): super().__init__( top_k=top_k, global_num_experts=global_num_experts, - eplb_state=eplb_state, - enable_eplb=enable_eplb, + eplb_manager=eplb_manager, indices_type_getter=indices_type_getter, ) self.e_score_correction_bias = e_score_correction_bias diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a8_fp8.py index b14571fe5013..efa28ac3b6ae 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a8_fp8.py @@ -309,10 +309,6 @@ def apply( topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, ) -> torch.Tensor: - if layer.enable_eplb: - raise NotImplementedError( - "EPLB not supported for `CompressedTensorsW4A8Fp8MoEMethod` yet." - ) assert self.moe_quant_config is not None from vllm.model_executor.layers.fused_moe.experts.cutlass_moe import ( diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a8_int8.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a8_int8.py index 88cdbadd3f83..c697b137420b 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a8_int8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a8_int8.py @@ -307,7 +307,6 @@ def apply_monolithic( router_logits: torch.Tensor, input_ids: torch.Tensor | None = None, ) -> torch.Tensor: - assert not layer.enable_eplb, "EPLB not supported for W4A8-int MoE yet." assert layer.activation in ( MoEActivation.SILU, MoEActivation.SWIGLUOAI, From 34988207a9faab704fd0ffa9316b8dcd6f8ec982 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 28 Apr 2026 19:52:00 +0000 Subject: [PATCH 03/10] fix Signed-off-by: Bill Nell --- tests/kernels/moe/test_moe_layer.py | 2 -- vllm/model_executor/layers/fused_moe/router/base_router.py | 5 +---- .../layers/fused_moe/router/custom_routing_router.py | 2 +- .../layers/fused_moe/router/fused_topk_bias_router.py | 2 +- .../layers/fused_moe/router/fused_topk_router.py | 2 +- .../layers/fused_moe/router/grouped_topk_router.py | 2 +- .../model_executor/layers/fused_moe/router/router_factory.py | 5 +---- .../layers/fused_moe/router/routing_simulator_router.py | 2 +- .../layers/fused_moe/router/zero_expert_router.py | 2 +- 9 files changed, 8 insertions(+), 16 deletions(-) diff --git a/tests/kernels/moe/test_moe_layer.py b/tests/kernels/moe/test_moe_layer.py index 89e28d950f9d..243ace519fb8 100644 --- a/tests/kernels/moe/test_moe_layer.py +++ b/tests/kernels/moe/test_moe_layer.py @@ -1004,7 +1004,6 @@ def make_fake_moe_layer( activation: str = "silu", indices_type: torch.dtype | None = None, expert_map: torch.Tensor | None = None, - enable_eplb: bool = False, expert_load_view: torch.Tensor | None = None, logical_to_physical_map: torch.Tensor | None = None, logical_replica_count: torch.Tensor | None = None, @@ -1032,7 +1031,6 @@ def make_fake_moe_layer( routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, num_fused_shared_experts=0, # TODO - enable_eplb=enable_eplb, # TODO(bnell): once we can construct the MK at init time, we # can make this a value. indices_type_getter=lambda: indices_type, diff --git a/vllm/model_executor/layers/fused_moe/router/base_router.py b/vllm/model_executor/layers/fused_moe/router/base_router.py index 18f19a6eacf1..e32816b395b9 100644 --- a/vllm/model_executor/layers/fused_moe/router/base_router.py +++ b/vllm/model_executor/layers/fused_moe/router/base_router.py @@ -2,19 +2,16 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import abstractmethod from collections.abc import Callable -from typing import TYPE_CHECKING import torch +from vllm.model_executor.layers.fused_moe.eplb_manager import EplbManager from vllm.model_executor.layers.fused_moe.router.fused_moe_router import ( FusedMoERouter, ) from vllm.platforms import current_platform from vllm.triton_utils import tl, triton -if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.eplb_manager import EplbManager - if current_platform.is_cuda_alike(): @triton.jit diff --git a/vllm/model_executor/layers/fused_moe/router/custom_routing_router.py b/vllm/model_executor/layers/fused_moe/router/custom_routing_router.py index 5be080242d10..41385c940040 100644 --- a/vllm/model_executor/layers/fused_moe/router/custom_routing_router.py +++ b/vllm/model_executor/layers/fused_moe/router/custom_routing_router.py @@ -20,7 +20,7 @@ def __init__( top_k: int, global_num_experts: int, custom_routing_function: Callable, - eplb_manager: EplbManager | None = None, + eplb_manager: "EplbManager | None" = None, renormalize: bool = True, indices_type_getter: Callable[[], torch.dtype | None] | None = None, ): diff --git a/vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py b/vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py index 35ef95a85199..6d3bd6ac5529 100644 --- a/vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py +++ b/vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py @@ -241,7 +241,7 @@ def __init__( e_score_correction_bias: torch.Tensor | None = None, renormalize: bool = True, routed_scaling_factor: float = 1.0, - eplb_manager: EplbManager | None = None, + eplb_manager: "EplbManager | None" = None, indices_type_getter: Callable[[], torch.dtype | None] | None = None, *, scoring_func: str = "sigmoid", diff --git a/vllm/model_executor/layers/fused_moe/router/fused_topk_router.py b/vllm/model_executor/layers/fused_moe/router/fused_topk_router.py index 7c540bdd46ad..d88786491d7b 100644 --- a/vllm/model_executor/layers/fused_moe/router/fused_topk_router.py +++ b/vllm/model_executor/layers/fused_moe/router/fused_topk_router.py @@ -125,7 +125,7 @@ def __init__( global_num_experts: int, scoring_func: str = "softmax", renormalize: bool = True, - eplb_manager: EplbManager | None = None, + eplb_manager: "EplbManager | None" = None, indices_type_getter: Callable[[], torch.dtype | None] | None = None, ): super().__init__( diff --git a/vllm/model_executor/layers/fused_moe/router/grouped_topk_router.py b/vllm/model_executor/layers/fused_moe/router/grouped_topk_router.py index b622f3bc7f57..461c5c351f05 100644 --- a/vllm/model_executor/layers/fused_moe/router/grouped_topk_router.py +++ b/vllm/model_executor/layers/fused_moe/router/grouped_topk_router.py @@ -261,7 +261,7 @@ def __init__( routed_scaling_factor: float = 1.0, e_score_correction_bias: torch.Tensor | None = None, num_fused_shared_experts: int = 0, - eplb_manager: EplbManager | None = None, + eplb_manager: "EplbManager | None" = None, indices_type_getter: Callable[[], torch.dtype | None] | None = None, ): super().__init__( diff --git a/vllm/model_executor/layers/fused_moe/router/router_factory.py b/vllm/model_executor/layers/fused_moe/router/router_factory.py index 718f734ac43a..89592830b23b 100644 --- a/vllm/model_executor/layers/fused_moe/router/router_factory.py +++ b/vllm/model_executor/layers/fused_moe/router/router_factory.py @@ -1,15 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Callable -from typing import TYPE_CHECKING import torch import vllm.envs as envs from vllm.model_executor.layers.fused_moe.config import RoutingMethodType - -if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.eplb_manager import EplbManager +from vllm.model_executor.layers.fused_moe.eplb_manager import EplbManager from vllm.model_executor.layers.fused_moe.router.custom_routing_router import ( CustomRoutingRouter, ) diff --git a/vllm/model_executor/layers/fused_moe/router/routing_simulator_router.py b/vllm/model_executor/layers/fused_moe/router/routing_simulator_router.py index 2db45f581634..7d0b8ba8b61a 100644 --- a/vllm/model_executor/layers/fused_moe/router/routing_simulator_router.py +++ b/vllm/model_executor/layers/fused_moe/router/routing_simulator_router.py @@ -316,7 +316,7 @@ def __init__( self, top_k: int, global_num_experts: int, - eplb_manager: EplbManager | None = None, + eplb_manager: "EplbManager | None" = None, indices_type_getter: Callable[[], torch.dtype | None] | None = None, ): super().__init__( diff --git a/vllm/model_executor/layers/fused_moe/router/zero_expert_router.py b/vllm/model_executor/layers/fused_moe/router/zero_expert_router.py index d8057e4300c3..d61056026c01 100644 --- a/vllm/model_executor/layers/fused_moe/router/zero_expert_router.py +++ b/vllm/model_executor/layers/fused_moe/router/zero_expert_router.py @@ -41,7 +41,7 @@ def __init__( scoring_func: str = "softmax", renormalize: bool = False, routed_scaling_factor: float = 1.0, - eplb_manager: EplbManager | None = None, + eplb_manager: "EplbManager | None" = None, indices_type_getter: Callable[[], torch.dtype | None] | None = None, ): super().__init__( From 692912ed475ce91a3f080419d7662418ec5def2f Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 28 Apr 2026 20:04:53 +0000 Subject: [PATCH 04/10] fix Signed-off-by: Bill Nell --- tests/kernels/moe/test_moe_layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/kernels/moe/test_moe_layer.py b/tests/kernels/moe/test_moe_layer.py index 243ace519fb8..b79fa925c2cb 100644 --- a/tests/kernels/moe/test_moe_layer.py +++ b/tests/kernels/moe/test_moe_layer.py @@ -1254,7 +1254,7 @@ def _test_body_eplb( ), ) - eplb_moe_layer.eplb_state.should_record_tensor = torch.ones( + eplb_moe_layer.eplb_manager.state.should_record_tensor = torch.ones( (), dtype=torch.bool, device=device ) From 90c74a86f07c19ed2e78f48447dcc83e0c1e7080 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 5 May 2026 15:12:28 +0000 Subject: [PATCH 05/10] move mapping fn back to FusedMoE Signed-off-by: Bill Nell --- .../layers/fused_moe/eplb_manager.py | 69 +------------------ vllm/model_executor/layers/fused_moe/layer.py | 65 +++++++++-------- 2 files changed, 39 insertions(+), 95 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/eplb_manager.py b/vllm/model_executor/layers/fused_moe/eplb_manager.py index 68137ccb01bc..d444664602e6 100644 --- a/vllm/model_executor/layers/fused_moe/eplb_manager.py +++ b/vllm/model_executor/layers/fused_moe/eplb_manager.py @@ -13,7 +13,7 @@ import torch -from vllm.distributed.eplb.eplb_state import EplbLayerState, EplbState +from vllm.distributed.eplb.eplb_state import EplbLayerState class EplbManager: @@ -177,70 +177,3 @@ def _maybe_make_contiguous( and not name.startswith("_routed_input_transform.") and not name.startswith("_routed_output_transform.") ] - - @staticmethod - def make_expert_params_mapping( - model: torch.nn.Module, - ckpt_gate_proj_name: str, - ckpt_down_proj_name: str, - ckpt_up_proj_name: str, - num_experts: int, - num_redundant_experts: int = 0, - ) -> list[tuple[str, str, int, str]]: - """ - Create expert parameter mapping for weight loading with redundant experts. - - This mapping handles the physical-to-logical expert ID conversion needed - when loading weights with EPLB redundant experts. - - Args: - model: The model containing the MoE layer - ckpt_gate_proj_name: Name of gate projection in checkpoint - ckpt_down_proj_name: Name of down projection in checkpoint - ckpt_up_proj_name: Name of up projection in checkpoint - num_experts: Number of logical (non-redundant) experts - num_redundant_experts: Number of redundant experts - - Returns: - List of tuples (param_name, weight_name, expert_id, shard_id) - where: - - param_name: Parameter name in the layer - - weight_name: Weight name in checkpoint - - expert_id: Physical expert ID - - shard_id: Shard identifier (w1, w2, w3) - """ - num_physical_experts = num_experts + num_redundant_experts - - # In the returned mapping: - # - `expert_id` is the physical expert id - # - `weight_name` contains the weight name of the logical expert - # So that we should map the expert id to logical in `weight_name` - physical_to_logical_map = ( - EplbState.build_initial_global_physical_to_logical_map( - num_experts, num_redundant_experts - ) - ) - - base_layer = ( - "base_layer." - if any(".base_layer." in name for name, _ in model.named_parameters()) - else "" - ) - - return [ - # (param_name, weight_name, expert_id, shard_id) - ( - f"experts.{base_layer}w13_" - if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name] - else f"experts.{base_layer}w2_", - f"experts.{physical_to_logical_map[expert_id]}.{weight_name}.{base_layer}", - expert_id, - shard_id, - ) - for expert_id in range(num_physical_experts) - for shard_id, weight_name in [ - ("w1", ckpt_gate_proj_name), - ("w2", ckpt_down_proj_name), - ("w3", ckpt_up_proj_name), - ] - ] diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 069215bd4cc3..086d5a93be56 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -16,6 +16,7 @@ get_pcp_group, get_tensor_model_parallel_world_size, ) +from vllm.distributed.eplb.eplb_state import EplbState from vllm.logger import init_logger from vllm.model_executor.custom_op import PluggableLayer from vllm.model_executor.layers.fused_moe.activation import MoEActivation @@ -539,9 +540,11 @@ def _get_quant_method() -> FusedMoEMethodBase: # for heuristic purposes, so it must be initialized first. self.quant_method: FusedMoEMethodBase = _get_quant_method() - if not self.moe_config.is_act_and_mul and not current_platform.is_cuda_alike(): + if not self.moe_config.is_act_and_mul and not ( + current_platform.is_cuda_alike() or current_platform.is_xpu() + ): raise NotImplementedError( - "is_act_and_mul=False is supported only for CUDA and ROCm for now" + "is_act_and_mul=False is supported only for CUDA and XPU for now" ) if enable_eplb and not self.quant_method.supports_eplb: @@ -1104,9 +1107,6 @@ def weight_loader( return_success: bool = False, ) -> bool | None: quant_config_name = self.quant_config and self.quant_config.get_name() - if quant_config_name == "humming": - assert hasattr(self.quant_method, "weight_schema") - quant_config_name = self.quant_method.weight_schema.quant_method if quant_config_name == "gpt_oss_mxfp4": # (FIXME) for gpt-oss all experts are combined if "bias" in weight_name: @@ -1519,31 +1519,42 @@ def make_expert_params_mapping( num_experts: int, num_redundant_experts: int = 0, ) -> list[tuple[str, str, int, str]]: - """ - Create expert parameter mapping for weight loading. - - Delegates to EplbManager for proper handling of redundant experts. - - Args: - model: The model containing the MoE layer - ckpt_gate_proj_name: Name of gate projection in checkpoint - ckpt_down_proj_name: Name of down projection in checkpoint - ckpt_up_proj_name: Name of up projection in checkpoint - num_experts: Number of logical (non-redundant) experts - num_redundant_experts: Number of redundant experts + num_physical_experts = num_experts + num_redundant_experts + + # In the returned mapping: + # - `expert_id` is the physical expert id + # - `weight_name` contains the weight name of the logical expert + # So that we should map the expert id to logical in `weight_name` + physical_to_logical_map = ( + EplbState.build_initial_global_physical_to_logical_map( + num_experts, num_redundant_experts + ) + ) - Returns: - List of tuples (param_name, weight_name, expert_id, shard_id) - """ - return EplbManager.make_expert_params_mapping( - model, - ckpt_gate_proj_name, - ckpt_down_proj_name, - ckpt_up_proj_name, - num_experts, - num_redundant_experts, + base_layer = ( + "base_layer." + if any(".base_layer." in name for name, _ in model.named_parameters()) + else "" ) + return [ + # (param_name, weight_name, expert_id, shard_id) + ( + f"experts.{base_layer}w13_" + if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name] + else f"experts.{base_layer}w2_", + f"experts.{physical_to_logical_map[expert_id]}.{weight_name}.{base_layer}", + expert_id, + shard_id, + ) + for expert_id in range(num_physical_experts) + for shard_id, weight_name in [ + ("w1", ckpt_gate_proj_name), + ("w2", ckpt_down_proj_name), + ("w3", ckpt_up_proj_name), + ] + ] + @property def hidden_size(self) -> int: return self.moe_config.hidden_dim From 0780907d3a6d9a51e84dfebc62e0e8d4cce01516 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 6 May 2026 18:20:43 +0000 Subject: [PATCH 06/10] review comments + redo stuff Signed-off-by: Bill Nell --- .../test_eplb_fused_moe_layer_dep_nvfp4.py | 3 +- tests/kernels/moe/test_moe_layer.py | 3 +- tests/kernels/moe/test_routing.py | 11 +- .../test_routed_experts_capture.py | 5 +- vllm/distributed/eplb/eplb_state.py | 11 ++ .../layers/fused_moe/eplb_manager.py | 178 ------------------ vllm/model_executor/layers/fused_moe/layer.py | 105 ++++++++--- .../layers/fused_moe/router/base_router.py | 16 +- .../fused_moe/router/custom_routing_router.py | 9 +- .../router/fused_topk_bias_router.py | 9 +- .../fused_moe/router/fused_topk_router.py | 9 +- .../fused_moe/router/grouped_topk_router.py | 9 +- .../layers/fused_moe/router/router_factory.py | 18 +- .../router/routing_simulator_router.py | 11 +- .../fused_moe/router/zero_expert_router.py | 9 +- .../layers/quantization/modelopt.py | 2 +- 16 files changed, 141 insertions(+), 267 deletions(-) delete mode 100644 vllm/model_executor/layers/fused_moe/eplb_manager.py diff --git a/tests/distributed/test_eplb_fused_moe_layer_dep_nvfp4.py b/tests/distributed/test_eplb_fused_moe_layer_dep_nvfp4.py index 68b2407c2e4b..9ab785af3135 100644 --- a/tests/distributed/test_eplb_fused_moe_layer_dep_nvfp4.py +++ b/tests/distributed/test_eplb_fused_moe_layer_dep_nvfp4.py @@ -10,6 +10,7 @@ from tests.kernels.moe.utils import make_test_quant_config from vllm.config import VllmConfig, set_current_vllm_config +from vllm.distributed.eplb.eplb_state import EplbLayerState from vllm.distributed.eplb.rebalance_execute import rearrange_expert_weights_inplace from vllm.distributed.parallel_state import ( ensure_model_parallel_initialized, @@ -201,7 +202,7 @@ def _test_eplb_fml(env, world_size: int, test_config: TestConfig): dtype=torch.int32, device=device, ) - fml.enable_eplb = True + fml.eplb_state = EplbLayerState() fml.set_eplb_state( lidx, torch.zeros( diff --git a/tests/kernels/moe/test_moe_layer.py b/tests/kernels/moe/test_moe_layer.py index b79fa925c2cb..2b27202b6b6f 100644 --- a/tests/kernels/moe/test_moe_layer.py +++ b/tests/kernels/moe/test_moe_layer.py @@ -1021,7 +1021,6 @@ def make_fake_moe_layer( router = create_fused_moe_router( top_k=top_k, global_num_experts=global_num_experts, - # eplb_state=None, # TODO renormalize=renormalize, use_grouped_topk=use_grouped_topk, num_expert_group=num_expert_group, @@ -1254,7 +1253,7 @@ def _test_body_eplb( ), ) - eplb_moe_layer.eplb_manager.state.should_record_tensor = torch.ones( + eplb_moe_layer.eplb_state.should_record_tensor = torch.ones( (), dtype=torch.bool, device=device ) diff --git a/tests/kernels/moe/test_routing.py b/tests/kernels/moe/test_routing.py index 90a6cd841efd..41dea8121938 100644 --- a/tests/kernels/moe/test_routing.py +++ b/tests/kernels/moe/test_routing.py @@ -36,9 +36,11 @@ def _is_aiter_capable() -> bool: NUM_EXPERTS = [8, 16, 64] -def setup_eplb_state(enable_eplb: bool, global_num_experts: int) -> EplbLayerState: +def setup_eplb_state( + enable_eplb: bool, global_num_experts: int +) -> EplbLayerState | None: if not enable_eplb: - return EplbLayerState() + return None # Initialize EPLB state with proper tensors for testing # For testing purposes, we use a simple 1:1 mapping (no redundant experts) @@ -349,7 +351,6 @@ def test_fused_topk( top_k=top_k, global_num_experts=global_num_experts, renormalize=renormalize, - enable_eplb=enable_eplb, eplb_state=eplb_state, ) @@ -400,7 +401,6 @@ def test_fused_topk_bias( top_k=top_k, global_num_experts=global_num_experts, renormalize=renormalize, - enable_eplb=enable_eplb, eplb_state=eplb_state, ) @@ -469,7 +469,6 @@ def test_grouped_topk( top_k=top_k, global_num_experts=global_num_experts, renormalize=renormalize, - enable_eplb=enable_eplb, eplb_state=eplb_state, ) @@ -540,7 +539,6 @@ def test_custom( global_num_experts=global_num_experts, custom_routing_function=custom_routing_function, renormalize=renormalize, - enable_eplb=enable_eplb, eplb_state=eplb_state, ) @@ -580,7 +578,6 @@ def test_custom( # router = create_fused_moe_router( # top_k=top_k, # global_num_experts=global_num_experts, -# enable_eplb=enable_eplb, # eplb_state=eplb_state, # ) diff --git a/tests/model_executor/test_routed_experts_capture.py b/tests/model_executor/test_routed_experts_capture.py index 0527417d1506..656661ee2b24 100644 --- a/tests/model_executor/test_routed_experts_capture.py +++ b/tests/model_executor/test_routed_experts_capture.py @@ -57,8 +57,7 @@ def _make_router() -> DummyRouter: return DummyRouter( top_k=2, global_num_experts=16, - eplb_state=EplbLayerState(), - enable_eplb=False, + eplb_state=None, indices_type_getter=None, ) @@ -84,7 +83,7 @@ def capture_fn(ids): def test_base_router_capture_with_eplb_enabled(): router = _make_router() - router.enable_eplb = True + router.eplb_state = EplbLayerState() router.eplb_state.expert_load_view = torch.zeros(32, dtype=torch.int64) router.eplb_state.logical_to_physical_map = torch.arange(32).view(32, 1) router.eplb_state.logical_replica_count = torch.ones(32, dtype=torch.int64) diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py index 1da39caccd80..319a5f22c922 100644 --- a/vllm/distributed/eplb/eplb_state.py +++ b/vllm/distributed/eplb/eplb_state.py @@ -940,6 +940,17 @@ class EplbLayerState: GPU work. """ + def set_layer_state( + self, + moe_layer_idx: int, + expert_load_view: torch.Tensor, + logical_to_physical_map: torch.Tensor, + logical_replica_count: torch.Tensor, + ) -> None: + self.expert_load_view = expert_load_view[moe_layer_idx] + self.logical_to_physical_map = logical_to_physical_map[moe_layer_idx] + self.logical_replica_count = logical_replica_count[moe_layer_idx] + def _node_count_with_rank_mapping( pg: ProcessGroup | StatelessProcessGroup, diff --git a/vllm/model_executor/layers/fused_moe/eplb_manager.py b/vllm/model_executor/layers/fused_moe/eplb_manager.py deleted file mode 100644 index a5f349472e66..000000000000 --- a/vllm/model_executor/layers/fused_moe/eplb_manager.py +++ /dev/null @@ -1,178 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -""" -EPLB (Expert Parallelism Load Balancing) Manager. - -This module provides the EplbManager class which encapsulates all EPLB-related -functionality for MoE layers, including state management, expert weight -collection, and expert parameter mapping. -""" - -from collections.abc import Iterable - -import torch - -from vllm.distributed.eplb.eplb_state import EplbLayerState - - -class EplbManager: - """ - Manages Expert Parallelism Load Balancing (EPLB) state and operations - for a MoE layer. - - This class encapsulates all EPLB-related functionality including: - - Runtime state (expert load view, logical-to-physical mapping) - - Expert weight collection for load balancing - - Expert parameter mapping for weight loading with redundant experts - - Validation of EPLB configuration constraints - """ - - def __init__( - self, - ep_size: int, - global_num_experts: int, - logical_num_experts: int, - num_redundant_experts: int = 0, - ): - """ - Initialize EPLB manager. - - Args: - ep_size: Expert parallel world size - global_num_experts: Total number of experts (including redundant) - logical_num_experts: Number of logical (non-redundant) experts - num_redundant_experts: Number of redundant experts - """ - self.ep_size = ep_size - self.global_num_experts = global_num_experts - self.logical_num_experts = logical_num_experts - self.num_redundant_experts = num_redundant_experts - - # Runtime EPLB state. - self.state = EplbLayerState() - - # Validate EPLB configuration. - # EPLB currently only supports even distribution of experts across ranks - if self.global_num_experts % self.ep_size != 0: - raise ValueError( - f"EPLB currently only supports even distribution of " - f"experts across ranks. Got {self.global_num_experts} experts " - f"and {self.ep_size} EP ranks." - ) - - def set_state( - self, - moe_layer_idx: int, - expert_load_view: torch.Tensor, - logical_to_physical_map: torch.Tensor, - logical_replica_count: torch.Tensor, - ) -> None: - """ - Register the EPLB state for this layer. - - This is used later in forward pass, where we get the expert mapping - and record the load metrics in `expert_load_view`. - - Args: - moe_layer_idx: Index of this MoE layer - expert_load_view: View into global expert load tracking tensor - logical_to_physical_map: Mapping from logical to physical expert IDs - logical_replica_count: Number of replicas for each logical expert - """ - self.state.expert_load_view = expert_load_view[moe_layer_idx] - self.state.logical_to_physical_map = logical_to_physical_map[moe_layer_idx] - self.state.logical_replica_count = logical_replica_count[moe_layer_idx] - - @staticmethod - def get_expert_weights( - layer: torch.nn.Module, # FusedMoE - ) -> Iterable[torch.Tensor]: - """ - Collect expert weights from the MoE layer for EPLB. - - Returns weights reshaped as (local_num_experts, -1) for efficient - expert weight swapping during load balancing. - - Args: - layer: The FusedMoE layer to collect weights from - - Returns: - Iterable of expert weight tensors - """ - - def _maybe_make_contiguous( - name: str, p: torch.nn.Parameter - ) -> torch.nn.Parameter: - """ - In some cases, the last 2 dimensions (the non-expert dimensions) - of the weight scale tensor are transposed. This function - transforms the tensor (view update) so the tensor is contiguous(). - Example: A non-contiguous scale tensor, - `x` of shape (E, 32, 16) and stride (512, 1, 32) is transformed to - `x_` of shape (E, 16, 32) and stride (512, 32, 1). - Note that we specifically use torch.transpose() so `x_` refers - to the same underlying memory. The tensors `x` and `x_`, pointing - to the same underlying memory make this transformation safe in the - context of EPLB. i.e. It is the same memory and just the view - is different. - Note: This function handles the "weight_scale" tensors specifically. - This could however be generalized to handle similar tensors. - """ - if p.ndim != 3: - return p - if p.is_contiguous(): - # Already contiguous. do nothing. - return p - # p is non-contiguous. We only handle the case where the last 2 - # dimensions of the scales tensor is transposed. We can handle - # other cases when they become relevant. - is_transposed_12 = p.stride(1) == 1 and p.stride(2) != 1 - if "weight_scale" not in name or not is_transposed_12: - # do nothing. - return p - - # Do not update the layer parameter as the layer's MoE operations would - # expect the parameter's tensor to the same shape / stride. Instead, - # make a new torch.nn.Parameter that is used just in the context of - # EPLB. - return torch.nn.Parameter( - torch.transpose(p.data, 1, 2), requires_grad=False - ) - - weights = list(layer.named_parameters()) - weights = [(name, _maybe_make_contiguous(name, p)) for name, p in weights] - - # `w13_input_scale` and `w2_input_scale` are global per-tensor - # activation scales shared across all experts (e.g. NVFP4). - # They are broadcast views (stride 0) from .expand() and are - # not actual expert weights, so exclude them from EPLB. - NON_EXPERT_WEIGHTS = { - "e_score_correction_bias", - "w13_input_scale", - "w2_input_scale", - } - - # Parameters of non-expert submodules that live inside runner (MoERunner). - # These must be excluded from EPLB weight rearrangement. - NON_EXPERT_PREFIXES = ( - "runner._shared_experts.", - "runner.gate.", - "runner.routed_input_transform.", - "runner.routed_output_transform.", - ) - - assert all( - weight.is_contiguous() - for name, weight in weights - if not name.startswith(NON_EXPERT_PREFIXES) - and name not in NON_EXPERT_WEIGHTS - ) - - return [ - weight.view(layer.local_num_experts, -1) - for name, weight in weights - if name not in NON_EXPERT_WEIGHTS - and weight.shape != torch.Size([]) - and not name.startswith(NON_EXPERT_PREFIXES) - ] diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 086d5a93be56..37fbfe5c658c 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -16,7 +16,7 @@ get_pcp_group, get_tensor_model_parallel_world_size, ) -from vllm.distributed.eplb.eplb_state import EplbState +from vllm.distributed.eplb.eplb_state import EplbLayerState, EplbState from vllm.logger import init_logger from vllm.model_executor.custom_op import PluggableLayer from vllm.model_executor.layers.fused_moe.activation import MoEActivation @@ -26,7 +26,6 @@ FusedMoEQuantConfig, RoutingMethodType, ) -from vllm.model_executor.layers.fused_moe.eplb_manager import EplbManager from vllm.model_executor.layers.fused_moe.fused_moe_method_base import ( FusedMoEMethodBase, ) @@ -347,14 +346,15 @@ def __init__( ) # Create EPLB manager (always constructed for consistent API) - self.eplb_manager: EplbManager | None = None + self.eplb_state: EplbLayerState | None = None if enable_eplb: - self.eplb_manager = EplbManager( - ep_size=self.moe_parallel_config.ep_size, - global_num_experts=self.global_num_experts, - logical_num_experts=self.logical_num_experts, - num_redundant_experts=num_redundant_experts, - ) + if self.global_num_experts % self.ep_size != 0: + raise ValueError( + f"EPLB currently only supports even distribution of " + f"experts across ranks. Got {self.global_num_experts} experts " + f"and {self.ep_size} EP ranks." + ) + self.eplb_state = EplbLayerState() else: # EPLB validation is handled by EplbManager.__init__ assert not self.use_ep or num_redundant_experts == 0, ( @@ -470,7 +470,7 @@ def __init__( self.router = create_fused_moe_router( top_k=top_k, global_num_experts=self.global_num_experts, - eplb_manager=self.eplb_manager, + eplb_state=self.eplb_state, renormalize=renormalize, use_grouped_topk=use_grouped_topk, num_expert_group=num_expert_group, @@ -1436,18 +1436,81 @@ def load_weights( yield param_name def get_expert_weights(self) -> Iterable[torch.Tensor]: - """ - Collect expert weights for EPLB load balancing. + def _maybe_make_contiguous( + name: str, p: torch.nn.Parameter + ) -> torch.nn.Parameter: + """ + In some cases, the last 2 dimensions (the non-expert dimensions) + of the weight scale tensor are transposed. This function + transforms the tensor (view update) so the tensor is contiguous(). + Example: A non-contiguous scale tensor, + `x` of shape (E, 32, 16) and stride (512, 1, 32) is transformed to + `x_` of shape (E, 16, 32) and stride (512, 32, 1). + Note that we specifically use torch.transpose() so `x_` refers + to the same underlying memory. The tensors `x` and `x_`, pointing + to the same underlying memory make this transformation safe in the + context of EPLB. i.e. It is the same memory and just the view + is different. + Note: This function handles the "weight_scale" tensors specifically. + This could however be generalized to handle similar tensors. + """ + if p.ndim != 3: + return p + if p.is_contiguous(): + # Already contiguous. do nothing. + return p + # p is non-contiguous. We only handle the case where the last 2 + # dimensions of the scales tensor is transposed. We can handle + # other cases when they become relevant. + is_transposed_12 = p.stride(1) == 1 and p.stride(2) != 1 + if "weight_scale" not in name or not is_transposed_12: + # do nothing. + return p + + # Do not update the layer parameter as the layer's MoE operations would + # expect the parameter's tensor to the same shape / stride. Instead, + # make a new torch.nn.Parameter that is used just in the context of + # EPLB. + return torch.nn.Parameter( + torch.transpose(p.data, 1, 2), requires_grad=False + ) - Returns weights reshaped as (local_num_experts, -1) for efficient - expert weight swapping during load balancing. + weights = list(self.named_parameters()) + weights = [(name, _maybe_make_contiguous(name, p)) for name, p in weights] + + # `w13_input_scale` and `w2_input_scale` are global per-tensor + # activation scales shared across all experts (e.g. NVFP4). + # They are broadcast views (stride 0) from .expand() and are + # not actual expert weights, so exclude them from EPLB. + NON_EXPERT_WEIGHTS = { + "e_score_correction_bias", + "w13_input_scale", + "w2_input_scale", + } - Delegates to EplbManager. + # Parameters of non-expert submodules that live inside runner (MoERunner). + # These must be excluded from EPLB weight rearrangement. + NON_EXPERT_PREFIXES = ( + "runner._shared_experts.", + "runner.gate.", + "runner.routed_input_transform.", + "runner.routed_output_transform.", + ) - Returns: - Iterable of expert weight tensors - """ - return EplbManager.get_expert_weights(self) + assert all( + weight.is_contiguous() + for name, weight in weights + if not name.startswith(NON_EXPERT_PREFIXES) + and name not in NON_EXPERT_WEIGHTS + ) + + return [ + weight.view(self.local_num_experts, -1) + for name, weight in weights + if name not in NON_EXPERT_WEIGHTS + and weight.shape != torch.Size([]) + and not name.startswith(NON_EXPERT_PREFIXES) + ] def set_eplb_state( self, @@ -1470,8 +1533,8 @@ def set_eplb_state( logical_to_physical_map: Mapping from logical to physical expert IDs logical_replica_count: Number of replicas for each logical expert """ - if self.eplb_manager is not None: - self.eplb_manager.set_state( + if self.eplb_state is not None: + self.eplb_state.set_layer_state( moe_layer_idx, expert_load_view, logical_to_physical_map, diff --git a/vllm/model_executor/layers/fused_moe/router/base_router.py b/vllm/model_executor/layers/fused_moe/router/base_router.py index e32816b395b9..ff023c1cd19c 100644 --- a/vllm/model_executor/layers/fused_moe/router/base_router.py +++ b/vllm/model_executor/layers/fused_moe/router/base_router.py @@ -5,7 +5,7 @@ import torch -from vllm.model_executor.layers.fused_moe.eplb_manager import EplbManager +from vllm.distributed.eplb.eplb_state import EplbLayerState from vllm.model_executor.layers.fused_moe.router.fused_moe_router import ( FusedMoERouter, ) @@ -148,7 +148,7 @@ def __init__( self, top_k: int, global_num_experts: int, - eplb_manager: EplbManager | None = None, + eplb_state: EplbLayerState | None = None, # TODO(bnell): Once the MK is constructed at layer init time, we # can make this a plain value instead of a callback. indices_type_getter: Callable[[], torch.dtype | None] | None = None, @@ -162,13 +162,13 @@ def __init__( Args: top_k: Number of experts to select per token global_num_experts: Total number of experts - eplb_manager: Optional EPLB manager for load balancing + eplb_state: Optional EPLBLayerState for load balancing indices_type_getter: Optional callback to get indices dtype """ super().__init__() self.top_k = top_k self.global_num_experts = global_num_experts - self.eplb_manager = eplb_manager + self.eplb_state = eplb_state self.indices_type_getter = indices_type_getter self.capture_fn: Callable[[torch.Tensor], None] | None = None @@ -178,8 +178,8 @@ def set_capture_fn(self, capture_fn: Callable[[torch.Tensor], None] | None) -> N def _validate_eplb_state(self) -> None: """Validate that EPLB state is properly initialized if EPLB is enabled.""" - if self.eplb_manager is not None: - eplb_state = self.eplb_manager.state + if self.eplb_state is not None: + eplb_state = self.eplb_state if eplb_state.expert_load_view is None: raise ValueError("enable_eplb=True requires expert_load_view != None") if eplb_state.logical_to_physical_map is None: @@ -203,8 +203,8 @@ def _get_indices_type(self) -> torch.dtype | None: def _apply_eplb_mapping(self, topk_ids: torch.Tensor) -> torch.Tensor: """Apply EPLB mapping to convert logical expert IDs to physical expert IDs.""" - if self.eplb_manager is not None: - eplb_state = self.eplb_manager.state + if self.eplb_state is not None: + eplb_state = self.eplb_state assert eplb_state.expert_load_view is not None assert eplb_state.logical_to_physical_map is not None assert eplb_state.logical_replica_count is not None diff --git a/vllm/model_executor/layers/fused_moe/router/custom_routing_router.py b/vllm/model_executor/layers/fused_moe/router/custom_routing_router.py index 00e3e7520031..6983a385a0af 100644 --- a/vllm/model_executor/layers/fused_moe/router/custom_routing_router.py +++ b/vllm/model_executor/layers/fused_moe/router/custom_routing_router.py @@ -1,16 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Callable -from typing import TYPE_CHECKING import torch +from vllm.distributed.eplb.eplb_state import EplbLayerState from vllm.model_executor.layers.fused_moe.config import RoutingMethodType from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter -if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.eplb_manager import EplbManager - class CustomRoutingRouter(BaseRouter): """Router using a custom user-provided routing function.""" @@ -20,14 +17,14 @@ def __init__( top_k: int, global_num_experts: int, custom_routing_function: Callable, - eplb_manager: "EplbManager | None" = None, + eplb_state: EplbLayerState | None = None, renormalize: bool = True, indices_type_getter: Callable[[], torch.dtype | None] | None = None, ): super().__init__( top_k=top_k, global_num_experts=global_num_experts, - eplb_manager=eplb_manager, + eplb_state=eplb_state, indices_type_getter=indices_type_getter, ) self.custom_routing_function = custom_routing_function diff --git a/vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py b/vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py index 6d3bd6ac5529..0ca5c3f97952 100644 --- a/vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py +++ b/vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import functools from collections.abc import Callable -from typing import TYPE_CHECKING import torch import torch.nn.functional as F @@ -10,15 +9,13 @@ import vllm._custom_ops as ops import vllm.envs as envs from vllm._aiter_ops import rocm_aiter_ops +from vllm.distributed.eplb.eplb_state import EplbLayerState from vllm.model_executor.layers.fused_moe.config import ( RoutingMethodType, get_routing_method_type, ) from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter -if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.eplb_manager import EplbManager - def vllm_topk_softmax( topk_weights: torch.Tensor, @@ -241,7 +238,7 @@ def __init__( e_score_correction_bias: torch.Tensor | None = None, renormalize: bool = True, routed_scaling_factor: float = 1.0, - eplb_manager: "EplbManager | None" = None, + eplb_state: EplbLayerState | None = None, indices_type_getter: Callable[[], torch.dtype | None] | None = None, *, scoring_func: str = "sigmoid", @@ -250,7 +247,7 @@ def __init__( super().__init__( top_k=top_k, global_num_experts=global_num_experts, - eplb_manager=eplb_manager, + eplb_state=eplb_state, indices_type_getter=indices_type_getter, ) self.e_score_correction_bias = e_score_correction_bias diff --git a/vllm/model_executor/layers/fused_moe/router/fused_topk_router.py b/vllm/model_executor/layers/fused_moe/router/fused_topk_router.py index d88786491d7b..a4800eabb908 100644 --- a/vllm/model_executor/layers/fused_moe/router/fused_topk_router.py +++ b/vllm/model_executor/layers/fused_moe/router/fused_topk_router.py @@ -1,21 +1,18 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Callable -from typing import TYPE_CHECKING import torch import vllm._custom_ops as ops from vllm._aiter_ops import rocm_aiter_ops +from vllm.distributed.eplb.eplb_state import EplbLayerState from vllm.model_executor.layers.fused_moe.config import ( RoutingMethodType, get_routing_method_type, ) from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter -if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.eplb_manager import EplbManager - def vllm_topk_softmax( topk_weights: torch.Tensor, @@ -125,13 +122,13 @@ def __init__( global_num_experts: int, scoring_func: str = "softmax", renormalize: bool = True, - eplb_manager: "EplbManager | None" = None, + eplb_state: EplbLayerState | None = None, indices_type_getter: Callable[[], torch.dtype | None] | None = None, ): super().__init__( top_k=top_k, global_num_experts=global_num_experts, - eplb_manager=eplb_manager, + eplb_state=eplb_state, indices_type_getter=indices_type_getter, ) self.renormalize = renormalize diff --git a/vllm/model_executor/layers/fused_moe/router/grouped_topk_router.py b/vllm/model_executor/layers/fused_moe/router/grouped_topk_router.py index 461c5c351f05..77624a1b9077 100644 --- a/vllm/model_executor/layers/fused_moe/router/grouped_topk_router.py +++ b/vllm/model_executor/layers/fused_moe/router/grouped_topk_router.py @@ -2,13 +2,13 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Callable from functools import partial -from typing import TYPE_CHECKING import torch from vllm import _custom_ops as ops from vllm import envs as envs from vllm._aiter_ops import rocm_aiter_ops +from vllm.distributed.eplb.eplb_state import EplbLayerState from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.fused_moe.config import ( RoutingMethodType, @@ -25,9 +25,6 @@ from vllm.model_executor.utils import maybe_disable_graph_partition from vllm.platforms import current_platform -if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.eplb_manager import EplbManager - def fused_grouped_topk( hidden_states: torch.Tensor, @@ -261,13 +258,13 @@ def __init__( routed_scaling_factor: float = 1.0, e_score_correction_bias: torch.Tensor | None = None, num_fused_shared_experts: int = 0, - eplb_manager: "EplbManager | None" = None, + eplb_state: EplbLayerState | None = None, indices_type_getter: Callable[[], torch.dtype | None] | None = None, ): super().__init__( top_k=top_k, global_num_experts=global_num_experts, - eplb_manager=eplb_manager, + eplb_state=eplb_state, indices_type_getter=indices_type_getter, ) self.num_expert_group = num_expert_group diff --git a/vllm/model_executor/layers/fused_moe/router/router_factory.py b/vllm/model_executor/layers/fused_moe/router/router_factory.py index 89592830b23b..debcf17edaa3 100644 --- a/vllm/model_executor/layers/fused_moe/router/router_factory.py +++ b/vllm/model_executor/layers/fused_moe/router/router_factory.py @@ -5,8 +5,8 @@ import torch import vllm.envs as envs +from vllm.distributed.eplb.eplb_state import EplbLayerState from vllm.model_executor.layers.fused_moe.config import RoutingMethodType -from vllm.model_executor.layers.fused_moe.eplb_manager import EplbManager from vllm.model_executor.layers.fused_moe.router.custom_routing_router import ( CustomRoutingRouter, ) @@ -48,7 +48,7 @@ def create_fused_moe_router( # custom routing parameters custom_routing_function: Callable | None = None, # eplb parameters - eplb_manager: EplbManager | None = None, + eplb_state: EplbLayerState | None = None, # zero expert parameters zero_expert_type: str | None = None, num_logical_experts: int | None = None, @@ -88,7 +88,7 @@ def create_fused_moe_router( custom_routing_function: Optional custom routing function EPLB arguments: - eplb_manager: Optional EPLB (Expert Parallelism Load Balancing) manager + eplb_state: Optional EplbLayerState, None when EPLB is disabled. Zero expert arguments: zero_expert_type: Type of zero expert (e.g. identity). If not None, @@ -108,7 +108,7 @@ def create_fused_moe_router( return RoutingSimulatorRouter( top_k=top_k, global_num_experts=global_num_experts, - eplb_manager=eplb_manager, + eplb_state=eplb_state, indices_type_getter=indices_type_getter, ) @@ -122,7 +122,7 @@ def create_fused_moe_router( return ZeroExpertRouter( top_k=top_k, global_num_experts=global_num_experts, - eplb_manager=eplb_manager, + eplb_state=eplb_state, e_score_correction_bias=e_score_correction_bias, num_logical_experts=num_logical_experts, zero_expert_type=zero_expert_type, @@ -142,7 +142,7 @@ def create_fused_moe_router( grouped_topk_router = GroupedTopKRouter( top_k=top_k, global_num_experts=global_num_experts, - eplb_manager=eplb_manager, + eplb_state=eplb_state, num_expert_group=num_expert_group, topk_group=topk_group, renormalize=renormalize, @@ -169,7 +169,7 @@ def create_fused_moe_router( return CustomRoutingRouter( top_k=top_k, global_num_experts=global_num_experts, - eplb_manager=eplb_manager, + eplb_state=eplb_state, custom_routing_function=custom_routing_function, renormalize=renormalize, indices_type_getter=indices_type_getter, @@ -181,7 +181,7 @@ def create_fused_moe_router( return FusedTopKBiasRouter( top_k=top_k, global_num_experts=global_num_experts, - eplb_manager=eplb_manager, + eplb_state=eplb_state, e_score_correction_bias=e_score_correction_bias, renormalize=renormalize, routed_scaling_factor=routed_scaling_factor, @@ -193,7 +193,7 @@ def create_fused_moe_router( return FusedTopKRouter( top_k=top_k, global_num_experts=global_num_experts, - eplb_manager=eplb_manager, + eplb_state=eplb_state, renormalize=renormalize, scoring_func=scoring_func, indices_type_getter=indices_type_getter, diff --git a/vllm/model_executor/layers/fused_moe/router/routing_simulator_router.py b/vllm/model_executor/layers/fused_moe/router/routing_simulator_router.py index 7d0b8ba8b61a..233dc82667c8 100644 --- a/vllm/model_executor/layers/fused_moe/router/routing_simulator_router.py +++ b/vllm/model_executor/layers/fused_moe/router/routing_simulator_router.py @@ -2,19 +2,16 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod from collections.abc import Callable -from typing import TYPE_CHECKING, Any +from typing import Any import torch import vllm.envs as envs +from vllm.distributed.eplb.eplb_state import EplbLayerState from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import RoutingMethodType from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter -if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.eplb_manager import EplbManager - - logger = init_logger(__name__) @@ -316,13 +313,13 @@ def __init__( self, top_k: int, global_num_experts: int, - eplb_manager: "EplbManager | None" = None, + eplb_state: EplbLayerState | None = None, indices_type_getter: Callable[[], torch.dtype | None] | None = None, ): super().__init__( top_k=top_k, global_num_experts=global_num_experts, - eplb_manager=eplb_manager, + eplb_state=eplb_state, indices_type_getter=indices_type_getter, ) diff --git a/vllm/model_executor/layers/fused_moe/router/zero_expert_router.py b/vllm/model_executor/layers/fused_moe/router/zero_expert_router.py index d61056026c01..54f0fa4fb0ac 100644 --- a/vllm/model_executor/layers/fused_moe/router/zero_expert_router.py +++ b/vllm/model_executor/layers/fused_moe/router/zero_expert_router.py @@ -2,10 +2,10 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Callable -from typing import TYPE_CHECKING import torch +from vllm.distributed.eplb.eplb_state import EplbLayerState from vllm.model_executor.layers.fused_moe.config import ( RoutingMethodType, get_routing_method_type, @@ -18,9 +18,6 @@ fused_topk_bias, ) -if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.eplb_manager import EplbManager - class ZeroExpertRouter(BaseRouter): """Router that handles zero expert computation as part of routing. @@ -41,13 +38,13 @@ def __init__( scoring_func: str = "softmax", renormalize: bool = False, routed_scaling_factor: float = 1.0, - eplb_manager: "EplbManager | None" = None, + eplb_state: EplbLayerState | None = None, indices_type_getter: Callable[[], torch.dtype | None] | None = None, ): super().__init__( top_k=top_k, global_num_experts=global_num_experts, - eplb_manager=eplb_manager, + eplb_state=eplb_state, indices_type_getter=indices_type_getter, ) self.e_score_correction_bias = e_score_correction_bias diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 0862efbea294..5f137e778066 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -1931,7 +1931,7 @@ def apply_monolithic( assert self.mxfp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM - if layer.enable_eplb: + if layer.eplb_state is not None: raise NotImplementedError( "EPLB is not supported for FlashInfer TRTLLM MXFP8 MoE backend." ) From 33db853a049917aa24d1afeb76f1c5287b81371d Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 11 May 2026 16:38:21 +0000 Subject: [PATCH 07/10] review comments + fix merge Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/layer.py | 4 ---- .../fused_moe/router/aiter_shared_routed_fused_moe_router.py | 4 +--- vllm/model_executor/layers/fused_moe/router/router_factory.py | 1 - 3 files changed, 1 insertion(+), 8 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 5f5b3f4f02b7..cce59f14a505 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -353,7 +353,6 @@ def __init__( vllm_config.parallel_config.expert_placement_strategy ) - # Create EPLB manager (always constructed for consistent API) self.eplb_state: EplbLayerState | None = None if enable_eplb: if self.global_num_experts % self.ep_size != 0: @@ -364,7 +363,6 @@ def __init__( ) self.eplb_state = EplbLayerState() else: - # EPLB validation is handled by EplbManager.__init__ assert not self.use_ep or num_redundant_experts == 0, ( "Redundant experts are only supported with EPLB." ) @@ -1536,8 +1534,6 @@ def set_eplb_state( This is used later in forward pass, where we get the expert mapping and record the load metrics in `expert_load_view`. - Delegates to EplbManager for state management. - Args: moe_layer_idx: Index of this MoE layer expert_load_view: View into global expert load tracking tensor diff --git a/vllm/model_executor/layers/fused_moe/router/aiter_shared_routed_fused_moe_router.py b/vllm/model_executor/layers/fused_moe/router/aiter_shared_routed_fused_moe_router.py index 9d92f570df5a..385e7ab50def 100644 --- a/vllm/model_executor/layers/fused_moe/router/aiter_shared_routed_fused_moe_router.py +++ b/vllm/model_executor/layers/fused_moe/router/aiter_shared_routed_fused_moe_router.py @@ -33,18 +33,16 @@ def __init__( self, top_k: int, global_num_experts: int, - eplb_state: EplbLayerState, + eplb_state: EplbLayerState | None, num_fused_shared_experts: int, scoring_func: str = "softmax", renormalize: bool = True, - enable_eplb: bool = False, indices_type_getter: Callable[[], torch.dtype | None] | None = None, ): super().__init__( top_k=top_k, global_num_experts=global_num_experts, eplb_state=eplb_state, - enable_eplb=enable_eplb, indices_type_getter=indices_type_getter, ) self.renormalize = renormalize diff --git a/vllm/model_executor/layers/fused_moe/router/router_factory.py b/vllm/model_executor/layers/fused_moe/router/router_factory.py index 0e878385b2ee..39674c8f8832 100644 --- a/vllm/model_executor/layers/fused_moe/router/router_factory.py +++ b/vllm/model_executor/layers/fused_moe/router/router_factory.py @@ -209,7 +209,6 @@ def create_fused_moe_router( num_fused_shared_experts=num_fused_shared_experts, renormalize=renormalize, scoring_func=scoring_func, - enable_eplb=enable_eplb, indices_type_getter=indices_type_getter, ) From 0e79f80398cf882db5c6af5bde6e3b0b372e6fc0 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 11 May 2026 17:26:18 +0000 Subject: [PATCH 08/10] review comments Signed-off-by: Bill Nell --- .../router/aiter_shared_routed_fused_moe_router.py | 2 +- .../layers/fused_moe/router/base_router.py | 14 ++++---------- 2 files changed, 5 insertions(+), 11 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/router/aiter_shared_routed_fused_moe_router.py b/vllm/model_executor/layers/fused_moe/router/aiter_shared_routed_fused_moe_router.py index 385e7ab50def..46d69cf90978 100644 --- a/vllm/model_executor/layers/fused_moe/router/aiter_shared_routed_fused_moe_router.py +++ b/vllm/model_executor/layers/fused_moe/router/aiter_shared_routed_fused_moe_router.py @@ -33,8 +33,8 @@ def __init__( self, top_k: int, global_num_experts: int, - eplb_state: EplbLayerState | None, num_fused_shared_experts: int, + eplb_state: EplbLayerState | None = None, scoring_func: str = "softmax", renormalize: bool = True, indices_type_getter: Callable[[], torch.dtype | None] | None = None, diff --git a/vllm/model_executor/layers/fused_moe/router/base_router.py b/vllm/model_executor/layers/fused_moe/router/base_router.py index ff023c1cd19c..3bc83e0648e8 100644 --- a/vllm/model_executor/layers/fused_moe/router/base_router.py +++ b/vllm/model_executor/layers/fused_moe/router/base_router.py @@ -181,19 +181,13 @@ def _validate_eplb_state(self) -> None: if self.eplb_state is not None: eplb_state = self.eplb_state if eplb_state.expert_load_view is None: - raise ValueError("enable_eplb=True requires expert_load_view != None") + raise ValueError("EPLB requires expert_load_view != None") if eplb_state.logical_to_physical_map is None: - raise ValueError( - "enable_eplb=True requires logical_to_physical_map != None" - ) + raise ValueError("EPLB requires logical_to_physical_map != None") if eplb_state.logical_replica_count is None: - raise ValueError( - "enable_eplb=True requires logical_replica_count != None" - ) + raise ValueError("EPLB requires logical_replica_count != None") if eplb_state.should_record_tensor is None: - raise ValueError( - "enable_eplb=True requires should_record_tensor != None" - ) + raise ValueError("EPLB requires should_record_tensor != None") def _get_indices_type(self) -> torch.dtype | None: """Get the desired indices dtype from the getter function.""" From 6f2a267e73d2c29abedf4d5a6ba58202d96e9a95 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 12 May 2026 13:46:41 +0000 Subject: [PATCH 09/10] fix merge Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 70bc810393c4..15cb7fe2a406 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -209,7 +209,7 @@ def __init__( self.eplb_state: EplbLayerState | None = None if enable_eplb: - if self.global_num_experts % self.ep_size != 0: + if self.use_ep and self.global_num_experts % self.ep_size != 0: raise ValueError( f"EPLB currently only supports even distribution of " f"experts across ranks. Got {self.global_num_experts} experts " From be9d8ef4e18a0aaca59039d18bbe4a63a0d05a59 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 12 May 2026 15:31:18 +0000 Subject: [PATCH 10/10] fix merge issue Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 15cb7fe2a406..68035df0cf48 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -261,7 +261,7 @@ def __init__( num_expert_group=num_expert_group, moe_parallel_config=self.moe_parallel_config, placement_strategy=self.expert_placement_strategy, - enable_eplb=self.enable_eplb, + enable_eplb=enable_eplb, num_fused_shared_experts=self.num_fused_shared_experts, rocm_aiter_enabled=self.rocm_aiter_fmoe_enabled, )