diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index 5ab499f67074..a47587846dae 100644
--- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -634,6 +634,16 @@ def apply(
 
         topk_weights, topk_ids, router_logits = topk_output
 
+        # Get expert_map for EP support
+        expert_map = None
+        global_num_experts = -1
+        if hasattr(layer, "dispatcher") and hasattr(
+            layer.dispatcher, "local_expert_mapping"
+        ):
+            expert_map = layer.dispatcher.local_expert_mapping
+        if expert_map is not None:
+            global_num_experts = self.moe_runner_config.num_experts
+
         output = fused_marlin_moe(
             x,
             layer.w13_weight_packed,
@@ -643,6 +653,8 @@ def apply(
             router_logits,
             topk_weights,
             topk_ids,
+            global_num_experts=global_num_experts,
+            expert_map=expert_map,
             g_idx1=layer.w13_weight_g_idx,
             g_idx2=layer.w2_weight_g_idx,
             sort_indices1=layer.w13_g_idx_sort_indices,