diff --git a/python/sglang/srt/disaggregation/nixl/conn.py b/python/sglang/srt/disaggregation/nixl/conn.py index aef6cbaf931..3dd8975c59a 100644 --- a/python/sglang/srt/disaggregation/nixl/conn.py +++ b/python/sglang/srt/disaggregation/nixl/conn.py @@ -31,23 +31,19 @@ logger = logging.getLogger(__name__) -NixlEngineInfo: TypeAlias = Dict[str, Union[str, int]] - GUARD = "NixlMsgGuard".encode("ascii") @dataclasses.dataclass class TransferInfo: + """Contains indices for a transfer, sent by KVReceiver. Received by prefill bootstrap thread.""" + room: int endpoint: str dst_port: int - agent_metadata: bytes agent_name: str - dst_kv_ptrs: list[int] dst_kv_indices: npt.NDArray[np.int32] - dst_aux_ptrs: list[int] dst_aux_index: int - dst_gpu_id: int required_dst_info_num: int def is_dummy(self): @@ -59,14 +55,37 @@ def from_zmq(cls, msg: List[bytes]): room=int(msg[0].decode("ascii")), endpoint=msg[1].decode("ascii"), dst_port=int(msg[2].decode("ascii")), - agent_metadata=msg[3], - agent_name=msg[4].decode("ascii"), + agent_name=msg[3].decode("ascii"), + dst_kv_indices=np.frombuffer(msg[4], dtype=np.int32), + dst_aux_index=int(msg[5].decode("ascii")), + required_dst_info_num=int(msg[6].decode("ascii")), + ) + + +@dataclasses.dataclass +class KVArgsRegisterInfo: + """Contains base pointers and other info which only needs to be sent once by KVReceiver. Received by prefill bootstrap thread.""" + + room: str + endpoint: str + dst_port: int + agent_name: str + agent_metadata: bytes + dst_kv_ptrs: list[int] + dst_aux_ptrs: list[int] + gpu_id: int + + @classmethod + def from_zmq(cls, msg: List[bytes]): + return cls( + room=str(msg[0].decode("ascii")), + endpoint=msg[1].decode("ascii"), + dst_port=int(msg[2].decode("ascii")), + agent_name=msg[3].decode("ascii"), + agent_metadata=msg[4], dst_kv_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])), - dst_kv_indices=np.frombuffer(msg[6], dtype=np.int32), - dst_aux_ptrs=list(struct.unpack(f"{len(msg[7])//8}Q", msg[7])), - dst_aux_index=int(msg[8].decode("ascii")), - dst_gpu_id=int(msg[9].decode("ascii")), - required_dst_info_num=int(msg[10].decode("ascii")), + dst_aux_ptrs=list(struct.unpack(f"{len(msg[6])//8}Q", msg[6])), + gpu_id=int(msg[7].decode("ascii")), ) @@ -109,9 +128,9 @@ def __init__( self.register_buffer_to_engine() if self.disaggregation_mode == DisaggregationMode.PREFILL: - self.request_status = {} - self.transfer_infos: Dict[int, TransferInfo] = {} - self.peer_names: Dict[str, str] = {} + self.request_status: Dict[int, KVPoll] = {} + self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {} + self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {} self._start_bootstrap_thread() elif self.disaggregation_mode == DisaggregationMode.DECODE: self.transfer_statuses: Dict[int, TransferStatus] = defaultdict( @@ -154,10 +173,13 @@ def register_buffer_to_engine(self): if not self.aux_descs: raise Exception("NIXL memory registration failed for aux tensors") - def _add_remote(self, agent_name: str, agent_metadata: bytes): - if agent_name not in self.peer_names: - self.peer_names[agent_name] = self.agent.add_remote_agent(agent_metadata) - return self.peer_names[agent_name] + def _add_remote_peer(self, decode_kv_args: KVArgsRegisterInfo): + agent_name = decode_kv_args.agent_name + if agent_name in self.decode_kv_args_table: + logger.info(f"Peer {agent_name} was already registered, ignoring.") + return + self.decode_kv_args_table[agent_name] = decode_kv_args + self.agent.add_remote_agent(decode_kv_args.agent_metadata) def send_kvcache( self, @@ -262,17 +284,17 @@ def add_transfer_request( if req.is_dummy(): continue - peer_name = self._add_remote(req.agent_name, req.agent_metadata) chunked_dst_kv_indice = req.dst_kv_indices[index_slice] assert len(chunked_dst_kv_indice) == len(kv_indices) + assert req.agent_name in self.decode_kv_args_table notif = "_".join([str(req.room), "kv", str(chunk_id), str(int(is_last))]) kv_xfer_handle = self.send_kvcache( - peer_name, + req.agent_name, kv_indices, - req.dst_kv_ptrs, + self.decode_kv_args_table[req.agent_name].dst_kv_ptrs, chunked_dst_kv_indice, - req.dst_gpu_id, + self.decode_kv_args_table[req.agent_name].gpu_id, notif, ) handles.append(kv_xfer_handle) @@ -280,13 +302,15 @@ def add_transfer_request( if is_last: assert aux_index is not None aux_xfer_handle = self.send_aux( - peer_name, + req.agent_name, aux_index, - req.dst_aux_ptrs, + self.decode_kv_args_table[req.agent_name].dst_aux_ptrs, req.dst_aux_index, str(req.room) + "_aux", ) handles.append(aux_xfer_handle) + if is_last: + del self.transfer_infos[bootstrap_room] return handles def update_transfer_status(self): @@ -328,16 +352,23 @@ def bootstrap_thread(): ), f"First message should be {GUARD}. Foreign traffic?" waiting_req_bytes = waiting_req_bytes[1:] room = waiting_req_bytes[0].decode("ascii") - - required_dst_info_num = int(waiting_req_bytes[10].decode("ascii")) + agent_name = waiting_req_bytes[3].decode("ascii") + if room == "None": + # Register new peer and save KV base pointers. + self._add_remote_peer( + KVArgsRegisterInfo.from_zmq(waiting_req_bytes) + ) + logger.debug(f"Register KVArgs from {agent_name} successfully") + continue room = int(room) - agent_name = waiting_req_bytes[4].decode("ascii") if room not in self.transfer_infos: self.transfer_infos[room] = {} self.transfer_infos[room][agent_name] = TransferInfo.from_zmq( waiting_req_bytes ) - + required_dst_info_num = self.transfer_infos[room][ + agent_name + ].required_dst_info_num logger.debug(f"got info {room=} {agent_name=} {required_dst_info_num=}") if len(self.transfer_infos[room]) == required_dst_info_num: logger.debug(f"{room=} is bootstrapped") @@ -391,6 +422,7 @@ def send( self.chunk_id += 1 if is_last: self.has_sent = True + del self.kv_mgr.request_status[self.bootstrap_room] def poll(self) -> KVPoll: if not self.has_sent: @@ -415,6 +447,7 @@ def __init__( data_parallel_rank: Optional[int] = None, ): self.started_transfer = False + self.conclude_state = None super().__init__(mgr, bootstrap_addr, bootstrap_room, data_parallel_rank) def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None): @@ -426,17 +459,8 @@ def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = Non f"Fetched bootstrap info: {bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}" ) is_dummy = bootstrap_info["is_dummy"] - - # TODO: send_kv_args earlier - packed_kv_data_ptrs = b"".join( - struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs - ) - packed_aux_data_ptrs = b"".join( - struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs - ) - logger.debug( - f"Sending to {self.prefill_server_url} with bootstrap room {self.bootstrap_room}" + f"Sending to {self.prefill_server_url} with bootstrap room {self.bootstrap_room} {is_dummy=}" ) sock, lock = self._connect("tcp://" + self.prefill_server_url) with lock: @@ -446,13 +470,9 @@ def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = Non str(self.bootstrap_room).encode("ascii"), get_local_ip_by_remote().encode("ascii"), str(self.kv_mgr.rank_port).encode("ascii"), - self.kv_mgr.agent.get_agent_metadata(), self.kv_mgr.agent.name.encode("ascii"), - packed_kv_data_ptrs, kv_indices.tobytes() if not is_dummy else b"", - packed_aux_data_ptrs, str(aux_index).encode("ascii"), - str(self.kv_mgr.kv_args.gpu_id).encode("ascii"), str(self.required_dst_info_num).encode("ascii"), ] ) @@ -460,17 +480,45 @@ def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = Non self.started_transfer = True def poll(self) -> KVPoll: + if self.conclude_state is not None: + return self.conclude_state if not self.started_transfer: return KVPoll.WaitingForInput # type: ignore self.kv_mgr.update_transfer_status() - if self.kv_mgr.check_transfer_done(self.bootstrap_room): # type: ignore + self.conclude_state = KVPoll.Success + del self.kv_mgr.transfer_statuses[self.bootstrap_room] return KVPoll.Success # type: ignore return KVPoll.WaitingForInput # type: ignore def _register_kv_args(self): - pass + for bootstrap_info in self.bootstrap_infos: + self.prefill_server_url = ( + f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}" + ) + packed_kv_data_ptrs = b"".join( + struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs + ) + packed_aux_data_ptrs = b"".join( + struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs + ) + + sock, lock = self._connect("tcp://" + self.prefill_server_url) + with lock: + sock.send_multipart( + [ + GUARD, + "None".encode("ascii"), + get_local_ip_by_remote().encode("ascii"), + str(self.kv_mgr.rank_port).encode("ascii"), + self.kv_mgr.agent.name.encode("ascii"), + self.kv_mgr.agent.get_agent_metadata(), + packed_kv_data_ptrs, + packed_aux_data_ptrs, + str(self.kv_mgr.kv_args.gpu_id).encode("ascii"), + ] + ) def failure_exception(self): raise Exception("Fake KVReceiver Exception")