diff --git a/python/sglang/srt/layers/communicator.py b/python/sglang/srt/layers/communicator.py index 5e0931ead0b9..aeb8449a17d7 100644 --- a/python/sglang/srt/layers/communicator.py +++ b/python/sglang/srt/layers/communicator.py @@ -24,8 +24,8 @@ tensor_model_parallel_all_reduce, ) from sglang.srt.layers.dp_attention import ( - attn_tp_all_gather, - attn_tp_reduce_scatter, + attn_tp_all_gather_into_tensor, + attn_tp_reduce_scatter_tensor, dp_gather_partial, dp_scatter, get_attention_dp_size, @@ -309,8 +309,8 @@ def _scattered_to_tp_attn_full( forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]], hidden_states, ) - attn_tp_all_gather( - list(hidden_states.tensor_split(context.attn_tp_size)), + attn_tp_all_gather_into_tensor( + hidden_states, local_hidden_states, ) return hidden_states @@ -400,9 +400,7 @@ def _gather_hidden_states_and_residual( ].clone(), residual, ) - attn_tp_all_gather( - list(residual.tensor_split(context.attn_tp_size)), local_residual - ) + attn_tp_all_gather_into_tensor(residual, local_residual) if context.attn_dp_size != 1: if context.attn_tp_rank == 0: hidden_states += residual @@ -442,9 +440,11 @@ def _scatter_hidden_states_and_residual( *, residual_input_mode, ): - tensor_list = list(hidden_states.tensor_split(context.attn_tp_size)) - hidden_states = tensor_list[context.attn_tp_rank] - attn_tp_reduce_scatter(hidden_states, tensor_list) + input_hidden_states = hidden_states + hidden_states = hidden_states.tensor_split(context.attn_tp_size)[ + context.attn_tp_rank + ] + attn_tp_reduce_scatter_tensor(hidden_states, input_hidden_states) if residual_input_mode == ScatterMode.TP_ATTN_FULL: residual = residual.tensor_split(context.attn_tp_size)[context.attn_tp_rank] if hidden_states.shape[0] != 0: @@ -547,8 +547,8 @@ def _gather( forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]], hidden_states, ) - attn_tp_all_gather( - list(hidden_states.tensor_split(context.attn_tp_size)), + attn_tp_all_gather_into_tensor( + hidden_states, local_hidden_states, ) return hidden_states, residual diff --git a/python/sglang/srt/layers/dp_attention.py b/python/sglang/srt/layers/dp_attention.py index c1402576b784..55db1333663e 100644 --- a/python/sglang/srt/layers/dp_attention.py +++ b/python/sglang/srt/layers/dp_attention.py @@ -355,12 +355,13 @@ def dp_scatter( ) -def attn_tp_reduce_scatter( - output: torch.Tensor, - input_list: List[torch.Tensor], -): - return get_attention_tp_group().reduce_scatter(output, input_list) +def attn_tp_reduce_scatter_tensor(output: torch.Tensor, input: torch.Tensor): + return get_attention_tp_group().reduce_scatter_tensor(output, input) + + +def attn_tp_all_gather_into_tensor(output: torch.Tensor, input: torch.Tensor): + return get_attention_tp_group().all_gather_into_tensor(output, input) -def attn_tp_all_gather(output_list: List[torch.Tensor], input_: torch.Tensor): - return get_attention_tp_group().all_gather(input_, output_tensor_list=output_list) +def attn_tp_all_gather(output_list: List[torch.Tensor], input: torch.Tensor): + return get_attention_tp_group().all_gather(input, output_tensor_list=output_list) diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index 5a2e53de99ad..0aee86f68a28 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -29,6 +29,7 @@ from sglang.srt.layers.dp_attention import ( DPPaddingMode, attn_tp_all_gather, + attn_tp_all_gather_into_tensor, dp_gather_replicate, dp_scatter, get_attention_dp_rank, @@ -456,15 +457,31 @@ def _get_logits( if self.do_tensor_parallel_all_gather: if self.use_attn_tp_group: - global_logits = torch.empty( - (self.config.vocab_size, logits.shape[0]), - device=logits.device, - dtype=logits.dtype, - ) - global_logits = global_logits.T - attn_tp_all_gather( - list(global_logits.tensor_split(self.attn_tp_size, dim=-1)), logits - ) + if self.config.vocab_size % self.attn_tp_size == 0: + global_logits = torch.empty( + ( + self.attn_tp_size, + logits.shape[0], + self.config.vocab_size // self.attn_tp_size, + ), + device=logits.device, + dtype=logits.dtype, + ) + attn_tp_all_gather_into_tensor(global_logits, logits) + global_logits = global_logits.permute(1, 0, 2).reshape( + logits.shape[0], self.config.vocab_size + ) + else: + global_logits = torch.empty( + (self.config.vocab_size, logits.shape[0]), + device=logits.device, + dtype=logits.dtype, + ) + global_logits = global_logits.T + attn_tp_all_gather( + list(global_logits.tensor_split(self.attn_tp_size, dim=-1)), + logits, + ) logits = global_logits else: logits = tensor_model_parallel_all_gather(logits)