Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
333 changes: 201 additions & 132 deletions python/sglang/srt/disaggregation/nixl/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@
CommonKVReceiver,
CommonKVSender,
)
from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous
from sglang.srt.disaggregation.common.utils import (
FastQueue,
group_concurrent_contiguous,
)
from sglang.srt.disaggregation.utils import (
DisaggregationMode,
filter_kv_indices_for_cp_rank,
Expand Down Expand Up @@ -94,6 +97,17 @@ def from_zmq(cls, msg: List[bytes]):
)


# One queued unit of prefill-side transfer work. Instances are built in
# add_transfer_request and consumed by transfer_worker.
@dataclasses.dataclass
class TransferKVChunk:
# Bootstrap room id; transfer_worker uses it to look up self.transfer_infos and to update status.
room: int
# Source (prefill-side) KV indices handed to send_kvcache / send_kvcache_slice.
# An empty array means no KV transfer is posted for this chunk.
prefill_kv_indices: npt.NDArray[np.int32]
# Slice applied to each destination's dst_kv_indices to pick the matching destination indices.
index_slice: slice
# True for the final chunk of a room; triggers the state/aux sends and the Success status.
is_last: bool
# Chunk sequence number embedded in the KV notification string.
chunk_id: int
# Source aux index passed to send_aux; transfer_worker raises if it is None on the last chunk.
prefill_aux_index: Optional[int]
# Source state indices passed to maybe_send_extra on the last chunk; None skips that send.
state_indices: Optional[List[int]]


