diff --git a/python/sglang/srt/disaggregation/common/conn.py b/python/sglang/srt/disaggregation/common/conn.py index 28c057abfd5c..286132d4f0f2 100644 --- a/python/sglang/srt/disaggregation/common/conn.py +++ b/python/sglang/srt/disaggregation/common/conn.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import dataclasses import logging import socket import threading @@ -43,6 +44,22 @@ logger = logging.getLogger(__name__) +@dataclasses.dataclass +class PrefillServerInfo: + attn_tp_size: int + dp_size: int + pp_size: int + page_size: Optional[int] + follow_bootstrap_room: bool + + def __post_init__(self): + self.attn_tp_size = int(self.attn_tp_size) + self.dp_size = int(self.dp_size) + self.pp_size = int(self.pp_size) + self.page_size = int(self.page_size) if self.page_size is not None else None + self.follow_bootstrap_room = bool(self.follow_bootstrap_room) + + class CommonKVManager(BaseKVManager): def __init__( self, @@ -92,11 +109,7 @@ def __init__( self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {} self.connection_lock = threading.Lock() self.required_prefill_response_num_table: Dict[int, int] = {} - self.prefill_attn_tp_size_table: Dict[str, int] = {} - self.prefill_dp_size_table: Dict[str, int] = {} - self.prefill_pp_size_table: Dict[str, int] = {} - self.prefill_page_size_table: Dict[str, Optional[int]] = {} - self.follow_bootstrap_room_table: Dict[str, bool] = {} + self.prefill_info_table: Dict[str, PrefillServerInfo] = {} else: raise ValueError( f"Unsupported DisaggregationMode: {self.disaggregation_mode}" @@ -106,30 +119,43 @@ def ensure_parallel_info(self, bootstrap_addr: str) -> bool: """Fetch and cache prefill parallel info if not yet available. Returns True if info is available (cached or freshly fetched). """ - if bootstrap_addr in self.prefill_dp_size_table: + if bootstrap_addr in self.prefill_info_table: return True - info = CommonKVReceiver._fetch_prefill_parallel_info(bootstrap_addr) + info = self._fetch_prefill_server_info(bootstrap_addr) if info is None: return False - tp_size, dp_size, pp_size, page_size, follow_bootstrap_room = info - if page_size is not None and page_size != self.kv_args.page_size: + if info.page_size is not None and info.page_size != self.kv_args.page_size: raise RuntimeError( - f"Page size mismatch: prefill server has page_size={page_size}, " + f"Page size mismatch: prefill server has page_size={info.page_size}, " f"but decode server has page_size={self.kv_args.page_size}. " f"Both servers must use the same --page-size value." ) - self.prefill_attn_tp_size_table[bootstrap_addr] = tp_size - self.prefill_dp_size_table[bootstrap_addr] = dp_size - self.prefill_pp_size_table[bootstrap_addr] = pp_size - self.prefill_page_size_table[bootstrap_addr] = page_size - self.follow_bootstrap_room_table[bootstrap_addr] = follow_bootstrap_room - logger.debug( - f"Prefill parallel info for [{bootstrap_addr}]: DP={dp_size} TP={tp_size} PP={pp_size} page_size={page_size} follow_bootstrap_room={follow_bootstrap_room}" - ) + self.prefill_info_table[bootstrap_addr] = info + logger.debug(f"Prefill parallel info for [{bootstrap_addr}]: {info}") return True + @staticmethod + def _fetch_prefill_server_info( + bootstrap_addr: str, + ) -> Optional[PrefillServerInfo]: + """Fetch the prefill server info from the bootstrap server.""" + try: + url = f"http://{bootstrap_addr}/route?engine_rank={-1}&prefill_dp_rank={-1}&target_pp_rank={-1}" + response = requests.get(url, timeout=5) + if response.status_code == 200: + data = response.json() + return PrefillServerInfo(**data) + else: + logger.error( + f"Failed to get prefill server info: {response.status_code}, {response.text}" + ) + return None + except Exception as e: + logger.error(f"Error fetching prefill server info from bootstrap: {e}") + return None + def register_to_bootstrap(self): """Register KVSender to bootstrap server via HTTP POST.""" if self.dist_init_addr: @@ -315,38 +341,31 @@ def __init__( self.bootstrap_infos = None return - self.prefill_attn_tp_size = self.kv_mgr.prefill_attn_tp_size_table[ - self.bootstrap_addr - ] - self.prefill_dp_size = self.kv_mgr.prefill_dp_size_table[self.bootstrap_addr] - self.prefill_pp_size = self.kv_mgr.prefill_pp_size_table[self.bootstrap_addr] - self.prefill_page_size = self.kv_mgr.prefill_page_size_table.get( - self.bootstrap_addr - ) + self.prefill_info = self.kv_mgr.prefill_info_table[self.bootstrap_addr] # Handling for PD with different TP sizes per DP rank - if self.kv_mgr.attn_tp_size == self.prefill_attn_tp_size: + if self.kv_mgr.attn_tp_size == self.prefill_info.attn_tp_size: self.target_tp_rank = ( self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size ) self.required_dst_info_num = 1 self.required_prefill_response_num = 1 * ( - self.prefill_pp_size // self.kv_mgr.pp_size + self.prefill_info.pp_size // self.kv_mgr.pp_size ) self.target_tp_ranks = [self.target_tp_rank] - elif self.kv_mgr.attn_tp_size > self.prefill_attn_tp_size: + elif self.kv_mgr.attn_tp_size > self.prefill_info.attn_tp_size: if not self.kv_mgr.is_mla_backend: logger.warning_once( "Performance is NOT guaranteed when using different TP sizes for non-MLA models. " ) self.target_tp_rank = ( self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size - ) // (self.kv_mgr.attn_tp_size // self.prefill_attn_tp_size) + ) // (self.kv_mgr.attn_tp_size // self.prefill_info.attn_tp_size) self.required_dst_info_num = ( - self.kv_mgr.attn_tp_size // self.prefill_attn_tp_size + self.kv_mgr.attn_tp_size // self.prefill_info.attn_tp_size ) self.required_prefill_response_num = 1 * ( - self.prefill_pp_size // self.kv_mgr.pp_size + self.prefill_info.pp_size // self.kv_mgr.pp_size ) self.target_tp_ranks = [self.target_tp_rank] else: @@ -359,9 +378,9 @@ def __init__( rank for rank in range( (self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size) - * (self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size), + * (self.prefill_info.attn_tp_size // self.kv_mgr.attn_tp_size), (self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size + 1) - * (self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size), + * (self.prefill_info.attn_tp_size // self.kv_mgr.attn_tp_size), ) ] @@ -372,23 +391,23 @@ def __init__( self.required_dst_info_num = 1 if self.kv_mgr.is_mla_backend: self.required_prefill_response_num = ( - self.prefill_pp_size // self.kv_mgr.pp_size + self.prefill_info.pp_size // self.kv_mgr.pp_size ) else: self.required_prefill_response_num = ( - self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size - ) * (self.prefill_pp_size // self.kv_mgr.pp_size) + self.prefill_info.attn_tp_size // self.kv_mgr.attn_tp_size + ) * (self.prefill_info.pp_size // self.kv_mgr.pp_size) # Decode pp size should be equal to prefill pp size or 1 assert ( - self.kv_mgr.pp_size == self.prefill_pp_size or self.kv_mgr.pp_size == 1 + self.kv_mgr.pp_size == self.prefill_info.pp_size or self.kv_mgr.pp_size == 1 ), ( - f"Decode pp size ({self.kv_mgr.pp_size}) should be equal to prefill pp size ({self.prefill_pp_size}) or 1", + f"Decode pp size ({self.kv_mgr.pp_size}) should be equal to prefill pp size ({self.prefill_info.pp_size}) or 1", ) - if self.prefill_pp_size == self.kv_mgr.pp_size: + if self.prefill_info.pp_size == self.kv_mgr.pp_size: self.target_pp_ranks = [self.kv_mgr.pp_rank] else: - self.target_pp_ranks = [rank for rank in range(self.prefill_pp_size)] + self.target_pp_ranks = [rank for rank in range(self.prefill_info.pp_size)] self.kv_mgr.required_prefill_response_num_table[self.bootstrap_room] = ( self.required_prefill_response_num @@ -465,36 +484,6 @@ def _get_bootstrap_info_from_server( logger.error(f"Error fetching prefill info from bootstrap: {e}") return None - @staticmethod - def _fetch_prefill_parallel_info( - bootstrap_addr: str, - ) -> Optional[Tuple[int, int, int, int, bool]]: - """Fetch the prefill parallel info from the bootstrap server. - - Returns (attn_tp_size, dp_size, pp_size, page_size, follow_bootstrap_room) - or None on failure. - """ - try: - url = f"http://{bootstrap_addr}/route?engine_rank={-1}&prefill_dp_rank={-1}&target_pp_rank={-1}" - response = requests.get(url, timeout=5) - if response.status_code == 200: - info = response.json() - return ( - int(info["prefill_attn_tp_size"]), - int(info["prefill_dp_size"]), - int(info["prefill_pp_size"]), - int(info["prefill_page_size"]), - bool(info.get("follow_bootstrap_room", True)), - ) - else: - logger.error( - f"Failed to get prefill parallel info: {response.status_code}, {response.text}" - ) - return None - except Exception as e: - logger.error(f"Error fetching prefill parallel info from bootstrap: {e}") - return None - @staticmethod def query_prefill_dp_ranks( bootstrap_addr: str, bootstrap_rooms: List[int] @@ -663,18 +652,18 @@ async def _handle_route_get(self, request: web.Request): and int(prefill_dp_rank) == -1 and int(target_pp_rank) == -1 ): - prefill_parallel_info = { - "prefill_attn_tp_size": self.attn_tp_size, - "prefill_dp_size": self.dp_size, - "prefill_pp_size": self.pp_size, - "prefill_page_size": self.page_size, - "follow_bootstrap_room": ( + info = PrefillServerInfo( + attn_tp_size=self.attn_tp_size, + dp_size=self.dp_size, + pp_size=self.pp_size, + page_size=self.page_size, + follow_bootstrap_room=( self.follow_bootstrap_room if self.follow_bootstrap_room is not None else True ), - } - return web.json_response(prefill_parallel_info, status=200) + ) + return web.json_response(dataclasses.asdict(info), status=200) # Find corresponding prefill info async with self.lock: diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py index 1b2d4741cc98..8d437988973b 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -364,14 +364,12 @@ def _resolve_dp_rank(self, req: Req) -> Optional[int]: bootstrap_addr = f"{req.bootstrap_host}:{req.bootstrap_port}" - if bootstrap_addr not in self.kv_manager.prefill_dp_size_table: + prefill_info = self.kv_manager.prefill_info_table.get(bootstrap_addr) + if prefill_info is None: return None - if self.kv_manager.follow_bootstrap_room_table[bootstrap_addr]: - return ( - req.bootstrap_room - % self.kv_manager.prefill_dp_size_table[bootstrap_addr] - ) + if prefill_info.follow_bootstrap_room: + return req.bootstrap_room % prefill_info.dp_size return None diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py index 4a80321cea42..09b265490983 100644 --- a/python/sglang/srt/disaggregation/mooncake/conn.py +++ b/python/sglang/srt/disaggregation/mooncake/conn.py @@ -997,7 +997,7 @@ def heartbeat_checker(): while True: time.sleep(self.heartbeat_interval) with self.connection_lock: - addresses = list(self.prefill_dp_size_table.keys()) + addresses = list(self.prefill_info_table.keys()) for bootstrap_addr in addresses: session = None @@ -1128,16 +1128,8 @@ def _handle_node_failure(self, failed_bootstrap_addr): possible_affected_rooms = self.addr_to_rooms_tracker.get( failed_bootstrap_addr, [] ) - keys_to_remove = [ - self.prefill_attn_tp_size_table, - self.prefill_dp_size_table, - self.prefill_pp_size_table, - self.follow_bootstrap_room_table, - self.addr_to_rooms_tracker, - ] - for k in keys_to_remove: - if failed_bootstrap_addr in k: - del k[failed_bootstrap_addr] + self.prefill_info_table.pop(failed_bootstrap_addr, None) + self.addr_to_rooms_tracker.pop(failed_bootstrap_addr, None) # Report the requests associated with the failed bootstrap addr and mark their status as KVPoll.Failed affected_rooms = [] diff --git a/python/sglang/srt/disaggregation/nixl/conn.py b/python/sglang/srt/disaggregation/nixl/conn.py index 25cd5cc7f183..457022a704c7 100644 --- a/python/sglang/srt/disaggregation/nixl/conn.py +++ b/python/sglang/srt/disaggregation/nixl/conn.py @@ -224,7 +224,7 @@ def heartbeat_checker(): while True: time.sleep(self.heartbeat_interval) with self.connection_lock: - addresses = list(self.prefill_dp_size_table.keys()) + addresses = list(self.prefill_info_table.keys()) for bootstrap_addr in addresses: session = None @@ -274,18 +274,12 @@ def _handle_node_failure(self, failed_bootstrap_addr): ] for k in keys_to_remove: del self.connection_pool[k] - if failed_bootstrap_addr in self.prefill_attn_tp_size_table: - del self.prefill_attn_tp_size_table[failed_bootstrap_addr] - if failed_bootstrap_addr in self.prefill_dp_size_table: - del self.prefill_dp_size_table[failed_bootstrap_addr] - if failed_bootstrap_addr in self.prefill_pp_size_table: - del self.prefill_pp_size_table[failed_bootstrap_addr] + self.prefill_info_table.pop(failed_bootstrap_addr, None) possible_affected_rooms = self.addr_to_rooms_tracker.get( failed_bootstrap_addr, [] ) - if failed_bootstrap_addr in self.addr_to_rooms_tracker: - del self.addr_to_rooms_tracker[failed_bootstrap_addr] + self.addr_to_rooms_tracker.pop(failed_bootstrap_addr, None) # Mark all pending transfers associated with the failed node as failed affected_rooms = []