Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions python/sglang/srt/layers/communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
attn_tp_all_gather_into_tensor,
attn_tp_reduce_scatter_tensor,
dp_gather_partial,
dp_reduce_scatter_tensor,
dp_scatter,
get_attention_dp_size,
get_attention_tp_rank,
Expand Down Expand Up @@ -149,10 +150,13 @@ def __init__(
layer_scatter_modes: LayerScatterModes,
input_layernorm: torch.nn.Module,
post_attention_layernorm: torch.nn.Module,
# Reduce scatter requires skipping all-reduce in model code after MoE/MLP, so only enable for models which have that implemented. Remove flag once done for all models that use LayerCommunicator.
allow_reduce_scatter: bool = False,
):
self.layer_scatter_modes = layer_scatter_modes
self.input_layernorm = input_layernorm
self.post_attention_layernorm = post_attention_layernorm
self.allow_reduce_scatter = allow_reduce_scatter

self._context = CommunicateContext.init_new()
self._communicate_simple_fn = CommunicateSimpleFn.get_fn(
Expand Down Expand Up @@ -239,6 +243,15 @@ def postprocess_layer(
residual=residual,
forward_batch=forward_batch,
context=self._context,
allow_reduce_scatter=self.allow_reduce_scatter,
)

def should_use_reduce_scatter(self, forward_batch: ForwardBatch):
    """Report whether the post-MLP combine can use reduce-scatter.

    Reduce-scatter replaces the all-reduce-then-scatter path only when all
    three conditions hold: this communicator was constructed with
    ``allow_reduce_scatter=True`` (i.e. the model skips the all-reduce after
    MoE/MLP), the selected pair-communication function is the
    scatter-hidden-states variant, and the batch was padded to max length
    (``dp_padding_mode.is_max_len()``).
    """
    if not self.allow_reduce_scatter:
        return False
    scatter_path_selected = (
        self._communicate_summable_tensor_pair_fn
        is CommunicateSummableTensorPairFn._scatter_hidden_states
    )
    return scatter_path_selected and forward_batch.dp_padding_mode.is_max_len()


Expand Down Expand Up @@ -524,6 +537,7 @@ def _trivial(
residual: torch.Tensor,
forward_batch: ForwardBatch,
context: CommunicateContext,
**kwargs,
):
return hidden_states, residual

Expand All @@ -533,15 +547,17 @@ def _scatter_hidden_states(
residual: torch.Tensor,
forward_batch: ForwardBatch,
context: CommunicateContext,
allow_reduce_scatter: bool = False,
):
# TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
# important: forward batch.gathered_buffer is used both after scatter and after gather.
# be careful about this!
hidden_states, global_hidden_states = (
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
hidden_states,
)
dp_scatter(hidden_states, global_hidden_states, forward_batch)
if allow_reduce_scatter and forward_batch.dp_padding_mode.is_max_len():
# When using padding, all_reduce is skipped after MLP and MOE and reduce scatter is used here instead.
dp_reduce_scatter_tensor(hidden_states, global_hidden_states)
else:
dp_scatter(hidden_states, global_hidden_states, forward_batch)
return hidden_states, residual

@staticmethod
Expand All @@ -550,6 +566,7 @@ def _gather(
residual: torch.Tensor,
forward_batch: ForwardBatch,
context: CommunicateContext,
**kwargs,
):
hidden_states += residual
residual = None
Expand Down
12 changes: 12 additions & 0 deletions python/sglang/srt/layers/dp_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from sglang.srt.distributed import (
GroupCoordinator,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
get_tp_group,
tensor_model_parallel_all_reduce,
Expand Down Expand Up @@ -355,6 +356,17 @@ def dp_scatter(
)


def dp_reduce_scatter_tensor(output: torch.Tensor, input: torch.Tensor):
    """Reduce-scatter `input` over the TP group, leaving the result in `output`.

    Fast path: when the attention-DP size equals the full TP world size, a
    single reduce-scatter over the TP group produces `output` directly.
    Otherwise, reduce-scatter into this rank's shard of `input` first, then
    all-gather the shards within the attention-TP group to fill `output`.
    """
    tp_group = get_tp_group()
    tp_world_size = get_tensor_model_parallel_world_size()
    if tp_world_size == get_attention_dp_size():
        tp_group.reduce_scatter_tensor(output, input)
        return
    # This rank's slice of the (evenly split) input acts as the
    # reduce-scatter destination before the intra-attention-TP gather.
    local_shard = input.tensor_split(tp_world_size)[get_tensor_model_parallel_rank()]
    tp_group.reduce_scatter_tensor(local_shard, input)
    get_attention_tp_group().all_gather_into_tensor(output, local_shard)


def attn_tp_reduce_scatter_tensor(output: torch.Tensor, input: torch.Tensor):
    """Reduce-scatter `input` into `output` within the attention-TP group."""
    attn_tp_group = get_attention_tp_group()
    return attn_tp_group.reduce_scatter_tensor(output, input)

Expand Down
4 changes: 2 additions & 2 deletions python/sglang/srt/layers/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -1277,7 +1277,7 @@ def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor
# It does not support additional parameters.
param.load_row_parallel_weight(loaded_weight)

def forward(self, input_, can_fuse_mlp_allreduce=False):
def forward(self, input_, skip_all_reduce=False):
if self.input_is_parallel:
input_parallel = input_
else:
Expand All @@ -1294,7 +1294,7 @@ def forward(self, input_, can_fuse_mlp_allreduce=False):
with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
sm.tag(output_parallel)
if self.reduce_results and self.tp_size > 1 and not can_fuse_mlp_allreduce:
if self.reduce_results and self.tp_size > 1 and not skip_all_reduce:
output = tensor_model_parallel_all_reduce(output_parallel)
else:
output = output_parallel
Expand Down
6 changes: 4 additions & 2 deletions python/sglang/srt/model_executor/forward_batch_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,8 +628,10 @@ def prepare_mlp_sync_batch(self, model_runner: ModelRunner):
self.dp_padding_mode = dp_padding_mode

if dp_padding_mode.is_max_len():
# when DP gather mode is all gather, we will use all_gather_into_tensor to gather hidden states,
# where transferred tokens should be padded to the same length.
# when DP gather mode is all gather, we will use
# all_gather_into_tensor to gather hidden states, where transferred
# tokens should be padded to the same length. We will also use
# reduce-scatter instead of all-reduce after MLP.
max_num_tokens = max(global_num_tokens)
global_num_tokens = [max_num_tokens] * sync_group_size
buffer_len = max_num_tokens * sync_group_size
Expand Down
42 changes: 33 additions & 9 deletions python/sglang/srt/models/deepseek_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,13 +208,21 @@ def __init__(
)
self.act_fn = SiluAndMul()

def forward(self, x, forward_batch=None, can_fuse_mlp_allreduce=False):
def forward(
self,
x,
forward_batch=None,
can_fuse_mlp_allreduce: bool = False,
use_reduce_scatter: bool = False,
):
if (self.tp_size == 1) and x.shape[0] == 0:
return x

gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x, can_fuse_mlp_allreduce=can_fuse_mlp_allreduce)
x, _ = self.down_proj(
x, skip_all_reduce=can_fuse_mlp_allreduce or use_reduce_scatter
)
return x


Expand Down Expand Up @@ -441,6 +449,7 @@ def forward(
hidden_states: torch.Tensor,
forward_batch: Optional[ForwardBatch] = None,
can_fuse_mlp_allreduce: bool = False,
use_reduce_scatter: bool = False,
) -> torch.Tensor:
if not self._enable_deepep_moe:
DUAL_STREAM_TOKEN_THRESHOLD = 1024
Expand All @@ -450,15 +459,20 @@ def forward(
and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
):
return self.forward_normal_dual_stream(
hidden_states, can_fuse_mlp_allreduce
hidden_states, can_fuse_mlp_allreduce, use_reduce_scatter
)
else:
return self.forward_normal(hidden_states, can_fuse_mlp_allreduce)
return self.forward_normal(
hidden_states, can_fuse_mlp_allreduce, use_reduce_scatter
)
else:
return self.forward_deepep(hidden_states, forward_batch)

def forward_normal_dual_stream(
self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False
self,
hidden_states: torch.Tensor,
can_fuse_mlp_allreduce: bool = False,
use_reduce_scatter: bool = False,
) -> torch.Tensor:

current_stream = torch.cuda.current_stream()
Expand Down Expand Up @@ -486,12 +500,15 @@ def forward_normal_dual_stream(
torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
final_hidden_states = final_hidden_states_out
sm.tag(final_hidden_states)
if self.tp_size > 1 and not can_fuse_mlp_allreduce:
if self.tp_size > 1 and not can_fuse_mlp_allreduce and not use_reduce_scatter:
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states

def forward_normal(
self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False
self,
hidden_states: torch.Tensor,
can_fuse_mlp_allreduce: bool = False,
use_reduce_scatter: bool = False,
) -> torch.Tensor:
if hasattr(self, "shared_experts") and use_intel_amx_backend(
self.shared_experts.gate_up_proj
Expand Down Expand Up @@ -520,7 +537,7 @@ def forward_normal(
torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
final_hidden_states = final_hidden_states_out
sm.tag(final_hidden_states)
if self.tp_size > 1 and not can_fuse_mlp_allreduce:
if self.tp_size > 1 and not can_fuse_mlp_allreduce and not use_reduce_scatter:
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states

Expand Down Expand Up @@ -1822,6 +1839,7 @@ def __init__(
layer_scatter_modes=self.layer_scatter_modes,
input_layernorm=self.input_layernorm,
post_attention_layernorm=self.post_attention_layernorm,
allow_reduce_scatter=True,
)

def _is_layer_sparse(self, layer_id: int, is_nextn: bool) -> bool:
Expand Down Expand Up @@ -1884,7 +1902,13 @@ def forward(
and not self.is_nextn
)

hidden_states = self.mlp(hidden_states, forward_batch, can_fuse_mlp_allreduce)
# For DP with padding, reduce scatter can be used instead of all-reduce.
use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
forward_batch
)
hidden_states = self.mlp(
hidden_states, forward_batch, can_fuse_mlp_allreduce, use_reduce_scatter
)

if can_fuse_mlp_allreduce:
hidden_states._sglang_needs_allreduce_fusion = True
Expand Down
2 changes: 1 addition & 1 deletion python/sglang/srt/models/glm4_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def forward(self, x, forward_batch=None, can_fuse_mlp_allreduce=False):

gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x, can_fuse_mlp_allreduce=can_fuse_mlp_allreduce)
x, _ = self.down_proj(x, skip_all_reduce=can_fuse_mlp_allreduce)
return x


Expand Down
Loading