Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 100 additions & 22 deletions python/sglang/srt/disaggregation/mooncake/conn.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from __future__ import annotations

import asyncio
import concurrent.futures
import dataclasses
import logging
import os
import queue
import socket
import struct
Expand Down Expand Up @@ -73,9 +75,7 @@ class TransferInfo:
endpoint: str
dst_port: int
mooncake_session_id: str
dst_kv_ptrs: list[int]
dst_kv_indices: npt.NDArray[np.int64]
dst_aux_ptrs: list[int]
dst_aux_index: int

@classmethod
Expand All @@ -85,10 +85,29 @@ def from_zmq(cls, msg: List[bytes]):
endpoint=msg[1].decode("ascii"),
dst_port=int(msg[2].decode("ascii")),
mooncake_session_id=msg[3].decode("ascii"),
dst_kv_indices=np.frombuffer(msg[4], dtype=np.int64),
dst_aux_index=int(msg[5].decode("ascii")),
)


@dataclasses.dataclass
class KVArgsRegisterInfo:
    """One-time registration message a decode rank sends to the prefill
    KVManager: the base pointers of its KV-cache and aux buffers, keyed by
    its mooncake session id. Sent with room == "None" (see bootstrap_thread).
    """

    room: str
    endpoint: str
    dst_port: int
    mooncake_session_id: str
    dst_kv_ptrs: list[int]
    dst_aux_ptrs: list[int]

    @classmethod
    def from_zmq(cls, msg: List[bytes]):
        """Decode a registration multipart ZMQ message.

        Wire layout: [room, endpoint, port, session_id,
        packed kv ptrs (u64 array), packed aux ptrs (u64 array)].
        """
        return cls(
            room=str(msg[0].decode("ascii")),
            endpoint=msg[1].decode("ascii"),
            dst_port=int(msg[2].decode("ascii")),
            mooncake_session_id=msg[3].decode("ascii"),
            # Pointer blobs are tightly packed unsigned 64-bit ints.
            dst_kv_ptrs=list(struct.unpack(f"{len(msg[4])//8}Q", msg[4])),
            dst_aux_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
        )


