diff --git a/docs/advanced_features/pd_disaggregation.md b/docs/advanced_features/pd_disaggregation.md index 64cb41a9fe5f..17b81b86368e 100644 --- a/docs/advanced_features/pd_disaggregation.md +++ b/docs/advanced_features/pd_disaggregation.md @@ -142,6 +142,7 @@ The `SGLANG_MOONCAKE_CUSTOM_MEM_POOL` environment variable enables the custom me | **`SGLANG_DISAGGREGATION_THREAD_POOL_SIZE`** | Controls the total number of worker threads for KVCache transfer operations per TP rank | A dynamic value calculated by `int(0.75 * os.cpu_count()) // 8)`, which is limited to be larger than 4 and less than 12 to ensure efficiency and prevent thread race conditions | | **`SGLANG_DISAGGREGATION_QUEUE_SIZE`** | Sets the number of parallel transfer queues. KVCache transfer requests from multiple decode instances will be sharded into these queues so that they can share the threads and the transfer bandwidth at the same time. If it is set to `1`, then we transfer requests one by one according to fcfs strategy | `4` | | **`SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT`** | Timeout (seconds) for receiving destination KV indices during request initialization | `300` | +| **`SGLANG_DISAGGREGATION_BOOTSTRAP_ENTRY_CLEANUP_INTERVAL`** | Interval (seconds) between cleanups of bootstrap entries | `120` | If a greater mean TTFT is acceptable, you can `export SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT=600` (10 minutes) to relax the timeout condition. Please be aware that this setting will cause prefill instances to take a longer time to clean up the affected memory resources when a running decode node loses connection. diff --git a/python/sglang/srt/disaggregation/base/conn.py b/python/sglang/srt/disaggregation/base/conn.py index da4629e52527..3233c5e31075 100644 --- a/python/sglang/srt/disaggregation/base/conn.py +++ b/python/sglang/srt/disaggregation/base/conn.py @@ -61,6 +61,11 @@ def __init__( is_mla_backend: Optional[bool] = False, ): ... + @abstractmethod + def register_to_bootstrap(self): + """Register to the bootstrap server.""" + ... + class BaseKVSender(ABC): @@ -158,4 +163,4 @@ def abort(self): class BaseKVBootstrapServer(ABC): @abstractmethod - def __init__(self, host: str, port: int): ... + def __init__(self, host: str, port: int, dp_size: int = 1): ... diff --git a/python/sglang/srt/disaggregation/common/conn.py b/python/sglang/srt/disaggregation/common/conn.py index 67fe82ad67f1..28c057abfd5c 100644 --- a/python/sglang/srt/disaggregation/common/conn.py +++ b/python/sglang/srt/disaggregation/common/conn.py @@ -4,6 +4,7 @@ import logging import socket import threading +import time from functools import cache from typing import Dict, List, Optional, Tuple, Union @@ -23,6 +24,7 @@ ) from sglang.srt.disaggregation.utils import DisaggregationMode from sglang.srt.distributed import get_pp_group +from sglang.srt.environ import envs from sglang.srt.layers.dp_attention import ( get_attention_dp_rank, get_attention_dp_size, @@ -52,6 +54,7 @@ def __init__( self.kv_args = args self.is_mla_backend = is_mla_backend self.disaggregation_mode = disaggregation_mode + self.server_args = server_args # for p/d multi node infer self.bootstrap_host = server_args.host self.bootstrap_port = server_args.disaggregation_bootstrap_port @@ -81,7 +84,7 @@ def __init__( self.request_status: Dict[int, KVPoll] = {} if self.disaggregation_mode == DisaggregationMode.PREFILL: - self._register_to_bootstrap() + self.register_to_bootstrap() self.transfer_infos = {} self.decode_kv_args_table = {} self.pp_group = get_pp_group() @@ -93,12 +96,41 @@ def __init__( self.prefill_dp_size_table: Dict[str, int] = {} self.prefill_pp_size_table: Dict[str, int] = {} self.prefill_page_size_table: Dict[str, Optional[int]] = {} + self.follow_bootstrap_room_table: Dict[str, bool] = {} else: raise ValueError( f"Unsupported DisaggregationMode: {self.disaggregation_mode}" ) - def _register_to_bootstrap(self): + def ensure_parallel_info(self, bootstrap_addr: str) -> bool: + """Fetch and cache prefill parallel info if not yet available. + Returns True if info is available (cached or freshly fetched). + """ + if bootstrap_addr in self.prefill_dp_size_table: + return True + info = CommonKVReceiver._fetch_prefill_parallel_info(bootstrap_addr) + if info is None: + return False + tp_size, dp_size, pp_size, page_size, follow_bootstrap_room = info + + if page_size is not None and page_size != self.kv_args.page_size: + raise RuntimeError( + f"Page size mismatch: prefill server has page_size={page_size}, " + f"but decode server has page_size={self.kv_args.page_size}. " + f"Both servers must use the same --page-size value." + ) + + self.prefill_attn_tp_size_table[bootstrap_addr] = tp_size + self.prefill_dp_size_table[bootstrap_addr] = dp_size + self.prefill_pp_size_table[bootstrap_addr] = pp_size + self.prefill_page_size_table[bootstrap_addr] = page_size + self.follow_bootstrap_room_table[bootstrap_addr] = follow_bootstrap_room + logger.debug( + f"Prefill parallel info for [{bootstrap_addr}]: DP={dp_size} TP={tp_size} PP={pp_size} page_size={page_size} follow_bootstrap_room={follow_bootstrap_room}" + ) + return True + + def register_to_bootstrap(self): """Register KVSender to bootstrap server via HTTP POST.""" if self.dist_init_addr: # Multi-node case: bootstrap server's host is dist_init_addr @@ -129,6 +161,7 @@ def _register_to_bootstrap(self): "rank_ip": self.local_ip, "rank_port": self.rank_port, "page_size": self.kv_args.page_size, + "load_balance_method": self.server_args.load_balance_method, } try: @@ -215,6 +248,27 @@ def __init__( # inner state self.curr_idx = 0 self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping) + if ( + self.kv_mgr.server_args.dp_size > 1 + and self.kv_mgr.server_args.load_balance_method != "follow_bootstrap_room" + ): + self._register_prefill_dp_rank() + + def _register_prefill_dp_rank(self): + """Register this request's prefill dp_rank to the bootstrap server.""" + url = f"http://{self.bootstrap_server_url}/register_dp_rank" + payload = { + "bootstrap_room": self.bootstrap_room, + "dp_rank": self.kv_mgr.attn_dp_rank, + } + try: + response = requests.post(url, json=payload, timeout=5) + if response.status_code != 200: + logger.error( + f"Failed to register prefill dp_rank: {response.status_code}, {response.text}" + ) + except Exception as e: + logger.error(f"Failed to register prefill dp_rank: {e}") def init(self, num_kv_indices: int, aux_index: Optional[int] = None): self.num_kv_indices = num_kv_indices @@ -252,65 +306,23 @@ def __init__( self.kv_mgr = mgr self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping) - if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table: - ( - self.prefill_attn_tp_size, - self.prefill_dp_size, - self.prefill_pp_size, - self.prefill_page_size, - ) = self._get_prefill_parallel_info_from_server() - if ( - self.prefill_attn_tp_size is None - or self.prefill_dp_size is None - or self.prefill_pp_size is None - ): - self.kv_mgr.record_failure( - self.bootstrap_room, - f"Could not fetch prefill parallel info from bootstrap_addr: {self.bootstrap_addr}", - ) - self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed) - self.bootstrap_infos = None - return - - if self.prefill_page_size is not None: - decode_page_size = self.kv_mgr.kv_args.page_size - if self.prefill_page_size != decode_page_size: - error_msg = ( - f"Page size mismatch: prefill server has page_size={self.prefill_page_size}, " - f"but decode server has page_size={decode_page_size}. " - f"Both servers must use the same --page-size value." - ) - logger.error(error_msg) - raise RuntimeError(error_msg) - - logger.debug( - f"Fetch prefill parallel info from [{self.bootstrap_addr}]: DP size:{self.prefill_dp_size}, TP size:{self.prefill_attn_tp_size} PP size:{self.prefill_pp_size} Page size:{self.prefill_page_size}" - ) - self.kv_mgr.prefill_attn_tp_size_table[self.bootstrap_addr] = ( - self.prefill_attn_tp_size - ) - self.kv_mgr.prefill_dp_size_table[self.bootstrap_addr] = ( - self.prefill_dp_size - ) - self.kv_mgr.prefill_pp_size_table[self.bootstrap_addr] = ( - self.prefill_pp_size - ) - self.kv_mgr.prefill_page_size_table[self.bootstrap_addr] = ( - self.prefill_page_size - ) - else: - self.prefill_attn_tp_size = self.kv_mgr.prefill_attn_tp_size_table[ - self.bootstrap_addr - ] - self.prefill_dp_size = self.kv_mgr.prefill_dp_size_table[ - self.bootstrap_addr - ] - self.prefill_pp_size = self.kv_mgr.prefill_pp_size_table[ - self.bootstrap_addr - ] - self.prefill_page_size = self.kv_mgr.prefill_page_size_table.get( - self.bootstrap_addr + if not self.kv_mgr.ensure_parallel_info(self.bootstrap_addr): + self.kv_mgr.record_failure( + self.bootstrap_room, + f"Could not fetch prefill parallel info from bootstrap_addr: {self.bootstrap_addr}", ) + self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed) + self.bootstrap_infos = None + return + + self.prefill_attn_tp_size = self.kv_mgr.prefill_attn_tp_size_table[ + self.bootstrap_addr + ] + self.prefill_dp_size = self.kv_mgr.prefill_dp_size_table[self.bootstrap_addr] + self.prefill_pp_size = self.kv_mgr.prefill_pp_size_table[self.bootstrap_addr] + self.prefill_page_size = self.kv_mgr.prefill_page_size_table.get( + self.bootstrap_addr + ) # Handling for PD with different TP sizes per DP rank if self.kv_mgr.attn_tp_size == self.prefill_attn_tp_size: @@ -367,14 +379,6 @@ def __init__( self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size ) * (self.prefill_pp_size // self.kv_mgr.pp_size) - if prefill_dp_rank is not None: - logger.debug(f"Targeting DP rank: {prefill_dp_rank}") - self.prefill_dp_rank = prefill_dp_rank - else: - self.prefill_dp_rank = bootstrap_room % self.prefill_dp_size - - self.target_dp_group = self.prefill_dp_rank - # Decode pp size should be equal to prefill pp size or 1 assert ( self.kv_mgr.pp_size == self.prefill_pp_size or self.kv_mgr.pp_size == 1 @@ -389,9 +393,17 @@ def __init__( self.kv_mgr.required_prefill_response_num_table[self.bootstrap_room] = ( self.required_prefill_response_num ) - # NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank + + assert ( + prefill_dp_rank is not None + ), "prefill_dp_rank must be resolved before creating receiver" + self.prefill_dp_rank = prefill_dp_rank + self._setup_bootstrap_infos() + + def _setup_bootstrap_infos(self): + # NOTE: key distinguished by bootstrap_addr, prefill_dp_rank, and target_tp_rank bootstrap_key = ( - f"{self.bootstrap_addr}_{self.target_dp_group}_{self.target_tp_rank}" + f"{self.bootstrap_addr}_{self.prefill_dp_rank}_{self.target_tp_rank}" ) if bootstrap_key not in self.kv_mgr.connection_pool: @@ -400,7 +412,7 @@ def __init__( # Enable higher PP ranks to be bootstrapped earlier to make PP PD requests bootstrap more robust for target_pp_rank in reversed(self.target_pp_ranks): bootstrap_info = self._get_bootstrap_info_from_server( - target_tp_rank, self.target_dp_group, target_pp_rank + target_tp_rank, self.prefill_dp_rank, target_pp_rank ) if bootstrap_info is not None: if self.kv_mgr.is_mla_backend: @@ -413,13 +425,13 @@ def __init__( # For non-MLA: all target_tp_ranks are selected real ranks bootstrap_info["is_dummy"] = False logger.debug( - f"Fetched bootstrap info: {bootstrap_info} for DP {self.target_dp_group} TP {target_tp_rank} PP {target_pp_rank}" + f"Fetched bootstrap info: {bootstrap_info} for DP {self.prefill_dp_rank} TP {target_tp_rank} PP {target_pp_rank}" ) bootstrap_infos.append(bootstrap_info) else: self.kv_mgr.record_failure( self.bootstrap_room, - f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank} and target_dp_group: {self.target_dp_group} and target_pp_rank {target_pp_rank}", + f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank} and prefill_dp_rank: {self.prefill_dp_rank} and target_pp_rank {target_pp_rank}", ) self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed) return @@ -435,11 +447,11 @@ def __init__( assert len(self.bootstrap_infos) > 0 def _get_bootstrap_info_from_server( - self, engine_rank, target_dp_group, target_pp_rank + self, engine_rank, prefill_dp_rank, target_pp_rank ): """Fetch the bootstrap info from the bootstrap server.""" try: - url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}&target_dp_group={target_dp_group}&target_pp_rank={target_pp_rank}" + url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}&prefill_dp_rank={prefill_dp_rank}&target_pp_rank={target_pp_rank}" response = requests.get(url, timeout=5) if response.status_code == 200: bootstrap_info = response.json() @@ -453,29 +465,58 @@ def _get_bootstrap_info_from_server( logger.error(f"Error fetching prefill info from bootstrap: {e}") return None - def _get_prefill_parallel_info_from_server( - self, - ) -> Tuple[Optional[int], Optional[int], Optional[int], Optional[int]]: - """Fetch the prefill parallel info from the bootstrap server.""" + @staticmethod + def _fetch_prefill_parallel_info( + bootstrap_addr: str, + ) -> Optional[Tuple[int, int, int, int, bool]]: + """Fetch the prefill parallel info from the bootstrap server. + + Returns (attn_tp_size, dp_size, pp_size, page_size, follow_bootstrap_room) + or None on failure. + """ try: - url = f"http://{self.bootstrap_addr}/route?engine_rank={-1}&target_dp_group={-1}&target_pp_rank={-1}" - response = requests.get(url) + url = f"http://{bootstrap_addr}/route?engine_rank={-1}&prefill_dp_rank={-1}&target_pp_rank={-1}" + response = requests.get(url, timeout=5) if response.status_code == 200: - prefill_parallel_info = response.json() + info = response.json() return ( - int(prefill_parallel_info["prefill_attn_tp_size"]), - int(prefill_parallel_info["prefill_dp_size"]), - int(prefill_parallel_info["prefill_pp_size"]), - int(prefill_parallel_info["prefill_page_size"]), + int(info["prefill_attn_tp_size"]), + int(info["prefill_dp_size"]), + int(info["prefill_pp_size"]), + int(info["prefill_page_size"]), + bool(info.get("follow_bootstrap_room", True)), ) else: logger.error( f"Failed to get prefill parallel info: {response.status_code}, {response.text}" ) - return None, None, None, None + return None except Exception as e: logger.error(f"Error fetching prefill parallel info from bootstrap: {e}") - return None, None, None, None + return None + + @staticmethod + def query_prefill_dp_ranks( + bootstrap_addr: str, bootstrap_rooms: List[int] + ) -> Dict[str, int]: + """Batch query prefill dp_ranks for given bootstrap_rooms.""" + try: + url = f"http://{bootstrap_addr}/query_dp_ranks" + response = requests.post( + url, + json={"bootstrap_rooms": bootstrap_rooms}, + timeout=5, + ) + if response.status_code == 200: + return response.json() + else: + logger.error( + f"Failed to query dp_ranks: {response.status_code}, {response.text}" + ) + return {} + except Exception as e: + logger.error(f"Error querying dp_ranks from bootstrap: {e}") + return {} @classmethod def _connect(cls, endpoint: str, is_ipv6: bool = False): @@ -507,7 +548,7 @@ def failure_exception(self): class CommonKVBootstrapServer(BaseKVBootstrapServer): - def __init__(self, host: str, port: int): + def __init__(self, host: str, port: int, dp_size: int = 1): self.host = host self.port = port self.app = web.Application() @@ -516,11 +557,16 @@ def __init__(self, host: str, port: int): self._setup_routes() self.pp_size = None self.attn_tp_size = None - self.dp_size = None + self.dp_size = dp_size self.page_size = None + self.follow_bootstrap_room: Optional[bool] = None self.prefill_port_table: Dict[ int, Dict[int, Dict[int, Dict[str, Union[str, int]]]] ] = {} + self.room_to_dp_rank: Dict[int, Dict[str, Union[int, float]]] = {} + self.entry_cleanup_interval = ( + envs.SGLANG_DISAGGREGATION_BOOTSTRAP_ENTRY_CLEANUP_INTERVAL.get() + ) # Start bootstrap server self.thread = threading.Thread(target=self._run_server, daemon=True) @@ -531,6 +577,8 @@ def run(self): def _setup_routes(self): self.app.router.add_route("*", "/route", self._handle_route) + self.app.router.add_post("/register_dp_rank", self._handle_register_dp_rank) + self.app.router.add_post("/query_dp_ranks", self._handle_query_dp_ranks) self.app.router.add_get("/health", self._handle_health_check) async def _handle_health_check(self, request): @@ -574,6 +622,12 @@ async def _handle_route_put(self, request: web.Request): if self.page_size is None and page_size is not None: self.page_size = page_size + if self.follow_bootstrap_room is None: + load_balance_method = data.get( + "load_balance_method", "follow_bootstrap_room" + ) + self.follow_bootstrap_room = load_balance_method == "follow_bootstrap_room" + if role == "Prefill": if system_dp_size == 1: dp_group = attn_dp_rank @@ -599,15 +653,14 @@ async def _handle_route_put(self, request: web.Request): async def _handle_route_get(self, request: web.Request): engine_rank = request.query.get("engine_rank") - target_dp_group = request.query.get("target_dp_group") + prefill_dp_rank = request.query.get("prefill_dp_rank") target_pp_rank = request.query.get("target_pp_rank") - if not engine_rank or not target_dp_group or not target_pp_rank: + if not engine_rank or not prefill_dp_rank or not target_pp_rank: return web.Response(text="Missing inputs for bootstrap server.", status=400) - # Currently we use engine_rank == -1 and target_dp_group == -1 to sync dp size if ( int(engine_rank) == -1 - and int(target_dp_group) == -1 + and int(prefill_dp_rank) == -1 and int(target_pp_rank) == -1 ): prefill_parallel_info = { @@ -615,12 +668,17 @@ async def _handle_route_get(self, request: web.Request): "prefill_dp_size": self.dp_size, "prefill_pp_size": self.pp_size, "prefill_page_size": self.page_size, + "follow_bootstrap_room": ( + self.follow_bootstrap_room + if self.follow_bootstrap_room is not None + else True + ), } return web.json_response(prefill_parallel_info, status=200) # Find corresponding prefill info async with self.lock: - bootstrap_info = self.prefill_port_table[int(target_dp_group)][ + bootstrap_info = self.prefill_port_table[int(prefill_dp_rank)][ int(engine_rank) ][int(target_pp_rank)] @@ -629,12 +687,55 @@ async def _handle_route_get(self, request: web.Request): else: return web.Response(text="Bootstrap info not Found", status=404) + async def _handle_register_dp_rank(self, request: web.Request): + data = await request.json() + bootstrap_room = int(data["bootstrap_room"]) + dp_rank = int(data["dp_rank"]) + async with self.lock: + self.room_to_dp_rank[bootstrap_room] = { + "dp_rank": dp_rank, + "timestamp": time.time(), + } + logger.debug(f"Registered dp_rank={dp_rank} for {bootstrap_room=}") + return web.Response(text="OK", status=200) + + async def _handle_query_dp_ranks(self, request: web.Request): + data = await request.json() + bootstrap_rooms = data["bootstrap_rooms"] + result = {} + async with self.lock: + for room in bootstrap_rooms: + room_int = int(room) + if room_int in self.room_to_dp_rank: + result[str(room_int)] = self.room_to_dp_rank[room_int]["dp_rank"] + return web.json_response(result, status=200) + + async def _cleanup_expired_entries(self): + """Remove entries older than cleanup interval from room_to_dp_rank.""" + while True: + await asyncio.sleep(self.entry_cleanup_interval) + current_time = time.time() + async with self.lock: + expired_keys = [ + key + for key, value in self.room_to_dp_rank.items() + if current_time - value["timestamp"] > self.entry_cleanup_interval + ] + for key in expired_keys: + del self.room_to_dp_rank[key] + if expired_keys: + logger.debug( + f"Cleaned up {len(expired_keys)} expired entries from room_to_dp_rank" + ) + def _run_server(self): try: # Event Loop self._loop = asyncio.new_event_loop() asyncio.set_event_loop(self._loop) + self._loop.create_task(self._cleanup_expired_entries()) + access_log = None if logging.getLogger(__name__).getEffectiveLevel() <= logging.DEBUG: access_log = self.app.logger diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py index 1d8baf0028e8..1b2d4741cc98 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -253,6 +253,7 @@ def __init__( # Queue for requests pending pre-allocation self.queue: List[DecodeRequest] = [] self.retracted_queue: List[Req] = [] + self.pending_reqs: List[Req] = [] self.prefill_pp_size = prefill_pp_size self.kv_manager = self._init_kv_manager() @@ -345,32 +346,59 @@ def add(self, req: Req, is_retracted: bool = False) -> None: req.retraction_mb_id = None self.retracted_queue.append(req) else: - # Auto enable FAKE mode if configured - if req.bootstrap_host == FAKE_BOOTSTRAP_HOST or ( - req.bootstrap_host is None - and self.scheduler.server_args.disaggregation_transfer_backend == "fake" - ): - kv_receiver_class = get_kv_class( - TransferBackend.FAKE, KVClassType.RECEIVER - ) - else: - kv_receiver_class = get_kv_class( - self.transfer_backend, KVClassType.RECEIVER - ) + dp_rank = self._resolve_dp_rank(req) + if dp_rank is None: + self.pending_reqs.append(req) + return + self._create_receiver_and_enqueue(req, dp_rank) + + def _resolve_dp_rank(self, req: Req) -> Optional[int]: + if req.data_parallel_rank is not None: + return req.data_parallel_rank + + if req.bootstrap_host == FAKE_BOOTSTRAP_HOST or ( + req.bootstrap_host is None + and self.scheduler.server_args.disaggregation_transfer_backend == "fake" + ): + return 0 - kv_receiver = kv_receiver_class( - mgr=self.kv_manager, - bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}", - bootstrap_room=req.bootstrap_room, - prefill_dp_rank=req.data_parallel_rank, + bootstrap_addr = f"{req.bootstrap_host}:{req.bootstrap_port}" + + if bootstrap_addr not in self.kv_manager.prefill_dp_size_table: + return None + + if self.kv_manager.follow_bootstrap_room_table[bootstrap_addr]: + return ( + req.bootstrap_room + % self.kv_manager.prefill_dp_size_table[bootstrap_addr] ) - req.add_latency(RequestStage.DECODE_PREPARE) - trace_slice_end(RequestStage.DECODE_PREPARE, req.rid, auto_next_anon=True) - self.queue.append( - DecodeRequest(req=req, kv_receiver=kv_receiver, waiting_for_input=False) + return None + + def _create_receiver_and_enqueue(self, req: Req, dp_rank: int) -> None: + if req.bootstrap_host == FAKE_BOOTSTRAP_HOST or ( + req.bootstrap_host is None + and self.scheduler.server_args.disaggregation_transfer_backend == "fake" + ): + kv_receiver_class = get_kv_class(TransferBackend.FAKE, KVClassType.RECEIVER) + else: + kv_receiver_class = get_kv_class( + self.transfer_backend, KVClassType.RECEIVER ) + kv_receiver = kv_receiver_class( + mgr=self.kv_manager, + bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}", + bootstrap_room=req.bootstrap_room, + prefill_dp_rank=dp_rank, + ) + + req.add_latency(RequestStage.DECODE_PREPARE) + trace_slice_end(RequestStage.DECODE_PREPARE, req.rid, auto_next_anon=True) + self.queue.append( + DecodeRequest(req=req, kv_receiver=kv_receiver, waiting_for_input=False) + ) + def _check_if_req_exceed_kv_capacity(self, req: Req) -> bool: if len(req.origin_input_ids) > self.max_total_num_tokens: message = f"Request {req.rid} exceeds the maximum number of tokens: {len(req.origin_input_ids)} > {self.max_total_num_tokens}" @@ -465,10 +493,56 @@ def _update_handshake_waiters( else: raise ValueError(f"Unexpected poll case: {poll}") + def _resolve_pending_reqs(self) -> None: + """Batch-resolve dp_ranks for pending requests and create receivers.""" + if not self.pending_reqs: + return + + bootstrap_addr = f"{self.pending_reqs[0].bootstrap_host}:{self.pending_reqs[0].bootstrap_port}" + + # If a request is following the bootstrap room, + # we need get the prefill info before resolving the dp_rank, + # which is a conflict with the lazy resolve logic in CommonKVReceiver, + # so we need to ensure the parallel info before resolving the dp_rank + if not self.kv_manager.ensure_parallel_info(bootstrap_addr): + return + + resolved = [] + need_query = [] + for req in self.pending_reqs: + # NOTE: we need resolve it again because we may ensure the parallel info here + dp_rank = self._resolve_dp_rank(req) + if dp_rank is not None: + resolved.append((req, dp_rank)) + else: + need_query.append(req) + + if need_query: + from sglang.srt.disaggregation.common.conn import CommonKVReceiver + + rooms = [req.bootstrap_room for req in need_query] + room_to_rank = CommonKVReceiver.query_prefill_dp_ranks( + bootstrap_addr, rooms + ) + remaining = [] + for req in need_query: + room_key = str(req.bootstrap_room) + if room_key in room_to_rank: + resolved.append((req, int(room_to_rank[room_key]))) + else: + remaining.append(req) + self.pending_reqs = remaining + else: + self.pending_reqs = [] + + for req, dp_rank in resolved: + self._create_receiver_and_enqueue(req, dp_rank) + def pop_preallocated( self, rids_to_check: Optional[List[str]] = None ) -> Tuple[List[DecodeRequest], List[DecodeRequest]]: """Pop the preallocated requests from the pending queue (FIFO).""" + self._resolve_pending_reqs() self._update_handshake_waiters(rids_to_check) failed_reqs = [] @@ -1086,7 +1160,7 @@ def process_decode_queue(self: Scheduler): if self.polling_count % self.polling_interval == 0: req_conns, _ = self.disagg_decode_prealloc_queue.pop_preallocated() self.disagg_decode_transfer_queue.extend(req_conns) - alloc_reqs = ( + transferred_reqs = ( self.disagg_decode_transfer_queue.pop_transferred() ) # the requests which kv has arrived - self.waiting_queue.extend(alloc_reqs) + self.waiting_queue.extend(transferred_reqs) diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py index d0d4efd958da..4a80321cea42 100644 --- a/python/sglang/srt/disaggregation/mooncake/conn.py +++ b/python/sglang/srt/disaggregation/mooncake/conn.py @@ -1124,18 +1124,20 @@ def _handle_node_failure(self, failed_bootstrap_addr): ] for k in keys_to_remove: del self.connection_pool[k] - if failed_bootstrap_addr in self.prefill_attn_tp_size_table: - del self.prefill_attn_tp_size_table[failed_bootstrap_addr] - if failed_bootstrap_addr in self.prefill_dp_size_table: - del self.prefill_dp_size_table[failed_bootstrap_addr] - if failed_bootstrap_addr in self.prefill_pp_size_table: - del self.prefill_pp_size_table[failed_bootstrap_addr] possible_affected_rooms = self.addr_to_rooms_tracker.get( failed_bootstrap_addr, [] ) - if failed_bootstrap_addr in self.addr_to_rooms_tracker: - del self.addr_to_rooms_tracker[failed_bootstrap_addr] + keys_to_remove = [ + self.prefill_attn_tp_size_table, + self.prefill_dp_size_table, + self.prefill_pp_size_table, + self.follow_bootstrap_room_table, + self.addr_to_rooms_tracker, + ] + for k in keys_to_remove: + if failed_bootstrap_addr in k: + del k[failed_bootstrap_addr] # Report the requests associated with the failed bootstrap addr and mark their status as KVPoll.Failed affected_rooms = [] diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py index 2308b212ccf7..7ef525c2523e 100644 --- a/python/sglang/srt/environ.py +++ b/python/sglang/srt/environ.py @@ -259,6 +259,7 @@ class Envs: SGLANG_REQ_WAITING_TIMEOUT = EnvFloat(-1) # in seconds SGLANG_NCCL_ALL_GATHER_IN_OVERLAP_SCHEDULER_SYNC_BATCH = EnvBool(False) SGLANG_REQ_RUNNING_TIMEOUT = EnvFloat(-1) # in seconds + SGLANG_DISAGGREGATION_BOOTSTRAP_ENTRY_CLEANUP_INTERVAL = EnvInt(120) # Test: pd-disaggregation SGLANG_TEST_PD_DISAGG_BACKEND = EnvStr("mooncake") diff --git a/python/sglang/srt/managers/disagg_service.py b/python/sglang/srt/managers/disagg_service.py index 57dfcd32e279..a1ade7532c65 100644 --- a/python/sglang/srt/managers/disagg_service.py +++ b/python/sglang/srt/managers/disagg_service.py @@ -28,6 +28,7 @@ def start_disagg_service( bootstrap_server: BaseKVBootstrapServer = kv_bootstrap_server_class( host=server_args.host, port=server_args.disaggregation_bootstrap_port, + dp_size=server_args.dp_size, ) is_create_store = ( server_args.node_rank == 0 and transfer_backend == TransferBackend.ASCEND diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index b080aeb1685a..adb40febfeaa 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -814,18 +814,6 @@ def _handle_load_balance_method(self): ) return - # Backward compat: in PD prefill, legacy "round_robin" means `bootstrap_room` routing. - if ( - self.disaggregation_mode == "prefill" - and self.load_balance_method == "round_robin" - ): - logger.warning( - "In PD-disaggregation prefill mode, the 'round_robin' load balancing method " - "means `bootstrap_room` routing (use 'follow_bootstrap_room' instead). " - "Falling back to 'follow_bootstrap_room' for backward compatibility." - ) - self.load_balance_method = "follow_bootstrap_room" - def _handle_deprecated_args(self): # Handle deprecated tool call parsers deprecated_tool_call_parsers = {"qwen25": "qwen", "glm45": "glm"} diff --git a/test/registered/distributed/test_disaggregation_dp_attention.py b/test/registered/distributed/test_disaggregation_dp_attention.py index 3fbb620b6fae..249a11e75a7b 100644 --- a/test/registered/distributed/test_disaggregation_dp_attention.py +++ b/test/registered/distributed/test_disaggregation_dp_attention.py @@ -20,6 +20,7 @@ class TestDisaggregationDPAttention(PDDisaggregationServerBase): PREFILL_DP_SIZE = 4 DECODE_DP_SIZE = 4 + LOAD_BALANCE_METHOD = "auto" @classmethod def setUpClass(cls): @@ -50,6 +51,8 @@ def start_prefill(cls): "--dp", str(cls.PREFILL_DP_SIZE), "--enable-dp-attention", + "--load-balance-method", + cls.LOAD_BALANCE_METHOD, ] prefill_args += cls.transfer_backend + cls.rdma_devices cls.process_prefill = popen_launch_pd_server( @@ -72,6 +75,8 @@ def start_decode(cls): "--enable-dp-attention", "--base-gpu-id", str(cls.PREFILL_DP_SIZE), + "--load-balance-method", + cls.LOAD_BALANCE_METHOD, ] decode_args += cls.transfer_backend + cls.rdma_devices cls.process_decode = popen_launch_pd_server( @@ -97,5 +102,9 @@ def test_gsm8k(self): self.assertGreater(metrics["accuracy"], 0.60) +class TestDisaggregationDPAttentionRoundRobin(TestDisaggregationDPAttention): + LOAD_BALANCE_METHOD = "round_robin" + + if __name__ == "__main__": unittest.main()