diff --git a/python/sglang/srt/disaggregation/base/conn.py b/python/sglang/srt/disaggregation/base/conn.py index 2309e8a83c3c..f7d4092d85dd 100644 --- a/python/sglang/srt/disaggregation/base/conn.py +++ b/python/sglang/srt/disaggregation/base/conn.py @@ -122,13 +122,23 @@ def __init__( @abstractmethod def init( + self, + prefill_dp_rank: int, + ): + """ + Resolve bootstrap metadata and mark the receiver ready for transfer metadata. + """ + ... + + @abstractmethod + def send_metadata( self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None, state_indices: Optional[List[int]] = None, ): """ - Set req's index metadata locally or notify the prefill server about the kv indices, aux index, and state_indices. + Notify the prefill server about the kv indices, aux index, and state_indices. """ ... diff --git a/python/sglang/srt/disaggregation/common/conn.py b/python/sglang/srt/disaggregation/common/conn.py index 399acde06d3b..072bd14e4ddc 100644 --- a/python/sglang/srt/disaggregation/common/conn.py +++ b/python/sglang/srt/disaggregation/common/conn.py @@ -489,20 +489,31 @@ def __init__( mgr: CommonKVManager, bootstrap_addr: str, bootstrap_room: Optional[int] = None, - prefill_dp_rank: Optional[int] = None, ): self.bootstrap_room = bootstrap_room self.bootstrap_addr = bootstrap_addr self.kv_mgr = mgr + self.conclude_state: Optional[KVPoll] = None + self.bootstrap_infos = None + self.prefill_info = None + self.prefill_dp_rank = None + self.target_tp_rank = None + self.target_tp_ranks = None + self.target_cp_ranks = None + self.target_pp_ranks = None + self.required_dst_info_num = None + self.required_prefill_response_num = None + self.kv_mgr.addr_to_rooms_tracker[self.bootstrap_addr].add(self.bootstrap_room) self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping) + def init(self, prefill_dp_rank: int): if self.bootstrap_addr not in self.kv_mgr.prefill_info_table: self.kv_mgr.record_failure( self.bootstrap_room, f"Prefill server with bootstrap_addr: {self.bootstrap_addr} is healthy before, but now it is down. Request (bootstrap_room: {self.bootstrap_room}) has been marked as failed.", ) + self.conclude_state = KVPoll.Failed self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed) - self.bootstrap_infos = None return # Read pre-computed rank mapping from prefill_info (computed in try_ensure_parallel_info) @@ -520,11 +531,9 @@ def __init__( self.required_prefill_response_num ) - assert ( - prefill_dp_rank is not None - ), "prefill_dp_rank must be resolved before creating receiver" self.prefill_dp_rank = prefill_dp_rank self._setup_bootstrap_infos() + self.kv_mgr.update_status(self.bootstrap_room, KVPoll.WaitingForInput) def _setup_bootstrap_infos(self): all_bootstrap_infos = [] @@ -562,6 +571,7 @@ def _setup_bootstrap_infos(self): self.bootstrap_room, f"Could not fetch bootstrap info for: prefill_dp_rank: {self.prefill_dp_rank} prefill_cp_rank: {target_cp_rank} target_tp_rank: {target_tp_rank} and target_pp_rank {target_pp_rank}", ) + self.conclude_state = KVPoll.Failed self.kv_mgr.update_status( self.bootstrap_room, KVPoll.Failed ) @@ -645,6 +655,14 @@ def _connect_to_bootstrap_server(cls, bootstrap_info: dict): def _register_kv_args(self): pass + def send_metadata( + self, + kv_indices: npt.NDArray[np.int32], + aux_index: Optional[int] = None, + state_indices: Optional[List[int]] = None, + ): + raise NotImplementedError + def failure_exception(self): raise Exception("Fake KVReceiver Exception") diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py index 68e112deafe3..a572ab6d231b 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -276,7 +276,7 @@ def __init__( # Queue for requests pending pre-allocation self.queue: List[DecodeRequest] = [] self.retracted_queue: List[Req] = [] - self.pending_reqs: List[Req] = [] + self.pending_reqs: List[DecodeRequest] = [] self._ensure_retry_count: Dict[str, int] = {} self._max_ensure_retries: int = 20 # scheduling cycles self._ensure_last_attempt_time: Dict[str, float] = {} @@ -368,17 +368,20 @@ def add(self, req: Req, is_retracted: bool = False) -> None: req.retraction_mb_id = None self.retracted_queue.append(req) else: + decode_req = self._create_receiver_and_enqueue(req) + # NOTE: fake transfer does not need to resolve prefill dp rank in the pending queue if _is_fake_transfer(req, self.scheduler.server_args): - self._create_receiver_and_enqueue(req, 0) + decode_req.kv_receiver.init(0) return # Fast path: cache-only lookup, no network calls prefill_dp_rank = self._resolve_prefill_dp_rank(req) if prefill_dp_rank is not None: - self._create_receiver_and_enqueue(req, prefill_dp_rank) - else: - self.pending_reqs.append(req) + decode_req.kv_receiver.init(prefill_dp_rank) + return + + self.pending_reqs.append(decode_req) def _resolve_prefill_dp_rank(self, req: Req) -> Optional[int]: if req.disagg_prefill_dp_rank is not None: @@ -396,7 +399,7 @@ def _resolve_prefill_dp_rank(self, req: Req) -> Optional[int]: return None - def _create_receiver_and_enqueue(self, req: Req, prefill_dp_rank: int) -> None: + def _create_receiver_and_enqueue(self, req: Req) -> DecodeRequest: backend = ( TransferBackend.FAKE if _is_fake_transfer(req, self.scheduler.server_args) @@ -408,12 +411,11 @@ def _create_receiver_and_enqueue(self, req: Req, prefill_dp_rank: int) -> None: mgr=self.kv_manager, bootstrap_addr=_bootstrap_addr(req), bootstrap_room=req.bootstrap_room, - prefill_dp_rank=prefill_dp_rank, ) - self.queue.append( - DecodeRequest(req=req, kv_receiver=kv_receiver, waiting_for_input=False) - ) + decode_req = DecodeRequest(req=req, kv_receiver=kv_receiver) + self.queue.append(decode_req) + return decode_req def _check_if_req_exceed_kv_capacity(self, req: Req) -> bool: if len(req.origin_input_ids) > self.max_total_num_tokens: @@ -511,12 +513,12 @@ def _update_handshake_waiters( raise ValueError(f"Unexpected poll case: {poll}") def _ensure_prefill_info( - self, addr_to_reqs: Dict[str, List[Req]] - ) -> Tuple[Dict[str, List[Req]], List[Req]]: + self, addr_to_reqs: Dict[str, List[DecodeRequest]] + ) -> Tuple[Dict[str, List[DecodeRequest]], List[DecodeRequest]]: """Non-blocking ensure parallel info for each addr. Returns (ready_addrs, remaining_reqs).""" - ready: Dict[str, List[Req]] = {} - remaining: List[Req] = [] + ready: Dict[str, List[DecodeRequest]] = {} + remaining: List[DecodeRequest] = [] now = time.monotonic() for bootstrap_addr, reqs in addr_to_reqs.items(): @@ -543,13 +545,17 @@ def _ensure_prefill_info( if count >= self._max_ensure_retries: error_msg = f"Could not fetch prefill parallel info from {bootstrap_addr} after {count} attempts" logger.error(error_msg) - for req in reqs: + for decode_req in reqs: prepare_abort( - req, error_msg, status_code=HTTPStatus.INTERNAL_SERVER_ERROR + decode_req.req, + error_msg, + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, ) if self.scheduler.enable_metrics: self.scheduler.metrics_collector.increment_bootstrap_failed_reqs() - self.scheduler.stream_output([req], req.return_logprob) + self.scheduler.stream_output( + [decode_req.req], decode_req.req.return_logprob + ) del self._ensure_retry_count[bootstrap_addr] del self._ensure_last_attempt_time[bootstrap_addr] else: @@ -558,46 +564,48 @@ def _ensure_prefill_info( return ready, remaining def _resolve_pending_reqs(self) -> None: - """Batch-resolve prefill_dp_ranks for pending requests and create receivers.""" + """Batch-resolve prefill_dp_ranks for pending requests and initialize receivers.""" if not self.pending_reqs: return # Group pending requests by bootstrap_addr - addr_to_reqs: Dict[str, List[Req]] = {} - for req in self.pending_reqs: - addr = _bootstrap_addr(req) - addr_to_reqs.setdefault(addr, []).append(req) + addr_to_reqs: Dict[str, List[DecodeRequest]] = {} + for decode_req in self.pending_reqs: + addr = _bootstrap_addr(decode_req.req) + addr_to_reqs.setdefault(addr, []).append(decode_req) # Pass 1: ensure parallel info for each addr ready_addrs, remaining = self._ensure_prefill_info(addr_to_reqs) - # Pass 2: resolve dp rank for addrs whose info is available - resolved = [] - for bootstrap_addr, reqs in ready_addrs.items(): - need_query: List[Req] = [] - for req in reqs: - prefill_dp_rank = self._resolve_prefill_dp_rank(req) + resolved: List[Tuple[DecodeRequest, int]] = [] + for bootstrap_addr, decode_reqs in ready_addrs.items(): + need_query: List[DecodeRequest] = [] + for decode_req in decode_reqs: + prefill_dp_rank = self._resolve_prefill_dp_rank(decode_req.req) if prefill_dp_rank is not None: - resolved.append((req, prefill_dp_rank)) + resolved.append((decode_req, prefill_dp_rank)) else: - need_query.append(req) + need_query.append(decode_req) + # Pass 2: resolve dp rank for addrs whose info is available if need_query: - rooms = [req.bootstrap_room for req in need_query] + rooms = [decode_req.req.bootstrap_room for decode_req in need_query] room_to_rank = CommonKVReceiver.query_prefill_dp_ranks( bootstrap_addr, rooms ) - for req in need_query: - prefill_dp_rank = room_to_rank.get(str(req.bootstrap_room)) + for decode_req in need_query: + prefill_dp_rank = room_to_rank.get( + str(decode_req.req.bootstrap_room) + ) if prefill_dp_rank is not None: - resolved.append((req, int(prefill_dp_rank))) + resolved.append((decode_req, int(prefill_dp_rank))) else: - remaining.append(req) + remaining.append(decode_req) self.pending_reqs = remaining - for req, prefill_dp_rank in resolved: - self._create_receiver_and_enqueue(req, prefill_dp_rank) + for decode_req, prefill_dp_rank in resolved: + decode_req.kv_receiver.init(prefill_dp_rank) def pop_preallocated( self, rids_to_check: Optional[List[str]] = None @@ -726,7 +734,7 @@ def pop_preallocated( ) assert decode_req.metadata_buffer_index is not None page_indices = kv_to_page_indices(kv_indices, page_size) - decode_req.kv_receiver.init( + decode_req.kv_receiver.send_metadata( page_indices, decode_req.metadata_buffer_index, state_indices ) preallocated_reqs.append(decode_req) diff --git a/python/sglang/srt/disaggregation/fake/conn.py b/python/sglang/srt/disaggregation/fake/conn.py index 4a3841e68208..03b79af189f7 100644 --- a/python/sglang/srt/disaggregation/fake/conn.py +++ b/python/sglang/srt/disaggregation/fake/conn.py @@ -82,28 +82,33 @@ def __init__( mgr: BaseKVManager, bootstrap_addr: str, bootstrap_room: Optional[int] = None, - prefill_dp_rank: Optional[int] = None, ): - self.has_init = False + self.bootstrap_done = False + self.has_sent_metadata = False def poll(self) -> KVPoll: - if self.has_init is False: - # Assume handshake completed instantly + if not self.bootstrap_done: + return KVPoll.Bootstrapping + if not self.has_sent_metadata: return KVPoll.WaitingForInput - else: - # Assume transfer completed instantly - logger.debug("FakeKVReceiver poll success") - return KVPoll.Success + logger.debug("FakeKVReceiver poll success") + return KVPoll.Success def init( + self, + prefill_dp_rank: int, + ): + self.bootstrap_done = True + + def send_metadata( self, kv_indices: list[int], aux_index: Optional[int] = None, state_indices: Optional[List[int]] = None, ): - self.has_init = True + self.has_sent_metadata = True logger.debug( - f"FakeKVReceiver init with kv_indices: {kv_indices}, aux_index: {aux_index}, state_indices: {state_indices}" + f"FakeKVReceiver send_metadata with kv_indices: {kv_indices}, aux_index: {aux_index}, state_indices: {state_indices}" ) def failure_exception(self): diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py index fd7090b2ecb8..e4cdcca2b31c 100644 --- a/python/sglang/srt/disaggregation/mooncake/conn.py +++ b/python/sglang/srt/disaggregation/mooncake/conn.py @@ -1238,15 +1238,10 @@ def __init__( mgr: MooncakeKVManager, bootstrap_addr: str, bootstrap_room: Optional[int] = None, - prefill_dp_rank: Optional[int] = None, ): self.session_id = mgr.get_session_id() - self.conclude_state = None self.init_time = None - super().__init__(mgr, bootstrap_addr, bootstrap_room, prefill_dp_rank) - - self.kv_mgr.addr_to_rooms_tracker[self.bootstrap_addr].add(self.bootstrap_room) - self.kv_mgr.update_status(self.bootstrap_room, KVPoll.WaitingForInput) + super().__init__(mgr, bootstrap_addr, bootstrap_room) def _register_kv_args(self): for bootstrap_info in self.bootstrap_infos: @@ -1297,6 +1292,12 @@ def _register_kv_args(self): ) def init( + self, + prefill_dp_rank: int, + ): + super().init(prefill_dp_rank) + + def send_metadata( self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None, diff --git a/python/sglang/srt/disaggregation/mori/conn.py b/python/sglang/srt/disaggregation/mori/conn.py index 89b11f03e94e..a244fa3ad1ef 100644 --- a/python/sglang/srt/disaggregation/mori/conn.py +++ b/python/sglang/srt/disaggregation/mori/conn.py @@ -985,17 +985,18 @@ def __init__( mgr: MoriKVManager, bootstrap_addr: str, bootstrap_room: Optional[int] = None, - prefill_dp_rank: Optional[int] = None, ): - super().__init__(mgr, bootstrap_addr, bootstrap_room, prefill_dp_rank) - self.conclude_state: Optional[KVPoll] = None + super().__init__(mgr, bootstrap_addr, bootstrap_room) self.init_time: Optional[float] = None - if self.bootstrap_room is None or self.bootstrap_infos is None: + + def init( + self, + prefill_dp_rank: int, + ): + super().init(prefill_dp_rank) + if self.bootstrap_room is None: return - self.kv_mgr.addr_to_rooms_tracker[self.bootstrap_addr].add(self.bootstrap_room) self.kv_mgr.room_to_bootstrap_addr[self.bootstrap_room] = self.bootstrap_addr - self.kv_mgr.update_status(self.bootstrap_room, KVPoll.WaitingForInput) - self._register_kv_args() def _register_kv_args(self): if self.bootstrap_infos is None: @@ -1029,7 +1030,7 @@ def _register_kv_args(self): ] ) - def init( + def send_metadata( self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None, diff --git a/python/sglang/srt/disaggregation/nixl/conn.py b/python/sglang/srt/disaggregation/nixl/conn.py index 764fd9e42689..38a4d15cf048 100644 --- a/python/sglang/srt/disaggregation/nixl/conn.py +++ b/python/sglang/srt/disaggregation/nixl/conn.py @@ -957,20 +957,18 @@ def __init__( mgr: NixlKVManager, bootstrap_addr: str, bootstrap_room: Optional[int] = None, - prefill_dp_rank: Optional[int] = None, ): self.started_transfer = False - self.conclude_state = None - super().__init__(mgr, bootstrap_addr, bootstrap_room, prefill_dp_rank) - - # Track this room with its bootstrap address for heartbeat monitoring - if hasattr(self.kv_mgr, "addr_to_rooms_tracker"): - self.kv_mgr.addr_to_rooms_tracker[self.bootstrap_addr].add( - self.bootstrap_room - ) + super().__init__(mgr, bootstrap_addr, bootstrap_room) self.init_time = None def init( + self, + prefill_dp_rank: int, + ): + super().init(prefill_dp_rank) + + def send_metadata( self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None, @@ -1026,7 +1024,7 @@ def poll(self) -> KVPoll: self.conclude_state = status return status if not self.started_transfer: - return KVPoll.WaitingForInput # type: ignore + return status now = time.time() elapsed = now - self.init_time diff --git a/test/registered/distributed/test_disaggregation_dp_attention.py b/test/registered/distributed/test_disaggregation_dp_attention.py index b6d52fee61da..c851650c4bd8 100644 --- a/test/registered/distributed/test_disaggregation_dp_attention.py +++ b/test/registered/distributed/test_disaggregation_dp_attention.py @@ -110,7 +110,6 @@ def test_gsm8k(self): class TestDisaggregationDPAttentionRoundRobin(TestDisaggregationDPAttention): LOAD_BALANCE_METHOD = "round_robin" - # TODO: add test for other load balance methods # TODO: add a balancedness metric def test_bench_serving(self): @@ -130,6 +129,48 @@ def test_bench_serving(self): self.assertEqual(result["completed"], 1000) +class TestDisaggregationDPAttentionTotalRequests(TestDisaggregationDPAttention): + LOAD_BALANCE_METHOD = "total_requests" + test_gsm8k = unittest.skip( + "Covered by base class; this class targets total_requests path." + )(TestDisaggregationDPAttention.test_gsm8k) + + def test_bench_serving(self): + args = get_benchmark_args( + base_url=f"http://{self.base_host}:{self.lb_port}", + dataset_name="random", + tokenizer=self.model, + num_prompts=256, + random_input_len=2048, + random_output_len=512, + request_rate=float("inf"), + max_concurrency=128, + ) + result = run_benchmark(args) + self.assertEqual(result["completed"], 256) + + +class TestDisaggregationDPAttentionTotalTokens(TestDisaggregationDPAttention): + LOAD_BALANCE_METHOD = "total_tokens" + test_gsm8k = unittest.skip( + "Covered by base class; this class targets total_tokens path." + )(TestDisaggregationDPAttention.test_gsm8k) + + def test_bench_serving(self): + args = get_benchmark_args( + base_url=f"http://{self.base_host}:{self.lb_port}", + dataset_name="random", + tokenizer=self.model, + num_prompts=256, + random_input_len=2048, + random_output_len=512, + request_rate=float("inf"), + max_concurrency=128, + ) + result = run_benchmark(args) + self.assertEqual(result["completed"], 256) + + @unittest.skip( "Skip this test until new testing logic in mini-lb has been updated in docker image." )