diff --git a/python/sglang/srt/eplb/expert_location.py b/python/sglang/srt/eplb/expert_location.py
index b9d9385dbd1a..0fdff48fe9c8 100644
--- a/python/sglang/srt/eplb/expert_location.py
+++ b/python/sglang/srt/eplb/expert_location.py
@@ -284,17 +284,9 @@ def update(
     # -------------------------------- usage ------------------------------------

     def logical_to_all_physical(
-        self,
-        layer_id: int,
-        logical_expert_id: int,
-        require_global_experts: bool = False,
+        self, layer_id: int, logical_expert_id: int
     ) -> List[int]:
         # Use CPU copy to avoid GPU→CPU sync on every call, which is expensive in update weights scenario
-        if require_global_experts:
-            num_physical_experts = self.logical_to_all_physical_map_cpu[layer_id].shape[
-                -1
-            ]
-            return list(torch.arange(0, num_physical_experts))
         return [
             physical_expert_id
             for physical_expert_id in self.logical_to_all_physical_map_cpu[
@@ -363,10 +355,14 @@ def _compute_logical_to_all_physical_map(
                 )

                 # Replace by the nearest physical expert
-                if nearest_expert != -1:
-                    logical_to_all_physical_map[layer_id][logical_expert_id] = [
-                        nearest_expert
-                    ]
+                mapped_physical_experts = logical_to_all_physical_map[layer_id][
+                    logical_expert_id
+                ]
+                if (
+                    nearest_expert != -1
+                    and nearest_expert not in mapped_physical_experts
+                ):
+                    mapped_physical_experts[0] = nearest_expert

     logical_to_all_physical_map = _pad_nested_array(
         logical_to_all_physical_map, pad_value=-1
diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
index 1aadc61fe4ed..4924fca1549d 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -539,12 +539,9 @@ def weight_loader(
             # This is a shared expert.
             physical_expert_ids = [expert_id]
         else:
-            require_global_experts = getattr(
-                param, "_sglang_require_global_experts", False
-            )
             physical_expert_ids = (
                 global_expert_location_metadata.logical_to_all_physical(
-                    self.layer_id, expert_id, require_global_experts
+                    self.layer_id, expert_id
                 )
             )
@@ -1129,6 +1126,7 @@ def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
                 local_expert_offset=self.moe_ep_rank * self.num_local_experts,
                 local_num_experts=self.num_local_experts,
                 routed_scaling_factor=self.moe_runner_config.routed_scaling_factor,
+                tile_tokens_dim=None,
                 routing_method_type=RoutingMethodType.DeepSeekV3,
                 do_finalize=True,
                 output=symm_output,
diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py
index 425d4f97fc75..7b4870c978c6 100644
--- a/python/sglang/srt/layers/quantization/fp8.py
+++ b/python/sglang/srt/layers/quantization/fp8.py
@@ -1245,6 +1245,7 @@ def apply_with_router_logits(
             routed_scaling_factor=(
                 routed_scaling_factor if routed_scaling_factor is not None else 1.0
             ),
+            tile_tokens_dim=None,
             routing_method_type=routing_method_type,
             use_shuffled_weight=False,
         )
diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py
index f3f7b85ea106..99ae27684f2e 100755
--- a/python/sglang/srt/layers/quantization/modelopt_quant.py
+++ b/python/sglang/srt/layers/quantization/modelopt_quant.py
@@ -695,6 +695,7 @@ def apply(
                 else 1.0
             ),
             use_routing_scales_on_input=use_routing_scales_on_input,
+            tile_tokens_dim=None,
             routing_method_type=routing_method_type,
         )
diff --git a/python/sglang/srt/layers/quantization/mxfp4.py b/python/sglang/srt/layers/quantization/mxfp4.py
index 137cdb683ec5..847eaf0ee250 100644
--- a/python/sglang/srt/layers/quantization/mxfp4.py
+++ b/python/sglang/srt/layers/quantization/mxfp4.py
@@ -681,6 +681,7 @@ def apply(
                 layer.moe_ep_rank * layer.num_local_experts,  # local_expert_offset
                 layer.num_local_experts,  # local num experts
                 None,
+                None,  # tile_tokens_dim
                 1,  # routing_method_type, renormalize
                 True,  # do finalize
                 output=symm_output,