Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 94 additions & 46 deletions python/sglang/srt/disaggregation/nixl/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,23 +31,19 @@

logger = logging.getLogger(__name__)

NixlEngineInfo: TypeAlias = Dict[str, Union[str, int]]

GUARD = "NixlMsgGuard".encode("ascii")


@dataclasses.dataclass
class TransferInfo:
"""Contains indices for a transfer, sent by KVReceiver. Received by prefill bootstrap thread."""

room: int
endpoint: str
dst_port: int
agent_metadata: bytes
agent_name: str
dst_kv_ptrs: list[int]
dst_kv_indices: npt.NDArray[np.int32]
dst_aux_ptrs: list[int]
dst_aux_index: int
dst_gpu_id: int
required_dst_info_num: int

def is_dummy(self):
Expand All @@ -59,14 +55,37 @@ def from_zmq(cls, msg: List[bytes]):
room=int(msg[0].decode("ascii")),
endpoint=msg[1].decode("ascii"),
dst_port=int(msg[2].decode("ascii")),
agent_metadata=msg[3],
agent_name=msg[4].decode("ascii"),
agent_name=msg[3].decode("ascii"),
dst_kv_indices=np.frombuffer(msg[4], dtype=np.int32),
dst_aux_index=int(msg[5].decode("ascii")),
required_dst_info_num=int(msg[6].decode("ascii")),
)


@dataclasses.dataclass
class KVArgsRegisterInfo:
    """Contains base pointers and other info which only needs to be sent once by KVReceiver. Received by prefill bootstrap thread."""

    # NOTE(review): `room` carries the literal string "None" for registration
    # messages — the bootstrap thread dispatches on that sentinel.
    room: str
    endpoint: str
    dst_port: int
    agent_name: str
    agent_metadata: bytes
    dst_kv_ptrs: list[int]
    dst_aux_ptrs: list[int]
    gpu_id: int

    @classmethod
    def from_zmq(cls, msg: List[bytes]):
        """Decode a ZMQ multipart registration message into a KVArgsRegisterInfo.

        Expected frame layout (after the GUARD frame has been stripped):
        [room, endpoint, dst_port, agent_name, agent_metadata,
         packed_kv_ptrs, packed_aux_ptrs, gpu_id].
        Pointer frames are concatenations of unsigned 64-bit values
        ("Q" struct format), hence the len(...)//8 count.
        """
        return cls(
            room=str(msg[0].decode("ascii")),
            endpoint=msg[1].decode("ascii"),
            dst_port=int(msg[2].decode("ascii")),
            agent_name=msg[3].decode("ascii"),
            agent_metadata=msg[4],
            dst_kv_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
            dst_aux_ptrs=list(struct.unpack(f"{len(msg[6])//8}Q", msg[6])),
            gpu_id=int(msg[7].decode("ascii")),
        )


