diff --git a/python/sglang/srt/disaggregation/nixl/conn.py b/python/sglang/srt/disaggregation/nixl/conn.py index 928dd7530ab..3ed021a6bdc 100644 --- a/python/sglang/srt/disaggregation/nixl/conn.py +++ b/python/sglang/srt/disaggregation/nixl/conn.py @@ -53,39 +53,23 @@ class TransferInfo: required_dst_info_num: int def is_dummy(self): - return self.endpoint == "" + return self.dst_kv_indices.size == 0 @classmethod def from_zmq(cls, msg: List[bytes]): - if len(msg) == 1: - # dummy msg - return cls( - room=int(msg[0].decode("ascii")), - endpoint="", - dst_port=0, - agent_metadata=b"", - agent_name="", - dst_kv_ptrs=[], - dst_kv_indices=np.array([], dtype=np.int64), - dst_aux_ptrs=[], - dst_aux_index=0, - dst_gpu_id=0, - required_dst_info_num=0, - ) - else: - return cls( - room=int(msg[0].decode("ascii")), - endpoint=msg[1].decode("ascii"), - dst_port=int(msg[2].decode("ascii")), - agent_metadata=msg[3], - agent_name=msg[4].decode("ascii"), - dst_kv_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])), - dst_kv_indices=np.frombuffer(msg[6], dtype=np.int64), - dst_aux_ptrs=list(struct.unpack(f"{len(msg[7])//8}Q", msg[7])), - dst_aux_index=int(msg[8].decode("ascii")), - dst_gpu_id=int(msg[9].decode("ascii")), - required_dst_info_num=int(msg[10].decode("ascii")), - ) + return cls( + room=int(msg[0].decode("ascii")), + endpoint=msg[1].decode("ascii"), + dst_port=int(msg[2].decode("ascii")), + agent_metadata=msg[3], + agent_name=msg[4].decode("ascii"), + dst_kv_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])), + dst_kv_indices=np.frombuffer(msg[6], dtype=np.int64), + dst_aux_ptrs=list(struct.unpack(f"{len(msg[7])//8}Q", msg[7])), + dst_aux_index=int(msg[8].decode("ascii")), + dst_gpu_id=int(msg[9].decode("ascii")), + required_dst_info_num=int(msg[10].decode("ascii")), + ) @dataclasses.dataclass @@ -278,7 +262,7 @@ def add_transfer_request( for req in reqs_to_be_processed: assert bootstrap_room == req.room if req.is_dummy(): - return [] + continue peer_name = self._add_remote(req.agent_name, req.agent_metadata) chunked_dst_kv_indice = req.dst_kv_indices[index_slice] @@ -346,8 +330,7 @@ def bootstrap_thread(): ), f"First message should be {GUARD}. Foreign traffic?" waiting_req_bytes = waiting_req_bytes[1:] room = waiting_req_bytes[0].decode("ascii") - if room == "None": - continue + required_dst_info_num = int(waiting_req_bytes[10].decode("ascii")) room = int(room) agent_name = waiting_req_bytes[4].decode("ascii") @@ -438,19 +421,6 @@ def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = Non ) is_dummy = bootstrap_info["is_dummy"] - # TODO: just send "" for indices for dummy - if is_dummy: - # TODO: need to set success?? - sock, lock = self._connect("tcp://" + self.prefill_server_url) - with lock: - sock.send_multipart( - [ - GUARD, - str(self.bootstrap_room).encode("ascii"), - ] - ) - continue - # TODO: send_kv_args earlier packed_kv_data_ptrs = b"".join( struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs @@ -473,7 +443,7 @@ def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = Non self.kv_mgr.agent.get_agent_metadata(), self.kv_mgr.agent.name.encode("ascii"), packed_kv_data_ptrs, - kv_indices.tobytes(), + kv_indices.tobytes() if not is_dummy else b"", packed_aux_data_ptrs, str(aux_index).encode("ascii"), str(self.kv_mgr.kv_args.gpu_id).encode("ascii"),