diff --git a/docs/advanced_features/pd_disaggregation.md b/docs/advanced_features/pd_disaggregation.md index 908129efdee5..6399800d70da 100644 --- a/docs/advanced_features/pd_disaggregation.md +++ b/docs/advanced_features/pd_disaggregation.md @@ -140,6 +140,7 @@ export MC_FORCE_MNNVL=True | **`SGLANG_DISAGGREGATION_THREAD_POOL_SIZE`** | Controls the total number of worker threads for KVCache transfer operations per TP rank | A dynamic value calculated by `int(0.75 * os.cpu_count()) // 8)`, which is limited to be larger than 4 and less than 12 to ensure efficiency and prevent thread race conditions | | **`SGLANG_DISAGGREGATION_QUEUE_SIZE`** | Sets the number of parallel transfer queues. KVCache transfer requests from multiple decode instances will be sharded into these queues so that they can share the threads and the transfer bandwidth at the same time. If it is set to `1`, then we transfer requests one by one according to fcfs strategy | `4` | | **`SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT`** | Timeout (seconds) for receiving destination KV indices during request initialization | `300` | +| **`SGLANG_DISAGGREGATION_BOOTSTRAP_ENTRY_CLEANUP_INTERVAL`** | Interval (seconds) between cleanups of bootstrap entries | `120` | If a greater mean TTFT is acceptable, you can `export SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT=600` (10 minutes) to relax the timeout condition. Please be aware that this setting will cause prefill instances to take a longer time to clean up the affected memory resources when a running decode node loses connection. diff --git a/python/sglang/srt/disaggregation/common/conn.py b/python/sglang/srt/disaggregation/common/conn.py index 67fe82ad67f1..938d4c4f91ab 100644 --- a/python/sglang/srt/disaggregation/common/conn.py +++ b/python/sglang/srt/disaggregation/common/conn.py @@ -4,6 +4,7 @@ import logging import socket import threading +import time from functools import cache from typing import Dict, List, Optional, Tuple, Union @@ -23,6 +24,7 @@ ) from sglang.srt.disaggregation.utils import DisaggregationMode from sglang.srt.distributed import get_pp_group +from sglang.srt.environ import envs from sglang.srt.layers.dp_attention import ( get_attention_dp_rank, get_attention_dp_size, @@ -52,6 +54,7 @@ def __init__( self.kv_args = args self.is_mla_backend = is_mla_backend self.disaggregation_mode = disaggregation_mode + self.server_args = server_args # for p/d multi node infer self.bootstrap_host = server_args.host self.bootstrap_port = server_args.disaggregation_bootstrap_port @@ -129,6 +132,7 @@ def _register_to_bootstrap(self): "rank_ip": self.local_ip, "rank_port": self.rank_port, "page_size": self.kv_args.page_size, + "bootstrap_room": -1, } try: @@ -144,6 +148,33 @@ def _register_to_bootstrap(self): f"Prefill instance failed to register to bootstrap server: {e}" ) + def _register_prefill_dp_rank(self, bootstrap_room: int, bootstrap_addr: str): + """Register request's prefill dp rank to bootstrap server via HTTP POST.""" + + url = f"http://{bootstrap_addr}/route" + payload = { + "role": "Prefill", + "attn_dp_rank": self.attn_dp_rank, + "rank_ip": self.local_ip, + "rank_port": self.rank_port, + "bootstrap_room": bootstrap_room, + } + + try: + response = requests.put(url, json=payload, timeout=5) + if response.status_code == 200: + logger.debug( + "Prefill dp rank successfully registered to bootstrap server." + ) + else: + logger.error( + f"Prefill instance failed to connect to bootstrap server: {response.status_code}, {response.text}" + ) + except Exception as e: + logger.error( + f"Prefill instance failed to register to bootstrap server: {e}" + ) + @cache def _connect(self, endpoint: str, is_ipv6: bool = False): socket = zmq.Context().socket(zmq.PUSH) @@ -215,6 +246,12 @@ def __init__( # inner state self.curr_idx = 0 self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping) + # TODO: do not register prefill dp rank if prefill server's load balance method is 'follow_bootstrap_room' + # If prefill server's load balance method is not 'follow_bootstrap_room', register prefill dp rank. + if self.kv_mgr.server_args.dp_size > 1: + self.kv_mgr._register_prefill_dp_rank( + self.bootstrap_room, self.bootstrap_server_url + ) def init(self, num_kv_indices: int, aux_index: Optional[int] = None): self.num_kv_indices = num_kv_indices @@ -367,14 +404,6 @@ def __init__( self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size ) * (self.prefill_pp_size // self.kv_mgr.pp_size) - if prefill_dp_rank is not None: - logger.debug(f"Targeting DP rank: {prefill_dp_rank}") - self.prefill_dp_rank = prefill_dp_rank - else: - self.prefill_dp_rank = bootstrap_room % self.prefill_dp_size - - self.target_dp_group = self.prefill_dp_rank - # Decode pp size should be equal to prefill pp size or 1 assert ( self.kv_mgr.pp_size == self.prefill_pp_size or self.kv_mgr.pp_size == 1 @@ -389,9 +418,24 @@ def __init__( self.kv_mgr.required_prefill_response_num_table[self.bootstrap_room] = ( self.required_prefill_response_num ) - # NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank + + # If router has not assigned prefill_dp_rank to each request, we will try to fetch it from bootstrap server. + if prefill_dp_rank is not None: + logger.debug(f"Targeting DP rank: {prefill_dp_rank}") + self.prefill_dp_rank = prefill_dp_rank + elif self.prefill_dp_size == 1: + self.prefill_dp_rank = 0 + self.should_notify_dp_rank = False + else: + self.should_notify_dp_rank = True + + if not self.should_notify_dp_rank: + self._setup_bootstrap_infos() + + def _setup_bootstrap_infos(self): + # NOTE: key distinguished by bootstrap_addr, prefill_dp_rank, and target_tp_rank bootstrap_key = ( - f"{self.bootstrap_addr}_{self.target_dp_group}_{self.target_tp_rank}" + f"{self.bootstrap_addr}_{self.prefill_dp_rank}_{self.target_tp_rank}" ) if bootstrap_key not in self.kv_mgr.connection_pool: @@ -400,7 +444,7 @@ def __init__( # Enable higher PP ranks to be bootstrapped earlier to make PP PD requests bootstrap more robust for target_pp_rank in reversed(self.target_pp_ranks): bootstrap_info = self._get_bootstrap_info_from_server( - target_tp_rank, self.target_dp_group, target_pp_rank + target_tp_rank, self.prefill_dp_rank, target_pp_rank ) if bootstrap_info is not None: if self.kv_mgr.is_mla_backend: @@ -413,13 +457,13 @@ def __init__( # For non-MLA: all target_tp_ranks are selected real ranks bootstrap_info["is_dummy"] = False logger.debug( - f"Fetched bootstrap info: {bootstrap_info} for DP {self.target_dp_group} TP {target_tp_rank} PP {target_pp_rank}" + f"Fetched bootstrap info: {bootstrap_info} for DP {self.prefill_dp_rank} TP {target_tp_rank} PP {target_pp_rank}" ) bootstrap_infos.append(bootstrap_info) else: self.kv_mgr.record_failure( self.bootstrap_room, - f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank} and target_dp_group: {self.target_dp_group} and target_pp_rank {target_pp_rank}", + f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank} and prefill_dp_rank: {self.prefill_dp_rank} and target_pp_rank {target_pp_rank}", ) self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed) return @@ -439,7 +483,7 @@ def _get_bootstrap_info_from_server( ): """Fetch the bootstrap info from the bootstrap server.""" try: - url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}&target_dp_group={target_dp_group}&target_pp_rank={target_pp_rank}" + url = f"http://{self.bootstrap_addr}/route?bootstrap_room={-1}&engine_rank={engine_rank}&target_dp_group={target_dp_group}&target_pp_rank={target_pp_rank}" response = requests.get(url, timeout=5) if response.status_code == 200: bootstrap_info = response.json() @@ -458,7 +502,7 @@ def _get_prefill_parallel_info_from_server( ) -> Tuple[Optional[int], Optional[int], Optional[int], Optional[int]]: """Fetch the prefill parallel info from the bootstrap server.""" try: - url = f"http://{self.bootstrap_addr}/route?engine_rank={-1}&target_dp_group={-1}&target_pp_rank={-1}" + url = f"http://{self.bootstrap_addr}/route?bootstrap_room={-1}&engine_rank={-1}&target_dp_group={-1}&target_pp_rank={-1}" response = requests.get(url) if response.status_code == 200: prefill_parallel_info = response.json() @@ -477,6 +521,39 @@ def _get_prefill_parallel_info_from_server( logger.error(f"Error fetching prefill parallel info from bootstrap: {e}") return None, None, None, None + def _get_prefill_dp_rank_from_server(self, bootstrap_rooms): + """Fetch requests' prefill dp rank info from the bootstrap server. + + Args: + bootstrap_rooms: List of bootstrap_room IDs to query + + Returns: + Dictionary mapping bootstrap_room (as string) to dp rank info + """ + try: + if not isinstance(bootstrap_rooms, (list, tuple)): + raise TypeError( + f"bootstrap_rooms must be a list or tuple, got {type(bootstrap_rooms).__name__}" + ) + + if not bootstrap_rooms: + return {} + + rooms_str = ",".join(str(room) for room in bootstrap_rooms) + url = f"http://{self.bootstrap_addr}/route?bootstrap_room={rooms_str}&engine_rank={-1}&target_dp_group={-1}&target_pp_rank={-1}" + response = requests.get(url, timeout=5) + if response.status_code == 200: + bootstrap_table = response.json() + return bootstrap_table + else: + logger.error( + f"Failed to get prefill server info: {response.status_code}, {response.text}" + ) + return None + except Exception as e: + logger.error(f"Error fetching prefill info from bootstrap: {e}") + return None + @classmethod def _connect(cls, endpoint: str, is_ipv6: bool = False): with cls._global_lock: @@ -507,7 +584,7 @@ def failure_exception(self): class CommonKVBootstrapServer(BaseKVBootstrapServer): - def __init__(self, host: str, port: int): + def __init__(self, host: str, port: int, dp_size: int): self.host = host self.port = port self.app = web.Application() @@ -521,6 +598,11 @@ def __init__(self, host: str, port: int): self.prefill_port_table: Dict[ int, Dict[int, Dict[int, Dict[str, Union[str, int]]]] ] = {} + self.prefill_dp_rank_table: Dict[int, Dict[str, Union[str, int, float]]] = {} + self.should_notify_dp_rank = dp_size > 1 + self.entry_cleanup_interval = ( + envs.SGLANG_DISAGGREGATION_BOOTSTRAP_ENTRY_CLEANUP_INTERVAL.get() + ) # Start bootstrap server self.thread = threading.Thread(target=self._run_server, daemon=True) @@ -550,16 +632,30 @@ async def _handle_route(self, request: web.Request): async def _handle_route_put(self, request: web.Request): data = await request.json() role = data["role"] + rank_ip = data["rank_ip"] + rank_port = int(data["rank_port"]) + bootstrap_room = int(data["bootstrap_room"]) + attn_dp_rank = data["attn_dp_rank"] + + # Use real bootstrap_room id to put prefill dp rank + if bootstrap_room != -1: + async with self.lock: + self.prefill_dp_rank_table[bootstrap_room] = { + "dp_rank": attn_dp_rank, + "rank_port": rank_port, + "rank_ip": rank_ip, + "timestamp": time.time(), + } + logger.debug(f"register bootstrap_room table {bootstrap_room=}") + return web.Response(text="OK", status=200) + attn_tp_size = data["attn_tp_size"] attn_tp_rank = data["attn_tp_rank"] attn_dp_size = data["attn_dp_size"] - attn_dp_rank = data["attn_dp_rank"] pp_size = data["pp_size"] pp_rank = data["pp_rank"] system_dp_size = data["system_dp_size"] system_dp_rank = data["system_dp_rank"] - rank_ip = data["rank_ip"] - rank_port = int(data["rank_port"]) page_size = int(data["page_size"]) if self.attn_tp_size is None: @@ -598,6 +694,38 @@ async def _handle_route_put(self, request: web.Request): return web.Response(text="OK", status=200) async def _handle_route_get(self, request: web.Request): + bootstrap_room = request.query.get("bootstrap_room") + # Batch query for prefill dp rank: bootstrap_room is comma-separated list like "1,2,3" + # Skip if bootstrap_room is -1 (special value for non-dp-rank queries) + if bootstrap_room and bootstrap_room.strip() and bootstrap_room.strip() != "-1": + # Parse comma-separated bootstrap_room IDs + room_ids = [] + for rid in bootstrap_room.split(","): + rid = rid.strip() + if rid: + try: + room_id = int(rid) + if room_id != -1: + room_ids.append(room_id) + except ValueError: + logger.warning(f"Invalid bootstrap_room value: {rid}") + + if room_ids: + # Return only the requested bootstrap_rooms and delete them after query + result = {} + async with self.lock: + for room_id in room_ids: + if room_id in self.prefill_dp_rank_table: + result[str(room_id)] = self.prefill_dp_rank_table[room_id] + # Delete the entry immediately after query + del self.prefill_dp_rank_table[room_id] + logger.debug( + f"Deleted bootstrap_room {room_id} from prefill_dp_rank_table after query" + ) + return web.json_response(result, status=200) + else: + return web.json_response({}, status=200) + engine_rank = request.query.get("engine_rank") target_dp_group = request.query.get("target_dp_group") target_pp_rank = request.query.get("target_pp_rank") @@ -629,12 +757,44 @@ async def _handle_route_get(self, request: web.Request): else: return web.Response(text="Bootstrap info not Found", status=404) + async def _cleanup_expired_entries(self): + """Remove entries older than cleanup interval from prefill_dp_rank_table.""" + while True: + await asyncio.sleep(self.entry_cleanup_interval) + current_time = time.time() + async with self.lock: + expired_keys = [ + key + for key, value in self.prefill_dp_rank_table.items() + if current_time - value["timestamp"] > self.entry_cleanup_interval + ] + if expired_keys: + start_time = time.time() + async with self.lock: + for key in expired_keys: + if key in self.prefill_dp_rank_table: + del self.prefill_dp_rank_table[key] + consumed_time = time.time() - start_time + # Check if cleanup took longer than 1ms + if consumed_time > 0.001: + logger.warning( + f"Cleaned up {len(expired_keys)} expired entries from prefill_dp_rank_table. It takes too long {consumed_time=} secs" + ) + else: + logger.debug( + f"Cleaned up {len(expired_keys)} expired entries from prefill_dp_rank_table" + ) + def _run_server(self): try: # Event Loop self._loop = asyncio.new_event_loop() asyncio.set_event_loop(self._loop) + # Schedule the cleanup task for prefill_dp_rank_table + if self.should_notify_dp_rank: + self._loop.create_task(self._cleanup_expired_entries()) + access_log = None if logging.getLogger(__name__).getEffectiveLevel() <= logging.DEBUG: access_log = self.app.logger diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py index 4738b032f8dd..df733f12be33 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -453,6 +453,8 @@ def pop_preallocated( failed_reqs = [] preallocated_reqs = [] indices_to_remove = set() + bootstrap_table = None + data_parallel_rank = None # We need to make sure that the sum of inflight tokens and allocatable tokens is greater than maximum input+output length of each inflight request # Otherwise it is possible for one request running decode out of memory, while all other requests are in the transfer queue that cannot be retracted. @@ -475,7 +477,51 @@ def pop_preallocated( indices_to_remove.add(i) # Then, preallocate the remaining requests if possible + # Batch fetch all bootstrap_rooms at once to reduce network overhead + if bootstrap_table is None and any( + decode_req.kv_receiver.should_notify_dp_rank for decode_req in self.queue + ): + bootstrap_rooms_to_fetch = [] + for decode_req in self.queue: + if ( + decode_req.kv_receiver.should_notify_dp_rank + and hasattr( + decode_req.kv_receiver, "_get_prefill_dp_rank_from_server" + ) + and decode_req.req.bootstrap_host != FAKE_BOOTSTRAP_HOST + ): + bootstrap_rooms_to_fetch.append(decode_req.req.bootstrap_room) + + if bootstrap_rooms_to_fetch: + # Use the first kv_receiver to batch fetch all bootstrap_rooms + for decode_req in self.queue: + if hasattr( + decode_req.kv_receiver, "_get_prefill_dp_rank_from_server" + ): + bootstrap_table = ( + decode_req.kv_receiver._get_prefill_dp_rank_from_server( + bootstrap_rooms_to_fetch + ) + ) + break + for i, decode_req in enumerate(self.queue): + if decode_req.kv_receiver.should_notify_dp_rank: + # Do not check warmup requests with FAKE_BOOTSTRAP_HOST + if decode_req.req.bootstrap_host != FAKE_BOOTSTRAP_HOST: + if ( + bootstrap_table + and str(decode_req.req.bootstrap_room) in bootstrap_table + ): + data_parallel_rank = bootstrap_table[ + str(decode_req.req.bootstrap_room) + ]["dp_rank"] + else: + logger.debug( + f"bootstrap info for {decode_req.req.bootstrap_room} {decode_req.req.bootstrap_host} not found" + ) + continue + if rids_to_check is not None and decode_req.req.rid not in rids_to_check: continue @@ -571,7 +617,10 @@ def pop_preallocated( assert decode_req.metadata_buffer_index is not None page_indices = kv_to_page_indices(kv_indices, page_size) decode_req.kv_receiver.init( - page_indices, decode_req.metadata_buffer_index, state_indices + page_indices, + decode_req.metadata_buffer_index, + state_indices, + data_parallel_rank, ) preallocated_reqs.append(decode_req) indices_to_remove.add(i) diff --git a/python/sglang/srt/disaggregation/fake/conn.py b/python/sglang/srt/disaggregation/fake/conn.py index e759465e49e4..4391f5d20a86 100644 --- a/python/sglang/srt/disaggregation/fake/conn.py +++ b/python/sglang/srt/disaggregation/fake/conn.py @@ -68,6 +68,7 @@ def __init__( prefill_dp_rank: Optional[int] = None, ): self.has_init = False + self.should_notify_dp_rank = False def poll(self) -> KVPoll: if self.has_init is False: @@ -83,6 +84,7 @@ def init( kv_indices: list[int], aux_index: Optional[int] = None, state_indices: Optional[List[int]] = None, + prefill_dp_rank: Optional[int] = None, ): self.has_init = True logger.debug( diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py index a0c80e0d1514..5c22c52d25bc 100644 --- a/python/sglang/srt/disaggregation/mooncake/conn.py +++ b/python/sglang/srt/disaggregation/mooncake/conn.py @@ -1233,7 +1233,13 @@ def init( kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None, state_indices: Optional[List[int]] = None, + prefill_dp_rank: Optional[int] = None, ): + # Dp rank for prefill server is synchronized now. + if self.should_notify_dp_rank: + self.prefill_dp_rank = prefill_dp_rank + self._setup_bootstrap_infos() + if self.bootstrap_infos is None: self.kv_mgr.record_failure( self.bootstrap_room, diff --git a/python/sglang/srt/disaggregation/nixl/conn.py b/python/sglang/srt/disaggregation/nixl/conn.py index d3e390c92f9b..4d536ac06d07 100644 --- a/python/sglang/srt/disaggregation/nixl/conn.py +++ b/python/sglang/srt/disaggregation/nixl/conn.py @@ -777,7 +777,13 @@ def init( kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None, state_indices: Optional[List[int]] = None, + prefill_dp_rank: Optional[int] = None, ): + # Dp rank for prefill server is synchronized now. + if self.should_notify_dp_rank: + self.prefill_dp_rank = prefill_dp_rank + self._setup_bootstrap_infos() + if self.bootstrap_infos is None: logger.error( f"Could not fetch prefill parallel info from bootstrap_addr: {self.bootstrap_addr}", diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py index 5adec318cc9b..98717ad28bc2 100644 --- a/python/sglang/srt/environ.py +++ b/python/sglang/srt/environ.py @@ -246,6 +246,7 @@ class Envs: SGLANG_PREFILL_DELAYER_TOKEN_USAGE_LOW_WATERMARK = EnvFloat(None) SGLANG_DATA_PARALLEL_BUDGET_INTERVAL = EnvInt(1) SGLANG_QUEUED_TIMEOUT_MS = EnvInt(-1) + SGLANG_DISAGGREGATION_BOOTSTRAP_ENTRY_CLEANUP_INTERVAL = EnvInt(120) # Test: pd-disaggregation SGLANG_TEST_PD_DISAGG_BACKEND = EnvStr("mooncake") diff --git a/python/sglang/srt/managers/disagg_service.py b/python/sglang/srt/managers/disagg_service.py index 7dd0216d0dd1..ca3b3c814613 100644 --- a/python/sglang/srt/managers/disagg_service.py +++ b/python/sglang/srt/managers/disagg_service.py @@ -28,6 +28,7 @@ def start_disagg_service( bootstrap_server: BaseKVBootstrapServer = kv_bootstrap_server_class( host=server_args.host, port=server_args.disaggregation_bootstrap_port, + dp_size=server_args.dp_size, ) is_create_store = ( server_args.node_rank == 0 and transfer_backend == TransferBackend.ASCEND diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 4816d04de54b..5bea023373b8 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -787,18 +787,6 @@ def _handle_load_balance_method(self): ) return - # Backward compat: in PD prefill, legacy "round_robin" means `bootstrap_room` routing. - if ( - self.disaggregation_mode == "prefill" - and self.load_balance_method == "round_robin" - ): - logger.warning( - "In PD-disaggregation prefill mode, the 'round_robin' load balancing method " - "means `bootstrap_room` routing (use 'follow_bootstrap_room' instead). " - "Falling back to 'follow_bootstrap_room' for backward compatibility." - ) - self.load_balance_method = "follow_bootstrap_room" - def _handle_deprecated_args(self): # Handle deprecated tool call parsers deprecated_tool_call_parsers = {"qwen25": "qwen", "glm45": "glm"}