diff --git a/python/sglang/srt/layers/communicator.py b/python/sglang/srt/layers/communicator.py
index 936eecb90b97..599982bb094d 100644
--- a/python/sglang/srt/layers/communicator.py
+++ b/python/sglang/srt/layers/communicator.py
@@ -47,6 +47,7 @@
     get_attention_dp_size,
     get_attention_tp_rank,
     get_attention_tp_size,
+    get_dp_global_num_tokens,
     get_global_dp_buffer,
     get_local_dp_buffer,
     is_allocation_symmetric,
@@ -55,6 +56,7 @@
 from sglang.srt.layers.flashinfer_comm_fusion import is_flashinfer_allreduce_unavailable
 from sglang.srt.layers.moe import (
     get_moe_a2a_backend,
+    should_use_dp_reduce_scatterv,
     should_use_flashinfer_cutlass_moe_fp4_allgather,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -1007,8 +1009,13 @@ def _scatter_hidden_states(
             get_local_dp_buffer(),
             hidden_states,
         )
-        if allow_reduce_scatter and forward_batch.dp_padding_mode.is_max_len():
-            # When using padding, all_reduce is skipped after MLP and MOE and reduce scatter is used here instead.
+        if should_use_dp_reduce_scatterv():
+            get_tp_group().reduce_scatterv(
+                global_hidden_states,
+                output=hidden_states,
+                sizes=get_dp_global_num_tokens(),
+            )
+        elif allow_reduce_scatter and forward_batch.dp_padding_mode.is_max_len():
             dp_reduce_scatter_tensor(hidden_states, global_hidden_states)
         else:
             dp_scatter(hidden_states, global_hidden_states, forward_batch)
diff --git a/python/sglang/srt/layers/moe/__init__.py b/python/sglang/srt/layers/moe/__init__.py
index 74d23ecd7c70..3984a8322814 100644
--- a/python/sglang/srt/layers/moe/__init__.py
+++ b/python/sglang/srt/layers/moe/__init__.py
@@ -10,6 +10,7 @@
     get_tbo_token_distribution_threshold,
     initialize_moe_config,
     is_tbo_enabled,
+    should_use_dp_reduce_scatterv,
     should_use_flashinfer_cutlass_moe_fp4_allgather,
 )
 
@@ -23,6 +24,7 @@
     "get_moe_a2a_backend",
     "get_moe_runner_backend",
     "get_deepep_mode",
+    "should_use_dp_reduce_scatterv",
     "should_use_flashinfer_cutlass_moe_fp4_allgather",
     "is_tbo_enabled",
     "get_tbo_token_distribution_threshold",
diff --git a/python/sglang/srt/layers/moe/utils.py b/python/sglang/srt/layers/moe/utils.py
index 0d5fa7ddbce4..d7fba833275a 100644
--- a/python/sglang/srt/layers/moe/utils.py
+++ b/python/sglang/srt/layers/moe/utils.py
@@ -289,6 +289,21 @@ def should_use_flashinfer_cutlass_moe_fp4_allgather():
     )
 
 
+def should_use_dp_reduce_scatterv():
+    """
+    Use reduce_scatterv in the standard dispatcher's combine() for DP attention
+    with EP, replacing the default all-reduce + dp_scatter path.
+    Only changes the combine (post-kernel) communication; dispatch is unchanged.
+    """
+    return (
+        not should_use_flashinfer_cutlass_moe_fp4_allgather()
+        and get_moe_a2a_backend().is_none()
+        and is_dp_attention_enabled()
+        and get_attention_dp_size() > 1
+        and get_moe_expert_parallel_world_size() == get_attention_dp_size()
+    )
+
+
 @contextmanager
 def speculative_moe_backend_context():
     """
diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py
index 8f3475a24323..2029465ac499 100644
--- a/python/sglang/srt/models/qwen2_moe.py
+++ b/python/sglang/srt/models/qwen2_moe.py
@@ -57,6 +57,7 @@
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe import (
     get_moe_a2a_backend,
+    should_use_dp_reduce_scatterv,
     should_use_flashinfer_cutlass_moe_fp4_allgather,
 )
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
@@ -352,6 +353,7 @@ def forward(
             and not should_allreduce_fusion
             and not use_reduce_scatter
             and not should_use_flashinfer_cutlass_moe_fp4_allgather()
+            and not should_use_dp_reduce_scatterv()
         ):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)