Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
246 changes: 160 additions & 86 deletions python/sglang/srt/disaggregation/nixl/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@
CommonKVReceiver,
CommonKVSender,
)
from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous
from sglang.srt.disaggregation.common.utils import (
FastQueue,
group_concurrent_contiguous,
)
from sglang.srt.disaggregation.utils import (
DisaggregationMode,
filter_kv_indices_for_cp_rank,
Expand Down Expand Up @@ -68,6 +71,17 @@ def from_zmq(cls, msg: List[bytes]):
)


@dataclasses.dataclass
class TransferKVChunk:
    """One queued unit of prefill-side KV transfer work, built by ``add_transfer_request`` and consumed by ``transfer_worker``."""

    # Bootstrap room id of the request this chunk belongs to; also used to
    # pick the transfer queue (room modulo number of queues).
    room: int
    # Prefill-side (source) KV indices for this chunk.
    prefill_kv_indices: npt.NDArray[np.int32]
    # Slice applied to each receiver's ``dst_kv_indices`` to select the
    # destination indices that correspond to this chunk.
    index_slice: slice
    # True for the final chunk of the room: the worker then also sends the
    # aux data (and state data, if any) and marks the room as Success.
    is_last: bool
    # Chunk sequence number; embedded in the KV notification string.
    chunk_id: int
    # Prefill-side aux index passed to ``send_aux``. The worker raises
    # RuntimeError if this is None when ``is_last`` is True.
    prefill_aux_index: Optional[int]
    # Optional state indices; when not None they are forwarded to
    # ``maybe_send_extra`` together with the last chunk.
    state_indices: Optional[List[int]]


