diff --git a/python/sglang/srt/layers/communicator.py b/python/sglang/srt/layers/communicator.py index 30f068595188..909085402255 100644 --- a/python/sglang/srt/layers/communicator.py +++ b/python/sglang/srt/layers/communicator.py @@ -1301,8 +1301,12 @@ def _scatter_hidden_states_moe( # DP scatter (if DP attention is enabled) if context.attn_dp_size > 1: + if get_tensor_model_parallel_world_size() == get_attention_dp_size(): + group = get_tp_group() + else: + group = get_attention_tp_group() hidden_states_output, global_hidden_states = ( - get_local_dp_buffer(), + get_local_dp_buffer(group), hidden_states, ) dp_scatter(hidden_states_output, global_hidden_states, forward_batch) diff --git a/python/sglang/srt/models/deepseek_v4.py b/python/sglang/srt/models/deepseek_v4.py index f4128884ad58..0a4740fa4a98 100644 --- a/python/sglang/srt/models/deepseek_v4.py +++ b/python/sglang/srt/models/deepseek_v4.py @@ -17,7 +17,11 @@ fused_rope_inplace, ) from sglang.srt.configs.deepseek_v4 import DeepSeekV4Config -from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size +from sglang.srt.distributed import ( + get_pp_group, + get_tensor_model_parallel_world_size, + get_tp_group, +) from sglang.srt.environ import envs from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation from sglang.srt.layers.attention.dsv4.compressor import Compressor @@ -32,11 +36,13 @@ from sglang.srt.layers.dp_attention import ( _DpGatheredBufferWrapper, attn_tp_all_gather, - dp_gather_partial, + attn_tp_all_reduce, + dp_gather_replicate, dp_scatter, get_attention_cp_rank, get_attention_cp_size, get_attention_dp_size, + get_attention_tp_group, get_attention_tp_rank, get_attention_tp_size, get_global_dp_buffer, @@ -303,7 +309,8 @@ def __init__( self.hidden_size, bias=False, quant_config=quant_config, - reduce_results=attn_tp_size > 1, + reduce_results=attn_tp_size == get_tensor_model_parallel_world_size() + and attn_tp_size > 1, prefix=add_prefix("wo_b", prefix), tp_rank=attn_tp_rank, tp_size=attn_tp_size, @@ -516,9 +523,6 @@ def forward( forward_batch: ForwardBatch, ) -> torch.Tensor: if not get_attn_tp_context().input_scattered and x.shape[0] == 0: - assert ( - not self.wo_b.reduce_results - ), "short-circuiting allreduce will lead to hangs" return x attn_backend = forward_batch.attn_backend @@ -601,6 +605,8 @@ def forward( o = torch.einsum("tgd,grd->tgr", o, wo_a) o, _ = self.wo_b(o.flatten(1)) + if self.tp_size > 1 and self.tp_size < get_tensor_model_parallel_world_size(): + o = attn_tp_all_reduce(o) return o @@ -830,8 +836,14 @@ def forward( input_ids = input_ids[cp_rank::cp_size].contiguous() input_ids_global = input_ids elif _use_tp_moe_gather: - hidden_states, local_hidden_states = get_global_dp_buffer(), hidden_states - dp_gather_partial(hidden_states, local_hidden_states, forward_batch) + hidden_states, local_hidden_states = ( + get_global_dp_buffer(get_tp_group()), + hidden_states, + ) + # hidden_states here follow TP_ATTN_FULL semantics: they are replicated + # within an attention-TP group. Use replicate gather to avoid summing the + # same activations across attention-TP ranks before entering MLP/MoE. + dp_gather_replicate(hidden_states, local_hidden_states, forward_batch) _a2a_scatter_chunks: Optional[List[torch.Tensor]] = None if _use_tp_attn_a2a_scatter: s, r = get_attention_tp_size(), get_attention_tp_rank() @@ -846,7 +858,10 @@ def forward( input_ids_global=input_ids_global, ) if _use_tp_moe_gather: - hidden_states, global_hidden_states = get_local_dp_buffer(), hidden_states + hidden_states, global_hidden_states = ( + get_local_dp_buffer(get_attention_tp_group()), + hidden_states, + ) dp_scatter(hidden_states, global_hidden_states, forward_batch) if _use_tp_attn_a2a_scatter: assert _a2a_scatter_chunks is not None @@ -950,7 +965,9 @@ def forward( dtype=input_ids.dtype, device=input_ids.device, ) - dp_gather_partial(input_ids_global, input_ids[:, None], forward_batch) + # Token ids are replicated within an attention-TP group. Use replicate + # gather here to avoid summing duplicated ids when attention_tp_size > 1. + dp_gather_replicate(input_ids_global, input_ids[:, None], forward_batch) input_ids_global = input_ids_global.squeeze(-1) else: input_ids_global = input_ids