Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 99 additions & 10 deletions python/sglang/srt/disaggregation/mooncake/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,13 @@ def __init__(
# for p/d multi node infer
self.bootstrap_port = server_args.disaggregation_bootstrap_port
self.dist_init_addr = server_args.dist_init_addr
self.tp_size = server_args.tp_size
self.dp_size = server_args.dp_size
self.enable_dp_attention = server_args.enable_dp_attention
if not server_args.enable_dp_attention and server_args.dp_size != 1:
raise ValueError(
"If dp_attention is not enabled, dp size must be 1 in disaggregation mode."
)
self.request_status: Dict[int, KVPoll] = {}
self.rank_port = None
self.server_socket = zmq.Context().socket(zmq.PULL)
Expand All @@ -121,6 +128,7 @@ def __init__(
elif self.disaggregation_mode == DisaggregationMode.DECODE:
self.start_decode_thread()
self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
self.prefill_dp_size_table: Dict[str, int] = {}
else:
raise ValueError(
f"Unsupported DisaggregationMode: {self.disaggregation_mode}"
Expand Down Expand Up @@ -331,6 +339,8 @@ def _register_to_bootstrap(self):
url = f"http://{bootstrap_server_url}/route"
payload = {
"role": "Prefill",
"tp_size": self.tp_size,
"dp_size": self.dp_size,
"rank_ip": get_local_ip_by_remote(),
"rank_port": self.rank_port,
"engine_rank": self.kv_args.engine_rank,
Expand Down Expand Up @@ -408,12 +418,41 @@ def __init__(
self.session_id = self.kv_mgr.get_session_id()
self.kv_mgr.update_status(bootstrap_room, KVPoll.Bootstrapping)

if not self.kv_mgr.enable_dp_attention:
# We assume dp_attention is enabled or disabled consistently on both
# the prefill role and the decode role. If the decode instance does
# not enable dp_attention, then the prefill instance does not enable
# it either. Therefore, we can skip querying the bootstrap server for
# the prefill dp size, which reduces bootstrap overhead.
self.prefill_dp_size = 1
elif self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
self.prefill_dp_size, tp_size_per_dp_rank = (
self._get_prefill_dp_size_from_server()
)
# Currently, we don't allow prefill instance and decode instance to
# have different TP sizes per DP rank.
assert tp_size_per_dp_rank == self.kv_mgr.tp_size // self.kv_mgr.dp_size
if self.prefill_dp_size is None:
logger.error(
f"Could not fetch prefill dp_size for bootstrap_addr: {self.bootstrap_addr}"
)
else:
self.kv_mgr.prefill_dp_size_table[self.bootstrap_addr] = (
self.prefill_dp_size
)
else:
self.prefill_dp_size = self.kv_mgr.prefill_dp_size_table[
self.bootstrap_addr
]

# NOTE: key distinguished by bootstrap_addr and engine_rank
self.target_dp_group = bootstrap_room % self.prefill_dp_size
bootstrap_key = f"{self.bootstrap_addr}_{self.kv_mgr.kv_args.engine_rank}"

if bootstrap_key not in self.kv_mgr.connection_pool:
self.bootstrap_info = self._get_bootstrap_info_from_server(
self.kv_mgr.kv_args.engine_rank
self.kv_mgr.kv_args.engine_rank,
self.target_dp_group,
)
if self.bootstrap_info is None:
logger.error(
Expand All @@ -427,10 +466,10 @@ def __init__(
assert self.bootstrap_info is not None
self.kv_mgr.update_status(bootstrap_room, KVPoll.WaitingForInput)

def _get_bootstrap_info_from_server(self, engine_rank):
def _get_bootstrap_info_from_server(self, engine_rank, target_dp_group):
"""Fetch the bootstrap info from the bootstrap server."""
try:
url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}"
url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}&target_dp_group={target_dp_group}"
response = requests.get(url)
if response.status_code == 200:
bootstrap_info = response.json()
Expand All @@ -444,6 +483,25 @@ def _get_bootstrap_info_from_server(self, engine_rank):
logger.error(f"Error fetching prefill info from bootstrap: {e}")
return None

def _get_prefill_dp_size_from_server(self):
    """Fetch the prefill parallel info from the bootstrap server.

    Issues ``GET /route`` against ``self.bootstrap_addr`` with
    ``engine_rank=-1`` and ``target_dp_group=-1`` and reads the
    ``prefill_dp_size`` and ``tp_size_per_dp_rank`` fields of the JSON reply.

    Returns:
        A ``(prefill_dp_size, tp_size_per_dp_rank)`` pair of ints on success.
        On any failure (non-200 status, request error, missing or
        non-integer fields) the error is logged and ``(None, None)`` is
        returned. A pair is returned in every case so that callers which
        tuple-unpack the result can check for ``None`` instead of crashing
        with ``TypeError`` while unpacking.
    """
    try:
        url = f"http://{self.bootstrap_addr}/route?engine_rank={-1}&target_dp_group={-1}"
        # NOTE(review): no timeout is passed, so this call can block
        # indefinitely if the bootstrap server never answers; consider
        # adding one once an acceptable value is agreed on.
        response = requests.get(url)
        if response.status_code != 200:
            logger.error(
                f"Failed to get prefill parallel info: {response.status_code}, {response.text}"
            )
            return None, None

        prefill_parallel_info = response.json()
        return (
            int(prefill_parallel_info["prefill_dp_size"]),
            int(prefill_parallel_info["tp_size_per_dp_rank"]),
        )
    except Exception as e:
        # Best-effort lookup: report the problem and let the caller decide.
        logger.error(f"Error fetching prefill parallel info from bootstrap: {e}")
        return None, None

@classmethod
def _connect(cls, endpoint: str):
with cls._global_lock:
Expand Down Expand Up @@ -497,7 +555,9 @@ def __init__(self, port: int):
self.store = dict()
self.lock = asyncio.Lock()
self._setup_routes()
self.prefill_port_table: Dict[int, Dict[str, Union[str, int]]] = {}
self.dp_size = None
self.tp_size_per_dp_rank = None
self.prefill_port_table: Dict[int, Dict[int, Dict[str, Union[str, int]]]] = {}

# Start bootstrap server
self.thread = threading.Thread(target=self._run_server, daemon=True)
Expand All @@ -523,35 +583,64 @@ async def _handle_route(self, request: web.Request):
async def _handle_route_put(self, request: web.Request):
    """Handle a registration request and record a prefill rank's address.

    The JSON body must contain ``role``, ``tp_size``, ``dp_size``,
    ``rank_ip``, ``rank_port`` and ``engine_rank``. The first request seen
    fixes ``self.dp_size`` and ``self.tp_size_per_dp_rank``; later requests
    do not overwrite them. For ``role == "Prefill"`` the rank's ip/port is
    stored in ``self.prefill_port_table`` keyed first by DP group and then
    by the rank's TP index inside that group.

    Returns:
        A 200 ``"OK"`` response.
    """
    data = await request.json()
    role = data["role"]
    # Coerce every numeric field the same way so the integer arithmetic
    # below does not depend on how the client serialized the values.
    tp_size = int(data["tp_size"])
    dp_size = int(data["dp_size"])
    rank_ip = data["rank_ip"]
    rank_port = int(data["rank_port"])
    engine_rank = int(data["engine_rank"])

    if self.dp_size is None:
        self.dp_size = dp_size

    tp_size_per_dp_rank = tp_size // dp_size
    if self.tp_size_per_dp_rank is None:
        self.tp_size_per_dp_rank = tp_size_per_dp_rank

    if role == "Prefill":
        # Split the flat engine rank into (DP group, TP rank in that group).
        dp_group = engine_rank // tp_size_per_dp_rank
        tp_rank_in_dp_group = engine_rank % tp_size_per_dp_rank

        # Add lock to make sure thread-safe
        async with self.lock:
            if dp_group not in self.prefill_port_table:
                self.prefill_port_table[dp_group] = {}

            self.prefill_port_table[dp_group][tp_rank_in_dp_group] = {
                "rank_ip": rank_ip,
                "rank_port": rank_port,
            }
        logger.debug(
            f"Registered Prefill bootstrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
        )

    return web.Response(text="OK", status=200)

async def _handle_route_get(self, request: web.Request):
    """Serve prefill parallel info or a single prefill rank's address.

    Query parameters:
        engine_rank: rank of the requesting engine.
        target_dp_group: prefill DP group the requester wants to reach.

    When both parameters are ``-1`` the reply is a JSON object holding
    ``prefill_dp_size`` and ``tp_size_per_dp_rank``. Otherwise the reply is
    the registered ``{"rank_ip", "rank_port"}`` entry for
    ``target_dp_group`` and ``engine_rank % tp_size_per_dp_rank``.

    Returns:
        200 with JSON on success, 400 when a parameter is missing or not an
        integer, and 404 when no matching prefill rank has been registered.
    """
    engine_rank = request.query.get("engine_rank")
    target_dp_group = request.query.get("target_dp_group")
    if not engine_rank or not target_dp_group:
        return web.Response(text="Missing inputs for bootstrap server.", status=400)

    try:
        engine_rank = int(engine_rank)
        target_dp_group = int(target_dp_group)
    except ValueError:
        # Malformed client input is a bad request, not a server error.
        return web.Response(text="Invalid inputs for bootstrap server.", status=400)

    # Currently we use engine_rank == -1 and target_dp_group == -1 to sync dp size
    if engine_rank == -1 and target_dp_group == -1:
        prefill_parallel_info = {
            "prefill_dp_size": self.dp_size,
            "tp_size_per_dp_rank": self.tp_size_per_dp_rank,
        }
        return web.json_response(prefill_parallel_info, status=200)

    # tp_size_per_dp_rank stays None until the first registration arrives,
    # so there is nothing to look up yet.
    if self.tp_size_per_dp_rank is None:
        return web.Response(text="Bootstrap info not Found", status=404)

    # Find corresponding prefill info
    tp_rank_in_dp_group = engine_rank % self.tp_size_per_dp_rank

    async with self.lock:
        # Use .get at both levels so an unregistered group or rank yields
        # None (and the 404 below) instead of raising KeyError.
        bootstrap_info = self.prefill_port_table.get(target_dp_group, {}).get(
            tp_rank_in_dp_group
        )

    if bootstrap_info is None:
        return web.Response(text="Bootstrap info not Found", status=404)
    return web.json_response(bootstrap_info, status=200)

def _run_server(self):
try:
Expand Down
Loading