From 28b6504c4e5e9e8406e9115d86e3045f522fe0ab Mon Sep 17 00:00:00 2001 From: Ovidiu Mara Date: Mon, 4 May 2026 16:42:17 +0200 Subject: [PATCH 1/5] Nixl async transfer -- rebased onto latest main Signed-off-by: Ovidiu Mara --- python/sglang/srt/disaggregation/nixl/conn.py | 279 +++++++++++------- python/sglang/srt/disaggregation/prefill.py | 6 +- 2 files changed, 183 insertions(+), 102 deletions(-) diff --git a/python/sglang/srt/disaggregation/nixl/conn.py b/python/sglang/srt/disaggregation/nixl/conn.py index 471ab416fda2..36508b4d0f16 100644 --- a/python/sglang/srt/disaggregation/nixl/conn.py +++ b/python/sglang/srt/disaggregation/nixl/conn.py @@ -20,7 +20,10 @@ CommonKVReceiver, CommonKVSender, ) -from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous +from sglang.srt.disaggregation.common.utils import ( + FastQueue, + group_concurrent_contiguous, +) from sglang.srt.disaggregation.utils import ( DisaggregationMode, filter_kv_indices_for_cp_rank, @@ -79,6 +82,17 @@ def from_zmq(cls, msg: List[bytes]): ) +@dataclasses.dataclass +class TransferKVChunk: + room: int + prefill_kv_indices: npt.NDArray[np.int32] + index_slice: slice + is_last: bool + chunk_id: int + prefill_aux_index: Optional[int] + state_indices: Optional[List[int]] + + @dataclasses.dataclass class KVArgsRegisterInfo: """Contains base pointers and other info which only needs to be sent once by KVReceiver. Received by prefill bootstrap thread.""" @@ -233,6 +247,14 @@ def __init__( self.register_buffer_to_engine() if self.disaggregation_mode == DisaggregationMode.PREFILL: + transfer_queue_size = envs.SGLANG_DISAGGREGATION_QUEUE_SIZE.get() + self.transfer_queues: List[FastQueue] = [ + FastQueue() for _ in range(transfer_queue_size) + ] + for queue in self.transfer_queues: + threading.Thread( + target=self.transfer_worker, args=(queue,), daemon=True + ).start() self._start_bootstrap_thread() elif self.disaggregation_mode == DisaggregationMode.DECODE: self.transfer_statuses: Dict[int, TransferStatus] = defaultdict( @@ -330,6 +352,140 @@ def _handle_node_failure(self, failed_bootstrap_addr): logger.error(f"Let room {room} be failed due to prefill down") self.update_status(room, KVPoll.Failed) + def check_status(self, bootstrap_room: int): + return self.request_status.get(bootstrap_room, KVPoll.Bootstrapping) + + def transfer_worker(self, queue: FastQueue): + while True: + kv_chunk: TransferKVChunk = queue.get() + room = kv_chunk.room + try: + if self.check_status(room) == KVPoll.Failed: + continue + + assert room in self.transfer_infos + + self.update_status(room, KVPoll.Transferring) + + reqs_to_be_processed = list(self.transfer_infos[room].values()) + handles: List = [] + + for req in reqs_to_be_processed: + assert room == req.room + if req.is_dummy(): + continue + + assert req.agent_name in self.decode_kv_args_table + decode_tp_size = self.decode_kv_args_table[ + req.agent_name + ].decode_tp_size + + # Skip KV RDMA transfer when there are no pages to send + # (e.g., decode-side radix cache matched the entire prefix). + # Aux data is still sent below when is_last=True. + if len(kv_chunk.prefill_kv_indices) > 0: + chunked_dst_kv_indice = req.dst_kv_indices[kv_chunk.index_slice] + + # NOTE: This is temporarily a workaround to deal with the case where the prefill_kv_indices + # is mismatched with the dst_kv_indices when page size > 1, this should never happen. + if len(chunked_dst_kv_indice) < len( + kv_chunk.prefill_kv_indices + ): + logger.warning( + f"len(chunked_dst_kv_indice) = {len(chunked_dst_kv_indice)}, len(kv_chunk.prefill_kv_indices) = {len(kv_chunk.prefill_kv_indices)}" + ) + kv_chunk.prefill_kv_indices = kv_chunk.prefill_kv_indices[ + : len(chunked_dst_kv_indice) + ] + + notif = f"{req.room}_kv_{kv_chunk.chunk_id}_{int(kv_chunk.is_last)}_{self.kv_args.pp_rank}" + + if self.is_mla_backend or (decode_tp_size == self.attn_tp_size): + kv_xfer_handle = self.send_kvcache( + req.agent_name, + kv_chunk.prefill_kv_indices, + self.decode_kv_args_table[req.agent_name].dst_kv_ptrs, + chunked_dst_kv_indice, + self.decode_kv_args_table[req.agent_name].gpu_id, + notif, + ) + else: + kv_xfer_handle = self.send_kvcache_slice( + req.agent_name, + kv_chunk.prefill_kv_indices, + self.decode_kv_args_table[req.agent_name].dst_kv_ptrs, + chunked_dst_kv_indice, + self.decode_kv_args_table[req.agent_name].gpu_id, + notif, + prefill_tp_size=self.attn_tp_size, + decode_tp_size=decode_tp_size, + decode_tp_rank=self.decode_kv_args_table[ + req.agent_name + ].decode_tp_rank, + dst_kv_item_len=self.decode_kv_args_table[ + req.agent_name + ].dst_kv_item_len, + ) + + handles.append(kv_xfer_handle) + + if kv_chunk.is_last: + if kv_chunk.state_indices is not None: + dst_info = self.decode_kv_args_table[req.agent_name] + state_xfer_handle = self.maybe_send_extra( + req.agent_name, + kv_chunk.state_indices, + dst_info.dst_state_data_ptrs, + req.dst_state_indices, + dst_info.gpu_id, + f"{req.room}_state_{self.kv_args.engine_rank}", + decode_tp_size, + decode_tp_rank=dst_info.decode_tp_rank, + dst_state_item_lens=dst_info.dst_state_item_lens, + dst_state_dim_per_tensor=dst_info.dst_state_dim_per_tensor, + ) + if state_xfer_handle is not None: + handles.append(state_xfer_handle) + + if kv_chunk.prefill_aux_index is None: + raise RuntimeError("Missing aux index for last chunk") + # When no KV pages were sent (decode-side cache hit), + # encode pp_rank in aux notif so receiver can mark + # expected_kvs_per_pp[pp_rank] = 0. + if len(kv_chunk.prefill_kv_indices) == 0: + aux_notif = f"{req.room}_aux_nokv_{self.kv_args.pp_rank}" + else: + aux_notif = f"{req.room}_aux" + aux_xfer_handle = self.send_aux( + req.agent_name, + kv_chunk.prefill_aux_index, + self.decode_kv_args_table[req.agent_name].dst_aux_ptrs, + req.dst_aux_index, + aux_notif, + ) + handles.append(aux_xfer_handle) + + while handles: + states = [self.agent.check_xfer_state(h) for h in handles] + if any(s == "ERR" for s in states): + raise RuntimeError(f"NIXL transfer encountered ERR room={room}") + if all(s == "DONE" for s in states): + break + time.sleep(0) + + if kv_chunk.is_last: + if room in self.transfer_infos: + del self.transfer_infos[room] + self.req_to_decode_prefix_len.pop(room, None) + self.update_status(room, KVPoll.Success) + else: + self.update_status(room, KVPoll.Transferring) + except Exception as e: + reason = f"Prefill transfer worker error room={room}: {e}" + logger.exception(reason) + self.record_failure(room, reason) + self.update_status(room, KVPoll.Failed) + def register_buffer_to_engine(self): kv_addrs = [] for kv_data_ptr, kv_data_len in zip( @@ -902,94 +1058,22 @@ def add_transfer_request( assert self.disaggregation_mode == DisaggregationMode.PREFILL assert not is_last or (is_last and aux_index is not None) - reqs_to_be_processed = self.transfer_infos[bootstrap_room].values() - handles = [] - for req in reqs_to_be_processed: - assert bootstrap_room == req.room - if req.is_dummy(): - continue - - chunked_dst_kv_indice = req.dst_kv_indices[index_slice] - assert len(chunked_dst_kv_indice) == len(kv_indices) - assert req.agent_name in self.decode_kv_args_table - - decode_tp_size = self.decode_kv_args_table[req.agent_name].decode_tp_size - - # Skip KV RDMA transfer when there are no pages to send - # (e.g., decode-side radix cache matched the entire prefix). - # Aux data is still sent below when is_last=True. - if len(kv_indices) > 0: - notif = ( - f"{req.room}_kv_{chunk_id}_{int(is_last)}_{self.kv_args.pp_rank}" - ) - - if self.is_mla_backend or (decode_tp_size == self.attn_tp_size): - kv_xfer_handle = self.send_kvcache( - req.agent_name, - kv_indices, - self.decode_kv_args_table[req.agent_name].dst_kv_ptrs, - chunked_dst_kv_indice, - self.decode_kv_args_table[req.agent_name].gpu_id, - notif, - ) - else: - kv_xfer_handle = self.send_kvcache_slice( - req.agent_name, - kv_indices, - self.decode_kv_args_table[req.agent_name].dst_kv_ptrs, - chunked_dst_kv_indice, - self.decode_kv_args_table[req.agent_name].gpu_id, - notif, - prefill_tp_size=self.attn_tp_size, - decode_tp_size=decode_tp_size, - decode_tp_rank=self.decode_kv_args_table[ - req.agent_name - ].decode_tp_rank, - dst_kv_item_len=self.decode_kv_args_table[ - req.agent_name - ].dst_kv_item_len, - ) - - handles.append(kv_xfer_handle) - # Only the last chunk we need to send the aux data. - if is_last: - if state_indices is not None: - dst_info = self.decode_kv_args_table[req.agent_name] - state_xfer_handle = self.maybe_send_extra( - req.agent_name, - state_indices, - dst_info.dst_state_data_ptrs, - req.dst_state_indices, - dst_info.gpu_id, - f"{req.room}_state_{self.kv_args.engine_rank}", - decode_tp_size, - decode_tp_rank=dst_info.decode_tp_rank, - dst_state_item_lens=dst_info.dst_state_item_lens, - dst_state_dim_per_tensor=dst_info.dst_state_dim_per_tensor, - ) - if state_xfer_handle is not None: - handles.append(state_xfer_handle) - - assert aux_index is not None - # When no KV pages were sent (decode-side cache hit), - # encode pp_rank in aux notif so receiver can mark - # expected_kvs_per_pp[pp_rank] = 0. - if len(kv_indices) == 0: - aux_notif = f"{req.room}_aux_nokv_{self.kv_args.pp_rank}" - else: - aux_notif = f"{req.room}_aux" - aux_xfer_handle = self.send_aux( - req.agent_name, - aux_index, - self.decode_kv_args_table[req.agent_name].dst_aux_ptrs, - req.dst_aux_index, - aux_notif, - ) - handles.append(aux_xfer_handle) - if is_last: - del self.transfer_infos[bootstrap_room] - self.req_to_decode_prefix_len.pop(bootstrap_room, None) - return handles + if bootstrap_room not in self.request_status: + self.update_status(bootstrap_room, KVPoll.Bootstrapping) + + shard_idx = bootstrap_room % len(self.transfer_queues) + self.transfer_queues[shard_idx].put( + TransferKVChunk( + room=bootstrap_room, + prefill_kv_indices=kv_indices, + index_slice=index_slice, + is_last=is_last, + chunk_id=chunk_id, + prefill_aux_index=aux_index, + state_indices=state_indices, + ) + ) + return None def update_transfer_status(self): # Process notifications from received transfers. @@ -1095,7 +1179,6 @@ def __init__( pp_rank: int, ): super().__init__(mgr, bootstrap_addr, bootstrap_room, dest_tp_ranks, pp_rank) - self.xfer_handles = [] self.has_sent = False self.chunk_id = 0 @@ -1128,7 +1211,7 @@ def send( self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Success) return - new_xfer_handles = self.kv_mgr.add_transfer_request( + self.kv_mgr.add_transfer_request( self.bootstrap_room, kv_indices, index_slice, @@ -1137,21 +1220,15 @@ def send( self.aux_index, state_indices, ) - self.xfer_handles.extend(new_xfer_handles) self.chunk_id += 1 if is_last: self.has_sent = True - del self.kv_mgr.request_status[self.bootstrap_room] def poll(self) -> KVPoll: - if not self.has_sent: - return self.kv_mgr.check_status(self.bootstrap_room) - states = [self.kv_mgr.agent.check_xfer_state(x) for x in self.xfer_handles] - if all([x == "DONE" for x in states]): - return KVPoll.Success # type: ignore - if any([x == "ERR" for x in states]): - raise Exception("KVSender transfer encountered an error.") - return KVPoll.WaitingForInput # type: ignore + return self.kv_mgr.check_status(self.bootstrap_room) + + def clear(self): + self.kv_mgr.request_status.pop(self.bootstrap_room, None) def failure_exception(self): raise RuntimeError("NIXL KVSender Exception") diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py index 533265ae72a4..4dea84deba29 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -616,7 +616,11 @@ def process_disagg_prefill_inflight_queue( undone_reqs.append(req) continue - if poll in [KVPoll.WaitingForInput, KVPoll.Transferring]: + if poll in [ + KVPoll.Bootstrapping, + KVPoll.WaitingForInput, + KVPoll.Transferring, + ]: undone_reqs.append(req) elif poll == KVPoll.Success: # transfer done release_kv_cache(req, self.tree_cache) # unlock the tree From 752f02b18dfdf3be11c0a1c11edc5a225a406898 Mon Sep 17 00:00:00 2001 From: Ovidiu Mara Date: Wed, 6 May 2026 01:30:51 +0200 Subject: [PATCH 2/5] Fix P>D notifications Signed-off-by: Ovidiu Mara --- python/sglang/srt/disaggregation/nixl/conn.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/disaggregation/nixl/conn.py b/python/sglang/srt/disaggregation/nixl/conn.py index 36508b4d0f16..212969f5ad2b 100644 --- a/python/sglang/srt/disaggregation/nixl/conn.py +++ b/python/sglang/srt/disaggregation/nixl/conn.py @@ -398,7 +398,7 @@ def transfer_worker(self, queue: FastQueue): : len(chunked_dst_kv_indice) ] - notif = f"{req.room}_kv_{kv_chunk.chunk_id}_{int(kv_chunk.is_last)}_{self.kv_args.pp_rank}" + notif = f"{req.room}_kv_{kv_chunk.chunk_id}_{int(kv_chunk.is_last)}_{self.kv_args.engine_rank}" if self.is_mla_backend or (decode_tp_size == self.attn_tp_size): kv_xfer_handle = self.send_kvcache( @@ -453,7 +453,9 @@ def transfer_worker(self, queue: FastQueue): # encode pp_rank in aux notif so receiver can mark # expected_kvs_per_pp[pp_rank] = 0. if len(kv_chunk.prefill_kv_indices) == 0: - aux_notif = f"{req.room}_aux_nokv_{self.kv_args.pp_rank}" + aux_notif = ( + f"{req.room}_aux_nokv_{self.kv_args.engine_rank}" + ) else: aux_notif = f"{req.room}_aux" aux_xfer_handle = self.send_aux( From b5d4d178a6b7c53ec4c6eceb307b6e449705bf06 Mon Sep 17 00:00:00 2001 From: Ovidiu Mara Date: Thu, 7 May 2026 14:51:34 +0200 Subject: [PATCH 3/5] Change check_status to default to WaitingForInput (the initial state after bootstrap) Signed-off-by: Ovidiu Mara --- python/sglang/srt/disaggregation/nixl/conn.py | 5 +---- python/sglang/srt/disaggregation/prefill.py | 6 +----- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/python/sglang/srt/disaggregation/nixl/conn.py b/python/sglang/srt/disaggregation/nixl/conn.py index ef5fcd8bcc79..32e1d6ac8de7 100644 --- a/python/sglang/srt/disaggregation/nixl/conn.py +++ b/python/sglang/srt/disaggregation/nixl/conn.py @@ -368,7 +368,7 @@ def _handle_node_failure(self, failed_bootstrap_addr): self.update_status(room, KVPoll.Failed) def check_status(self, bootstrap_room: int): - return self.request_status.get(bootstrap_room, KVPoll.Bootstrapping) + return self.request_status.get(bootstrap_room, KVPoll.WaitingForInput) def transfer_worker(self, queue: FastQueue): while True: @@ -1083,9 +1083,6 @@ def add_transfer_request( assert self.disaggregation_mode == DisaggregationMode.PREFILL assert not is_last or (is_last and aux_index is not None) - if bootstrap_room not in self.request_status: - self.update_status(bootstrap_room, KVPoll.Bootstrapping) - shard_idx = bootstrap_room % len(self.transfer_queues) self.transfer_queues[shard_idx].put( TransferKVChunk( diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py index f518125e2aef..1a089e8ffccb 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -619,11 +619,7 @@ def process_disagg_prefill_inflight_queue( undone_reqs.append(req) continue - if poll in [ - KVPoll.Bootstrapping, - KVPoll.WaitingForInput, - KVPoll.Transferring, - ]: + if poll in [KVPoll.WaitingForInput, KVPoll.Transferring]: undone_reqs.append(req) elif poll == KVPoll.Success: # transfer done release_kv_cache(req, self.tree_cache) # unlock the tree From 227fa3bb18c2a1e5cefa95a3a477a6e597c766d0 Mon Sep 17 00:00:00 2001 From: Ovidiu Mara Date: Thu, 7 May 2026 15:05:23 +0200 Subject: [PATCH 4/5] Handle separately _NIXL_TRANSPORT_ERRORS and generic exceptions Signed-off-by: Ovidiu Mara --- python/sglang/srt/disaggregation/nixl/conn.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/disaggregation/nixl/conn.py b/python/sglang/srt/disaggregation/nixl/conn.py index 80758043d281..95bb6e40da6b 100644 --- a/python/sglang/srt/disaggregation/nixl/conn.py +++ b/python/sglang/srt/disaggregation/nixl/conn.py @@ -494,10 +494,15 @@ def transfer_worker(self, queue: FastQueue): self.update_status(room, KVPoll.Success) else: self.update_status(room, KVPoll.Transferring) + except _NIXL_TRANSPORT_ERRORS as e: + logger.warning(f"NIXL transport error for room {room}: {e}") + self.record_failure(room, str(e)) + self.update_status(room, KVPoll.Failed) except Exception as e: - reason = f"Prefill transfer worker error room={room}: {e}" - logger.exception(reason) - self.record_failure(room, reason) + logger.exception( + f"Unexpected transfer worker error for room {room}: {e}" + ) + self.record_failure(room, str(e)) self.update_status(room, KVPoll.Failed) def register_buffer_to_engine(self): From ba5ad5255424b32de1ee4ca79869d4397ed51551 Mon Sep 17 00:00:00 2001 From: Ovidiu Mara Date: Thu, 7 May 2026 15:39:01 +0200 Subject: [PATCH 5/5] Propagate exceptions for logging and telemetry Signed-off-by: Ovidiu Mara --- python/sglang/srt/disaggregation/nixl/conn.py | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/disaggregation/nixl/conn.py b/python/sglang/srt/disaggregation/nixl/conn.py index 95bb6e40da6b..cb21fd5c945b 100644 --- a/python/sglang/srt/disaggregation/nixl/conn.py +++ b/python/sglang/srt/disaggregation/nixl/conn.py @@ -266,6 +266,7 @@ def __init__( self.transfer_queues: List[FastQueue] = [ FastQueue() for _ in range(transfer_queue_size) ] + self.exceptions: Dict[int, Exception] = {} for queue in self.transfer_queues: threading.Thread( target=self.transfer_worker, args=(queue,), daemon=True @@ -494,14 +495,16 @@ def transfer_worker(self, queue: FastQueue): self.update_status(room, KVPoll.Success) else: self.update_status(room, KVPoll.Transferring) - except _NIXL_TRANSPORT_ERRORS as e: - logger.warning(f"NIXL transport error for room {room}: {e}") - self.record_failure(room, str(e)) - self.update_status(room, KVPoll.Failed) except Exception as e: - logger.exception( - f"Unexpected transfer worker error for room {room}: {e}" - ) + # Catch all exceptions to prevent silently killing this + # worker thread, but still propagate via failure_exception(). + if isinstance(e, _NIXL_TRANSPORT_ERRORS): + logger.warning(f"NIXL transport error for room {room}: {e}") + else: + logger.exception( + f"Unexpected transfer worker error for room {room}" + ) + self.exceptions[room] = e self.record_failure(room, str(e)) self.update_status(room, KVPoll.Failed) @@ -1280,6 +1283,9 @@ def clear(self): def failure_exception(self): if self._send_error is not None: raise self._send_error + exc = self.kv_mgr.exceptions.pop(self.bootstrap_room, None) + if exc is not None: + raise exc raise RuntimeError("NIXL KVSender Exception")