From 6996769b783eb641a1a0b99c1c05f790b45658d4 Mon Sep 17 00:00:00 2001
From: ch-wan
Date: Mon, 16 Jun 2025 00:59:08 +0000
Subject: [PATCH 1/3] fix gather under world size one

---
 python/sglang/srt/distributed/parallel_state.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py
index cc2ba95a614..f510ae4e234 100644
--- a/python/sglang/srt/distributed/parallel_state.py
+++ b/python/sglang/srt/distributed/parallel_state.py
@@ -523,17 +523,19 @@ def all_gather(
         self,
         input_: torch.Tensor,
         dim: int = -1,
-        tensor_list: List[torch.Tensor] = None,
+        output_tensor_list: Optional[List[torch.Tensor]] = None,
     ) -> torch.Tensor:
         world_size = self.world_size
         # Bypass the function if we are using only 1 GPU.
         if world_size == 1:
+            if output_tensor_list is not None:
+                output_tensor_list[0].copy_(input_)
             return input_
 
-        if tensor_list is not None:
+        if output_tensor_list is not None:
             # TODO(ch-wan): support other backends
             return torch.distributed.all_gather(
-                tensor_list, input_, group=self.device_group
+                output_tensor_list, input_, group=self.device_group
             )
 
         assert (

From 815698095f577dffab00be0e814b6e83c051aba3 Mon Sep 17 00:00:00 2001
From: ch-wan
Date: Mon, 16 Jun 2025 01:04:11 +0000
Subject: [PATCH 2/3] minor

---
 python/sglang/srt/distributed/parallel_state.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py
index f510ae4e234..78ce89d7d88 100644
--- a/python/sglang/srt/distributed/parallel_state.py
+++ b/python/sglang/srt/distributed/parallel_state.py
@@ -530,7 +530,9 @@ def all_gather(
         if world_size == 1:
             if output_tensor_list is not None:
                 output_tensor_list[0].copy_(input_)
-            return input_
+                return None
+            else:
+                return input_
 
         if output_tensor_list is not None:
             # TODO(ch-wan): support other backends

From ddcd88e5e96f6276cb366e00dc6b1f93931be53b Mon Sep 17 00:00:00 2001
From: Cheng Wan <54331508+ch-wan@users.noreply.github.com>
Date: Tue, 17 Jun 2025 01:06:47 -0700
Subject: [PATCH 3/3] add warning info

---
 python/sglang/srt/distributed/parallel_state.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py
index 78ce89d7d88..24731d2b878 100644
--- a/python/sglang/srt/distributed/parallel_state.py
+++ b/python/sglang/srt/distributed/parallel_state.py
@@ -529,6 +529,10 @@ def all_gather(
         # Bypass the function if we are using only 1 GPU.
         if world_size == 1:
             if output_tensor_list is not None:
+                logger.warning(
+                    "Performing in-place all-gather with a group size of 1. "
+                    "This may be unnecessary; consider bypassing it for better efficiency."
+                )
                 output_tensor_list[0].copy_(input_)
                 return None
             else:
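
Note: taken together, the three patches above change all_gather so that, when a caller passes a pre-allocated output_tensor_list and the group has only one rank, the input is copied into slot 0 in place and None is returned instead of the input tensor (and a warning is logged). A minimal single-process sketch of that caller-side contract follows; it only mirrors the bypass branch with plain torch (no process group is initialized, no GroupCoordinator is constructed), and the variable names are illustrative rather than part of the patched API.

    import torch

    # Caller-side view of the world_size == 1 bypass added by the patches.
    input_ = torch.arange(4, dtype=torch.float32)
    output_tensor_list = [torch.empty_like(input_)]  # pre-allocated by the caller

    # In-place mode: the input is copied into the caller's buffer and the
    # call yields None rather than a gathered tensor.
    output_tensor_list[0].copy_(input_)
    result = None

    assert result is None
    assert torch.equal(output_tensor_list[0], input_)

Returning None in the in-place path keeps the bypass consistent with torch.distributed.all_gather, which likewise writes into the provided list and does not return the gathered tensor.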