Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 30 additions & 10 deletions python/sglang/srt/disaggregation/mooncake/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

logger = logging.getLogger(__name__)


def find_available_ports(base_port: int, count: int) -> List[int]:
"""Find consecutive available ports starting from base_port."""
available_ports = []
Expand All @@ -43,6 +44,7 @@ def find_available_ports(base_port: int, count: int) -> List[int]:

return available_ports


def group_concurrent_contiguous(
src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
Expand Down Expand Up @@ -265,7 +267,9 @@ def transfer_thread():
)
if ret != 0:
self.request_status[kv_chunk.room] = KVPoll.Failed
self.sync_status_to_decode_endpoint(req.endpoint, req.dst_port, req.room)
self.sync_status_to_decode_endpoint(
req.endpoint, req.dst_port, req.room
)
continue

if kv_chunk.is_last:
Expand All @@ -279,7 +283,9 @@ def transfer_thread():
self.request_status[req.room] = (
KVPoll.Success if ret == 0 else KVPoll.Failed
)
self.sync_status_to_decode_endpoint(req.endpoint, req.dst_port, req.room)
self.sync_status_to_decode_endpoint(
req.endpoint, req.dst_port, req.room
)
self.transfer_infos.pop(req.room)

except queue.Empty:
Expand Down Expand Up @@ -443,13 +449,14 @@ def _get_prefill_info_from_bootstrap(self, tp_rank: int):
prefill_info = response.json()
return prefill_info
else:
logger.error(f"Failed to get prefill server info: {response.status_code}, {response.text}")
logger.error(
f"Failed to get prefill server info: {response.status_code}, {response.text}"
)
return None
except Exception as e:
logger.error(f"Error fetching prefill info from bootstrap: {e}")
return None


@cache
def _connect(self, endpoint: str):
socket = zmq.Context().socket(zmq.PUSH)
Expand All @@ -466,17 +473,25 @@ def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = Non
)
if prefill_info is None:
logger.error(
logger.error(f"Could not fetch prefill server info for tp_rank {self.kv_mgr.kv_args.engine_rank}")
logger.error(
f"Could not fetch prefill server info for tp_rank {self.kv_mgr.kv_args.engine_rank}"
)
)
else:
self.kv_mgr.connection_pool[self.kv_mgr.kv_args.engine_rank] = prefill_info
self.kv_mgr.connection_pool[self.kv_mgr.kv_args.engine_rank] = (
prefill_info
)
else:
prefill_info = self.kv_mgr.connection_pool[self.kv_mgr.kv_args.engine_rank]

if prefill_info:
self.prefill_server_url = f"{prefill_info['serve_ip']}:{prefill_info['serve_port']}"
self.prefill_server_url = (
f"{prefill_info['serve_ip']}:{prefill_info['serve_port']}"
)

logger.info(f"Fetched prefill server info: {prefill_info} for tp_rank {self.kv_mgr.kv_args.engine_rank}")
logger.info(
f"Fetched prefill server info: {prefill_info} for tp_rank {self.kv_mgr.kv_args.engine_rank}"
)
self.handshake_prefill_server(kv_indices, aux_index)

def handshake_prefill_server(
Expand Down Expand Up @@ -598,8 +613,13 @@ async def _handle_kv_route_put(self, request: web.Request):
# Add lock to make sure thread-safe
if role == "Prefill":
async with self.lock:
self.prefill_port_table[tp_rank] = {"serve_ip": serve_ip, "serve_port": serve_port}
logger.info(f"Registered Prefill tp_rank: {tp_rank} with serve_ip: {serve_ip} and serve_port: {serve_port}")
self.prefill_port_table[tp_rank] = {
"serve_ip": serve_ip,
"serve_port": serve_port,
}
logger.info(
f"Registered Prefill tp_rank: {tp_rank} with serve_ip: {serve_ip} and serve_port: {serve_port}"
)

return web.Response(text="OK", status=200)

Expand Down
1 change: 1 addition & 0 deletions python/sglang/srt/entrypoints/verl_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def generate(
rank=self._tp_rank,
dist_group=self._device_mesh_cpu.get_group(),
src=self._device_mesh_cpu.mesh[0].item(),
force_cpu_device=False,
)

return output
Expand Down
5 changes: 4 additions & 1 deletion python/sglang/srt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,9 +846,12 @@ def broadcast_pyobj(
rank: int,
dist_group: Optional[torch.distributed.ProcessGroup] = None,
src: int = 0,
force_cpu_device: bool = True,
):
"""Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device(
"cuda" if torch.cuda.is_available() and not force_cpu_device else "cpu"
)

if rank == 0:
if len(data) == 0:
Expand Down
Loading