diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py
index 72311f9d3ffe..7c9c28242baf 100644
--- a/python/sglang/srt/distributed/parallel_state.py
+++ b/python/sglang/srt/distributed/parallel_state.py
@@ -1237,19 +1237,22 @@ def broadcast_tensor_dict(
             async_handle.wait()
         return tensor_dict
 
-    def send_tensor_dict(
+    def isend_tensor_dict(
         self,
         tensor_dict: Dict[str, Union[torch.Tensor, Any]],
         dst: Optional[int] = None,
         all_gather_group: Optional["GroupCoordinator"] = None,
-        async_send: bool = False,
-    ) -> Optional[List[P2PWork]]:
-        """Send the input tensor dictionary.
-        NOTE: `dst` is the local rank of the source rank.
+    ) -> List["P2PWork"]:
         """
-        # Bypass the function if we are using only 1 GPU.
-        if self.world_size == 1:
-            return tensor_dict
+        Non-blocking send of a tensor dictionary. Returns a list of P2PWork
+        handles that the caller must wait on before the send buffers can be
+        reused or freed.
+
+        This is the async building-block; send_tensor_dict() is a
+        synchronous wrapper around this method.
+        """
+        if self.world_size <= 1:
+            return []
 
         all_gather_size = 1 if all_gather_group is None else all_gather_group.world_size
         all_gather_rank = (
@@ -1263,46 +1266,89 @@ def send_tensor_dict(
             dst = (self.rank_in_group + 1) % self.world_size
         assert dst < self.world_size, f"Invalid dst rank ({dst})"
 
-        assert isinstance(
-            tensor_dict, dict
-        ), f"Expecting a dictionary, got {type(tensor_dict)}"
         metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
 
-        # Note: While switching to Device-to-Device (D2D) would introduce an extra
-        # Device-to-Host (D2H) memory copy overhead for serialization, our benchmarks
-        # show better overall transmission performance with D2D due to:
-        # 1. Superior D2D transfer bandwidth
-        # 2. Ability to overlap send and recv operations
-        # Thus the net performance gain justifies this approach.
-        send_func = torch.distributed.isend if async_send else torch.distributed.send
-        p2p_works = self.send_object(metadata_list, dst=dst, async_send=async_send)
+        # Async send metadata (pickle object on CPU)
+        p2p_works = self.send_object(metadata_list, dst=dst, async_send=True)
 
         for tensor in tensor_list:
             if tensor.numel() == 0:
-                # Skip sending empty tensors.
                 continue
-
-            # send-allgather: send only a slice, then do allgather.
             if all_gather_group is not None and tensor.numel() % all_gather_size == 0:
                 tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]
 
             comm_group = metadata_group if tensor.is_cpu else group
-            work = send_func(tensor, self.ranks[dst], group=comm_group)
-            if async_send:
-                p2p_works.append(P2PWork(work, tensor))
+            work = torch.distributed.isend(tensor, self.ranks[dst], group=comm_group)
+
+            if tensor.is_cuda:
+                tensor.record_stream(torch.cuda.current_stream(tensor.device))
+
+            p2p_works.append(P2PWork(work, tensor))
+
         return p2p_works
 
-    def recv_tensor_dict(
+    def send_tensor_dict(
+        self,
+        tensor_dict: Dict[str, Union[torch.Tensor, Any]],
+        dst: Optional[int] = None,
+        all_gather_group: Optional["GroupCoordinator"] = None,
+        async_send: bool = False,
+    ) -> Optional[List[P2PWork]]:
+        """Send the input tensor dictionary.
+        NOTE: `dst` is the local rank of the destination rank.
+
+        When async_send=True, returns a list of P2PWork handles.
+        When async_send=False (default), blocks until all sends complete.
+ """ + if self.world_size == 1: + return tensor_dict if not async_send else [] + + handles = self.isend_tensor_dict( + tensor_dict, + dst=dst, + all_gather_group=all_gather_group, + ) + + # Async mode + if async_send: + return handles + + # Sync mode + for h in handles: + if h.work is not None: + h.work.wait() + return None + + def irecv_tensor_dict( self, src: Optional[int] = None, all_gather_group: Optional["GroupCoordinator"] = None, - ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: - """Recv the input tensor dictionary. - NOTE: `src` is the local rank of the source rank. + ) -> Tuple[ + Optional[Dict[str, Union[torch.Tensor, Any]]], + List[torch.distributed.Work], + List[Callable[[], None]], + ]: """ - # Bypass the function if we are using only 1 GPU. - if not torch.distributed.is_initialized() or self.world_size == 1: - return None + Non-blocking receive of a tensor dictionary. + + Returns: + tensor_dict: pre-allocated tensor dict (tensor data is not ready yet, + need to wait until handles complete) + handles: irecv's work object list, caller needs to wait + postprocess: the function list to be executed after wait completed + (used by all_gather to rebuild complete tensor) + + Usage: + tensor_dict, handles, postprocess = pp_group.irecv_tensor_dict(src=src) + # ... here can do overlap to perform computation ... + for h in handles: + h.wait() + for fn in postprocess: + fn() + # till now data in tensor_dict is ready + """ + if not torch.distributed.is_initialized() or self.world_size <= 1: + return None, [], [] all_gather_size = 1 if all_gather_group is None else all_gather_group.world_size all_gather_rank = ( @@ -1316,40 +1362,85 @@ def recv_tensor_dict( src = (self.rank_in_group - 1) % self.world_size assert src < self.world_size, f"Invalid src rank ({src})" + # recv metadata recv_metadata_list = self.recv_object(src=src) + tensor_dict: Dict[str, Any] = {} + handles: List[torch.distributed.Work] = [] + postprocess: List[Callable[[], None]] = [] + for key, value in recv_metadata_list: if isinstance(value, TensorMetadata): - tensor = torch.empty(value.size, dtype=value.dtype, device=value.device) - if tensor.numel() == 0: - # Skip broadcasting empty tensors. - tensor_dict[key] = tensor + full_tensor = torch.empty( + value.size, dtype=value.dtype, device=value.device + ) + if full_tensor.numel() == 0: + tensor_dict[key] = full_tensor continue - # send-allgather: send only a slice, then do allgather. 
+                # send-allgather
                 use_all_gather = (
                     all_gather_group is not None
-                    and tensor.numel() % all_gather_size == 0
+                    and full_tensor.numel() % all_gather_size == 0
                 )
                 if use_all_gather:
-                    orig_shape = tensor.shape
-                    tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]
+                    orig_shape = full_tensor.shape
+                    slice_tensor = full_tensor.reshape(all_gather_size, -1)[
+                        all_gather_rank
+                    ]
+                    comm_group = metadata_group if slice_tensor.is_cpu else group
+                    handle = torch.distributed.irecv(
+                        slice_tensor, src=self.ranks[src], group=comm_group
+                    )
+                    handles.append(handle)
+
+                    def _postprocess(
+                        key=key,
+                        slice_tensor=slice_tensor,
+                        orig_shape=tuple(orig_shape),
+                        all_gather_group=all_gather_group,
+                    ):
+                        assert all_gather_group is not None
+                        tensor_dict[key] = all_gather_group.all_gather(
+                            slice_tensor, dim=0
+                        ).reshape(orig_shape)
+
+                    postprocess.append(_postprocess)
+                    tensor_dict[key] = slice_tensor
+                else:
+                    comm_group = metadata_group if full_tensor.is_cpu else group
+                    handle = torch.distributed.irecv(
+                        full_tensor, src=self.ranks[src], group=comm_group
+                    )
+                    handles.append(handle)
+                    tensor_dict[key] = full_tensor
+            else:
+                tensor_dict[key] = value
 
-                # We have to use irecv here to make it work for both isend and send.
-                comm_group = metadata_group if tensor.is_cpu else group
-                work = torch.distributed.irecv(
-                    tensor, src=self.ranks[src], group=comm_group
-                )
-                work.wait()
+        return tensor_dict, handles, postprocess
 
-                if use_all_gather:
-                    tensor = all_gather_group.all_gather(tensor, dim=0)
-                    tensor = tensor.reshape(orig_shape)
+    def recv_tensor_dict(
+        self,
+        src: Optional[int] = None,
+        all_gather_group: Optional["GroupCoordinator"] = None,
+    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
+        """Recv the input tensor dictionary (synchronous wrapper).
+        NOTE: `src` is the local rank of the source rank.
+ """ + if not torch.distributed.is_initialized() or self.world_size == 1: + return None + + tensor_dict, handles, postprocess = self.irecv_tensor_dict( + src=src, + all_gather_group=all_gather_group, + ) + + for handle in handles: + handle.wait() + for fn in postprocess: + fn() - tensor_dict[key] = tensor - else: - tensor_dict[key] = value return tensor_dict def barrier(self): diff --git a/python/sglang/srt/managers/scheduler_pp_mixin.py b/python/sglang/srt/managers/scheduler_pp_mixin.py index ba9cc0ac2342..458a0c895d51 100644 --- a/python/sglang/srt/managers/scheduler_pp_mixin.py +++ b/python/sglang/srt/managers/scheduler_pp_mixin.py @@ -90,9 +90,13 @@ def event_loop_pp(self: Scheduler): self.mbs[mb_id] = self.get_next_batch_to_run() self.running_mbs[mb_id] = self.running_batch self.cur_batch: Optional[ScheduleBatch] = self.mbs[mb_id] + + # Async pp processing + proxy_recv_state = None if self.cur_batch: server_is_idle = False - pp_proxy_tensors = self._pp_recv_proxy_tensors() + proxy_recv_state = self._pp_irecv_proxy_tensors() + next_pp_outputs = None next_batch_result = None d2h_event = None @@ -104,13 +108,16 @@ def event_loop_pp(self: Scheduler): ) ) self._pp_commit_comm_work(self.send_proxy_work) + if self.cur_batch: + pp_proxy_tensors = self._pp_wait_proxy_tensors(proxy_recv_state) result, self.launch_event = self._pp_launch_batch( mb_id, pp_proxy_tensors, self.mb_metadata, self.last_rank_comm_queue, ) + if self.server_args.pp_async_batch_depth == 0: next_pp_outputs, next_batch_result, d2h_event = ( self._pp_commit_send_output_work_and_preprocess_output_tensors( @@ -229,10 +236,12 @@ def event_loop_pp_disagg_prefill(self: Scheduler): self.mbs[mb_id] = batch self.running_mbs[mb_id] = self.running_batch + # Async PP processing + proxy_recv_state = None self.cur_batch: Optional[ScheduleBatch] = self.mbs[mb_id] if self.cur_batch: server_is_idle = False - pp_proxy_tensors = self._pp_recv_proxy_tensors() + proxy_recv_state = self._pp_irecv_proxy_tensors() if self.server_args.pp_async_batch_depth > 0: next_pp_outputs, next_batch_result, d2h_event = ( @@ -242,13 +251,16 @@ def event_loop_pp_disagg_prefill(self: Scheduler): ) ) self._pp_commit_comm_work(self.send_proxy_work) + if self.cur_batch: + pp_proxy_tensors = self._pp_wait_proxy_tensors(proxy_recv_state) result, self.launch_event = self._pp_launch_batch( mb_id, pp_proxy_tensors, self.mb_metadata, self.last_rank_comm_queue, ) + if self.server_args.pp_async_batch_depth == 0: next_pp_outputs, next_batch_result, d2h_event = ( self._pp_commit_send_output_work_and_preprocess_output_tensors( @@ -378,11 +390,14 @@ def event_loop_pp_disagg_decode(self: Scheduler): self.running_mbs[mb_id] = self.running_batch self.cur_batch: Optional[ScheduleBatch] = self.mbs[mb_id] + + # Async PP processing + proxy_recv_state = None if self.cur_batch: server_is_idle = False pp_proxy_tensors = None if not self.cur_batch.forward_mode.is_prebuilt(): - pp_proxy_tensors = self._pp_recv_proxy_tensors() + proxy_recv_state = self._pp_irecv_proxy_tensors() # early send output if possible if self.server_args.pp_async_batch_depth > 0: @@ -395,6 +410,9 @@ def event_loop_pp_disagg_decode(self: Scheduler): self._pp_commit_comm_work(self.send_proxy_work) if self.cur_batch: + pp_proxy_tensors = None + if not self.cur_batch.forward_mode.is_prebuilt(): + pp_proxy_tensors = self._pp_wait_proxy_tensors(proxy_recv_state) result, self.launch_event = self._pp_launch_batch( mb_id, pp_proxy_tensors, @@ -746,6 +764,40 @@ def process_bootstrapped_queue( return [[req.rid for req in 
         return None
 
+    def _pp_irecv_proxy_tensors(
+        self: "Scheduler",
+    ) -> Optional[Tuple[Dict, List, List]]:
+        """Start an async recv of proxy tensors from the previous PP stage.
+        Returns None if this is the first rank (no recv is needed); otherwise
+        returns the (tensor_dict, handles, postprocess) tuple from irecv_tensor_dict.
+        """
+        if self.pp_group.is_first_rank:
+            return None
+
+        tensor_dict, handles, postprocess = self.pp_group.irecv_tensor_dict(
+            all_gather_group=(
+                self.attn_tp_group if self.require_attn_tp_allgather else None
+            ),
+        )
+        return (tensor_dict, handles, postprocess)
+
+    def _pp_wait_proxy_tensors(
+        self: "Scheduler",
+        recv_state: Optional[Tuple[Dict, List, List]],
+    ) -> Optional["PPProxyTensors"]:
+        """Wait for the async proxy tensor recv to complete."""
+        if recv_state is None:
+            return None
+
+        tensor_dict, handles, postprocess = recv_state
+
+        for h in handles:
+            h.wait()
+        for fn in postprocess:
+            fn()
+
+        return PPProxyTensors(tensor_dict)
+
     def _pp_pd_get_bootstrapped_ids(self: Scheduler):
         # communicate pre-consensus bootstrapp reqs
         if self.pp_group.is_first_rank:
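
A minimal usage sketch of the split send/recv API introduced above, assuming a GroupCoordinator instance named pp_group and a placeholder tensor hidden_states; error handling and the scheduler plumbing are omitted:

    # Sender side: start non-blocking sends and keep the P2PWork handles
    # alive until they complete, so the send buffers are not reused too early.
    send_works = pp_group.isend_tensor_dict({"hidden_states": hidden_states})
    # ... overlap other work here ...
    for w in send_works:
        if w.work is not None:
            w.work.wait()

    # Receiver side: start the irecv, overlap computation, then wait and run
    # the postprocess callbacks (which rebuild any all-gathered tensors).
    tensor_dict, handles, postprocess = pp_group.irecv_tensor_dict()
    # ... overlap computation here ...
    for h in handles:
        h.wait()
    for fn in postprocess:
        fn()
    # tensor_dict now holds the received tensors.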