@dataclasses.dataclass
class KVArgsRegisterInfo:
"""Contains base pointers and other info which only needs to be sent once by KVReceiver. Received by prefill bootstrap thread."""
Expand Down Expand Up @@ -203,6 +217,14 @@ def __init__(
self.register_buffer_to_engine()

if self.disaggregation_mode == DisaggregationMode.PREFILL:
transfer_queue_size = envs.SGLANG_DISAGGREGATION_QUEUE_SIZE.get()
self.transfer_queues: List[FastQueue] = [
FastQueue() for _ in range(transfer_queue_size)
]
for queue in self.transfer_queues:
threading.Thread(
target=self.transfer_worker, args=(queue,), daemon=True
).start()
self._start_bootstrap_thread()
elif self.disaggregation_mode == DisaggregationMode.DECODE:
self.transfer_statuses: Dict[int, TransferStatus] = defaultdict(
Expand Down Expand Up @@ -300,6 +322,118 @@ def _handle_node_failure(self, failed_bootstrap_addr):
logger.error(f"Let room {room} be failed due to prefill down")
self.update_status(room, KVPoll.Failed)

def check_status(self, bootstrap_room: int):
    """Return the status currently recorded for ``bootstrap_room``.

    Rooms that have no entry in ``self.request_status`` are reported as
    ``KVPoll.Bootstrapping``.
    """
    fallback = KVPoll.Bootstrapping
    return self.request_status.get(bootstrap_room, fallback)

def transfer_worker(self, queue: FastQueue):
while True:
kv_chunk: TransferKVChunk = queue.get()
room = kv_chunk.room
try:
if self.check_status(room) == KVPoll.Failed:
continue

assert room in self.transfer_infos

self.update_status(room, KVPoll.Transferring)

reqs_to_be_processed = list(self.transfer_infos[room].values())
handles: List = []

for req in reqs_to_be_processed:
assert room == req.room
if req.is_dummy():
continue

chunked_dst_kv_indice = req.dst_kv_indices[kv_chunk.index_slice]
if len(chunked_dst_kv_indice) < len(kv_chunk.prefill_kv_indices):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The old code has:

assert len(chunked_dst_kv_indice) == len(kv_indices)

Why was this changed?

kv_chunk.prefill_kv_indices = kv_chunk.prefill_kv_indices[
: len(chunked_dst_kv_indice)
]
assert req.agent_name in self.decode_kv_args_table

notif = f"{req.room}_kv_{kv_chunk.chunk_id}_{int(kv_chunk.is_last)}_{self.kv_args.engine_rank}"
decode_tp_size = self.decode_kv_args_table[
req.agent_name
].decode_tp_size

if self.is_mla_backend or (decode_tp_size == self.attn_tp_size):
kv_xfer_handle = self.send_kvcache(
req.agent_name,
kv_chunk.prefill_kv_indices,
self.decode_kv_args_table[req.agent_name].dst_kv_ptrs,
chunked_dst_kv_indice,
self.decode_kv_args_table[req.agent_name].gpu_id,
notif,
)
else:
kv_xfer_handle = self.send_kvcache_slice(
req.agent_name,
kv_chunk.prefill_kv_indices,
self.decode_kv_args_table[req.agent_name].dst_kv_ptrs,
chunked_dst_kv_indice,
self.decode_kv_args_table[req.agent_name].gpu_id,
notif,
prefill_tp_size=self.attn_tp_size,
decode_tp_size=decode_tp_size,
decode_tp_rank=self.decode_kv_args_table[
req.agent_name
].decode_tp_rank,
dst_kv_item_len=self.decode_kv_args_table[
req.agent_name
].dst_kv_item_len,
)
handles.append(kv_xfer_handle)

if kv_chunk.is_last:
if kv_chunk.state_indices is not None:
dst_info = self.decode_kv_args_table[req.agent_name]
state_xfer_handle = self.maybe_send_extra(
req.agent_name,
kv_chunk.state_indices,
dst_info.dst_state_data_ptrs,
req.dst_state_indices,
dst_info.gpu_id,
f"{req.room}_state_{self.kv_args.engine_rank}",
decode_tp_size,
decode_tp_rank=dst_info.decode_tp_rank,
dst_state_item_lens=dst_info.dst_state_item_lens,
dst_state_dim_per_tensor=dst_info.dst_state_dim_per_tensor,
)
if state_xfer_handle is not None:
handles.append(state_xfer_handle)

if kv_chunk.prefill_aux_index is None:
raise RuntimeError("Missing aux index for last chunk")
aux_xfer_handle = self.send_aux(
req.agent_name,
kv_chunk.prefill_aux_index,
self.decode_kv_args_table[req.agent_name].dst_aux_ptrs,
req.dst_aux_index,
f"{req.room}_aux",
)
handles.append(aux_xfer_handle)

while handles:
states = [self.agent.check_xfer_state(h) for h in handles]
if any(s == "ERR" for s in states):
raise RuntimeError(f"NIXL transfer encountered ERR room={room}")
if all(s == "DONE" for s in states):
break
time.sleep(0)

if kv_chunk.is_last:
if room in self.transfer_infos:
del self.transfer_infos[room]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see multiple threads modifying and reading self.transfer_infos. Do we need locking? Or pass the data to the workers in a safer way?

self.update_status(room, KVPoll.Success)
else:
self.update_status(room, KVPoll.Transferring)
except Exception as e:
reason = f"Prefill transfer worker error room={room}: {e}"
logger.exception(reason)
self.record_failure(room, reason)
self.update_status(room, KVPoll.Failed)

def register_buffer_to_engine(self):
kv_addrs = []
for kv_data_ptr, kv_data_len in zip(
Expand Down Expand Up @@ -872,81 +1006,22 @@ def add_transfer_request(
assert self.disaggregation_mode == DisaggregationMode.PREFILL
assert not is_last or (is_last and aux_index is not None)

reqs_to_be_processed = self.transfer_infos[bootstrap_room].values()
handles = []
for req in reqs_to_be_processed:
assert bootstrap_room == req.room
if req.is_dummy():
continue

chunked_dst_kv_indice = req.dst_kv_indices[index_slice]
assert len(chunked_dst_kv_indice) == len(kv_indices)
assert req.agent_name in self.decode_kv_args_table

notif = (
f"{req.room}_kv_{chunk_id}_{int(is_last)}_{self.kv_args.engine_rank}"
if bootstrap_room not in self.request_status:
self.update_status(bootstrap_room, KVPoll.Bootstrapping)

shard_idx = bootstrap_room % len(self.transfer_queues)
self.transfer_queues[shard_idx].put(
TransferKVChunk(
room=bootstrap_room,
prefill_kv_indices=kv_indices,
index_slice=index_slice,
is_last=is_last,
chunk_id=chunk_id,
prefill_aux_index=aux_index,
state_indices=state_indices,
)
decode_tp_size = self.decode_kv_args_table[req.agent_name].decode_tp_size

if self.is_mla_backend or (decode_tp_size == self.attn_tp_size):
kv_xfer_handle = self.send_kvcache(
req.agent_name,
kv_indices,
self.decode_kv_args_table[req.agent_name].dst_kv_ptrs,
chunked_dst_kv_indice,
self.decode_kv_args_table[req.agent_name].gpu_id,
notif,
)
else:
kv_xfer_handle = self.send_kvcache_slice(
req.agent_name,
kv_indices,
self.decode_kv_args_table[req.agent_name].dst_kv_ptrs,
chunked_dst_kv_indice,
self.decode_kv_args_table[req.agent_name].gpu_id,
notif,
prefill_tp_size=self.attn_tp_size,
decode_tp_size=decode_tp_size,
decode_tp_rank=self.decode_kv_args_table[
req.agent_name
].decode_tp_rank,
dst_kv_item_len=self.decode_kv_args_table[
req.agent_name
].dst_kv_item_len,
)

handles.append(kv_xfer_handle)
# Only the last chunk we need to send the aux data.
if is_last:
if state_indices is not None:
dst_info = self.decode_kv_args_table[req.agent_name]
state_xfer_handle = self.maybe_send_extra(
req.agent_name,
state_indices,
dst_info.dst_state_data_ptrs,
req.dst_state_indices,
dst_info.gpu_id,
f"{req.room}_state_{self.kv_args.engine_rank}",
decode_tp_size,
decode_tp_rank=dst_info.decode_tp_rank,
dst_state_item_lens=dst_info.dst_state_item_lens,
dst_state_dim_per_tensor=dst_info.dst_state_dim_per_tensor,
)
if state_xfer_handle is not None:
handles.append(state_xfer_handle)

assert aux_index is not None
aux_xfer_handle = self.send_aux(
req.agent_name,
aux_index,
self.decode_kv_args_table[req.agent_name].dst_aux_ptrs,
req.dst_aux_index,
f"{req.room}_aux",
)
handles.append(aux_xfer_handle)
if is_last:
del self.transfer_infos[bootstrap_room]
return handles
)
return None

def update_transfer_status(self):
# Process notifications from received transfers.
Expand Down Expand Up @@ -1035,7 +1110,6 @@ def __init__(
pp_rank: int,
):
super().__init__(mgr, bootstrap_addr, bootstrap_room, dest_tp_ranks, pp_rank)
self.xfer_handles = []
self.has_sent = False
self.chunk_id = 0

Expand All @@ -1062,7 +1136,7 @@ def send(
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Success)
return

new_xfer_handles = self.kv_mgr.add_transfer_request(
self.kv_mgr.add_transfer_request(
self.bootstrap_room,
kv_indices,
index_slice,
Expand All @@ -1071,21 +1145,21 @@ def send(
self.aux_index,
state_indices,
)
self.xfer_handles.extend(new_xfer_handles)
self.chunk_id += 1
if is_last:
self.has_sent = True
del self.kv_mgr.request_status[self.bootstrap_room]

def poll(self) -> KVPoll:
status = self.kv_mgr.check_status(self.bootstrap_room)
if not self.has_sent:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like we return status here in all cases, so we can just return directly and remove the ifs

return self.kv_mgr.check_status(self.bootstrap_room)
states = [self.kv_mgr.agent.check_xfer_state(x) for x in self.xfer_handles]
if all([x == "DONE" for x in states]):
return KVPoll.Success # type: ignore
if any([x == "ERR" for x in states]):
raise Exception("KVSender transfer encountered an error.")
return KVPoll.WaitingForInput # type: ignore
return status
if status in (KVPoll.Success, KVPoll.Failed):
return status
return status

def clear(self):
    """Remove this sender's room from the manager's ``request_status`` table.

    Calling this when the room has no entry (or was already removed) is a
    no-op.
    """
    # Pop with a default instead of "check membership, then pop": one lookup,
    # and no KeyError if the entry disappears between the test and the pop.
    self.kv_mgr.request_status.pop(self.bootstrap_room, None)

def failure_exception(self):
    """Raise the generic error that signals a failed NIXL KV send.

    Raises:
        RuntimeError: always.
    """
    error = RuntimeError("NIXL KVSender Exception")
    raise error
Expand Down
6 changes: 5 additions & 1 deletion python/sglang/srt/disaggregation/prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,7 +628,11 @@ def process_disagg_prefill_inflight_queue(
undone_reqs.append(req)
continue

if poll in [KVPoll.WaitingForInput, KVPoll.Transferring]:
if poll in [
KVPoll.Bootstrapping,
KVPoll.WaitingForInput,
KVPoll.Transferring,
]:
undone_reqs.append(req)
elif poll == KVPoll.Success: # transfer done
release_kv_cache(req, self.tree_cache) # unlock the tree
Expand Down
Loading