Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion python/sglang/srt/disaggregation/base/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,13 +122,23 @@ def __init__(

@abstractmethod
def init(
self,
prefill_dp_rank: int,
):
"""
Resolve bootstrap metadata and mark the receiver ready for transfer metadata.
"""
...

@abstractmethod
def send_metadata(
self,
kv_indices: npt.NDArray[np.int32],
aux_index: Optional[int] = None,
state_indices: Optional[List[int]] = None,
):
"""
Set req's index metadata locally or notify the prefill server about the kv indices, aux index, and state_indices.
Notify the prefill server about the kv indices, aux index, and state_indices.
"""
...

Expand Down
28 changes: 23 additions & 5 deletions python/sglang/srt/disaggregation/common/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,20 +489,31 @@ def __init__(
mgr: CommonKVManager,
bootstrap_addr: str,
bootstrap_room: Optional[int] = None,
prefill_dp_rank: Optional[int] = None,
):
self.bootstrap_room = bootstrap_room
self.bootstrap_addr = bootstrap_addr
self.kv_mgr = mgr
self.conclude_state: Optional[KVPoll] = None
self.bootstrap_infos = None
self.prefill_info = None
self.prefill_dp_rank = None
self.target_tp_rank = None
self.target_tp_ranks = None
self.target_cp_ranks = None
self.target_pp_ranks = None
self.required_dst_info_num = None
self.required_prefill_response_num = None
self.kv_mgr.addr_to_rooms_tracker[self.bootstrap_addr].add(self.bootstrap_room)
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)

def init(self, prefill_dp_rank: int):
if self.bootstrap_addr not in self.kv_mgr.prefill_info_table:
self.kv_mgr.record_failure(
self.bootstrap_room,
f"Prefill server with bootstrap_addr: {self.bootstrap_addr} is healthy before, but now it is down. Request (bootstrap_room: {self.bootstrap_room}) has been marked as failed.",
)
self.conclude_state = KVPoll.Failed
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed)
self.bootstrap_infos = None
return

# Read pre-computed rank mapping from prefill_info (computed in try_ensure_parallel_info)
Expand All @@ -520,11 +531,9 @@ def __init__(
self.required_prefill_response_num
)

assert (
prefill_dp_rank is not None
), "prefill_dp_rank must be resolved before creating receiver"
self.prefill_dp_rank = prefill_dp_rank
self._setup_bootstrap_infos()
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.WaitingForInput)

def _setup_bootstrap_infos(self):
all_bootstrap_infos = []
Expand Down Expand Up @@ -562,6 +571,7 @@ def _setup_bootstrap_infos(self):
self.bootstrap_room,
f"Could not fetch bootstrap info for: prefill_dp_rank: {self.prefill_dp_rank} prefill_cp_rank: {target_cp_rank} target_tp_rank: {target_tp_rank} and target_pp_rank {target_pp_rank}",
)
self.conclude_state = KVPoll.Failed
self.kv_mgr.update_status(
self.bootstrap_room, KVPoll.Failed
)
Expand Down Expand Up @@ -645,6 +655,14 @@ def _connect_to_bootstrap_server(cls, bootstrap_info: dict):
def _register_kv_args(self):
pass

def send_metadata(
self,
kv_indices: npt.NDArray[np.int32],
aux_index: Optional[int] = None,
state_indices: Optional[List[int]] = None,
):
raise NotImplementedError

def failure_exception(self):
raise Exception("Fake KVReceiver Exception")

