195 changes: 143 additions & 52 deletions python/sglang/srt/distributed/parallel_state.py
@@ -1237,19 +1237,22 @@ def broadcast_tensor_dict(
async_handle.wait()
return tensor_dict

def send_tensor_dict(
def isend_tensor_dict(
self,
tensor_dict: Dict[str, Union[torch.Tensor, Any]],
dst: Optional[int] = None,
all_gather_group: Optional["GroupCoordinator"] = None,
async_send: bool = False,
) -> Optional[List[P2PWork]]:
"""Send the input tensor dictionary.
NOTE: `dst` is the local rank of the source rank.
) -> List["P2PWork"]:
"""
# Bypass the function if we are using only 1 GPU.
if self.world_size == 1:
return tensor_dict
Non-blocking send of a tensor dictionary. Returns a list of P2PWork
handles that the caller must wait on before the send buffers can be
reused or freed.

This is the async building-block; send_tensor_dict() is a
synchronous wrapper around this method.
"""
if self.world_size <= 1:
return []

all_gather_size = 1 if all_gather_group is None else all_gather_group.world_size
all_gather_rank = (
@@ -1263,46 +1266,89 @@ def send_tensor_dict(
dst = (self.rank_in_group + 1) % self.world_size
assert dst < self.world_size, f"Invalid dst rank ({dst})"

assert isinstance(
tensor_dict, dict
), f"Expecting a dictionary, got {type(tensor_dict)}"
metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
# Note: While switching to Device-to-Device (D2D) would introduce an extra
# Device-to-Host (D2H) memory copy overhead for serialization, our benchmarks
# show better overall transmission performance with D2D due to:
# 1. Superior D2D transfer bandwidth
# 2. Ability to overlap send and recv operations
# Thus the net performance gain justifies this approach.

send_func = torch.distributed.isend if async_send else torch.distributed.send
p2p_works = self.send_object(metadata_list, dst=dst, async_send=async_send)
# Asynchronously send the metadata (a pickled object on CPU)
p2p_works = self.send_object(metadata_list, dst=dst, async_send=True)

for tensor in tensor_list:
if tensor.numel() == 0:
# Skip sending empty tensors.
continue

# send-allgather: send only a slice, then do allgather.
if all_gather_group is not None and tensor.numel() % all_gather_size == 0:
tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]

comm_group = metadata_group if tensor.is_cpu else group
work = send_func(tensor, self.ranks[dst], group=comm_group)
if async_send:
p2p_works.append(P2PWork(work, tensor))
work = torch.distributed.isend(tensor, self.ranks[dst], group=comm_group)

if tensor.is_cuda:
tensor.record_stream(torch.cuda.current_stream(tensor.device))

p2p_works.append(P2PWork(work, tensor))

return p2p_works

def recv_tensor_dict(
def send_tensor_dict(
self,
tensor_dict: Dict[str, Union[torch.Tensor, Any]],
dst: Optional[int] = None,
all_gather_group: Optional["GroupCoordinator"] = None,
async_send: bool = False,
) -> Optional[List[P2PWork]]:
"""Send the input tensor dictionary.
NOTE: `dst` is the local rank of the destination rank.

When async_send=True, returns a list of P2PWork handles.
When async_send=False (default), blocks until all sends complete.
"""
if self.world_size == 1:
return tensor_dict if not async_send else []

handles = self.isend_tensor_dict(
tensor_dict,
dst=dst,
all_gather_group=all_gather_group,
)

# Async mode
if async_send:
return handles

# Sync mode
for h in handles:
if h.work is not None:
h.work.wait()
return None

def irecv_tensor_dict(
self,
src: Optional[int] = None,
all_gather_group: Optional["GroupCoordinator"] = None,
) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
"""Recv the input tensor dictionary.
NOTE: `src` is the local rank of the source rank.
) -> Tuple[
Optional[Dict[str, Union[torch.Tensor, Any]]],
List[torch.distributed.Work],
List[Callable[[], None]],
]:
"""
# Bypass the function if we are using only 1 GPU.
if not torch.distributed.is_initialized() or self.world_size == 1:
return None
Non-blocking receive of a tensor dictionary.

Returns:
tensor_dict: pre-allocated tensor dict (the tensor data is not ready
until all handles have completed)
handles: list of irecv work objects that the caller must wait on
postprocess: list of callables to run after the waits complete
(used by the all-gather path to rebuild the full tensors)

Usage:
tensor_dict, handles, postprocess = pp_group.irecv_tensor_dict(src=src)
# ... overlap communication with computation here ...
for h in handles:
h.wait()
for fn in postprocess:
fn()
# the data in tensor_dict is now ready
"""
if not torch.distributed.is_initialized() or self.world_size <= 1:
return None, [], []

all_gather_size = 1 if all_gather_group is None else all_gather_group.world_size
all_gather_rank = (
@@ -1316,40 +1362,85 @@ def recv_tensor_dict(
src = (self.rank_in_group - 1) % self.world_size
assert src < self.world_size, f"Invalid src rank ({src})"

# recv metadata
recv_metadata_list = self.recv_object(src=src)

tensor_dict: Dict[str, Any] = {}
handles: List[torch.distributed.Work] = []
postprocess: List[Callable[[], None]] = []

for key, value in recv_metadata_list:
if isinstance(value, TensorMetadata):
tensor = torch.empty(value.size, dtype=value.dtype, device=value.device)
if tensor.numel() == 0:
# Skip broadcasting empty tensors.
tensor_dict[key] = tensor
full_tensor = torch.empty(
value.size, dtype=value.dtype, device=value.device
)
if full_tensor.numel() == 0:
tensor_dict[key] = full_tensor
continue

# send-allgather: send only a slice, then do allgather.
# send-allgather
use_all_gather = (
all_gather_group is not None
and tensor.numel() % all_gather_size == 0
and full_tensor.numel() % all_gather_size == 0
)

if use_all_gather:
orig_shape = tensor.shape
tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]
orig_shape = full_tensor.shape
slice_tensor = full_tensor.reshape(all_gather_size, -1)[
all_gather_rank
]
comm_group = metadata_group if slice_tensor.is_cpu else group
handle = torch.distributed.irecv(
slice_tensor, src=self.ranks[src], group=comm_group
)
handles.append(handle)

def _postprocess(
key=key,
slice_tensor=slice_tensor,
orig_shape=tuple(orig_shape),
all_gather_group=all_gather_group,
):
assert all_gather_group is not None
tensor_dict[key] = all_gather_group.all_gather(
slice_tensor, dim=0
).reshape(orig_shape)

postprocess.append(_postprocess)
tensor_dict[key] = slice_tensor
else:
comm_group = metadata_group if full_tensor.is_cpu else group
handle = torch.distributed.irecv(
full_tensor, src=self.ranks[src], group=comm_group
)
handles.append(handle)
tensor_dict[key] = full_tensor
else:
tensor_dict[key] = value

# We have to use irecv here to make it work for both isend and send.
comm_group = metadata_group if tensor.is_cpu else group
work = torch.distributed.irecv(
tensor, src=self.ranks[src], group=comm_group
)
work.wait()
return tensor_dict, handles, postprocess

if use_all_gather:
tensor = all_gather_group.all_gather(tensor, dim=0)
tensor = tensor.reshape(orig_shape)
def recv_tensor_dict(
self,
src: Optional[int] = None,
all_gather_group: Optional["GroupCoordinator"] = None,
) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
"""Recv the input tensor dictionary (synchronous wrapper).
NOTE: `src` is the local rank of the source rank.
"""
if not torch.distributed.is_initialized() or self.world_size == 1:
return None

tensor_dict, handles, postprocess = self.irecv_tensor_dict(
src=src,
all_gather_group=all_gather_group,
)

for handle in handles:
handle.wait()
for fn in postprocess:
fn()

tensor_dict[key] = tensor
else:
tensor_dict[key] = value
return tensor_dict

def barrier(self):
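
For reference, a minimal sketch of how the split isend_tensor_dict / irecv_tensor_dict API above can be used to overlap point-to-point transfers with computation. Here pp_group, send_dict, and compute_fn are illustrative placeholders (a GroupCoordinator instance and caller-supplied data/logic), not names taken from this diff.

def exchange_with_overlap(pp_group, send_dict, compute_fn):
    # Non-blocking send to the next stage; the returned P2PWork handles keep
    # the send buffers alive until they are waited on.
    send_works = pp_group.isend_tensor_dict(send_dict)

    # Non-blocking receive from the previous stage.
    recv_dict, recv_handles, postprocess = pp_group.irecv_tensor_dict()

    # Overlap: run computation while both transfers are in flight.
    compute_result = compute_fn()

    # Complete the receive before reading recv_dict.
    for handle in recv_handles:
        handle.wait()
    for fn in postprocess:
        fn()

    # Complete the send before the buffers in send_dict may be reused or freed.
    for p2p_work in send_works:
        if p2p_work.work is not None:
            p2p_work.work.wait()

    return recv_dict, compute_result
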
58 changes: 55 additions & 3 deletions python/sglang/srt/managers/scheduler_pp_mixin.py
@@ -90,9 +90,13 @@ def event_loop_pp(self: Scheduler):
self.mbs[mb_id] = self.get_next_batch_to_run()
self.running_mbs[mb_id] = self.running_batch
self.cur_batch: Optional[ScheduleBatch] = self.mbs[mb_id]

# Async PP processing
proxy_recv_state = None
if self.cur_batch:
server_is_idle = False
pp_proxy_tensors = self._pp_recv_proxy_tensors()
proxy_recv_state = self._pp_irecv_proxy_tensors()

next_pp_outputs = None
next_batch_result = None
d2h_event = None
@@ -104,13 +108,16 @@
)
)
self._pp_commit_comm_work(self.send_proxy_work)

if self.cur_batch:
pp_proxy_tensors = self._pp_wait_proxy_tensors(proxy_recv_state)
result, self.launch_event = self._pp_launch_batch(
mb_id,
pp_proxy_tensors,
self.mb_metadata,
self.last_rank_comm_queue,
)

if self.server_args.pp_async_batch_depth == 0:
next_pp_outputs, next_batch_result, d2h_event = (
self._pp_commit_send_output_work_and_preprocess_output_tensors(
@@ -229,10 +236,12 @@ def event_loop_pp_disagg_prefill(self: Scheduler):
self.mbs[mb_id] = batch
self.running_mbs[mb_id] = self.running_batch

# Async PP processing
proxy_recv_state = None
self.cur_batch: Optional[ScheduleBatch] = self.mbs[mb_id]
if self.cur_batch:
server_is_idle = False
pp_proxy_tensors = self._pp_recv_proxy_tensors()
proxy_recv_state = self._pp_irecv_proxy_tensors()

if self.server_args.pp_async_batch_depth > 0:
next_pp_outputs, next_batch_result, d2h_event = (
@@ -242,13 +251,16 @@
)
)
self._pp_commit_comm_work(self.send_proxy_work)

if self.cur_batch:
pp_proxy_tensors = self._pp_wait_proxy_tensors(proxy_recv_state)
result, self.launch_event = self._pp_launch_batch(
mb_id,
pp_proxy_tensors,
self.mb_metadata,
self.last_rank_comm_queue,
)

if self.server_args.pp_async_batch_depth == 0:
next_pp_outputs, next_batch_result, d2h_event = (
self._pp_commit_send_output_work_and_preprocess_output_tensors(
@@ -378,11 +390,14 @@ def event_loop_pp_disagg_decode(self: Scheduler):
self.running_mbs[mb_id] = self.running_batch

self.cur_batch: Optional[ScheduleBatch] = self.mbs[mb_id]

# Async PP processing
proxy_recv_state = None
if self.cur_batch:
server_is_idle = False
pp_proxy_tensors = None
if not self.cur_batch.forward_mode.is_prebuilt():
pp_proxy_tensors = self._pp_recv_proxy_tensors()
proxy_recv_state = self._pp_irecv_proxy_tensors()

# early send output if possible
if self.server_args.pp_async_batch_depth > 0:
@@ -395,6 +410,9 @@ self._pp_commit_comm_work(self.send_proxy_work)
self._pp_commit_comm_work(self.send_proxy_work)

if self.cur_batch:
pp_proxy_tensors = None
if not self.cur_batch.forward_mode.is_prebuilt():
pp_proxy_tensors = self._pp_wait_proxy_tensors(proxy_recv_state)
result, self.launch_event = self._pp_launch_batch(
mb_id,
pp_proxy_tensors,
@@ -746,6 +764,40 @@ def process_bootstrapped_queue(
return [[req.rid for req in good_reqs], [req.rid for req in failed_reqs]]
return None

def _pp_irecv_proxy_tensors(
self: "Scheduler",
) -> Optional[Tuple[Dict, List, List]]:
"""Start async recv of proxy tensors from the previous PP stage.
Returns None if this is the first rank (no recv needed),
otherwise returns (tensor_dict, handles, postprocess) from irecv_tensor_dict.
"""
if self.pp_group.is_first_rank:
return None

tensor_dict, handles, postprocess = self.pp_group.irecv_tensor_dict(
all_gather_group=(
self.attn_tp_group if self.require_attn_tp_allgather else None
),
)
return (tensor_dict, handles, postprocess)

def _pp_wait_proxy_tensors(
self: "Scheduler",
recv_state: Optional[Tuple[Dict, List, List]],
) -> Optional["PPProxyTensors"]:
"""Wait for async proxy tensor recv to complete."""
if recv_state is None:
return None

tensor_dict, handles, postprocess = recv_state

for h in handles:
h.wait()
for fn in postprocess:
fn()

return PPProxyTensors(tensor_dict)

def _pp_pd_get_bootstrapped_ids(self: Scheduler):
# communicate pre-consensus bootstrapped reqs
if self.pp_group.is_first_rank:
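
Putting the scheduler helpers together, a condensed sketch of the recv / commit / wait ordering that the event loops above follow. This only restates the existing pattern; scheduler and mb_id stand in for the Scheduler instance and the microbatch index, and the return values are illustrative.

def run_pp_microbatch(scheduler, mb_id):
    # Start the async receive of proxy tensors before flushing outstanding
    # sends, so the transfer can progress in the background.
    proxy_recv_state = None
    if scheduler.cur_batch:
        proxy_recv_state = scheduler._pp_irecv_proxy_tensors()

    # Commit previously issued proxy sends (as in the event loops above).
    scheduler._pp_commit_comm_work(scheduler.send_proxy_work)

    if scheduler.cur_batch:
        # Block only at the point where the proxy tensors are actually needed.
        pp_proxy_tensors = scheduler._pp_wait_proxy_tensors(proxy_recv_state)
        result, launch_event = scheduler._pp_launch_batch(
            mb_id,
            pp_proxy_tensors,
            scheduler.mb_metadata,
            scheduler.last_rank_comm_queue,
        )
        return result, launch_event
    return None, None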