Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions python/sglang/srt/disaggregation/common/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
from sglang.srt.distributed import get_pp_group
from sglang.srt.environ import envs
from sglang.srt.layers.dp_attention import (
get_attention_cp_rank,
get_attention_cp_size,
get_attention_dp_rank,
get_attention_dp_size,
get_attention_tp_rank,
Expand Down Expand Up @@ -83,8 +85,18 @@ def __init__(
self.dist_init_addr = server_args.dist_init_addr
self.attn_tp_size = get_attention_tp_size()
self.attn_tp_rank = get_attention_tp_rank()
self.attn_cp_size = get_attention_cp_size()
self.attn_cp_rank = get_attention_cp_rank()
self.attn_dp_size = get_attention_dp_size()
self.attn_dp_rank = get_attention_dp_rank()
# For MLA + CP prefill, a single authoritative CP rank per attn-TP rank
# is enough because KV is already rebuilt to full sequence on each CP rank.
self.skip_register_prefill = (
disaggregation_mode == DisaggregationMode.PREFILL
and is_mla_backend
and self.attn_cp_size > 1
and self.attn_cp_rank != 0
)
self.system_dp_size = (
1 if server_args.enable_dp_attention else server_args.dp_size
)
Expand All @@ -108,7 +120,8 @@ def __init__(
self.failure_lock = threading.Lock()

if self.disaggregation_mode == DisaggregationMode.PREFILL:
self.register_to_bootstrap()
if not self.skip_register_prefill:
self.register_to_bootstrap()
self.transfer_infos = {}
self.decode_kv_args_table = {}
self.pp_group = get_pp_group()
Expand Down Expand Up @@ -258,7 +271,6 @@ def register_to_bootstrap(self):
"kv_cache_dtype": self.server_args.kv_cache_dtype,
"load_balance_method": self.server_args.load_balance_method,
}

try:
response = requests.put(url, json=payload, timeout=5)
if response.status_code == 200:
Expand Down Expand Up @@ -342,9 +354,14 @@ def __init__(
self.bootstrap_server_url = bootstrap_addr
# inner state
self.curr_idx = 0
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)
initial_status = KVPoll.Bootstrapping
if self.kv_mgr.skip_register_prefill:
# Non-authoritative CP ranks are dummy participants.
initial_status = KVPoll.WaitingForInput
self.kv_mgr.update_status(self.bootstrap_room, initial_status)
if (
self.kv_mgr.server_args.dp_size > 1
not self.kv_mgr.skip_register_prefill
and self.kv_mgr.server_args.dp_size > 1
and self.kv_mgr.server_args.load_balance_method != "follow_bootstrap_room"
):
self._register_prefill_dp_rank()
Expand Down
27 changes: 17 additions & 10 deletions python/sglang/srt/disaggregation/mooncake/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -891,8 +891,11 @@ def transfer_worker(
endpoint, dst_port, room, status, local_rank
)
else:
# Dummy request means the decode instance is not used, so its status can be marked as success directly
# Dummy request does not need to sync status to decode endpoint
# Dummy request means the decode instance is not used,
# so its status can be marked as success directly.
# Keep the membership guard to avoid resurrecting a
# request that has already been cleared concurrently.
# Dummy request does not need to sync status to decode endpoint.
if kv_chunk.is_last and req.room in self.request_status:
self.update_status(req.room, KVPoll.Success)

Expand Down Expand Up @@ -1038,12 +1041,12 @@ def add_transfer_request(
bootstrap_room: int,
kv_indices: npt.NDArray[np.int32],
index_slice: slice,
is_last: bool,
is_last_chunk: bool,
aux_index: Optional[int] = None,
state_indices: Optional[List[int]] = None,
):
assert self.disaggregation_mode == DisaggregationMode.PREFILL
assert not is_last or (is_last and aux_index is not None)
assert not is_last_chunk or (is_last_chunk and aux_index is not None)

if (
bootstrap_room not in self.request_status
Expand All @@ -1057,7 +1060,11 @@ def add_transfer_request(
if bootstrap_room not in self.transfer_infos:
# This means that the current rank is a dummy rank for this request,
# and it has already been marked as success, so there is no need to
# add further chunks into the transfer queue.
# add further chunks into the transfer queue. We still guard on
# request_status membership to avoid recreating terminal status for
# requests that have already been cleared on this rank.
if is_last_chunk and bootstrap_room in self.request_status:
self.update_status(bootstrap_room, KVPoll.Success)
return

# NOTE(shangming): sharding according to the dst_infos to make sure
Expand All @@ -1072,7 +1079,7 @@ def add_transfer_request(
room=bootstrap_room,
prefill_kv_indices=kv_indices,
index_slice=index_slice,
is_last=is_last,
is_last=is_last_chunk,
prefill_aux_index=aux_index,
state_indices=state_indices,
)
Expand Down Expand Up @@ -1134,21 +1141,21 @@ def send(
):
index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices))
self.curr_idx += len(kv_indices)
is_last = self.curr_idx == self.num_kv_indices
is_last_chunk = self.curr_idx == self.num_kv_indices

if not is_last:
if not is_last_chunk:
self.kv_mgr.add_transfer_request(
self.bootstrap_room,
kv_indices,
index_slice,
False,
is_last_chunk=False,
)
else:
self.kv_mgr.add_transfer_request(
self.bootstrap_room,
kv_indices,
index_slice,
True,
is_last_chunk=is_last_chunk,
aux_index=self.aux_index,
state_indices=state_indices,
)
Expand Down
13 changes: 10 additions & 3 deletions python/sglang/srt/disaggregation/mori/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -798,9 +798,16 @@ def add_transfer_request(
assert self.disaggregation_mode == DisaggregationMode.PREFILL
transfer_infos = self.transfer_infos.get(bootstrap_room)
if not transfer_infos:
raise RuntimeError(
f"No transfer info found for bootstrap_room={bootstrap_room}"
)
if not self.skip_register_prefill:
raise RuntimeError(
f"No transfer info found for bootstrap_room={bootstrap_room}"
)
# Non-authoritative CP ranks can have no transfer_infos for this room.
# If this is the last chunk, finish the local sender as a no-op.
if is_last and bootstrap_room in self.request_status:
self.update_status(bootstrap_room, KVPoll.Success)
return [], []
return [], None
result_statuses = []
target_infos_snapshot: Optional[List[TransferInfo]] = None
with self.transfer_lock:
Expand Down
11 changes: 11 additions & 0 deletions python/sglang/srt/disaggregation/nixl/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,6 +734,17 @@ def add_transfer_request(
assert self.disaggregation_mode == DisaggregationMode.PREFILL
assert not is_last or (is_last and aux_index is not None)

if bootstrap_room not in self.transfer_infos:
if not self.skip_register_prefill:
raise RuntimeError(
f"No transfer info found for bootstrap_room={bootstrap_room}"
)
# Non-authoritative CP ranks can have no transfer_infos for this room.
# If this is the last chunk, finish the local sender as a no-op.
if is_last and bootstrap_room in self.request_status:
self.update_status(bootstrap_room, KVPoll.Success)
return []

reqs_to_be_processed = self.transfer_infos[bootstrap_room].values()
handles = []
for req in reqs_to_be_processed:
Expand Down
42 changes: 34 additions & 8 deletions python/sglang/srt/disaggregation/prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from typing import TYPE_CHECKING, List, Optional

import torch
import torch.distributed as dist

from sglang.srt.disaggregation.base import KVPoll
from sglang.srt.disaggregation.common.conn import CommonKVManager
Expand All @@ -49,14 +50,35 @@
from sglang.srt.observability.req_time_stats import set_schedule_time_batch

if TYPE_CHECKING:
from torch.distributed import ProcessGroup

from sglang.srt.managers.scheduler import GenerationBatchResult, Scheduler
from sglang.srt.mem_cache.memory_pool import KVCache

logger = logging.getLogger(__name__)


def poll_and_all_reduce_attn_groups(
    pollers,
    attn_cp_cpu_group: dist.ProcessGroup,
    attn_tp_cpu_group: Optional[dist.ProcessGroup] = None,
):
    """Poll KV senders and synchronize their statuses across attention groups.

    Stage 1 reduces the per-request poll results across CP ranks (the ranks
    sharing one attn-TP shard). Stage 2 then reduces across attn-TP ranks so
    every TPxCP participant of a single DP shard observes the same status
    transitions.

    Args:
        pollers: sequence of objects with a ``poll()`` method (KV senders).
        attn_cp_cpu_group: CPU process group spanning the CP ranks.
        attn_tp_cpu_group: optional CPU process group spanning the attn-TP
            ranks; the second reduction is skipped when it is absent or
            has a single member.

    Returns:
        List of reduced poll statuses, one per poller.
    """
    # Stage 1: agree on a status within the CP group (same attn_tp shard).
    statuses = poll_and_all_reduce(pollers, attn_cp_cpu_group)

    # Stage 2: nothing more to do without a non-trivial attn-TP group.
    if attn_tp_cpu_group is None:
        return statuses
    if dist.get_world_size(group=attn_tp_cpu_group) <= 1:
        return statuses

    # MIN-reduce so the most conservative (least advanced) status wins
    # across attn-TP ranks, matching the CP-level reduction semantics.
    status_buf = torch.tensor(statuses, dtype=torch.uint8, device="cpu")
    dist.all_reduce(status_buf, op=dist.ReduceOp.MIN, group=attn_tp_cpu_group)
    return status_buf.tolist()


def release_req_to_metadata_buffer(
req: Req, allocator: ReqToMetadataIdxAllocator
) -> None:
Expand Down Expand Up @@ -93,7 +115,7 @@ def __init__(
tp_size: int,
gpu_id: int,
bootstrap_port: int,
gloo_group: ProcessGroup,
gloo_group: dist.ProcessGroup,
max_total_num_tokens: int,
decode_tp_size: int,
decode_dp_size: int,
Expand Down Expand Up @@ -269,8 +291,10 @@ def pop_bootstrapped(
else:
return [], []

polls = poll_and_all_reduce(
[req.disagg_kv_sender for req in self.queue], self.gloo_group
polls = poll_and_all_reduce_attn_groups(
[req.disagg_kv_sender for req in self.queue],
self.scheduler.attn_cp_cpu_group,
self.scheduler.attn_tp_cpu_group,
)

for i, (req, poll) in enumerate(zip(self.queue, polls)):
Expand Down Expand Up @@ -553,13 +577,13 @@ def process_disagg_prefill_inflight_queue(

done_reqs = []

polls = poll_and_all_reduce(
polls = poll_and_all_reduce_attn_groups(
[req.disagg_kv_sender for req in self.disagg_prefill_inflight_queue],
self.attn_cp_cpu_group,
self.attn_tp_cpu_group,
)

undone_reqs: List[Req] = []
# Check .poll() for the reqs in disagg_prefill_inflight_queue. If Success, respond to the client and remove it from the queue
for req, poll in zip(self.disagg_prefill_inflight_queue, polls):

if rids_to_check is not None:
Expand Down Expand Up @@ -621,8 +645,9 @@ def get_transferred_rids(self: Scheduler) -> List[str]:
"""
Used by PP, get the transferred rids but **do not pop**
"""
polls = poll_and_all_reduce(
polls = poll_and_all_reduce_attn_groups(
[req.disagg_kv_sender for req in self.disagg_prefill_inflight_queue],
self.attn_cp_cpu_group,
self.attn_tp_cpu_group,
)

Expand Down Expand Up @@ -740,4 +765,5 @@ def send_kv_chunk(
f"Skip sending kv chunk for request {req.rid=} {req.bootstrap_room=} because page_indices is empty"
)
return

req.disagg_kv_sender.send(page_indices, state_indices)
Loading