2 changes: 2 additions & 0 deletions python/sglang/srt/models/bailing_moe.py

@@ -58,6 +58,7 @@
 from sglang.srt.layers.moe import (
     get_deepep_mode,
     get_moe_a2a_backend,
+    should_use_dp_reduce_scatterv,
     should_use_flashinfer_cutlass_moe_fp4_allgather,
 )
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class

@@ -386,6 +387,7 @@ def forward_normal(
             and not should_allreduce_fusion
             and not use_reduce_scatter
             and not should_use_flashinfer_cutlass_moe_fp4_allgather()
+            and not should_use_dp_reduce_scatterv()
         ):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states.view(num_tokens, hidden_size)
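Every file touched by this PR applies the same guard: when the data-parallel reduce-scatterv path is active, the cross-rank reduction of the MoE output is handled later by a reduce-scatterv, so running the tensor-parallel all-reduce here would duplicate work. A minimal, self-contained sketch of the control flow, with stand-in stubs for the SGLang helpers (`should_use_dp_reduce_scatterv` and `tensor_model_parallel_all_reduce` are the real names; the stub bodies and the `finalize_moe_output` wrapper are illustrative only):

```python
import torch


def should_use_dp_reduce_scatterv() -> bool:
    # Stand-in: in SGLang this reads the global MoE / DP-attention config.
    return False


def tensor_model_parallel_all_reduce(x: torch.Tensor) -> torch.Tensor:
    # Stand-in: in SGLang this is a collective over the TP process group.
    return x


def finalize_moe_output(
    final_hidden_states: torch.Tensor,
    tp_size: int,
    use_reduce_scatter: bool,
) -> torch.Tensor:
    # The TP all-reduce runs only when no other mechanism already owns the
    # cross-rank reduction; DP reduce-scatterv is the newly excluded path.
    if (
        tp_size > 1
        and not use_reduce_scatter
        and not should_use_dp_reduce_scatterv()
    ):
        final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
    return final_hidden_states
```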
8 changes: 7 additions & 1 deletion python/sglang/srt/models/bailing_moe_linear.py

@@ -34,6 +34,7 @@
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe import should_use_dp_reduce_scatterv
 from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.topk import TopK

@@ -347,7 +348,12 @@ def forward(
         if self.num_shared_experts > 0:
             final_hidden_states = final_hidden_states + shared_output

-        if self.tp_size > 1 and not use_reduce_scatter and not should_allreduce_fusion:
+        if (
+            self.tp_size > 1
+            and not use_reduce_scatter
+            and not should_allreduce_fusion
+            and not should_use_dp_reduce_scatterv()
+        ):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states
3 changes: 3 additions & 0 deletions python/sglang/srt/models/deepseek_v2.py

@@ -85,6 +85,7 @@
 from sglang.srt.layers.moe import (
     get_moe_a2a_backend,
     get_moe_runner_backend,
+    should_use_dp_reduce_scatterv,
     should_use_flashinfer_cutlass_moe_fp4_allgather,
 )
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class

@@ -655,6 +656,7 @@ def forward_normal_dual_stream(
             and not should_allreduce_fusion
             and not use_reduce_scatter
             and not should_use_flashinfer_cutlass_moe_fp4_allgather()
+            and not should_use_dp_reduce_scatterv()
         ):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states

@@ -744,6 +746,7 @@ def _post_combine_hook(
             and not should_allreduce_fusion
             and not use_reduce_scatter
             and not should_use_flashinfer_cutlass_moe_fp4_allgather()
+            and not should_use_dp_reduce_scatterv()
         ):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states
8 changes: 6 additions & 2 deletions python/sglang/srt/models/exaone_moe.py

@@ -47,7 +47,7 @@
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
-from sglang.srt.layers.moe import get_moe_a2a_backend
+from sglang.srt.layers.moe import get_moe_a2a_backend, should_use_dp_reduce_scatterv
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.moe.topk import TopK

@@ -300,7 +300,11 @@ def forward(

         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
-        if self.tp_size > 1 and not use_reduce_scatter:
+        if (
+            self.tp_size > 1
+            and not use_reduce_scatter
+            and not should_use_dp_reduce_scatterv()
+        ):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

         return final_hidden_states.view(num_tokens, hidden_dim)
3 changes: 3 additions & 0 deletions python/sglang/srt/models/glm4_moe.py

@@ -61,6 +61,7 @@
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe import (
     get_moe_a2a_backend,
+    should_use_dp_reduce_scatterv,
     should_use_flashinfer_cutlass_moe_fp4_allgather,
 )
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class

@@ -598,6 +599,7 @@ def forward_normal_dual_stream(
             and not should_allreduce_fusion
             and not use_reduce_scatter
             and not should_use_flashinfer_cutlass_moe_fp4_allgather()
+            and not should_use_dp_reduce_scatterv()
         ):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states

@@ -632,6 +634,7 @@ def forward_normal(
             and not should_allreduce_fusion
             and not use_reduce_scatter
             and not should_use_flashinfer_cutlass_moe_fp4_allgather()
+            and not should_use_dp_reduce_scatterv()
         ):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states
11 changes: 7 additions & 4 deletions python/sglang/srt/models/hunyuan_v3.py

@@ -34,6 +34,7 @@
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe import should_use_dp_reduce_scatterv
 from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig

@@ -191,10 +192,11 @@ def _forward_single_stream(self, hidden_states: torch.Tensor) -> torch.Tensor:
             hidden_states=hidden_states, topk_output=topk_output
         )

-        if self.ep_size > 1:
+        skip_post_reduce = should_use_dp_reduce_scatterv()
+        if self.ep_size > 1 and not skip_post_reduce:
             final_hidden_states = moe_expert_parallel_all_reduce(final_hidden_states)

-        if self.tp_size > 1:
+        if self.tp_size > 1 and not skip_post_reduce:
             final_hidden_states = moe_tensor_model_parallel_all_reduce(
                 final_hidden_states
             )

@@ -222,10 +224,11 @@ def _forward_dual_stream(self, hidden_states: torch.Tensor) -> torch.Tensor:
         current_stream.wait_stream(self.alt_stream)
         final_hidden_states = final_hidden_states + shared_output

-        if self.ep_size > 1:
+        skip_post_reduce = should_use_dp_reduce_scatterv()
+        if self.ep_size > 1 and not skip_post_reduce:
             final_hidden_states = moe_expert_parallel_all_reduce(final_hidden_states)

-        if self.tp_size > 1:
+        if self.tp_size > 1 and not skip_post_reduce:
             final_hidden_states = moe_tensor_model_parallel_all_reduce(
                 final_hidden_states
             )
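hunyuan_v3.py is the one model here with two separate post-MoE collectives, an expert-parallel all-reduce followed by a tensor-parallel all-reduce, so the predicate is read once into `skip_post_reduce` and reused for both branches, which keeps the two collectives consistent within a single forward pass. A self-contained sketch of that shape (the stub bodies are stand-ins for the SGLang helpers):

```python
import torch


def should_use_dp_reduce_scatterv() -> bool:
    return False  # stand-in for the SGLang config predicate


def moe_expert_parallel_all_reduce(x: torch.Tensor) -> torch.Tensor:
    return x  # stand-in for the EP-group collective


def moe_tensor_model_parallel_all_reduce(x: torch.Tensor) -> torch.Tensor:
    return x  # stand-in for the TP-group collective


def post_moe_reduce(x: torch.Tensor, ep_size: int, tp_size: int) -> torch.Tensor:
    # Read the predicate once so the EP and TP branches cannot disagree.
    skip_post_reduce = should_use_dp_reduce_scatterv()
    if ep_size > 1 and not skip_post_reduce:
        x = moe_expert_parallel_all_reduce(x)
    if tp_size > 1 and not skip_post_reduce:
        x = moe_tensor_model_parallel_all_reduce(x)
    return x
```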
12 changes: 10 additions & 2 deletions python/sglang/srt/models/llada2.py

@@ -55,7 +55,11 @@
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe import get_deepep_mode, get_moe_a2a_backend
+from sglang.srt.layers.moe import (
+    get_deepep_mode,
+    get_moe_a2a_backend,
+    should_use_dp_reduce_scatterv,
+)
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.token_dispatcher import DeepEPDispatcher

@@ -379,7 +383,11 @@ def forward_normal(
         if self.num_shared_experts > 0:
             final_hidden_states = final_hidden_states + shared_output

-        if self.tp_size > 1 and not use_reduce_scatter:
+        if (
+            self.tp_size > 1
+            and not use_reduce_scatter
+            and not should_use_dp_reduce_scatterv()
+        ):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states.view(num_tokens, hidden_size)
7 changes: 6 additions & 1 deletion python/sglang/srt/models/llama4.py

@@ -39,6 +39,7 @@
     ReplicatedLinear,
     RowParallelLinear,
 )
+from sglang.srt.layers.moe import should_use_dp_reduce_scatterv
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig

@@ -145,7 +146,11 @@ def forward(

         out_aD = routed_out + shared_out

-        if self.tp_size > 1 and not use_reduce_scatter:
+        if (
+            self.tp_size > 1
+            and not use_reduce_scatter
+            and not should_use_dp_reduce_scatterv()
+        ):
             out_aD = tensor_model_parallel_all_reduce(out_aD)

         return out_aD
2 changes: 2 additions & 0 deletions python/sglang/srt/models/mimo_v2_flash.py

@@ -51,6 +51,7 @@
 from sglang.srt.layers.moe import (
     get_moe_a2a_backend,
     get_moe_runner_backend,
+    should_use_dp_reduce_scatterv,
     should_use_flashinfer_cutlass_moe_fp4_allgather,
 )
 from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class

@@ -302,6 +303,7 @@ def forward_normal(
             and not should_allreduce_fusion
             and not use_reduce_scatter
             and not should_use_flashinfer_cutlass_moe_fp4_allgather()
+            and not should_use_dp_reduce_scatterv()
         ):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
2 changes: 2 additions & 0 deletions python/sglang/srt/models/minimax_m2.py

@@ -61,6 +61,7 @@
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe import (
     get_moe_a2a_backend,
+    should_use_dp_reduce_scatterv,
     should_use_flashinfer_cutlass_moe_fp4_allgather,
 )
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class

@@ -556,6 +557,7 @@ def forward_normal(
             and not should_allreduce_fusion
             and not use_reduce_scatter
             and not should_use_flashinfer_cutlass_moe_fp4_allgather()
+            and not should_use_dp_reduce_scatterv()
         ):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
7 changes: 6 additions & 1 deletion python/sglang/srt/models/sarvam_moe.py

@@ -39,7 +39,10 @@
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
-from sglang.srt.layers.moe import should_use_flashinfer_cutlass_moe_fp4_allgather
+from sglang.srt.layers.moe import (
+    should_use_dp_reduce_scatterv,
+    should_use_flashinfer_cutlass_moe_fp4_allgather,
+)
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.topk import TopK

@@ -375,6 +378,7 @@ def forward_normal_dual_stream(
             and not should_allreduce_fusion
             and not use_reduce_scatter
             and not should_use_flashinfer_cutlass_moe_fp4_allgather()
+            and not should_use_dp_reduce_scatterv()
         ):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states.view(num_tokens, hidden_dim)

@@ -418,6 +422,7 @@ def forward_normal(
             and not should_allreduce_fusion
             and not use_reduce_scatter
             and not should_use_flashinfer_cutlass_moe_fp4_allgather()
+            and not should_use_dp_reduce_scatterv()
         ):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
4 changes: 3 additions & 1 deletion python/sglang/srt/models/sdar_moe.py

@@ -34,6 +34,7 @@
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe import (
     get_moe_a2a_backend,
+    should_use_dp_reduce_scatterv,
     should_use_flashinfer_cutlass_moe_fp4_allgather,
 )
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class

@@ -160,12 +161,13 @@ def forward_normal(
         topk_output = self.topk(hidden_states, router_logits)
         out = self.experts(hidden_states, topk_output)  # (T, H)

-        # TP all-reduce (unless fused / reduce_scatter / fp4 allgather path)
+        # TP all-reduce (unless fused / reduce_scatter / fp4 allgather / dp reduce_scatterv path)
         if (
             self.tp_size > 1
             and not should_allreduce_fusion
             and not use_reduce_scatter
             and not should_use_flashinfer_cutlass_moe_fp4_allgather()
+            and not should_use_dp_reduce_scatterv()
         ):
             out = tensor_model_parallel_all_reduce(out)
2 changes: 2 additions & 0 deletions python/sglang/srt/models/step3p5.py

@@ -31,6 +31,7 @@
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe import (
     get_moe_a2a_backend,
+    should_use_dp_reduce_scatterv,
     should_use_flashinfer_cutlass_moe_fp4_allgather,
 )
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class

@@ -237,6 +238,7 @@ def forward_normal(
             and not should_allreduce_fusion
             and not use_reduce_scatter
             and not should_use_flashinfer_cutlass_moe_fp4_allgather()
+            and not should_use_dp_reduce_scatterv()
         ):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
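For context on why skipping is safe: an all-reduce would leave every TP rank holding the full summed hidden states, but under DP attention each rank only needs the rows of its own, generally unequal, batch shard. A reduce-scatterv performs the same summation while handing each rank just its slice, so an all-reduce beforehand would be redundant communication. A single-process sketch of the semantics (no process group involved; the helper name and sizes are illustrative):

```python
import torch


def reduce_scatterv_semantics(
    partial_outputs: list[torch.Tensor], split_sizes: list[int]
) -> list[torch.Tensor]:
    # Each entry of partial_outputs is one rank's partial MoE output over the
    # whole DP-concatenated batch. Summing is the "reduce"; the uneven split
    # is the "scatterv": rank i keeps only the rows of its own batch shard.
    reduced = torch.stack(partial_outputs).sum(dim=0)  # what all-reduce yields
    return list(torch.split(reduced, split_sizes, dim=0))  # per-rank slices


# Two ranks, DP shards of 3 and 1 tokens, hidden size 4.
parts = [torch.randn(4, 4), torch.randn(4, 4)]
shards = reduce_scatterv_semantics(parts, split_sizes=[3, 1])
assert [s.shape[0] for s in shards] == [3, 1]
```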