Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions python/sglang/srt/layers/communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
tensor_model_parallel_all_reduce,
)
from sglang.srt.layers.dp_attention import (
attn_tp_all_gather,
attn_tp_reduce_scatter,
attn_tp_all_gather_into_tensor,
attn_tp_reduce_scatter_tensor,
dp_gather_partial,
dp_scatter,
get_attention_dp_size,
Expand Down Expand Up @@ -309,8 +309,8 @@ def _scattered_to_tp_attn_full(
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
hidden_states,
)
attn_tp_all_gather(
list(hidden_states.tensor_split(context.attn_tp_size)),
attn_tp_all_gather_into_tensor(
hidden_states,
local_hidden_states,
)
return hidden_states
Expand Down Expand Up @@ -400,9 +400,7 @@ def _gather_hidden_states_and_residual(
].clone(),
residual,
)
attn_tp_all_gather(
list(residual.tensor_split(context.attn_tp_size)), local_residual
)
attn_tp_all_gather_into_tensor(residual, local_residual)
if context.attn_dp_size != 1:
if context.attn_tp_rank == 0:
hidden_states += residual
Expand Down Expand Up @@ -442,9 +440,11 @@ def _scatter_hidden_states_and_residual(
*,
residual_input_mode,
):
tensor_list = list(hidden_states.tensor_split(context.attn_tp_size))
hidden_states = tensor_list[context.attn_tp_rank]
attn_tp_reduce_scatter(hidden_states, tensor_list)
input_hidden_states = hidden_states
hidden_states = hidden_states.tensor_split(context.attn_tp_size)[
context.attn_tp_rank
]
attn_tp_reduce_scatter_tensor(hidden_states, input_hidden_states)
if residual_input_mode == ScatterMode.TP_ATTN_FULL:
residual = residual.tensor_split(context.attn_tp_size)[context.attn_tp_rank]
if hidden_states.shape[0] != 0:
Expand Down Expand Up @@ -547,8 +547,8 @@ def _gather(
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
hidden_states,
)
attn_tp_all_gather(
list(hidden_states.tensor_split(context.attn_tp_size)),
attn_tp_all_gather_into_tensor(
hidden_states,
local_hidden_states,
)
return hidden_states, residual
Expand Down
15 changes: 8 additions & 7 deletions python/sglang/srt/layers/dp_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,12 +355,13 @@ def dp_scatter(
)


def attn_tp_reduce_scatter(
output: torch.Tensor,
input_list: List[torch.Tensor],
):
return get_attention_tp_group().reduce_scatter(output, input_list)
def attn_tp_reduce_scatter_tensor(output: torch.Tensor, input: torch.Tensor):
return get_attention_tp_group().reduce_scatter_tensor(output, input)


def attn_tp_all_gather_into_tensor(output: torch.Tensor, input: torch.Tensor):
return get_attention_tp_group().all_gather_into_tensor(output, input)


def attn_tp_all_gather(output_list: List[torch.Tensor], input_: torch.Tensor):
return get_attention_tp_group().all_gather(input_, output_tensor_list=output_list)
def attn_tp_all_gather(output_list: List[torch.Tensor], input: torch.Tensor):
return get_attention_tp_group().all_gather(input, output_tensor_list=output_list)
35 changes: 26 additions & 9 deletions python/sglang/srt/layers/logits_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from sglang.srt.layers.dp_attention import (
DPPaddingMode,
attn_tp_all_gather,
attn_tp_all_gather_into_tensor,
dp_gather_replicate,
dp_scatter,
get_attention_dp_rank,
Expand Down Expand Up @@ -456,15 +457,31 @@ def _get_logits(

if self.do_tensor_parallel_all_gather:
if self.use_attn_tp_group:
global_logits = torch.empty(
(self.config.vocab_size, logits.shape[0]),
device=logits.device,
dtype=logits.dtype,
)
global_logits = global_logits.T
attn_tp_all_gather(
list(global_logits.tensor_split(self.attn_tp_size, dim=-1)), logits
)
if self.config.vocab_size % self.attn_tp_size == 0:
global_logits = torch.empty(
(
self.attn_tp_size,
logits.shape[0],
self.config.vocab_size // self.attn_tp_size,
),
device=logits.device,
dtype=logits.dtype,
)
attn_tp_all_gather_into_tensor(global_logits, logits)
global_logits = global_logits.permute(1, 0, 2).reshape(
logits.shape[0], self.config.vocab_size
)
else:
global_logits = torch.empty(
(self.config.vocab_size, logits.shape[0]),
device=logits.device,
dtype=logits.dtype,
)
global_logits = global_logits.T
attn_tp_all_gather(
list(global_logits.tensor_split(self.attn_tp_size, dim=-1)),
logits,
)
logits = global_logits
else:
logits = tensor_model_parallel_all_gather(logits)
Expand Down