Expand Down
84 changes: 46 additions & 38 deletions python/sglang/srt/disaggregation/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ def __init__(
# Queue for requests pending pre-allocation
self.queue: List[DecodeRequest] = []
self.retracted_queue: List[Req] = []
self.pending_reqs: List[Req] = []
self.pending_reqs: List[DecodeRequest] = []
self._ensure_retry_count: Dict[str, int] = {}
self._max_ensure_retries: int = 20 # scheduling cycles
self._ensure_last_attempt_time: Dict[str, float] = {}
Expand Down Expand Up @@ -368,17 +368,20 @@ def add(self, req: Req, is_retracted: bool = False) -> None:
req.retraction_mb_id = None
self.retracted_queue.append(req)
else:
decode_req = self._create_receiver_and_enqueue(req)

# NOTE: fake transfer does not need to resolve prefill dp rank in the pending queue
if _is_fake_transfer(req, self.scheduler.server_args):
self._create_receiver_and_enqueue(req, 0)
decode_req.kv_receiver.init(0)
return

# Fast path: cache-only lookup, no network calls
prefill_dp_rank = self._resolve_prefill_dp_rank(req)
if prefill_dp_rank is not None:
self._create_receiver_and_enqueue(req, prefill_dp_rank)
else:
self.pending_reqs.append(req)
decode_req.kv_receiver.init(prefill_dp_rank)
return

self.pending_reqs.append(decode_req)

def _resolve_prefill_dp_rank(self, req: Req) -> Optional[int]:
if req.disagg_prefill_dp_rank is not None:
Expand All @@ -396,7 +399,7 @@ def _resolve_prefill_dp_rank(self, req: Req) -> Optional[int]:

return None

def _create_receiver_and_enqueue(self, req: Req, prefill_dp_rank: int) -> None:
def _create_receiver_and_enqueue(self, req: Req) -> DecodeRequest:
backend = (
TransferBackend.FAKE
if _is_fake_transfer(req, self.scheduler.server_args)
Expand All @@ -408,12 +411,11 @@ def _create_receiver_and_enqueue(self, req: Req, prefill_dp_rank: int) -> None:
mgr=self.kv_manager,
bootstrap_addr=_bootstrap_addr(req),
bootstrap_room=req.bootstrap_room,
prefill_dp_rank=prefill_dp_rank,
)

self.queue.append(
DecodeRequest(req=req, kv_receiver=kv_receiver, waiting_for_input=False)
)
decode_req = DecodeRequest(req=req, kv_receiver=kv_receiver)
self.queue.append(decode_req)
return decode_req

def _check_if_req_exceed_kv_capacity(self, req: Req) -> bool:
if len(req.origin_input_ids) > self.max_total_num_tokens:
Expand Down Expand Up @@ -511,12 +513,12 @@ def _update_handshake_waiters(
raise ValueError(f"Unexpected poll case: {poll}")

def _ensure_prefill_info(
self, addr_to_reqs: Dict[str, List[Req]]
) -> Tuple[Dict[str, List[Req]], List[Req]]:
self, addr_to_reqs: Dict[str, List[DecodeRequest]]
) -> Tuple[Dict[str, List[DecodeRequest]], List[DecodeRequest]]:
"""Non-blocking ensure parallel info for each addr.
Returns (ready_addrs, remaining_reqs)."""
ready: Dict[str, List[Req]] = {}
remaining: List[Req] = []
ready: Dict[str, List[DecodeRequest]] = {}
remaining: List[DecodeRequest] = []

now = time.monotonic()
for bootstrap_addr, reqs in addr_to_reqs.items():
Expand All @@ -543,13 +545,17 @@ def _ensure_prefill_info(
if count >= self._max_ensure_retries:
error_msg = f"Could not fetch prefill parallel info from {bootstrap_addr} after {count} attempts"
logger.error(error_msg)
for req in reqs:
for decode_req in reqs:
prepare_abort(
req, error_msg, status_code=HTTPStatus.INTERNAL_SERVER_ERROR
decode_req.req,
error_msg,
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
)
if self.scheduler.enable_metrics:
self.scheduler.metrics_collector.increment_bootstrap_failed_reqs()
self.scheduler.stream_output([req], req.return_logprob)
self.scheduler.stream_output(
[decode_req.req], decode_req.req.return_logprob
)
del self._ensure_retry_count[bootstrap_addr]
del self._ensure_last_attempt_time[bootstrap_addr]
else:
Expand All @@ -558,46 +564,48 @@ def _ensure_prefill_info(
return ready, remaining

def _resolve_pending_reqs(self) -> None:
"""Batch-resolve prefill_dp_ranks for pending requests and create receivers."""
"""Batch-resolve prefill_dp_ranks for pending requests and initialize receivers."""
if not self.pending_reqs:
return

# Group pending requests by bootstrap_addr
addr_to_reqs: Dict[str, List[Req]] = {}
for req in self.pending_reqs:
addr = _bootstrap_addr(req)
addr_to_reqs.setdefault(addr, []).append(req)
addr_to_reqs: Dict[str, List[DecodeRequest]] = {}
for decode_req in self.pending_reqs:
addr = _bootstrap_addr(decode_req.req)
addr_to_reqs.setdefault(addr, []).append(decode_req)

# Pass 1: ensure parallel info for each addr
ready_addrs, remaining = self._ensure_prefill_info(addr_to_reqs)

# Pass 2: resolve dp rank for addrs whose info is available
resolved = []
for bootstrap_addr, reqs in ready_addrs.items():
need_query: List[Req] = []
for req in reqs:
prefill_dp_rank = self._resolve_prefill_dp_rank(req)
resolved: List[Tuple[DecodeRequest, int]] = []
for bootstrap_addr, decode_reqs in ready_addrs.items():
need_query: List[DecodeRequest] = []
for decode_req in decode_reqs:
prefill_dp_rank = self._resolve_prefill_dp_rank(decode_req.req)
if prefill_dp_rank is not None:
resolved.append((req, prefill_dp_rank))
resolved.append((decode_req, prefill_dp_rank))
else:
need_query.append(req)
need_query.append(decode_req)

# Pass 2: resolve dp rank for addrs whose info is available
if need_query:
rooms = [req.bootstrap_room for req in need_query]
rooms = [decode_req.req.bootstrap_room for decode_req in need_query]
room_to_rank = CommonKVReceiver.query_prefill_dp_ranks(
bootstrap_addr, rooms
)
for req in need_query:
prefill_dp_rank = room_to_rank.get(str(req.bootstrap_room))
for decode_req in need_query:
prefill_dp_rank = room_to_rank.get(
str(decode_req.req.bootstrap_room)
)
if prefill_dp_rank is not None:
resolved.append((req, int(prefill_dp_rank)))
resolved.append((decode_req, int(prefill_dp_rank)))
else:
remaining.append(req)
remaining.append(decode_req)

self.pending_reqs = remaining

for req, prefill_dp_rank in resolved:
self._create_receiver_and_enqueue(req, prefill_dp_rank)
for decode_req, prefill_dp_rank in resolved:
decode_req.kv_receiver.init(prefill_dp_rank)

def pop_preallocated(
self, rids_to_check: Optional[List[str]] = None
Expand Down Expand Up @@ -726,7 +734,7 @@ def pop_preallocated(
)
assert decode_req.metadata_buffer_index is not None
page_indices = kv_to_page_indices(kv_indices, page_size)
decode_req.kv_receiver.init(
decode_req.kv_receiver.send_metadata(
page_indices, decode_req.metadata_buffer_index, state_indices
)
preallocated_reqs.append(decode_req)
Expand Down
25 changes: 15 additions & 10 deletions python/sglang/srt/disaggregation/fake/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,28 +82,33 @@ def __init__(
mgr: BaseKVManager,
bootstrap_addr: str,
bootstrap_room: Optional[int] = None,
prefill_dp_rank: Optional[int] = None,
):
self.has_init = False
self.bootstrap_done = False
self.has_sent_metadata = False

def poll(self) -> KVPoll:
if self.has_init is False:
# Assume handshake completed instantly
if not self.bootstrap_done:
return KVPoll.Bootstrapping
if not self.has_sent_metadata:
return KVPoll.WaitingForInput
else:
# Assume transfer completed instantly
logger.debug("FakeKVReceiver poll success")
return KVPoll.Success
logger.debug("FakeKVReceiver poll success")
return KVPoll.Success

def init(
self,
prefill_dp_rank: int,
):
self.bootstrap_done = True

def send_metadata(
self,
kv_indices: list[int],
aux_index: Optional[int] = None,
state_indices: Optional[List[int]] = None,
):
self.has_init = True
self.has_sent_metadata = True
logger.debug(
f"FakeKVReceiver init with kv_indices: {kv_indices}, aux_index: {aux_index}, state_indices: {state_indices}"
f"FakeKVReceiver send_metadata with kv_indices: {kv_indices}, aux_index: {aux_index}, state_indices: {state_indices}"
)

def failure_exception(self):
Expand Down
13 changes: 7 additions & 6 deletions python/sglang/srt/disaggregation/mooncake/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1238,15 +1238,10 @@ def __init__(
mgr: MooncakeKVManager,
bootstrap_addr: str,
bootstrap_room: Optional[int] = None,
prefill_dp_rank: Optional[int] = None,
):
self.session_id = mgr.get_session_id()
self.conclude_state = None
self.init_time = None
super().__init__(mgr, bootstrap_addr, bootstrap_room, prefill_dp_rank)

self.kv_mgr.addr_to_rooms_tracker[self.bootstrap_addr].add(self.bootstrap_room)
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.WaitingForInput)
super().__init__(mgr, bootstrap_addr, bootstrap_room)

def _register_kv_args(self):
for bootstrap_info in self.bootstrap_infos:
Expand Down Expand Up @@ -1297,6 +1292,12 @@ def _register_kv_args(self):
)

def init(
self,
prefill_dp_rank: int,
):
super().init(prefill_dp_rank)

def send_metadata(
self,
kv_indices: npt.NDArray[np.int32],
aux_index: Optional[int] = None,
Expand Down
17 changes: 9 additions & 8 deletions python/sglang/srt/disaggregation/mori/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -985,17 +985,18 @@ def __init__(
mgr: MoriKVManager,
bootstrap_addr: str,
bootstrap_room: Optional[int] = None,
prefill_dp_rank: Optional[int] = None,
):
super().__init__(mgr, bootstrap_addr, bootstrap_room, prefill_dp_rank)
self.conclude_state: Optional[KVPoll] = None
super().__init__(mgr, bootstrap_addr, bootstrap_room)
self.init_time: Optional[float] = None
if self.bootstrap_room is None or self.bootstrap_infos is None:

def init(
self,
prefill_dp_rank: int,
):
super().init(prefill_dp_rank)
if self.bootstrap_room is None:
return
self.kv_mgr.addr_to_rooms_tracker[self.bootstrap_addr].add(self.bootstrap_room)
self.kv_mgr.room_to_bootstrap_addr[self.bootstrap_room] = self.bootstrap_addr
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.WaitingForInput)
self._register_kv_args()

def _register_kv_args(self):
if self.bootstrap_infos is None:
Expand Down Expand Up @@ -1029,7 +1030,7 @@ def _register_kv_args(self):
]
)

def init(
def send_metadata(
self,
kv_indices: npt.NDArray[np.int32],
aux_index: Optional[int] = None,
Expand Down
Loading
Loading