Expand Down Expand Up @@ -123,8 +142,15 @@ def __init__(
if self.disaggregation_mode == DisaggregationMode.PREFILL:
self.transfer_queue = queue.Queue()
self.transfer_infos: Dict[int, TransferInfo] = {}
self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {}
self.start_prefill_thread()
self._register_to_bootstrap()

# Determine the number of threads to use for kv sender
cpu_count = os.cpu_count()
self.executor = concurrent.futures.ThreadPoolExecutor(
max_workers=cpu_count if cpu_count is not None else 64
)
elif self.disaggregation_mode == DisaggregationMode.DECODE:
self.start_decode_thread()
self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
Expand Down Expand Up @@ -158,28 +184,53 @@ def send_kvcache(
dst_kv_ptrs: list[int],
dst_kv_indices: npt.NDArray[np.int64],
):
# group by indices
# Group by indices
prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous(
prefill_kv_indices, dst_kv_indices
)

num_layers = len(self.kv_args.kv_data_ptrs)
for layer_id in range(num_layers):
src_ptr = self.kv_args.kv_data_ptrs[layer_id]
dst_ptr = dst_kv_ptrs[layer_id]
item_len = self.kv_args.kv_item_lens[layer_id]
layers_params = [
(
self.kv_args.kv_data_ptrs[layer_id],
dst_kv_ptrs[layer_id],
self.kv_args.kv_item_lens[layer_id],
)
for layer_id in range(num_layers)
]

# Worker function for processing a single layer
def process_layer(src_ptr: int, dst_ptr: int, item_len: int) -> int:
for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
src_addr = src_ptr + int(prefill_index[0]) * item_len
dst_addr = dst_ptr + int(decode_index[0]) * item_len
length = item_len * len(prefill_index)

# TODO: make async later
status = self.engine.transfer_sync(
mooncake_session_id, src_addr, dst_addr, length
)
if status != 0:
return status
return 0

futures = [
self.executor.submit(
process_layer,
src_ptr,
dst_ptr,
item_len,
)
for (src_ptr, dst_ptr, item_len) in layers_params
]

for future in concurrent.futures.as_completed(futures):
status = future.result()
if status != 0:
# Immediate shutdown on first error (existing tasks will finish)
executor.shutdown(wait=False)
for f in futures:
f.cancel()
return status

return 0

Expand Down Expand Up @@ -223,6 +274,13 @@ def bootstrap_thread():
waiting_req_bytes = self.server_socket.recv_multipart()
room = waiting_req_bytes[0].decode("ascii")
if room == "None":
mooncake_session_id = waiting_req_bytes[3].decode("ascii")
self.decode_kv_args_table[mooncake_session_id] = (
KVArgsRegisterInfo.from_zmq(waiting_req_bytes)
)
logger.debug(
f"Register KVArgs from {mooncake_session_id} successfully"
)
continue
room = int(room)
self.transfer_infos[room] = TransferInfo.from_zmq(waiting_req_bytes)
Expand All @@ -244,7 +302,7 @@ def transfer_thread():
ret = self.send_kvcache(
req.mooncake_session_id,
kv_chunk.prefill_kv_indices,
req.dst_kv_ptrs,
self.decode_kv_args_table[req.mooncake_session_id].dst_kv_ptrs,
chunked_dst_kv_indice,
)
if ret != 0:
Expand All @@ -259,7 +317,9 @@ def transfer_thread():
ret = self.send_aux(
req.mooncake_session_id,
kv_chunk.prefill_aux_index,
req.dst_aux_ptrs,
self.decode_kv_args_table[
req.mooncake_session_id
].dst_aux_ptrs,
req.dst_aux_index,
)
self.request_status[req.room] = (
Expand Down Expand Up @@ -460,6 +520,8 @@ def __init__(
)
else:
self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_info
# Register kv_args only once to prefill KVManager according to the info fetched from the bootstrap server
self._register_kv_args()
else:
self.bootstrap_info = self.kv_mgr.connection_pool[bootstrap_key]

Expand Down Expand Up @@ -502,6 +564,30 @@ def _get_prefill_dp_size_from_server(self) -> int:
logger.error(f"Error fetching prefill parallel info from bootstrap: {e}")
return None

def _register_kv_args(self):
self.prefill_server_url = (
f"{self.bootstrap_info['rank_ip']}:{self.bootstrap_info['rank_port']}"
)

packed_kv_data_ptrs = b"".join(
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
)
packed_aux_data_ptrs = b"".join(
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
)
sock, lock = self._connect("tcp://" + self.prefill_server_url)
with lock:
sock.send_multipart(
[
"None".encode("ascii"),
get_local_ip_by_remote().encode("ascii"),
str(self.kv_mgr.rank_port).encode("ascii"),
self.session_id.encode("ascii"),
packed_kv_data_ptrs,
packed_aux_data_ptrs,
]
)

@classmethod
def _connect(cls, endpoint: str):
with cls._global_lock:
Expand All @@ -520,12 +606,6 @@ def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = Non
f"Fetched bootstrap info: {self.bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
)

packed_kv_data_ptrs = b"".join(
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
)
packed_aux_data_ptrs = b"".join(
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
)
sock, lock = self._connect("tcp://" + self.prefill_server_url)
with lock:
sock.send_multipart(
Expand All @@ -534,9 +614,7 @@ def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = Non
get_local_ip_by_remote().encode("ascii"),
str(self.kv_mgr.rank_port).encode("ascii"),
self.session_id.encode("ascii"),
packed_kv_data_ptrs,
kv_indices.tobytes(),
packed_aux_data_ptrs,
str(aux_index).encode("ascii"),
]
)
Expand Down Expand Up @@ -610,7 +688,7 @@ async def _handle_route_put(self, request: web.Request):
"rank_port": rank_port,
}
logger.debug(
f"Registered Prefill bootstrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
f"Register Prefill bootstrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
)

return web.Response(text="OK", status=200)
Expand Down
Loading