Skip to content
6 changes: 5 additions & 1 deletion python/sglang/srt/layers/communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1301,8 +1301,12 @@ def _scatter_hidden_states_moe(

# DP scatter (if DP attention is enabled)
if context.attn_dp_size > 1:
if get_tensor_model_parallel_world_size() == get_attention_dp_size():
group = get_tp_group()
else:
group = get_attention_tp_group()
hidden_states_output, global_hidden_states = (
get_local_dp_buffer(),
get_local_dp_buffer(group),
hidden_states,
)
dp_scatter(hidden_states_output, global_hidden_states, forward_batch)
Expand Down
37 changes: 27 additions & 10 deletions python/sglang/srt/models/deepseek_v4.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@
fused_rope_inplace,
)
from sglang.srt.configs.deepseek_v4 import DeepSeekV4Config
from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size
from sglang.srt.distributed import (
get_pp_group,
get_tensor_model_parallel_world_size,
get_tp_group,
)
from sglang.srt.environ import envs
from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
from sglang.srt.layers.attention.dsv4.compressor import Compressor
Expand All @@ -32,11 +36,13 @@
from sglang.srt.layers.dp_attention import (
_DpGatheredBufferWrapper,
attn_tp_all_gather,
dp_gather_partial,
attn_tp_all_reduce,
dp_gather_replicate,
dp_scatter,
get_attention_cp_rank,
get_attention_cp_size,
get_attention_dp_size,
get_attention_tp_group,
get_attention_tp_rank,
get_attention_tp_size,
get_global_dp_buffer,
Expand Down Expand Up @@ -303,7 +309,8 @@ def __init__(
self.hidden_size,
bias=False,
quant_config=quant_config,
reduce_results=attn_tp_size > 1,
reduce_results=attn_tp_size == get_tensor_model_parallel_world_size()
and attn_tp_size > 1,
prefix=add_prefix("wo_b", prefix),
tp_rank=attn_tp_rank,
tp_size=attn_tp_size,
Expand Down Expand Up @@ -516,9 +523,6 @@ def forward(
forward_batch: ForwardBatch,
) -> torch.Tensor:
if not get_attn_tp_context().input_scattered and x.shape[0] == 0:
assert (
not self.wo_b.reduce_results
), "short-circuiting allreduce will lead to hangs"
return x
Comment on lines 525 to 526
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Removing the assertion on `self.wo_b.reduce_results` could lead to hangs if not handled carefully. The original `assert not self.wo_b.reduce_results` was a safeguard against short-circuiting an all-reduce operation when `x.shape[0] == 0`.

While the `input_scattered` logic seems to cover the DP case, it is safer to retain a check to prevent potential hangs in non-DP scenarios where token distribution might be uneven across ranks, or if all ranks have zero tokens but `reduce_results` is true. A more robust approach would be to handle the zero-sized tensor case within the `RowParallelLinear` layer itself; as a direct fix, consider either reintroducing a check or ensuring that `x.shape[0] == 0` implies all ranks have zero tokens whenever `reduce_results` is true.


attn_backend = forward_batch.attn_backend
Expand Down Expand Up @@ -601,6 +605,8 @@ def forward(
o = torch.einsum("tgd,grd->tgr", o, wo_a)

o, _ = self.wo_b(o.flatten(1))
if self.tp_size > 1 and self.tp_size < get_tensor_model_parallel_world_size():
o = attn_tp_all_reduce(o)

return o

Expand Down Expand Up @@ -830,8 +836,14 @@ def forward(
input_ids = input_ids[cp_rank::cp_size].contiguous()
input_ids_global = input_ids
elif _use_tp_moe_gather:
hidden_states, local_hidden_states = get_global_dp_buffer(), hidden_states
dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
hidden_states, local_hidden_states = (
get_global_dp_buffer(get_tp_group()),
hidden_states,
)
# hidden_states here follow TP_ATTN_FULL semantics: they are replicated
# within an attention-TP group. Use replicate gather to avoid summing the
# same activations across attention-TP ranks before entering MLP/MoE.
dp_gather_replicate(hidden_states, local_hidden_states, forward_batch)
_a2a_scatter_chunks: Optional[List[torch.Tensor]] = None
if _use_tp_attn_a2a_scatter:
s, r = get_attention_tp_size(), get_attention_tp_rank()
Expand All @@ -846,7 +858,10 @@ def forward(
input_ids_global=input_ids_global,
)
if _use_tp_moe_gather:
hidden_states, global_hidden_states = get_local_dp_buffer(), hidden_states
hidden_states, global_hidden_states = (
get_local_dp_buffer(get_attention_tp_group()),
hidden_states,
)
dp_scatter(hidden_states, global_hidden_states, forward_batch)
if _use_tp_attn_a2a_scatter:
assert _a2a_scatter_chunks is not None
Expand Down Expand Up @@ -950,7 +965,9 @@ def forward(
dtype=input_ids.dtype,
device=input_ids.device,
)
dp_gather_partial(input_ids_global, input_ids[:, None], forward_batch)
# Token ids are replicated within an attention-TP group. Use replicate
# gather here to avoid summing duplicated ids when attention_tp_size > 1.
dp_gather_replicate(input_ids_global, input_ids[:, None], forward_batch)
input_ids_global = input_ids_global.squeeze(-1)
else:
input_ids_global = input_ids
Expand Down
Loading