Expand Down Expand Up @@ -109,9 +128,9 @@ def __init__(
self.register_buffer_to_engine()

if self.disaggregation_mode == DisaggregationMode.PREFILL:
self.request_status = {}
self.transfer_infos: Dict[int, TransferInfo] = {}
self.peer_names: Dict[str, str] = {}
self.request_status: Dict[int, KVPoll] = {}
self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {}
self._start_bootstrap_thread()
elif self.disaggregation_mode == DisaggregationMode.DECODE:
self.transfer_statuses: Dict[int, TransferStatus] = defaultdict(
Expand Down Expand Up @@ -154,10 +173,13 @@ def register_buffer_to_engine(self):
if not self.aux_descs:
raise Exception("NIXL memory registration failed for aux tensors")

def _add_remote_peer(self, decode_kv_args: KVArgsRegisterInfo):
    """Register a decode-side peer's KV args and its NIXL agent metadata.

    Idempotent: a peer that already registered (same agent_name) is
    ignored so repeated registration messages cannot double-add the
    remote agent.
    """
    agent_name = decode_kv_args.agent_name
    if agent_name in self.decode_kv_args_table:
        logger.info(f"Peer {agent_name} was already registered, ignoring.")
        return
    # Save base pointers / gpu_id for later transfers, then make the
    # remote agent known to the local NIXL agent.
    self.decode_kv_args_table[agent_name] = decode_kv_args
    self.agent.add_remote_agent(decode_kv_args.agent_metadata)

def send_kvcache(
self,
Expand Down Expand Up @@ -262,31 +284,33 @@ def add_transfer_request(
if req.is_dummy():
continue

peer_name = self._add_remote(req.agent_name, req.agent_metadata)
chunked_dst_kv_indice = req.dst_kv_indices[index_slice]
assert len(chunked_dst_kv_indice) == len(kv_indices)
assert req.agent_name in self.decode_kv_args_table

notif = "_".join([str(req.room), "kv", str(chunk_id), str(int(is_last))])
kv_xfer_handle = self.send_kvcache(
peer_name,
req.agent_name,
kv_indices,
req.dst_kv_ptrs,
self.decode_kv_args_table[req.agent_name].dst_kv_ptrs,
chunked_dst_kv_indice,
req.dst_gpu_id,
self.decode_kv_args_table[req.agent_name].gpu_id,
notif,
)
handles.append(kv_xfer_handle)
# Only the last chunk we need to send the aux data.
if is_last:
assert aux_index is not None
aux_xfer_handle = self.send_aux(
peer_name,
req.agent_name,
aux_index,
req.dst_aux_ptrs,
self.decode_kv_args_table[req.agent_name].dst_aux_ptrs,
req.dst_aux_index,
str(req.room) + "_aux",
)
handles.append(aux_xfer_handle)
if is_last:
del self.transfer_infos[bootstrap_room]
return handles

def update_transfer_status(self):
Expand Down Expand Up @@ -328,16 +352,23 @@ def bootstrap_thread():
), f"First message should be {GUARD}. Foreign traffic?"
waiting_req_bytes = waiting_req_bytes[1:]
room = waiting_req_bytes[0].decode("ascii")

required_dst_info_num = int(waiting_req_bytes[10].decode("ascii"))
agent_name = waiting_req_bytes[3].decode("ascii")
if room == "None":
# Register new peer and save KV base pointers.
self._add_remote_peer(
KVArgsRegisterInfo.from_zmq(waiting_req_bytes)
)
logger.debug(f"Register KVArgs from {agent_name} successfully")
continue
room = int(room)
agent_name = waiting_req_bytes[4].decode("ascii")
if room not in self.transfer_infos:
self.transfer_infos[room] = {}
self.transfer_infos[room][agent_name] = TransferInfo.from_zmq(
waiting_req_bytes
)

required_dst_info_num = self.transfer_infos[room][
agent_name
].required_dst_info_num
logger.debug(f"got info {room=} {agent_name=} {required_dst_info_num=}")
if len(self.transfer_infos[room]) == required_dst_info_num:
logger.debug(f"{room=} is bootstrapped")
Expand Down Expand Up @@ -391,6 +422,7 @@ def send(
self.chunk_id += 1
if is_last:
self.has_sent = True
del self.kv_mgr.request_status[self.bootstrap_room]

def poll(self) -> KVPoll:
if not self.has_sent:
Expand All @@ -415,6 +447,7 @@ def __init__(
data_parallel_rank: Optional[int] = None,
):
self.started_transfer = False
self.conclude_state = None
super().__init__(mgr, bootstrap_addr, bootstrap_room, data_parallel_rank)

def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
Expand All @@ -426,17 +459,8 @@ def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = Non
f"Fetched bootstrap info: {bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
)
is_dummy = bootstrap_info["is_dummy"]

# TODO: send_kv_args earlier
packed_kv_data_ptrs = b"".join(
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
)
packed_aux_data_ptrs = b"".join(
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
)

logger.debug(
f"Sending to {self.prefill_server_url} with bootstrap room {self.bootstrap_room}"
f"Sending to {self.prefill_server_url} with bootstrap room {self.bootstrap_room} {is_dummy=}"
)
sock, lock = self._connect("tcp://" + self.prefill_server_url)
with lock:
Expand All @@ -446,31 +470,55 @@ def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = Non
str(self.bootstrap_room).encode("ascii"),
get_local_ip_by_remote().encode("ascii"),
str(self.kv_mgr.rank_port).encode("ascii"),
self.kv_mgr.agent.get_agent_metadata(),
self.kv_mgr.agent.name.encode("ascii"),
packed_kv_data_ptrs,
kv_indices.tobytes() if not is_dummy else b"",
packed_aux_data_ptrs,
str(aux_index).encode("ascii"),
str(self.kv_mgr.kv_args.gpu_id).encode("ascii"),
str(self.required_dst_info_num).encode("ascii"),
]
)

self.started_transfer = True

def poll(self) -> KVPoll:
    """Poll the receiver-side transfer state.

    Returns a cached terminal state if one was already reached;
    otherwise advances transfer bookkeeping and reports progress.
    """
    # Terminal state already reached on an earlier poll — the per-room
    # status entry was deleted then, so do not consult it again.
    if self.conclude_state is not None:
        return self.conclude_state
    if not self.started_transfer:
        return KVPoll.WaitingForInput  # type: ignore

    # Drain completion notifications from the transfer engine.
    self.kv_mgr.update_transfer_status()

    if self.kv_mgr.check_transfer_done(self.bootstrap_room):  # type: ignore
        # Cache the result and free the status entry for this room.
        self.conclude_state = KVPoll.Success
        del self.kv_mgr.transfer_statuses[self.bootstrap_room]
        return KVPoll.Success  # type: ignore
    return KVPoll.WaitingForInput  # type: ignore

def _register_kv_args(self):
    """Send this receiver's KV/aux base pointers to every prefill peer once.

    The room frame is the literal string "None", which the prefill
    bootstrap thread uses to recognize a registration message (see
    KVArgsRegisterInfo) as opposed to a per-request TransferInfo.
    """
    for bootstrap_info in self.bootstrap_infos:
        self.prefill_server_url = (
            f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
        )
        # Pack each base pointer as an unsigned 64-bit value; the peer
        # unpacks with the matching "{n}Q" format.
        packed_kv_data_ptrs = b"".join(
            struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
        )
        packed_aux_data_ptrs = b"".join(
            struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
        )

        sock, lock = self._connect("tcp://" + self.prefill_server_url)
        with lock:
            sock.send_multipart(
                [
                    GUARD,
                    "None".encode("ascii"),
                    get_local_ip_by_remote().encode("ascii"),
                    str(self.kv_mgr.rank_port).encode("ascii"),
                    self.kv_mgr.agent.name.encode("ascii"),
                    self.kv_mgr.agent.get_agent_metadata(),
                    packed_kv_data_ptrs,
                    packed_aux_data_ptrs,
                    str(self.kv_mgr.kv_args.gpu_id).encode("ascii"),
                ]
            )

def failure_exception(self):
    """Raise the canned exception used to signal a fake-receiver failure."""
    failure = Exception("Fake KVReceiver Exception")
    raise failure
Expand Down
Loading