diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py index 699de552416..fb0fef46485 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -28,13 +28,7 @@ import torch from torch.distributed import ProcessGroup -from sglang.srt.disaggregation.base import ( - BaseKVManager, - BaseKVReceiver, - BaseKVSender, - KVArgs, - KVPoll, -) +from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVArgs, KVPoll from sglang.srt.disaggregation.utils import ( DisaggregationMode, KVClassType, diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py index 062d43a5969..67159784d82 100644 --- a/python/sglang/srt/disaggregation/mooncake/conn.py +++ b/python/sglang/srt/disaggregation/mooncake/conn.py @@ -329,7 +329,7 @@ def _register_to_bootstrap(self): "role": "Prefill", "rank_ip": get_local_ip_by_remote(), "rank_port": self.rank_port, - "bootstrap_key": f"{bootstrap_server_url}_{self.kv_args.engine_rank}", + "engine_rank": self.kv_args.engine_rank, } try: @@ -400,28 +400,29 @@ def __init__( self.session_id = self.kv_mgr.get_session_id() self.kv_mgr.update_status(bootstrap_room, KVPoll.Bootstrapping) - self.bootstrap_key = f"{self.bootstrap_addr}_{self.kv_mgr.kv_args.engine_rank}" + # NOTE: key distinguished by bootstrap_addr and engine_rank + bootstrap_key = f"{self.bootstrap_addr}_{self.kv_mgr.kv_args.engine_rank}" - if self.bootstrap_key not in self.kv_mgr.connection_pool: + if bootstrap_key not in self.kv_mgr.connection_pool: self.bootstrap_info = self._get_bootstrap_info_from_server( - self.bootstrap_key + self.kv_mgr.kv_args.engine_rank ) if self.bootstrap_info is None: logger.error( f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank}" ) else: - self.kv_mgr.connection_pool[self.bootstrap_key] = self.bootstrap_info + self.kv_mgr.connection_pool[bootstrap_key] = 
self.bootstrap_info else: - self.bootstrap_info = self.kv_mgr.connection_pool[self.bootstrap_key] + self.bootstrap_info = self.kv_mgr.connection_pool[bootstrap_key] assert self.bootstrap_info is not None self.kv_mgr.update_status(bootstrap_room, KVPoll.WaitingForInput) - def _get_bootstrap_info_from_server(self, bootstrap_key: str): + def _get_bootstrap_info_from_server(self, engine_rank): """Fetch the bootstrap info from the bootstrap server.""" try: - url = f"http://{self.bootstrap_addr}/route?bootstrap_key={bootstrap_key}" + url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}" response = requests.get(url) if response.status_code == 200: bootstrap_info = response.json() @@ -556,28 +557,28 @@ async def _handle_route_put(self, request: web.Request): role = data["role"] rank_ip = data["rank_ip"] rank_port = int(data["rank_port"]) - bootstrap_key = data["bootstrap_key"] + engine_rank = int(data["engine_rank"]) # Add lock to make sure thread-safe if role == "Prefill": - self.prefill_port_table[bootstrap_key] = { + self.prefill_port_table[engine_rank] = { "rank_ip": rank_ip, "rank_port": rank_port, } logger.debug( - f"Registered Prefill bootstrap_key: {bootstrap_key} with rank_ip: {rank_ip} and rank_port: {rank_port}" + f"Registered Prefill bootstrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}" ) return web.Response(text="OK", status=200) async def _handle_route_get(self, request: web.Request): - bootstrap_key = request.query.get("bootstrap_key") - if not bootstrap_key: - return web.Response(text="Missing bootstrap_key", status=400) + engine_rank = request.query.get("engine_rank") + if not engine_rank: + return web.Response(text="Missing rank", status=400) # Find corresponding prefill info async with self.lock: - bootstrap_info = self.prefill_port_table.get(bootstrap_key) + bootstrap_info = self.prefill_port_table.get(int(engine_rank)) if bootstrap_info is not None: return web.json_response(bootstrap_info, status=200) diff --git 
a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py index 408402c4f83..7ad548ccc3d 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -24,13 +24,7 @@ import torch -from sglang.srt.disaggregation.base import ( - BaseKVManager, - BaseKVReceiver, - BaseKVSender, - KVArgs, - KVPoll, -) +from sglang.srt.disaggregation.base import BaseKVManager, KVArgs, KVPoll from sglang.srt.disaggregation.utils import ( DisaggregationMode, KVClassType,