diff --git a/tests/distributed/test_eplb_fused_moe_layer_dep_nvfp4.py b/tests/distributed/test_eplb_fused_moe_layer_dep_nvfp4.py index 68b2407c2e4b..9ab785af3135 100644 --- a/tests/distributed/test_eplb_fused_moe_layer_dep_nvfp4.py +++ b/tests/distributed/test_eplb_fused_moe_layer_dep_nvfp4.py @@ -10,6 +10,7 @@ from tests.kernels.moe.utils import make_test_quant_config from vllm.config import VllmConfig, set_current_vllm_config +from vllm.distributed.eplb.eplb_state import EplbLayerState from vllm.distributed.eplb.rebalance_execute import rearrange_expert_weights_inplace from vllm.distributed.parallel_state import ( ensure_model_parallel_initialized, @@ -201,7 +202,7 @@ def _test_eplb_fml(env, world_size: int, test_config: TestConfig): dtype=torch.int32, device=device, ) - fml.enable_eplb = True + fml.eplb_state = EplbLayerState() fml.set_eplb_state( lidx, torch.zeros( diff --git a/tests/kernels/moe/test_moe_layer.py b/tests/kernels/moe/test_moe_layer.py index 89e28d950f9d..2b27202b6b6f 100644 --- a/tests/kernels/moe/test_moe_layer.py +++ b/tests/kernels/moe/test_moe_layer.py @@ -1004,7 +1004,6 @@ def make_fake_moe_layer( activation: str = "silu", indices_type: torch.dtype | None = None, expert_map: torch.Tensor | None = None, - enable_eplb: bool = False, expert_load_view: torch.Tensor | None = None, logical_to_physical_map: torch.Tensor | None = None, logical_replica_count: torch.Tensor | None = None, @@ -1022,7 +1021,6 @@ def make_fake_moe_layer( router = create_fused_moe_router( top_k=top_k, global_num_experts=global_num_experts, - # eplb_state=None, # TODO renormalize=renormalize, use_grouped_topk=use_grouped_topk, num_expert_group=num_expert_group, @@ -1032,7 +1030,6 @@ def make_fake_moe_layer( routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, num_fused_shared_experts=0, # TODO - enable_eplb=enable_eplb, # TODO(bnell): once we can construct the MK at init time, we # can make this a value. indices_type_getter=lambda: indices_type, diff --git a/tests/kernels/moe/test_routing.py b/tests/kernels/moe/test_routing.py index 90a6cd841efd..41dea8121938 100644 --- a/tests/kernels/moe/test_routing.py +++ b/tests/kernels/moe/test_routing.py @@ -36,9 +36,11 @@ def _is_aiter_capable() -> bool: NUM_EXPERTS = [8, 16, 64] -def setup_eplb_state(enable_eplb: bool, global_num_experts: int) -> EplbLayerState: +def setup_eplb_state( + enable_eplb: bool, global_num_experts: int +) -> EplbLayerState | None: if not enable_eplb: - return EplbLayerState() + return None # Initialize EPLB state with proper tensors for testing # For testing purposes, we use a simple 1:1 mapping (no redundant experts) @@ -349,7 +351,6 @@ def test_fused_topk( top_k=top_k, global_num_experts=global_num_experts, renormalize=renormalize, - enable_eplb=enable_eplb, eplb_state=eplb_state, ) @@ -400,7 +401,6 @@ def test_fused_topk_bias( top_k=top_k, global_num_experts=global_num_experts, renormalize=renormalize, - enable_eplb=enable_eplb, eplb_state=eplb_state, ) @@ -469,7 +469,6 @@ def test_grouped_topk( top_k=top_k, global_num_experts=global_num_experts, renormalize=renormalize, - enable_eplb=enable_eplb, eplb_state=eplb_state, ) @@ -540,7 +539,6 @@ def test_custom( global_num_experts=global_num_experts, custom_routing_function=custom_routing_function, renormalize=renormalize, - enable_eplb=enable_eplb, eplb_state=eplb_state, ) @@ -580,7 +578,6 @@ def test_custom( # router = create_fused_moe_router( # top_k=top_k, # global_num_experts=global_num_experts, -# enable_eplb=enable_eplb, # eplb_state=eplb_state, # ) diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py index 1da39caccd80..319a5f22c922 100644 --- a/vllm/distributed/eplb/eplb_state.py +++ b/vllm/distributed/eplb/eplb_state.py @@ -940,6 +940,17 @@ class EplbLayerState: GPU work. """ + def set_layer_state( + self, + moe_layer_idx: int, + expert_load_view: torch.Tensor, + logical_to_physical_map: torch.Tensor, + logical_replica_count: torch.Tensor, + ) -> None: + self.expert_load_view = expert_load_view[moe_layer_idx] + self.logical_to_physical_map = logical_to_physical_map[moe_layer_idx] + self.logical_replica_count = logical_replica_count[moe_layer_idx] + def _node_count_with_rank_mapping( pg: ProcessGroup | StatelessProcessGroup, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index ca18e8588798..68035df0cf48 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -203,13 +203,24 @@ def __init__( compilation_config.static_all_moe_layers.append(prefix) self.layer_name = prefix - self.enable_eplb = enable_eplb - # TODO(bnell): should this be owned by router? - self.eplb_state = EplbLayerState() self.expert_placement_strategy: ExpertPlacementStrategy = ( vllm_config.parallel_config.expert_placement_strategy ) + self.eplb_state: EplbLayerState | None = None + if enable_eplb: + if self.use_ep and self.global_num_experts % self.ep_size != 0: + raise ValueError( + f"EPLB currently only supports even distribution of " + f"experts across ranks. Got {self.global_num_experts} experts " + f"and {self.ep_size} EP ranks." + ) + self.eplb_state = EplbLayerState() + else: + assert not self.use_ep or num_redundant_experts == 0, ( + "Redundant experts are only supported with EPLB." + ) + # ROCm aiter shared experts fusion # AITER only supports gated activations (silu/gelu), so disable it # for non-gated MoE (is_act_and_mul=False) @@ -237,17 +248,6 @@ def __init__( ) # Determine expert maps - if self.use_ep: - if self.enable_eplb: - assert self.global_num_experts % self.ep_size == 0, ( - "EPLB currently only supports even distribution of " - "experts across ranks." - ) - else: - assert num_redundant_experts == 0, ( - "Redundant experts are only supported with EPLB." - ) - max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens # Create ExpertMapManager to handle expert mapping and placement for EP. @@ -261,7 +261,7 @@ def __init__( num_expert_group=num_expert_group, moe_parallel_config=self.moe_parallel_config, placement_strategy=self.expert_placement_strategy, - enable_eplb=self.enable_eplb, + enable_eplb=enable_eplb, num_fused_shared_experts=self.num_fused_shared_experts, rocm_aiter_enabled=self.rocm_aiter_fmoe_enabled, ) @@ -313,7 +313,6 @@ def __init__( routed_scaling_factor=self.routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, num_fused_shared_experts=self.num_fused_shared_experts, - enable_eplb=enable_eplb, # TODO(bnell): once we can construct the MK at init time, we # can make this a value. indices_type_getter=lambda: self.quant_method.topk_indices_dtype, @@ -381,7 +380,7 @@ def _get_quant_method() -> FusedMoEMethodBase: "is_act_and_mul=False is supported only for CUDA and XPU for now" ) - if self.enable_eplb and not self.quant_method.supports_eplb: + if enable_eplb and not self.quant_method.supports_eplb: # TODO: Add support for additional quantization methods. # The implementation for other quantization methods does not # contain essential differences, but the current quant API @@ -1283,10 +1282,20 @@ def set_eplb_state( This is used later in forward pass, where we get the expert mapping and record the load metrics in `expert_load_view`. + + Args: + moe_layer_idx: Index of this MoE layer + expert_load_view: View into global expert load tracking tensor + logical_to_physical_map: Mapping from logical to physical expert IDs + logical_replica_count: Number of replicas for each logical expert """ - self.eplb_state.expert_load_view = expert_load_view[moe_layer_idx] - self.eplb_state.logical_to_physical_map = logical_to_physical_map[moe_layer_idx] - self.eplb_state.logical_replica_count = logical_replica_count[moe_layer_idx] + if self.eplb_state is not None: + self.eplb_state.set_layer_state( + moe_layer_idx, + expert_load_view, + logical_to_physical_map, + logical_replica_count, + ) def ensure_moe_quant_config_init(self): if self.quant_method.moe_quant_config is None: diff --git a/vllm/model_executor/layers/fused_moe/router/aiter_shared_routed_fused_moe_router.py b/vllm/model_executor/layers/fused_moe/router/aiter_shared_routed_fused_moe_router.py index 9d92f570df5a..46d69cf90978 100644 --- a/vllm/model_executor/layers/fused_moe/router/aiter_shared_routed_fused_moe_router.py +++ b/vllm/model_executor/layers/fused_moe/router/aiter_shared_routed_fused_moe_router.py @@ -33,18 +33,16 @@ def __init__( self, top_k: int, global_num_experts: int, - eplb_state: EplbLayerState, num_fused_shared_experts: int, + eplb_state: EplbLayerState | None = None, scoring_func: str = "softmax", renormalize: bool = True, - enable_eplb: bool = False, indices_type_getter: Callable[[], torch.dtype | None] | None = None, ): super().__init__( top_k=top_k, global_num_experts=global_num_experts, eplb_state=eplb_state, - enable_eplb=enable_eplb, indices_type_getter=indices_type_getter, ) self.renormalize = renormalize diff --git a/vllm/model_executor/layers/fused_moe/router/base_router.py b/vllm/model_executor/layers/fused_moe/router/base_router.py index 0138eb59c91c..3bc83e0648e8 100644 --- a/vllm/model_executor/layers/fused_moe/router/base_router.py +++ b/vllm/model_executor/layers/fused_moe/router/base_router.py @@ -148,8 +148,7 @@ def __init__( self, top_k: int, global_num_experts: int, - eplb_state: EplbLayerState, - enable_eplb: bool = False, + eplb_state: EplbLayerState | None = None, # TODO(bnell): Once the MK is constructed at layer init time, we # can make this a plain value instead of a callback. indices_type_getter: Callable[[], torch.dtype | None] | None = None, @@ -159,12 +158,17 @@ def __init__( time, so we need to supply a callback to get it at runtime. This is because the indices type is supplied by modular kernels which are created after MoE layer/router construction. + + Args: + top_k: Number of experts to select per token + global_num_experts: Total number of experts + eplb_state: Optional EPLBLayerState for load balancing + indices_type_getter: Optional callback to get indices dtype """ super().__init__() self.top_k = top_k self.global_num_experts = global_num_experts self.eplb_state = eplb_state - self.enable_eplb = enable_eplb self.indices_type_getter = indices_type_getter self.capture_fn: Callable[[torch.Tensor], None] | None = None @@ -174,21 +178,16 @@ def set_capture_fn(self, capture_fn: Callable[[torch.Tensor], None] | None) -> N def _validate_eplb_state(self) -> None: """Validate that EPLB state is properly initialized if EPLB is enabled.""" - if self.enable_eplb: - if self.eplb_state.expert_load_view is None: - raise ValueError("enable_eplb=True requires expert_load_view != None") - if self.eplb_state.logical_to_physical_map is None: - raise ValueError( - "enable_eplb=True requires logical_to_physical_map != None" - ) - if self.eplb_state.logical_replica_count is None: - raise ValueError( - "enable_eplb=True requires logical_replica_count != None" - ) - if self.eplb_state.should_record_tensor is None: - raise ValueError( - "enable_eplb=True requires should_record_tensor != None" - ) + if self.eplb_state is not None: + eplb_state = self.eplb_state + if eplb_state.expert_load_view is None: + raise ValueError("EPLB requires expert_load_view != None") + if eplb_state.logical_to_physical_map is None: + raise ValueError("EPLB requires logical_to_physical_map != None") + if eplb_state.logical_replica_count is None: + raise ValueError("EPLB requires logical_replica_count != None") + if eplb_state.should_record_tensor is None: + raise ValueError("EPLB requires should_record_tensor != None") def _get_indices_type(self) -> torch.dtype | None: """Get the desired indices dtype from the getter function.""" @@ -198,17 +197,18 @@ def _get_indices_type(self) -> torch.dtype | None: def _apply_eplb_mapping(self, topk_ids: torch.Tensor) -> torch.Tensor: """Apply EPLB mapping to convert logical expert IDs to physical expert IDs.""" - if self.enable_eplb: - assert self.eplb_state.expert_load_view is not None - assert self.eplb_state.logical_to_physical_map is not None - assert self.eplb_state.logical_replica_count is not None - assert self.eplb_state.should_record_tensor is not None + if self.eplb_state is not None: + eplb_state = self.eplb_state + assert eplb_state.expert_load_view is not None + assert eplb_state.logical_to_physical_map is not None + assert eplb_state.logical_replica_count is not None + assert eplb_state.should_record_tensor is not None return eplb_map_to_physical_and_record( topk_ids=topk_ids, - logical_to_physical_map=self.eplb_state.logical_to_physical_map, - logical_replica_count=self.eplb_state.logical_replica_count, - expert_load_view=self.eplb_state.expert_load_view, - record_enabled=self.eplb_state.should_record_tensor, + logical_to_physical_map=eplb_state.logical_to_physical_map, + logical_replica_count=eplb_state.logical_replica_count, + expert_load_view=eplb_state.expert_load_view, + record_enabled=eplb_state.should_record_tensor, ) return topk_ids diff --git a/vllm/model_executor/layers/fused_moe/router/custom_routing_router.py b/vllm/model_executor/layers/fused_moe/router/custom_routing_router.py index a3e0075f2b7d..731afffd15f8 100644 --- a/vllm/model_executor/layers/fused_moe/router/custom_routing_router.py +++ b/vllm/model_executor/layers/fused_moe/router/custom_routing_router.py @@ -16,17 +16,15 @@ def __init__( self, top_k: int, global_num_experts: int, - eplb_state: EplbLayerState, custom_routing_function: Callable, + eplb_state: EplbLayerState | None = None, renormalize: bool = True, - enable_eplb: bool = False, indices_type_getter: Callable[[], torch.dtype | None] | None = None, ): super().__init__( top_k=top_k, global_num_experts=global_num_experts, eplb_state=eplb_state, - enable_eplb=enable_eplb, indices_type_getter=indices_type_getter, ) self.custom_routing_function = custom_routing_function diff --git a/vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py b/vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py index 84eaad7f65e6..0ca5c3f97952 100644 --- a/vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py +++ b/vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py @@ -235,11 +235,10 @@ def __init__( self, top_k: int, global_num_experts: int, - eplb_state: EplbLayerState, e_score_correction_bias: torch.Tensor | None = None, renormalize: bool = True, routed_scaling_factor: float = 1.0, - enable_eplb: bool = False, + eplb_state: EplbLayerState | None = None, indices_type_getter: Callable[[], torch.dtype | None] | None = None, *, scoring_func: str = "sigmoid", @@ -249,7 +248,6 @@ def __init__( top_k=top_k, global_num_experts=global_num_experts, eplb_state=eplb_state, - enable_eplb=enable_eplb, indices_type_getter=indices_type_getter, ) self.e_score_correction_bias = e_score_correction_bias diff --git a/vllm/model_executor/layers/fused_moe/router/fused_topk_router.py b/vllm/model_executor/layers/fused_moe/router/fused_topk_router.py index 45311dba08e3..a4800eabb908 100644 --- a/vllm/model_executor/layers/fused_moe/router/fused_topk_router.py +++ b/vllm/model_executor/layers/fused_moe/router/fused_topk_router.py @@ -120,17 +120,15 @@ def __init__( self, top_k: int, global_num_experts: int, - eplb_state: EplbLayerState, scoring_func: str = "softmax", renormalize: bool = True, - enable_eplb: bool = False, + eplb_state: EplbLayerState | None = None, indices_type_getter: Callable[[], torch.dtype | None] | None = None, ): super().__init__( top_k=top_k, global_num_experts=global_num_experts, eplb_state=eplb_state, - enable_eplb=enable_eplb, indices_type_getter=indices_type_getter, ) self.renormalize = renormalize diff --git a/vllm/model_executor/layers/fused_moe/router/grouped_topk_router.py b/vllm/model_executor/layers/fused_moe/router/grouped_topk_router.py index 37d812a24bbe..6f792b46a0aa 100644 --- a/vllm/model_executor/layers/fused_moe/router/grouped_topk_router.py +++ b/vllm/model_executor/layers/fused_moe/router/grouped_topk_router.py @@ -251,7 +251,6 @@ def __init__( self, top_k: int, global_num_experts: int, - eplb_state: EplbLayerState, num_expert_group: int, topk_group: int, renormalize: bool = True, @@ -259,14 +258,13 @@ def __init__( routed_scaling_factor: float = 1.0, e_score_correction_bias: torch.Tensor | None = None, num_fused_shared_experts: int = 0, - enable_eplb: bool = False, + eplb_state: EplbLayerState | None = None, indices_type_getter: Callable[[], torch.dtype | None] | None = None, ): super().__init__( top_k=top_k, global_num_experts=global_num_experts, eplb_state=eplb_state, - enable_eplb=enable_eplb, indices_type_getter=indices_type_getter, ) self.num_expert_group = num_expert_group diff --git a/vllm/model_executor/layers/fused_moe/router/router_factory.py b/vllm/model_executor/layers/fused_moe/router/router_factory.py index 2f2c4d39c460..39674c8f8832 100644 --- a/vllm/model_executor/layers/fused_moe/router/router_factory.py +++ b/vllm/model_executor/layers/fused_moe/router/router_factory.py @@ -35,8 +35,6 @@ ZeroExpertRouter, ) -EMPTY_EPLB_STATE: EplbLayerState = EplbLayerState() - def create_fused_moe_router( # common parameters @@ -56,8 +54,7 @@ def create_fused_moe_router( # custom routing parameters custom_routing_function: Callable | None = None, # eplb parameters - enable_eplb: bool = False, - eplb_state: EplbLayerState = EMPTY_EPLB_STATE, + eplb_state: EplbLayerState | None = None, # zero expert parameters zero_expert_type: str | None = None, num_logical_experts: int | None = None, @@ -98,8 +95,7 @@ def create_fused_moe_router( custom_routing_function: Optional custom routing function EPLB arguments: - enable_eplb: Whether EPLB is enabled - eplb_state: EPLB (Expert Parallelism Load Balancing) state + eplb_state: Optional EplbLayerState, None when EPLB is disabled. Zero expert arguments: zero_expert_type: Type of zero expert (e.g. identity). If not None, @@ -120,7 +116,6 @@ def create_fused_moe_router( top_k=top_k, global_num_experts=global_num_experts, eplb_state=eplb_state, - enable_eplb=enable_eplb, indices_type_getter=indices_type_getter, ) @@ -141,7 +136,6 @@ def create_fused_moe_router( scoring_func=scoring_func, renormalize=renormalize, routed_scaling_factor=routed_scaling_factor, - enable_eplb=enable_eplb, indices_type_getter=indices_type_getter, ) @@ -163,7 +157,6 @@ def create_fused_moe_router( routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, num_fused_shared_experts=num_fused_shared_experts, - enable_eplb=enable_eplb, indices_type_getter=indices_type_getter, ) if ( @@ -186,7 +179,6 @@ def create_fused_moe_router( eplb_state=eplb_state, custom_routing_function=custom_routing_function, renormalize=renormalize, - enable_eplb=enable_eplb, indices_type_getter=indices_type_getter, ) @@ -200,7 +192,6 @@ def create_fused_moe_router( e_score_correction_bias=e_score_correction_bias, renormalize=renormalize, routed_scaling_factor=routed_scaling_factor, - enable_eplb=enable_eplb, indices_type_getter=indices_type_getter, scoring_func=scoring_func, hash_indices_table=hash_indices_table, @@ -218,7 +209,6 @@ def create_fused_moe_router( num_fused_shared_experts=num_fused_shared_experts, renormalize=renormalize, scoring_func=scoring_func, - enable_eplb=enable_eplb, indices_type_getter=indices_type_getter, ) @@ -228,6 +218,5 @@ def create_fused_moe_router( eplb_state=eplb_state, renormalize=renormalize, scoring_func=scoring_func, - enable_eplb=enable_eplb, indices_type_getter=indices_type_getter, ) diff --git a/vllm/model_executor/layers/fused_moe/router/routing_simulator_router.py b/vllm/model_executor/layers/fused_moe/router/routing_simulator_router.py index 8fb36b72cb70..233dc82667c8 100644 --- a/vllm/model_executor/layers/fused_moe/router/routing_simulator_router.py +++ b/vllm/model_executor/layers/fused_moe/router/routing_simulator_router.py @@ -313,15 +313,13 @@ def __init__( self, top_k: int, global_num_experts: int, - eplb_state: EplbLayerState, - enable_eplb: bool = False, + eplb_state: EplbLayerState | None = None, indices_type_getter: Callable[[], torch.dtype | None] | None = None, ): super().__init__( top_k=top_k, global_num_experts=global_num_experts, eplb_state=eplb_state, - enable_eplb=enable_eplb, indices_type_getter=indices_type_getter, ) diff --git a/vllm/model_executor/layers/fused_moe/router/zero_expert_router.py b/vllm/model_executor/layers/fused_moe/router/zero_expert_router.py index 65760727770a..54f0fa4fb0ac 100644 --- a/vllm/model_executor/layers/fused_moe/router/zero_expert_router.py +++ b/vllm/model_executor/layers/fused_moe/router/zero_expert_router.py @@ -32,21 +32,19 @@ def __init__( self, top_k: int, global_num_experts: int, - eplb_state: EplbLayerState, e_score_correction_bias: torch.Tensor, num_logical_experts: int, zero_expert_type: str, scoring_func: str = "softmax", renormalize: bool = False, routed_scaling_factor: float = 1.0, - enable_eplb: bool = False, + eplb_state: EplbLayerState | None = None, indices_type_getter: Callable[[], torch.dtype | None] | None = None, ): super().__init__( top_k=top_k, global_num_experts=global_num_experts, eplb_state=eplb_state, - enable_eplb=enable_eplb, indices_type_getter=indices_type_getter, ) self.e_score_correction_bias = e_score_correction_bias diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a8_fp8.py index b14571fe5013..efa28ac3b6ae 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a8_fp8.py @@ -309,10 +309,6 @@ def apply( topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, ) -> torch.Tensor: - if layer.enable_eplb: - raise NotImplementedError( - "EPLB not supported for `CompressedTensorsW4A8Fp8MoEMethod` yet." - ) assert self.moe_quant_config is not None from vllm.model_executor.layers.fused_moe.experts.cutlass_moe import ( diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a8_int8.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a8_int8.py index 88cdbadd3f83..c697b137420b 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a8_int8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a8_int8.py @@ -307,7 +307,6 @@ def apply_monolithic( router_logits: torch.Tensor, input_ids: torch.Tensor | None = None, ) -> torch.Tensor: - assert not layer.enable_eplb, "EPLB not supported for W4A8-int MoE yet." assert layer.activation in ( MoEActivation.SILU, MoEActivation.SWIGLUOAI, diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 13cbbabd7c3d..db148508cde4 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -2102,7 +2102,7 @@ def apply_monolithic( assert self.mxfp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM - if layer.enable_eplb: + if layer.eplb_state is not None: raise NotImplementedError( "EPLB is not supported for FlashInfer TRTLLM MXFP8 MoE backend." )