Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 68 additions & 79 deletions python/sglang/srt/disaggregation/common/conn.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
import dataclasses
import logging
import socket
import threading
Expand Down Expand Up @@ -43,6 +44,22 @@
logger = logging.getLogger(__name__)


@dataclasses.dataclass
class PrefillServerInfo:
    """Settings a prefill server reports through the bootstrap server.

    Field values are normalised after construction: the three size fields
    become ``int``, ``page_size`` becomes ``int`` unless it is ``None``,
    and ``follow_bootstrap_room`` becomes ``bool``. This lets an instance
    be built directly from a decoded JSON payload.
    """

    attn_tp_size: int
    dp_size: int
    pp_size: int
    page_size: Optional[int]
    follow_bootstrap_room: bool

    def __post_init__(self):
        # Mandatory integer fields, coerced in declaration order.
        for size_field in ("attn_tp_size", "dp_size", "pp_size"):
            setattr(self, size_field, int(getattr(self, size_field)))

        # page_size is the only field that may legitimately stay None.
        if self.page_size is not None:
            self.page_size = int(self.page_size)

        self.follow_bootstrap_room = bool(self.follow_bootstrap_room)


class CommonKVManager(BaseKVManager):
def __init__(
self,
Expand Down Expand Up @@ -92,11 +109,7 @@ def __init__(
self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
self.connection_lock = threading.Lock()
self.required_prefill_response_num_table: Dict[int, int] = {}
self.prefill_attn_tp_size_table: Dict[str, int] = {}
self.prefill_dp_size_table: Dict[str, int] = {}
self.prefill_pp_size_table: Dict[str, int] = {}
self.prefill_page_size_table: Dict[str, Optional[int]] = {}
self.follow_bootstrap_room_table: Dict[str, bool] = {}
self.prefill_info_table: Dict[str, PrefillServerInfo] = {}
else:
raise ValueError(
f"Unsupported DisaggregationMode: {self.disaggregation_mode}"
Expand All @@ -106,30 +119,43 @@ def ensure_parallel_info(self, bootstrap_addr: str) -> bool:
"""Fetch and cache prefill parallel info if not yet available.
Returns True if info is available (cached or freshly fetched).
"""
if bootstrap_addr in self.prefill_dp_size_table:
if bootstrap_addr in self.prefill_info_table:
return True
info = CommonKVReceiver._fetch_prefill_parallel_info(bootstrap_addr)
info = self._fetch_prefill_server_info(bootstrap_addr)
if info is None:
return False
tp_size, dp_size, pp_size, page_size, follow_bootstrap_room = info

if page_size is not None and page_size != self.kv_args.page_size:
if info.page_size is not None and info.page_size != self.kv_args.page_size:
raise RuntimeError(
f"Page size mismatch: prefill server has page_size={page_size}, "
f"Page size mismatch: prefill server has page_size={info.page_size}, "
f"but decode server has page_size={self.kv_args.page_size}. "
f"Both servers must use the same --page-size value."
)

self.prefill_attn_tp_size_table[bootstrap_addr] = tp_size
self.prefill_dp_size_table[bootstrap_addr] = dp_size
self.prefill_pp_size_table[bootstrap_addr] = pp_size
self.prefill_page_size_table[bootstrap_addr] = page_size
self.follow_bootstrap_room_table[bootstrap_addr] = follow_bootstrap_room
logger.debug(
f"Prefill parallel info for [{bootstrap_addr}]: DP={dp_size} TP={tp_size} PP={pp_size} page_size={page_size} follow_bootstrap_room={follow_bootstrap_room}"
)
self.prefill_info_table[bootstrap_addr] = info
logger.debug(f"Prefill parallel info for [{bootstrap_addr}]: {info}")
return True

@staticmethod
def _fetch_prefill_server_info(
    bootstrap_addr: str,
) -> Optional[PrefillServerInfo]:
    """Fetch the prefill server info from the bootstrap server.

    Issues ``GET /route`` with ``engine_rank``, ``prefill_dp_rank`` and
    ``target_pp_rank`` all set to ``-1`` and builds a ``PrefillServerInfo``
    from the JSON reply.

    The payload is mapped onto the dataclass fields explicitly instead of
    being splatted, so that:

    * keys using the older ``prefill_``-prefixed wire format
      (``prefill_attn_tp_size``, ``prefill_dp_size``, ...) are accepted;
    * keys that are not ``PrefillServerInfo`` fields are ignored rather
      than raising ``TypeError``;
    * a missing ``follow_bootstrap_room`` defaults to ``True``.

    Args:
        bootstrap_addr: ``host:port`` of the bootstrap server.

    Returns:
        The parsed ``PrefillServerInfo``, or ``None`` if the request fails,
        the server answers with a non-200 status, or the payload cannot be
        converted (the failure is logged).
    """
    url = (
        f"http://{bootstrap_addr}/route"
        "?engine_rank=-1&prefill_dp_rank=-1&target_pp_rank=-1"
    )
    try:
        response = requests.get(url, timeout=5)
        if response.status_code != 200:
            logger.error(
                f"Failed to get prefill server info: {response.status_code}, {response.text}"
            )
            return None

        data = response.json()
        kwargs = {}
        for field in dataclasses.fields(PrefillServerInfo):
            legacy_key = f"prefill_{field.name}"
            if field.name in data:
                kwargs[field.name] = data[field.name]
            elif legacy_key in data:
                # Older bootstrap servers prefix the size fields.
                kwargs[field.name] = data[legacy_key]
        kwargs.setdefault("follow_bootstrap_room", True)
        # A still-missing required field raises TypeError here, which is
        # reported through the handler below.
        return PrefillServerInfo(**kwargs)
    except Exception as e:
        # Best-effort boundary: callers treat None as "info not available
        # yet", so any transport or decoding error is logged, not raised.
        logger.error(f"Error fetching prefill server info from bootstrap: {e}")
        return None

def register_to_bootstrap(self):
"""Register KVSender to bootstrap server via HTTP POST."""
if self.dist_init_addr:
Expand Down Expand Up @@ -315,38 +341,31 @@ def __init__(
self.bootstrap_infos = None
return

self.prefill_attn_tp_size = self.kv_mgr.prefill_attn_tp_size_table[
self.bootstrap_addr
]
self.prefill_dp_size = self.kv_mgr.prefill_dp_size_table[self.bootstrap_addr]
self.prefill_pp_size = self.kv_mgr.prefill_pp_size_table[self.bootstrap_addr]
self.prefill_page_size = self.kv_mgr.prefill_page_size_table.get(
self.bootstrap_addr
)
self.prefill_info = self.kv_mgr.prefill_info_table[self.bootstrap_addr]

# Handling for PD with different TP sizes per DP rank
if self.kv_mgr.attn_tp_size == self.prefill_attn_tp_size:
if self.kv_mgr.attn_tp_size == self.prefill_info.attn_tp_size:
self.target_tp_rank = (
self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size
)
self.required_dst_info_num = 1
self.required_prefill_response_num = 1 * (
self.prefill_pp_size // self.kv_mgr.pp_size
self.prefill_info.pp_size // self.kv_mgr.pp_size
)
self.target_tp_ranks = [self.target_tp_rank]
elif self.kv_mgr.attn_tp_size > self.prefill_attn_tp_size:
elif self.kv_mgr.attn_tp_size > self.prefill_info.attn_tp_size:
if not self.kv_mgr.is_mla_backend:
logger.warning_once(
"Performance is NOT guaranteed when using different TP sizes for non-MLA models. "
)
self.target_tp_rank = (
self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size
) // (self.kv_mgr.attn_tp_size // self.prefill_attn_tp_size)
) // (self.kv_mgr.attn_tp_size // self.prefill_info.attn_tp_size)
self.required_dst_info_num = (
self.kv_mgr.attn_tp_size // self.prefill_attn_tp_size
self.kv_mgr.attn_tp_size // self.prefill_info.attn_tp_size
)
self.required_prefill_response_num = 1 * (
self.prefill_pp_size // self.kv_mgr.pp_size
self.prefill_info.pp_size // self.kv_mgr.pp_size
)
self.target_tp_ranks = [self.target_tp_rank]
else:
Expand All @@ -359,9 +378,9 @@ def __init__(
rank
for rank in range(
(self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size)
* (self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size),
* (self.prefill_info.attn_tp_size // self.kv_mgr.attn_tp_size),
(self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size + 1)
* (self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size),
* (self.prefill_info.attn_tp_size // self.kv_mgr.attn_tp_size),
)
]

Expand All @@ -372,23 +391,23 @@ def __init__(
self.required_dst_info_num = 1
if self.kv_mgr.is_mla_backend:
self.required_prefill_response_num = (
self.prefill_pp_size // self.kv_mgr.pp_size
self.prefill_info.pp_size // self.kv_mgr.pp_size
)
else:
self.required_prefill_response_num = (
self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size
) * (self.prefill_pp_size // self.kv_mgr.pp_size)
self.prefill_info.attn_tp_size // self.kv_mgr.attn_tp_size
) * (self.prefill_info.pp_size // self.kv_mgr.pp_size)

# Decode pp size should be equal to prefill pp size or 1
assert (
self.kv_mgr.pp_size == self.prefill_pp_size or self.kv_mgr.pp_size == 1
self.kv_mgr.pp_size == self.prefill_info.pp_size or self.kv_mgr.pp_size == 1
), (
f"Decode pp size ({self.kv_mgr.pp_size}) should be equal to prefill pp size ({self.prefill_pp_size}) or 1",
f"Decode pp size ({self.kv_mgr.pp_size}) should be equal to prefill pp size ({self.prefill_info.pp_size}) or 1",
)
if self.prefill_pp_size == self.kv_mgr.pp_size:
if self.prefill_info.pp_size == self.kv_mgr.pp_size:
self.target_pp_ranks = [self.kv_mgr.pp_rank]
else:
self.target_pp_ranks = [rank for rank in range(self.prefill_pp_size)]
self.target_pp_ranks = [rank for rank in range(self.prefill_info.pp_size)]

self.kv_mgr.required_prefill_response_num_table[self.bootstrap_room] = (
self.required_prefill_response_num
Expand Down Expand Up @@ -465,36 +484,6 @@ def _get_bootstrap_info_from_server(
logger.error(f"Error fetching prefill info from bootstrap: {e}")
return None

@staticmethod
def _fetch_prefill_parallel_info(
    bootstrap_addr: str,
) -> Optional[Tuple[int, int, int, Optional[int], bool]]:
    """Fetch the prefill parallel info from the bootstrap server.

    Issues ``GET /route`` with ``engine_rank``, ``prefill_dp_rank`` and
    ``target_pp_rank`` all set to ``-1`` and decodes the JSON reply.

    Args:
        bootstrap_addr: ``host:port`` of the bootstrap server.

    Returns:
        ``(attn_tp_size, dp_size, pp_size, page_size, follow_bootstrap_room)``
        or ``None`` on failure (the failure is logged). ``page_size`` is
        ``None`` when the server reports no page size;
        ``follow_bootstrap_room`` defaults to ``True`` when the key is
        absent from the reply.
    """
    try:
        url = (
            f"http://{bootstrap_addr}/route"
            "?engine_rank=-1&prefill_dp_rank=-1&target_pp_rank=-1"
        )
        response = requests.get(url, timeout=5)
        if response.status_code == 200:
            info = response.json()
            # The server may report a null page size; int(None) would raise
            # and turn an otherwise valid reply into a failed fetch.
            page_size = info["prefill_page_size"]
            return (
                int(info["prefill_attn_tp_size"]),
                int(info["prefill_dp_size"]),
                int(info["prefill_pp_size"]),
                int(page_size) if page_size is not None else None,
                bool(info.get("follow_bootstrap_room", True)),
            )
        else:
            logger.error(
                f"Failed to get prefill parallel info: {response.status_code}, {response.text}"
            )
            return None
    except Exception as e:
        # Best-effort boundary: callers treat None as "info not available".
        logger.error(f"Error fetching prefill parallel info from bootstrap: {e}")
        return None

@staticmethod
def query_prefill_dp_ranks(
bootstrap_addr: str, bootstrap_rooms: List[int]
Expand Down Expand Up @@ -663,18 +652,18 @@ async def _handle_route_get(self, request: web.Request):
and int(prefill_dp_rank) == -1
and int(target_pp_rank) == -1
):
prefill_parallel_info = {
"prefill_attn_tp_size": self.attn_tp_size,
"prefill_dp_size": self.dp_size,
"prefill_pp_size": self.pp_size,
"prefill_page_size": self.page_size,
"follow_bootstrap_room": (
info = PrefillServerInfo(
attn_tp_size=self.attn_tp_size,
dp_size=self.dp_size,
pp_size=self.pp_size,
page_size=self.page_size,
follow_bootstrap_room=(
self.follow_bootstrap_room
if self.follow_bootstrap_room is not None
else True
),
}
return web.json_response(prefill_parallel_info, status=200)
)
return web.json_response(dataclasses.asdict(info), status=200)

# Find corresponding prefill info
async with self.lock:
Expand Down
10 changes: 4 additions & 6 deletions python/sglang/srt/disaggregation/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,14 +364,12 @@ def _resolve_dp_rank(self, req: Req) -> Optional[int]:

bootstrap_addr = f"{req.bootstrap_host}:{req.bootstrap_port}"

if bootstrap_addr not in self.kv_manager.prefill_dp_size_table:
prefill_info = self.kv_manager.prefill_info_table.get(bootstrap_addr)
if prefill_info is None:
return None

if self.kv_manager.follow_bootstrap_room_table[bootstrap_addr]:
return (
req.bootstrap_room
% self.kv_manager.prefill_dp_size_table[bootstrap_addr]
)
if prefill_info.follow_bootstrap_room:
return req.bootstrap_room % prefill_info.dp_size

return None

Expand Down
14 changes: 3 additions & 11 deletions python/sglang/srt/disaggregation/mooncake/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -997,7 +997,7 @@ def heartbeat_checker():
while True:
time.sleep(self.heartbeat_interval)
with self.connection_lock:
addresses = list(self.prefill_dp_size_table.keys())
addresses = list(self.prefill_info_table.keys())

for bootstrap_addr in addresses:
session = None
Expand Down Expand Up @@ -1128,16 +1128,8 @@ def _handle_node_failure(self, failed_bootstrap_addr):
possible_affected_rooms = self.addr_to_rooms_tracker.get(
failed_bootstrap_addr, []
)
keys_to_remove = [
self.prefill_attn_tp_size_table,
self.prefill_dp_size_table,
self.prefill_pp_size_table,
self.follow_bootstrap_room_table,
self.addr_to_rooms_tracker,
]
for k in keys_to_remove:
if failed_bootstrap_addr in k:
del k[failed_bootstrap_addr]
self.prefill_info_table.pop(failed_bootstrap_addr, None)
self.addr_to_rooms_tracker.pop(failed_bootstrap_addr, None)

# Report the requests associated with the failed bootstrap addr and mark their status as KVPoll.Failed
affected_rooms = []
Expand Down
12 changes: 3 additions & 9 deletions python/sglang/srt/disaggregation/nixl/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def heartbeat_checker():
while True:
time.sleep(self.heartbeat_interval)
with self.connection_lock:
addresses = list(self.prefill_dp_size_table.keys())
addresses = list(self.prefill_info_table.keys())

for bootstrap_addr in addresses:
session = None
Expand Down Expand Up @@ -274,18 +274,12 @@ def _handle_node_failure(self, failed_bootstrap_addr):
]
for k in keys_to_remove:
del self.connection_pool[k]
if failed_bootstrap_addr in self.prefill_attn_tp_size_table:
del self.prefill_attn_tp_size_table[failed_bootstrap_addr]
if failed_bootstrap_addr in self.prefill_dp_size_table:
del self.prefill_dp_size_table[failed_bootstrap_addr]
if failed_bootstrap_addr in self.prefill_pp_size_table:
del self.prefill_pp_size_table[failed_bootstrap_addr]
self.prefill_info_table.pop(failed_bootstrap_addr, None)

possible_affected_rooms = self.addr_to_rooms_tracker.get(
failed_bootstrap_addr, []
)
if failed_bootstrap_addr in self.addr_to_rooms_tracker:
del self.addr_to_rooms_tracker[failed_bootstrap_addr]
self.addr_to_rooms_tracker.pop(failed_bootstrap_addr, None)

# Mark all pending transfers associated with the failed node as failed
affected_rooms = []
Expand Down
Loading