diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py
index cc2ba95a614..24731d2b878 100644
--- a/python/sglang/srt/distributed/parallel_state.py
+++ b/python/sglang/srt/distributed/parallel_state.py
@@ -523,17 +523,25 @@ def all_gather(
         self,
         input_: torch.Tensor,
         dim: int = -1,
-        tensor_list: List[torch.Tensor] = None,
+        output_tensor_list: Optional[List[torch.Tensor]] = None,
     ) -> torch.Tensor:
         world_size = self.world_size
         # Bypass the function if we are using only 1 GPU.
         if world_size == 1:
-            return input_
+            if output_tensor_list is not None:
+                logger.warning(
+                    "Performing in-place all-gather with a group size of 1. "
+                    "This may be unnecessary; consider bypassing it for better efficiency."
+                )
+                output_tensor_list[0].copy_(input_)
+                return None
+            else:
+                return input_
 
-        if tensor_list is not None:
+        if output_tensor_list is not None:
             # TODO(ch-wan): support other backends
             return torch.distributed.all_gather(
-                tensor_list, input_, group=self.device_group
+                output_tensor_list, input_, group=self.device_group
             )
         assert (
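
A minimal usage sketch of the new output path, not part of the diff above: it assumes `group` is an already-initialized GroupCoordinator from parallel_state, and the helper name `gather_into_buffers` and its variables are illustrative only. With `output_tensor_list` provided, `all_gather` writes into the caller's pre-allocated per-rank buffers (via `torch.distributed.all_gather`) and returns None; without it, the original concatenated-tensor return path is used.

    import torch

    def gather_into_buffers(group, local_tensor: torch.Tensor):
        # One output buffer per rank, matching the local tensor's shape/dtype.
        output_tensor_list = [
            torch.empty_like(local_tensor) for _ in range(group.world_size)
        ]
        # Fills the buffers in place and returns None when output_tensor_list
        # is passed (on a size-1 group it just copies input_ into buffer 0).
        group.all_gather(local_tensor, output_tensor_list=output_tensor_list)
        return output_tensor_list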