@dataclasses.dataclass
class KVArgsRegisterInfo:
"""Contains base pointers and other info which only needs to be sent once by KVReceiver. Received by prefill bootstrap thread."""
Expand Down Expand Up @@ -248,6 +262,15 @@ def __init__(
self.register_buffer_to_engine()

if self.disaggregation_mode == DisaggregationMode.PREFILL:
transfer_queue_size = envs.SGLANG_DISAGGREGATION_QUEUE_SIZE.get()
self.transfer_queues: List[FastQueue] = [
FastQueue() for _ in range(transfer_queue_size)
]
self.exceptions: Dict[int, Exception] = {}
for queue in self.transfer_queues:
threading.Thread(
target=self.transfer_worker, args=(queue,), daemon=True
).start()
self._start_bootstrap_thread()
elif self.disaggregation_mode == DisaggregationMode.DECODE:
self.transfer_statuses: Dict[int, TransferStatus] = defaultdict(
Expand Down Expand Up @@ -345,6 +368,146 @@ def _handle_node_failure(self, failed_bootstrap_addr):
logger.error(f"Let room {room} be failed due to prefill down")
self.update_status(room, KVPoll.Failed)

def check_status(self, bootstrap_room: int):
    """Return the poll status currently recorded for ``bootstrap_room``.

    A room with no entry in ``self.request_status`` is reported as
    ``KVPoll.WaitingForInput``; the lookup does not insert anything.
    """
    fallback_status = KVPoll.WaitingForInput
    return self.request_status.get(bootstrap_room, fallback_status)

# Worker loop for one transfer queue (prefill side). Each iteration takes one
# TransferKVChunk from `queue`, posts its KV / state / aux transfers to every
# non-dummy destination registered for the chunk's room, polls the returned
# handles until they all finish, and then records the room's status.
# The loop never returns. Any exception raised while handling a chunk is caught,
# stored in self.exceptions[room], and the room is marked KVPoll.Failed.
def transfer_worker(self, queue: FastQueue):
while True:
# NOTE(review): presumably blocks until a chunk is queued — FastQueue is defined elsewhere; confirm.
kv_chunk: TransferKVChunk = queue.get()
room = kv_chunk.room
try:
# Drop chunks whose room has already been marked Failed.
if self.check_status(room) == KVPoll.Failed:
continue

assert room in self.transfer_infos

self.update_status(room, KVPoll.Transferring)

# Snapshot the room's transfer infos into a list before iterating.
reqs_to_be_processed = list(self.transfer_infos[room].values())
# Transfer handles posted for this chunk, across all destinations.
handles: List = []

for req in reqs_to_be_processed:
assert room == req.room
if req.is_dummy():
continue

assert req.agent_name in self.decode_kv_args_table
decode_tp_size = self.decode_kv_args_table[
req.agent_name
].decode_tp_size

# Skip KV RDMA transfer when there are no pages to send
# (e.g., decode-side radix cache matched the entire prefix).
# Aux data is still sent below when is_last=True.
if len(kv_chunk.prefill_kv_indices) > 0:
chunked_dst_kv_indice = req.dst_kv_indices[kv_chunk.index_slice]

# NOTE: This is temporarily a workaround to deal with the case where the prefill_kv_indices
# is mismatched with the dst_kv_indices when page size > 1, this should never happen.
if len(chunked_dst_kv_indice) < len(
kv_chunk.prefill_kv_indices
):
logger.warning(
f"len(chunked_dst_kv_indice) = {len(chunked_dst_kv_indice)}, len(kv_chunk.prefill_kv_indices) = {len(kv_chunk.prefill_kv_indices)}"
)
# NOTE(review): this rebinds the field on the chunk object itself, so the
# truncated array is also what later destinations in this loop and the
# len(...) == 0 aux-notification check below will see — confirm intended.
kv_chunk.prefill_kv_indices = kv_chunk.prefill_kv_indices[
: len(chunked_dst_kv_indice)
]

# Notification string: room, chunk id, is_last as 0/1, and self.kv_args.engine_rank.
notif = f"{req.room}_kv_{kv_chunk.chunk_id}_{int(kv_chunk.is_last)}_{self.kv_args.engine_rank}"

# MLA backends and equal TP sizes use send_kvcache; otherwise send_kvcache_slice
# is called with the additional TP size / rank / item-length arguments.
if self.is_mla_backend or (decode_tp_size == self.attn_tp_size):
kv_xfer_handle = self.send_kvcache(
req.agent_name,
kv_chunk.prefill_kv_indices,
self.decode_kv_args_table[req.agent_name].dst_kv_ptrs,
chunked_dst_kv_indice,
self.decode_kv_args_table[req.agent_name].gpu_id,
notif,
)
else:
kv_xfer_handle = self.send_kvcache_slice(
req.agent_name,
kv_chunk.prefill_kv_indices,
self.decode_kv_args_table[req.agent_name].dst_kv_ptrs,
chunked_dst_kv_indice,
self.decode_kv_args_table[req.agent_name].gpu_id,
notif,
prefill_tp_size=self.attn_tp_size,
decode_tp_size=decode_tp_size,
decode_tp_rank=self.decode_kv_args_table[
req.agent_name
].decode_tp_rank,
dst_kv_item_len=self.decode_kv_args_table[
req.agent_name
].dst_kv_item_len,
)

handles.append(kv_xfer_handle)

# State and aux data are sent only with the last chunk of a room.
if kv_chunk.is_last:
if kv_chunk.state_indices is not None:
dst_info = self.decode_kv_args_table[req.agent_name]
state_xfer_handle = self.maybe_send_extra(
req.agent_name,
kv_chunk.state_indices,
dst_info.dst_state_data_ptrs,
req.dst_state_indices,
dst_info.gpu_id,
f"{req.room}_state_{self.kv_args.engine_rank}",
decode_tp_size,
decode_tp_rank=dst_info.decode_tp_rank,
dst_state_item_lens=dst_info.dst_state_item_lens,
dst_state_dim_per_tensor=dst_info.dst_state_dim_per_tensor,
)
# maybe_send_extra may return None, in which case there is nothing to wait on.
if state_xfer_handle is not None:
handles.append(state_xfer_handle)

if kv_chunk.prefill_aux_index is None:
raise RuntimeError("Missing aux index for last chunk")
# When no KV pages were sent (decode-side cache hit),
# encode pp_rank in aux notif so receiver can mark
# expected_kvs_per_pp[pp_rank] = 0.
# NOTE(review): the value embedded below is self.kv_args.engine_rank, while the
# comment above speaks of pp_rank — confirm the receiver expects this rank.
if len(kv_chunk.prefill_kv_indices) == 0:
aux_notif = (
f"{req.room}_aux_nokv_{self.kv_args.engine_rank}"
)
else:
aux_notif = f"{req.room}_aux"
aux_xfer_handle = self.send_aux(
req.agent_name,
kv_chunk.prefill_aux_index,
self.decode_kv_args_table[req.agent_name].dst_aux_ptrs,
req.dst_aux_index,
aux_notif,
)
handles.append(aux_xfer_handle)

# Poll every posted handle until all report "DONE"; a single "ERR" fails the chunk.
# Skipped entirely when nothing was posted (handles is empty).
while handles:
states = [self.agent.check_xfer_state(h) for h in handles]
if any(s == "ERR" for s in states):
raise RuntimeError(f"NIXL transfer encountered ERR room={room}")
if all(s == "DONE" for s in states):
break
# Zero-length sleep between polls so other threads get a chance to run.
time.sleep(0)

# Only the last chunk moves the room to Success; earlier chunks re-assert Transferring.
if kv_chunk.is_last:
self.update_status(room, KVPoll.Success)
else:
self.update_status(room, KVPoll.Transferring)
except Exception as e:
# Catch all exceptions to prevent silently killing this
# worker thread, but still propagate via failure_exception().
if isinstance(e, _NIXL_TRANSPORT_ERRORS):
logger.warning(f"NIXL transport error for room {room}: {e}")
else:
logger.exception(
f"Unexpected transfer worker error for room {room}"
)
# Keep the exception object keyed by room (failure_exception pops it from this dict),
# then record the failure text and mark the room Failed.
self.exceptions[room] = e
self.record_failure(room, str(e))
self.update_status(room, KVPoll.Failed)

def register_buffer_to_engine(self):
kv_addrs = []
for kv_data_ptr, kv_data_len in zip(
Expand Down Expand Up @@ -925,91 +1088,19 @@ def add_transfer_request(
assert self.disaggregation_mode == DisaggregationMode.PREFILL
assert not is_last or (is_last and aux_index is not None)

reqs_to_be_processed = self.transfer_infos[bootstrap_room].values()
handles = []
for req in reqs_to_be_processed:
assert bootstrap_room == req.room
if req.is_dummy():
continue

chunked_dst_kv_indice = req.dst_kv_indices[index_slice]
assert len(chunked_dst_kv_indice) == len(kv_indices)
assert req.agent_name in self.decode_kv_args_table

decode_tp_size = self.decode_kv_args_table[req.agent_name].decode_tp_size

# Skip KV RDMA transfer when there are no pages to send
# (e.g., decode-side radix cache matched the entire prefix).
# Aux data is still sent below when is_last=True.
if len(kv_indices) > 0:
notif = (
f"{req.room}_kv_{chunk_id}_{int(is_last)}_{self.kv_args.pp_rank}"
)

if self.is_mla_backend or (decode_tp_size == self.attn_tp_size):
kv_xfer_handle = self.send_kvcache(
req.agent_name,
kv_indices,
self.decode_kv_args_table[req.agent_name].dst_kv_ptrs,
chunked_dst_kv_indice,
self.decode_kv_args_table[req.agent_name].gpu_id,
notif,
)
else:
kv_xfer_handle = self.send_kvcache_slice(
req.agent_name,
kv_indices,
self.decode_kv_args_table[req.agent_name].dst_kv_ptrs,
chunked_dst_kv_indice,
self.decode_kv_args_table[req.agent_name].gpu_id,
notif,
prefill_tp_size=self.attn_tp_size,
decode_tp_size=decode_tp_size,
decode_tp_rank=self.decode_kv_args_table[
req.agent_name
].decode_tp_rank,
dst_kv_item_len=self.decode_kv_args_table[
req.agent_name
].dst_kv_item_len,
)

handles.append(kv_xfer_handle)
# Only the last chunk we need to send the aux data.
if is_last:
if state_indices is not None:
dst_info = self.decode_kv_args_table[req.agent_name]
state_xfer_handle = self.maybe_send_extra(
req.agent_name,
state_indices,
dst_info.dst_state_data_ptrs,
req.dst_state_indices,
dst_info.gpu_id,
f"{req.room}_state_{self.kv_args.engine_rank}",
decode_tp_size,
decode_tp_rank=dst_info.decode_tp_rank,
dst_state_item_lens=dst_info.dst_state_item_lens,
dst_state_dim_per_tensor=dst_info.dst_state_dim_per_tensor,
)
if state_xfer_handle is not None:
handles.append(state_xfer_handle)

assert aux_index is not None
# When no KV pages were sent (decode-side cache hit),
# encode pp_rank in aux notif so receiver can mark
# expected_kvs_per_pp[pp_rank] = 0.
if len(kv_indices) == 0:
aux_notif = f"{req.room}_aux_nokv_{self.kv_args.pp_rank}"
else:
aux_notif = f"{req.room}_aux"
aux_xfer_handle = self.send_aux(
req.agent_name,
aux_index,
self.decode_kv_args_table[req.agent_name].dst_aux_ptrs,
req.dst_aux_index,
aux_notif,
)
handles.append(aux_xfer_handle)
return handles
shard_idx = bootstrap_room % len(self.transfer_queues)
self.transfer_queues[shard_idx].put(
TransferKVChunk(
room=bootstrap_room,
prefill_kv_indices=kv_indices,
index_slice=index_slice,
is_last=is_last,
chunk_id=chunk_id,
prefill_aux_index=aux_index,
state_indices=state_indices,
)
)
return None

def update_transfer_status(self):
# Process notifications from received transfers.
Expand Down Expand Up @@ -1115,7 +1206,6 @@ def __init__(
pp_rank: int,
):
super().__init__(mgr, bootstrap_addr, bootstrap_room, dest_tp_ranks, pp_rank)
self.xfer_handles = []
self.has_sent = False
self.chunk_id = 0
self._send_failed = False
Expand Down Expand Up @@ -1159,64 +1249,43 @@ def send(
):
self._transfer_start_time = time.perf_counter()

try:
new_xfer_handles = self.kv_mgr.add_transfer_request(
self.bootstrap_room,
kv_indices,
index_slice,
is_last,
self.chunk_id,
self.aux_index,
state_indices,
)
except _NIXL_TRANSPORT_ERRORS as e:
logger.warning(
f"KVSender transfer request failed for room {self.bootstrap_room}: {e}"
)
self._send_failed = True
self._send_error = e
return

self.kv_mgr.add_transfer_request(
self.bootstrap_room,
kv_indices,
index_slice,
is_last,
self.chunk_id,
self.aux_index,
state_indices,
)
self._record_transfer_indices(kv_indices, state_indices)
self.xfer_handles.extend(new_xfer_handles)
self.chunk_id += 1
if is_last:
self.has_sent = True

def poll(self) -> KVPoll:
if self._send_failed:
return KVPoll.Failed # type: ignore
if not self.has_sent:
return self.kv_mgr.check_status(self.bootstrap_room)
try:
states = [self.kv_mgr.agent.check_xfer_state(x) for x in self.xfer_handles]
except _NIXL_TRANSPORT_ERRORS as e:
logger.warning(
f"KVSender check_xfer_state failed for room {self.bootstrap_room}: {e}"
)
self._send_failed = True
self._send_error = e
return KVPoll.Failed # type: ignore
if all(x == "DONE" for x in states):
if (
self._transfer_start_time is not None
and self._transfer_metric.transfer_latency_s is None
):
self._transfer_metric.transfer_latency_s = (
time.perf_counter() - self._transfer_start_time
)
return KVPoll.Success # type: ignore
if any(x == "ERR" for x in states):
self._send_failed = True
self._send_error = RuntimeError(
f"NIXL transfer error for room {self.bootstrap_room}"
Comment on lines -1193 to -1212
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CC: @cctry

Copy link
Copy Markdown
Contributor Author

@ovidiusm ovidiusm May 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. I have now changed the code to catch exceptions in the worker thread, pass them to the main thread, and re-raise them there, so that we can still detect _NIXL_TRANSPORT_ERRORS as before. The worker thread still has to catch all exceptions; otherwise it may die on other errors, which could cause hangs.

status = self.kv_mgr.check_status(self.bootstrap_room)
if (
status == KVPoll.Success
and self._transfer_start_time is not None
and self._transfer_metric.transfer_latency_s is None
):
self._transfer_metric.transfer_latency_s = (
time.perf_counter() - self._transfer_start_time
)
return KVPoll.Failed # type: ignore
return KVPoll.WaitingForInput # type: ignore
return status

# Delegates to the parent class's clear(); this override adds no cleanup of its own.
def clear(self):
super().clear()

def failure_exception(self):
    """Raise the exception that explains why this sender's transfer failed.

    Precedence is unchanged: an error recorded on the sender itself
    (``self._send_error``) wins, then an exception stored for this room in
    ``self.kv_mgr.exceptions``, then a generic ``RuntimeError``.

    Raises:
        Exception: ``self._send_error`` when it is set.
        Exception: the exception stored under ``self.bootstrap_room`` in
            ``self.kv_mgr.exceptions`` when no sender-level error is set.
        RuntimeError: when neither source holds an exception.
    """
    # Always remove the manager-side entry for this room, even when the
    # sender-level error takes precedence. Otherwise the stored exception
    # (and its traceback) would stay in the dict indefinitely.
    worker_exc = self.kv_mgr.exceptions.pop(self.bootstrap_room, None)
    if self._send_error is not None:
        raise self._send_error
    if worker_exc is not None:
        raise worker_exc
    raise RuntimeError("NIXL KVSender Exception")


Expand Down
Loading