Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 17 additions & 47 deletions python/sglang/srt/disaggregation/nixl/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,39 +53,23 @@ class TransferInfo:
required_dst_info_num: int

def is_dummy(self):
return self.endpoint == ""
return self.dst_kv_indices.size == 0

@classmethod
def from_zmq(cls, msg: List[bytes]):
if len(msg) == 1:
# dummy msg
return cls(
room=int(msg[0].decode("ascii")),
endpoint="",
dst_port=0,
agent_metadata=b"",
agent_name="",
dst_kv_ptrs=[],
dst_kv_indices=np.array([], dtype=np.int64),
dst_aux_ptrs=[],
dst_aux_index=0,
dst_gpu_id=0,
required_dst_info_num=0,
)
else:
return cls(
room=int(msg[0].decode("ascii")),
endpoint=msg[1].decode("ascii"),
dst_port=int(msg[2].decode("ascii")),
agent_metadata=msg[3],
agent_name=msg[4].decode("ascii"),
dst_kv_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
dst_kv_indices=np.frombuffer(msg[6], dtype=np.int64),
dst_aux_ptrs=list(struct.unpack(f"{len(msg[7])//8}Q", msg[7])),
dst_aux_index=int(msg[8].decode("ascii")),
dst_gpu_id=int(msg[9].decode("ascii")),
required_dst_info_num=int(msg[10].decode("ascii")),
)
return cls(
room=int(msg[0].decode("ascii")),
endpoint=msg[1].decode("ascii"),
dst_port=int(msg[2].decode("ascii")),
agent_metadata=msg[3],
agent_name=msg[4].decode("ascii"),
dst_kv_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
dst_kv_indices=np.frombuffer(msg[6], dtype=np.int64),
dst_aux_ptrs=list(struct.unpack(f"{len(msg[7])//8}Q", msg[7])),
dst_aux_index=int(msg[8].decode("ascii")),
dst_gpu_id=int(msg[9].decode("ascii")),
required_dst_info_num=int(msg[10].decode("ascii")),
)


@dataclasses.dataclass
Expand Down Expand Up @@ -278,7 +262,7 @@ def add_transfer_request(
for req in reqs_to_be_processed:
assert bootstrap_room == req.room
if req.is_dummy():
return []
continue

peer_name = self._add_remote(req.agent_name, req.agent_metadata)
chunked_dst_kv_indice = req.dst_kv_indices[index_slice]
Expand Down Expand Up @@ -346,8 +330,7 @@ def bootstrap_thread():
), f"First message should be {GUARD}. Foreign traffic?"
waiting_req_bytes = waiting_req_bytes[1:]
room = waiting_req_bytes[0].decode("ascii")
if room == "None":
continue

required_dst_info_num = int(waiting_req_bytes[10].decode("ascii"))
room = int(room)
agent_name = waiting_req_bytes[4].decode("ascii")
Expand Down Expand Up @@ -438,19 +421,6 @@ def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = Non
)
is_dummy = bootstrap_info["is_dummy"]

# TODO: just send "" for indices for dummy
if is_dummy:
# TODO: need to set success??
sock, lock = self._connect("tcp://" + self.prefill_server_url)
with lock:
sock.send_multipart(
[
GUARD,
str(self.bootstrap_room).encode("ascii"),
]
)
continue

# TODO: send_kv_args earlier
packed_kv_data_ptrs = b"".join(
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
Expand All @@ -473,7 +443,7 @@ def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = Non
self.kv_mgr.agent.get_agent_metadata(),
self.kv_mgr.agent.name.encode("ascii"),
packed_kv_data_ptrs,
kv_indices.tobytes(),
kv_indices.tobytes() if not is_dummy else b"",
packed_aux_data_ptrs,
str(aux_index).encode("ascii"),
str(self.kv_mgr.kv_args.gpu_id).encode("ascii"),
Expand Down
Loading