diff --git a/python/sglang/srt/disaggregation/common/conn.py b/python/sglang/srt/disaggregation/common/conn.py index 844c847afe5f..6030d76989c1 100644 --- a/python/sglang/srt/disaggregation/common/conn.py +++ b/python/sglang/srt/disaggregation/common/conn.py @@ -135,13 +135,26 @@ def record_failure(self, bootstrap_room: int, failure_reason: str): with self.failure_lock: self.failure_records[bootstrap_room] = failure_reason - def ensure_parallel_info(self, bootstrap_addr: str) -> bool: + def ensure_parallel_info( + self, bootstrap_addr: str, max_retries: int = 20, retry_interval: float = 1.0 + ) -> bool: """Fetch and cache prefill parallel info if not yet available. Returns True if info is available (cached or freshly fetched). + Retries with backoff if the prefill server hasn't registered yet. """ if bootstrap_addr in self.prefill_info_table: return True - info = self._fetch_prefill_server_info(bootstrap_addr) + info = None + for attempt in range(max_retries): + info = self._fetch_prefill_server_info(bootstrap_addr) + if info is not None: + break + if attempt < max_retries - 1: + logger.info( + f"Prefill server info not available from {bootstrap_addr}, " + f"retrying ({attempt + 1}/{max_retries})..." + ) + time.sleep(retry_interval) if info is None: return False @@ -573,6 +586,7 @@ def __init__(self, host: str, port: int, dp_size: int = 1): int, Dict[int, Dict[int, Dict[str, Union[str, int]]]] ] = {} self.room_to_dp_rank: Dict[int, Dict[str, Union[int, float]]] = {} + self._registered_count = 0 self.entry_cleanup_interval = ( envs.SGLANG_DISAGGREGATION_BOOTSTRAP_ENTRY_CLEANUP_INTERVAL.get() ) @@ -584,6 +598,14 @@ def __init__(self, host: str, port: int, dp_size: int = 1): def run(self): self.thread.start() + def _is_ready(self) -> bool: + if self.attn_tp_size is None or self.pp_size is None: + return False + # TODO: verify this expected count is correct for all parallelism + # combinations (CP / DP attention / system DP / TP / PP). + expected = self.dp_size * self.attn_tp_size * self.pp_size + return self._registered_count >= expected + def _setup_routes(self): self.app.router.add_route("*", "/route", self._handle_route) self.app.router.add_post("/register_dp_rank", self._handle_register_dp_rank) @@ -654,8 +676,11 @@ async def _handle_route_put(self, request: web.Request): "rank_ip": rank_ip, "rank_port": rank_port, } + self._registered_count += 1 + expected = self.dp_size * self.attn_tp_size * self.pp_size logger.debug( f"Register prefill bootstrap: DP{dp_group} TP{attn_tp_rank} PP{pp_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}" + f" ({self._registered_count}/{expected} registered)" ) return web.Response(text="OK", status=200) @@ -672,6 +697,12 @@ async def _handle_route_get(self, request: web.Request): and int(prefill_dp_rank) == -1 and int(target_pp_rank) == -1 ): + if not self._is_ready(): + return web.Response( + text=f"Prefill server not fully registered yet" + f" ({self._registered_count} workers registered).", + status=503, + ) info = PrefillServerInfo( attn_tp_size=self.attn_tp_size, dp_size=self.dp_size, @@ -685,16 +716,27 @@ async def _handle_route_get(self, request: web.Request): ) return web.json_response(dataclasses.asdict(info), status=200) + if not self._is_ready(): + return web.Response( + text=f"Prefill server not fully registered yet" + f" ({self._registered_count} workers registered).", + status=503, + ) + # Find corresponding prefill info - async with self.lock: - bootstrap_info = self.prefill_port_table[int(prefill_dp_rank)][ - int(engine_rank) - ][int(target_pp_rank)] + try: + async with self.lock: + bootstrap_info = self.prefill_port_table[int(prefill_dp_rank)][ + int(engine_rank) + ][int(target_pp_rank)] + except KeyError: + return web.Response( + text=f"Bootstrap info not found for dp_rank={prefill_dp_rank} " + f"engine_rank={engine_rank} pp_rank={target_pp_rank}", + status=404, + ) - if bootstrap_info is not None: - return web.json_response(bootstrap_info, status=200) - else: - return web.Response(text="Bootstrap info not Found", status=404) + return web.json_response(bootstrap_info, status=200) async def _handle_register_dp_rank(self, request: web.Request): data = await request.json()