Skip to content
15 changes: 10 additions & 5 deletions python/sglang/srt/layers/communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
get_attention_cp_rank,
get_attention_cp_size,
get_attention_dp_size,
get_attention_tp_group,
get_attention_tp_rank,
get_attention_tp_size,
get_global_dp_buffer,
Expand Down Expand Up @@ -740,7 +741,7 @@ def _scattered_to_tp_attn_full(
return tuple(gathered_hidden_states)

hidden_states, local_hidden_states = (
get_local_dp_buffer(),
get_local_dp_buffer(get_attention_tp_group()),
hidden_states,
)
attn_tp_all_gather_into_tensor(
Expand Down Expand Up @@ -837,7 +838,7 @@ def _gather_hidden_states_and_residual(

if residual_input_mode == ScatterMode.SCATTERED and context.attn_tp_size > 1:
residual, local_residual = (
get_local_dp_buffer(),
get_local_dp_buffer(get_attention_tp_group()),
residual,
)
attn_tp_all_gather_into_tensor(residual, local_residual)
Expand All @@ -854,7 +855,7 @@ def _gather_hidden_states_and_residual(
hidden_states += residual

hidden_states, local_hidden_states = (
get_global_dp_buffer(),
get_global_dp_buffer(get_tp_group()),
hidden_states,
)
dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
Expand Down Expand Up @@ -993,8 +994,12 @@ def _scatter_hidden_states(
context: CommunicateContext,
allow_reduce_scatter: bool = False,
):
if get_tensor_model_parallel_world_size() == get_attention_dp_size():
group = get_tp_group()
else:
group = get_attention_tp_group()
hidden_states, global_hidden_states = (
get_local_dp_buffer(),
get_local_dp_buffer(group),
hidden_states,
)
if allow_reduce_scatter and forward_batch.dp_padding_mode.is_max_len():
Expand All @@ -1015,7 +1020,7 @@ def _gather(
hidden_states += residual
residual = None
hidden_states, local_hidden_states = (
get_local_dp_buffer(),
get_local_dp_buffer(get_attention_tp_group()),
hidden_states,
)
attn_tp_all_gather_into_tensor(
Expand Down
3 changes: 2 additions & 1 deletion python/sglang/srt/layers/communicator_nsa_cp.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from sglang.srt.layers.dp_attention import (
attn_cp_all_gather_into_tensor,
attn_cp_reduce_scatter_tensor,
get_attention_cp_group,
get_local_dp_buffer,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
Expand Down Expand Up @@ -154,7 +155,7 @@ def _gather_hidden_states_and_residual(
if nsa_use_prefill_cp(forward_batch):
assert context.attn_dp_size == 1
hidden_states, local_hidden_states = (
get_local_dp_buffer(),
get_local_dp_buffer(get_attention_cp_group()),
hidden_states,
)
attn_cp_all_gather_into_tensor(
Expand Down
16 changes: 8 additions & 8 deletions python/sglang/srt/layers/dp_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,8 @@ def set_dp_buffer_len(
cls._global_num_tokens = global_num_tokens

@classmethod
def get_global_dp_buffer(cls) -> torch.Tensor:
with use_symmetric_memory(get_tp_group(), disabled=not cls._dp_max_padding):
def get_global_dp_buffer(cls, group: GroupCoordinator) -> torch.Tensor:
with use_symmetric_memory(group, disabled=not cls._dp_max_padding):
buffer = torch.empty(
(cls._global_dp_buffer_len, cls._hidden_size),
dtype=cls._dtype,
Expand All @@ -126,8 +126,8 @@ def get_global_dp_buffer(cls) -> torch.Tensor:
return buffer

@classmethod
def get_local_dp_buffer(cls) -> torch.Tensor:
with use_symmetric_memory(get_tp_group(), disabled=not cls._dp_max_padding):
def get_local_dp_buffer(cls, group: GroupCoordinator) -> torch.Tensor:
with use_symmetric_memory(group, disabled=not cls._dp_max_padding):
Comment on lines -129 to +130
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not make group: Optional[GroupCoordinator] = None, if group==None, then we still use get_tp_group() by default.

Copy link
Copy Markdown
Contributor Author

@wangfakang wangfakang Feb 26, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not make group: Optional[GroupCoordinator] = None, if group==None, then we still use get_tp_group() by default.

Using default values can easily lead to group inconsistency, so it's necessary to explicitly declare the correct group to ensure consistency with the communication operator's context.

buffer = torch.empty(
(cls._local_dp_buffer_len, cls._hidden_size),
dtype=cls._dtype,
Expand Down Expand Up @@ -183,12 +183,12 @@ def set_dp_buffer_len(
)


def get_global_dp_buffer() -> torch.Tensor:
return _DpGatheredBufferWrapper.get_global_dp_buffer()
def get_global_dp_buffer(group: GroupCoordinator) -> torch.Tensor:
    """Module-level convenience wrapper around ``_DpGatheredBufferWrapper``.

    Returns a freshly allocated global DP gather buffer; *group* selects the
    symmetric-memory context the buffer is allocated under (passed explicitly
    so it matches the communication operator actually used by the caller).
    """
    return _DpGatheredBufferWrapper.get_global_dp_buffer(group)


def get_local_dp_buffer() -> torch.Tensor:
return _DpGatheredBufferWrapper.get_local_dp_buffer()
def get_local_dp_buffer(group: GroupCoordinator) -> torch.Tensor:
    """Module-level convenience wrapper around ``_DpGatheredBufferWrapper``.

    Returns a freshly allocated local DP buffer; *group* selects the
    symmetric-memory context the buffer is allocated under (passed explicitly
    so it matches the communication operator actually used by the caller).
    """
    return _DpGatheredBufferWrapper.get_local_dp_buffer(group)


def get_global_dp_buffer_len() -> int:
Expand Down
5 changes: 4 additions & 1 deletion python/sglang/srt/layers/moe/token_dispatcher/standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,10 @@ def dispatch(
def combine(self, combine_input: StandardCombineInput) -> torch.Tensor:
(hidden_states,) = combine_input
if should_use_flashinfer_cutlass_moe_fp4_allgather():
hidden_states, global_hidden_states = get_local_dp_buffer(), hidden_states
hidden_states, global_hidden_states = (
get_local_dp_buffer(get_tp_group()),
hidden_states,
)
get_tp_group().reduce_scatterv(
global_hidden_states,
output=hidden_states,
Expand Down
Loading