diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py index 97e8040e8d35..2675fcda1166 100644 --- a/python/sglang/srt/environ.py +++ b/python/sglang/srt/environ.py @@ -611,7 +611,7 @@ class Envs: # Distributed SGLANG_DSV4_FIX_TP_ATTN_A2A_SCATTER = EnvBool(True) - + SGLANG_SHARED_EXPERT_TP1 = EnvBool(False) # Symmetric Memory SGLANG_SYMM_MEM_PREALLOC_GB_SIZE = EnvInt(-1) SGLANG_DEBUG_SYMM_MEM = EnvBool(False) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 7537be81cd2d..f64881f9913b 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -1155,8 +1155,10 @@ def check_quantized_moe_compatibility(self): ) if ( - moe_intermediate_size // moe_tp_size - ) % weight_block_size_n != 0 and not _use_aiter: + not envs.SGLANG_SHARED_EXPERT_TP1.get() + and (moe_intermediate_size // moe_tp_size) % weight_block_size_n != 0 + and not _use_aiter + ): raise ValueError( f"For quantized MoE models, please make sure ({moe_intermediate_size=} / {moe_tp_size=}) % {weight_block_size_n=} == 0 " f"where moe_tp_size is equal to tp_size ({self.tp_size}) divided by ep_size ({self.moe_ep_size}). " diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 2c8637f0ec48..cd02b9c0bbfd 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -534,6 +534,7 @@ def __init__( self.shared_experts_is_int8 = False self.shared_experts_is_fp8 = False self.shared_experts_weight_block_size = None + self._shared_expert_tp1 = False # Shared experts: skip when fused into MoE kernel (self.num_fused_shared_experts > 0) # or when DeepEP fusion is enabled (shared expert is local slot 16 in FusedMoE, no separate MLP). if ( @@ -543,7 +544,19 @@ def __init__( and not _is_deepep_fusion ): intermediate_size = config.moe_intermediate_size * config.n_shared_experts - # disable tp for shared experts when enable deepep moe, or with fp4 allgather + # Disable TP for shared experts for A2A/FP4 allgather paths, or when + # explicitly requested for DSV4 checkpoints whose shared scales are + # not divisible by the global TP size. + _shared_expert_use_tp1 = ( + get_moe_a2a_backend().is_deepep() + or get_moe_a2a_backend().is_mooncake() + or get_moe_a2a_backend().is_nixl() + or get_moe_a2a_backend().is_mori() + or get_moe_a2a_backend().is_ascend_fuseep() + or get_moe_a2a_backend().is_flashinfer() + or should_use_flashinfer_cutlass_moe_fp4_allgather() + or envs.SGLANG_SHARED_EXPERT_TP1.get() + ) self.shared_experts = DeepseekV2MLP( hidden_size=config.hidden_size, intermediate_size=intermediate_size, @@ -552,18 +565,9 @@ def __init__( reduce_results=False, swiglu_limit=getattr(config, "swiglu_limit", None), prefix=add_prefix("shared_experts", prefix), - **( - dict(tp_rank=0, tp_size=1) - if get_moe_a2a_backend().is_deepep() - or get_moe_a2a_backend().is_mooncake() - or get_moe_a2a_backend().is_nixl() - or get_moe_a2a_backend().is_mori() - or get_moe_a2a_backend().is_ascend_fuseep() - or get_moe_a2a_backend().is_flashinfer() - or should_use_flashinfer_cutlass_moe_fp4_allgather() - else {} - ), + **(dict(tp_rank=0, tp_size=1) if _shared_expert_use_tp1 else {}), ) + self._shared_expert_tp1 = _shared_expert_use_tp1 is_packed_weight = hasattr( self.shared_experts.gate_up_proj.quant_method, "quant_config" ) and self.shared_experts.gate_up_proj.quant_method.quant_config.get_name() in { @@ -740,7 +744,7 @@ def forward_normal_dual_stream( final_hidden_states = maybe_fuse_routed_scale_and_shared_add( self.experts, final_hidden_states, - shared_output, + None if self._shared_expert_tp1 else shared_output, self.routed_scaling_factor, ) @@ -750,6 +754,10 @@ def forward_normal_dual_stream( should_allreduce_fusion=should_allreduce_fusion, ): final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) + # TP1 shared experts are replicated, so add them after all-reduce to + # avoid summing the same shared output once per TP rank. + if self._shared_expert_tp1: + final_hidden_states += shared_output return final_hidden_states def forward_normal( @@ -842,7 +850,7 @@ def _post_combine_hook( final_hidden_states = maybe_fuse_routed_scale_and_shared_add( self.experts, final_hidden_states, - shared_output, + None if self._shared_expert_tp1 else shared_output, self.routed_scaling_factor, ) @@ -852,6 +860,10 @@ def _post_combine_hook( should_allreduce_fusion=should_allreduce_fusion, ): final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) + # TP1 shared experts are replicated, so add them after all-reduce to + # avoid summing the same shared output once per TP rank. + if shared_output is not None and self._shared_expert_tp1: + final_hidden_states += shared_output return final_hidden_states def forward_cpu(