diff --git a/python/sglang/srt/disaggregation/nixl/conn.py b/python/sglang/srt/disaggregation/nixl/conn.py index e414116fc399..cb21fd5c945b 100644 --- a/python/sglang/srt/disaggregation/nixl/conn.py +++ b/python/sglang/srt/disaggregation/nixl/conn.py @@ -20,7 +20,10 @@ CommonKVReceiver, CommonKVSender, ) -from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous +from sglang.srt.disaggregation.common.utils import ( + FastQueue, + group_concurrent_contiguous, +) from sglang.srt.disaggregation.utils import ( DisaggregationMode, filter_kv_indices_for_cp_rank, @@ -94,6 +97,17 @@ def from_zmq(cls, msg: List[bytes]): ) +@dataclasses.dataclass +class TransferKVChunk: + room: int + prefill_kv_indices: npt.NDArray[np.int32] + index_slice: slice + is_last: bool + chunk_id: int + prefill_aux_index: Optional[int] + state_indices: Optional[List[int]] + + @dataclasses.dataclass class KVArgsRegisterInfo: """Contains base pointers and other info which only needs to be sent once by KVReceiver. Received by prefill bootstrap thread.""" @@ -248,6 +262,15 @@ def __init__( self.register_buffer_to_engine() if self.disaggregation_mode == DisaggregationMode.PREFILL: + transfer_queue_size = envs.SGLANG_DISAGGREGATION_QUEUE_SIZE.get() + self.transfer_queues: List[FastQueue] = [ + FastQueue() for _ in range(transfer_queue_size) + ] + self.exceptions: Dict[int, Exception] = {} + for queue in self.transfer_queues: + threading.Thread( + target=self.transfer_worker, args=(queue,), daemon=True + ).start() self._start_bootstrap_thread() elif self.disaggregation_mode == DisaggregationMode.DECODE: self.transfer_statuses: Dict[int, TransferStatus] = defaultdict( @@ -345,6 +368,146 @@ def _handle_node_failure(self, failed_bootstrap_addr): logger.error(f"Let room {room} be failed due to prefill down") self.update_status(room, KVPoll.Failed) + def check_status(self, bootstrap_room: int): + return self.request_status.get(bootstrap_room, KVPoll.WaitingForInput) + + def transfer_worker(self, queue: FastQueue): + while True: + kv_chunk: TransferKVChunk = queue.get() + room = kv_chunk.room + try: + if self.check_status(room) == KVPoll.Failed: + continue + + assert room in self.transfer_infos + + self.update_status(room, KVPoll.Transferring) + + reqs_to_be_processed = list(self.transfer_infos[room].values()) + handles: List = [] + + for req in reqs_to_be_processed: + assert room == req.room + if req.is_dummy(): + continue + + assert req.agent_name in self.decode_kv_args_table + decode_tp_size = self.decode_kv_args_table[ + req.agent_name + ].decode_tp_size + + # Skip KV RDMA transfer when there are no pages to send + # (e.g., decode-side radix cache matched the entire prefix). + # Aux data is still sent below when is_last=True. + if len(kv_chunk.prefill_kv_indices) > 0: + chunked_dst_kv_indice = req.dst_kv_indices[kv_chunk.index_slice] + + # NOTE: This is temporarily a workaround to deal with the case where the prefill_kv_indices + # is mismatched with the dst_kv_indices when page size > 1, this should never happen. + if len(chunked_dst_kv_indice) < len( + kv_chunk.prefill_kv_indices + ): + logger.warning( + f"len(chunked_dst_kv_indice) = {len(chunked_dst_kv_indice)}, len(kv_chunk.prefill_kv_indices) = {len(kv_chunk.prefill_kv_indices)}" + ) + kv_chunk.prefill_kv_indices = kv_chunk.prefill_kv_indices[ + : len(chunked_dst_kv_indice) + ] + + notif = f"{req.room}_kv_{kv_chunk.chunk_id}_{int(kv_chunk.is_last)}_{self.kv_args.engine_rank}" + + if self.is_mla_backend or (decode_tp_size == self.attn_tp_size): + kv_xfer_handle = self.send_kvcache( + req.agent_name, + kv_chunk.prefill_kv_indices, + self.decode_kv_args_table[req.agent_name].dst_kv_ptrs, + chunked_dst_kv_indice, + self.decode_kv_args_table[req.agent_name].gpu_id, + notif, + ) + else: + kv_xfer_handle = self.send_kvcache_slice( + req.agent_name, + kv_chunk.prefill_kv_indices, + self.decode_kv_args_table[req.agent_name].dst_kv_ptrs, + chunked_dst_kv_indice, + self.decode_kv_args_table[req.agent_name].gpu_id, + notif, + prefill_tp_size=self.attn_tp_size, + decode_tp_size=decode_tp_size, + decode_tp_rank=self.decode_kv_args_table[ + req.agent_name + ].decode_tp_rank, + dst_kv_item_len=self.decode_kv_args_table[ + req.agent_name + ].dst_kv_item_len, + ) + + handles.append(kv_xfer_handle) + + if kv_chunk.is_last: + if kv_chunk.state_indices is not None: + dst_info = self.decode_kv_args_table[req.agent_name] + state_xfer_handle = self.maybe_send_extra( + req.agent_name, + kv_chunk.state_indices, + dst_info.dst_state_data_ptrs, + req.dst_state_indices, + dst_info.gpu_id, + f"{req.room}_state_{self.kv_args.engine_rank}", + decode_tp_size, + decode_tp_rank=dst_info.decode_tp_rank, + dst_state_item_lens=dst_info.dst_state_item_lens, + dst_state_dim_per_tensor=dst_info.dst_state_dim_per_tensor, + ) + if state_xfer_handle is not None: + handles.append(state_xfer_handle) + + if kv_chunk.prefill_aux_index is None: + raise RuntimeError("Missing aux index for last chunk") + # When no KV pages were sent (decode-side cache hit), + # encode pp_rank in aux notif so receiver can mark + # expected_kvs_per_pp[pp_rank] = 0. + if len(kv_chunk.prefill_kv_indices) == 0: + aux_notif = ( + f"{req.room}_aux_nokv_{self.kv_args.engine_rank}" + ) + else: + aux_notif = f"{req.room}_aux" + aux_xfer_handle = self.send_aux( + req.agent_name, + kv_chunk.prefill_aux_index, + self.decode_kv_args_table[req.agent_name].dst_aux_ptrs, + req.dst_aux_index, + aux_notif, + ) + handles.append(aux_xfer_handle) + + while handles: + states = [self.agent.check_xfer_state(h) for h in handles] + if any(s == "ERR" for s in states): + raise RuntimeError(f"NIXL transfer encountered ERR room={room}") + if all(s == "DONE" for s in states): + break + time.sleep(0) + + if kv_chunk.is_last: + self.update_status(room, KVPoll.Success) + else: + self.update_status(room, KVPoll.Transferring) + except Exception as e: + # Catch all exceptions to prevent silently killing this + # worker thread, but still propagate via failure_exception(). + if isinstance(e, _NIXL_TRANSPORT_ERRORS): + logger.warning(f"NIXL transport error for room {room}: {e}") + else: + logger.exception( + f"Unexpected transfer worker error for room {room}" + ) + self.exceptions[room] = e + self.record_failure(room, str(e)) + self.update_status(room, KVPoll.Failed) + def register_buffer_to_engine(self): kv_addrs = [] for kv_data_ptr, kv_data_len in zip( @@ -925,91 +1088,19 @@ def add_transfer_request( assert self.disaggregation_mode == DisaggregationMode.PREFILL assert not is_last or (is_last and aux_index is not None) - reqs_to_be_processed = self.transfer_infos[bootstrap_room].values() - handles = [] - for req in reqs_to_be_processed: - assert bootstrap_room == req.room - if req.is_dummy(): - continue - - chunked_dst_kv_indice = req.dst_kv_indices[index_slice] - assert len(chunked_dst_kv_indice) == len(kv_indices) - assert req.agent_name in self.decode_kv_args_table - - decode_tp_size = self.decode_kv_args_table[req.agent_name].decode_tp_size - - # Skip KV RDMA transfer when there are no pages to send - # (e.g., decode-side radix cache matched the entire prefix). - # Aux data is still sent below when is_last=True. - if len(kv_indices) > 0: - notif = ( - f"{req.room}_kv_{chunk_id}_{int(is_last)}_{self.kv_args.pp_rank}" - ) - - if self.is_mla_backend or (decode_tp_size == self.attn_tp_size): - kv_xfer_handle = self.send_kvcache( - req.agent_name, - kv_indices, - self.decode_kv_args_table[req.agent_name].dst_kv_ptrs, - chunked_dst_kv_indice, - self.decode_kv_args_table[req.agent_name].gpu_id, - notif, - ) - else: - kv_xfer_handle = self.send_kvcache_slice( - req.agent_name, - kv_indices, - self.decode_kv_args_table[req.agent_name].dst_kv_ptrs, - chunked_dst_kv_indice, - self.decode_kv_args_table[req.agent_name].gpu_id, - notif, - prefill_tp_size=self.attn_tp_size, - decode_tp_size=decode_tp_size, - decode_tp_rank=self.decode_kv_args_table[ - req.agent_name - ].decode_tp_rank, - dst_kv_item_len=self.decode_kv_args_table[ - req.agent_name - ].dst_kv_item_len, - ) - - handles.append(kv_xfer_handle) - # Only the last chunk we need to send the aux data. - if is_last: - if state_indices is not None: - dst_info = self.decode_kv_args_table[req.agent_name] - state_xfer_handle = self.maybe_send_extra( - req.agent_name, - state_indices, - dst_info.dst_state_data_ptrs, - req.dst_state_indices, - dst_info.gpu_id, - f"{req.room}_state_{self.kv_args.engine_rank}", - decode_tp_size, - decode_tp_rank=dst_info.decode_tp_rank, - dst_state_item_lens=dst_info.dst_state_item_lens, - dst_state_dim_per_tensor=dst_info.dst_state_dim_per_tensor, - ) - if state_xfer_handle is not None: - handles.append(state_xfer_handle) - - assert aux_index is not None - # When no KV pages were sent (decode-side cache hit), - # encode pp_rank in aux notif so receiver can mark - # expected_kvs_per_pp[pp_rank] = 0. - if len(kv_indices) == 0: - aux_notif = f"{req.room}_aux_nokv_{self.kv_args.pp_rank}" - else: - aux_notif = f"{req.room}_aux" - aux_xfer_handle = self.send_aux( - req.agent_name, - aux_index, - self.decode_kv_args_table[req.agent_name].dst_aux_ptrs, - req.dst_aux_index, - aux_notif, - ) - handles.append(aux_xfer_handle) - return handles + shard_idx = bootstrap_room % len(self.transfer_queues) + self.transfer_queues[shard_idx].put( + TransferKVChunk( + room=bootstrap_room, + prefill_kv_indices=kv_indices, + index_slice=index_slice, + is_last=is_last, + chunk_id=chunk_id, + prefill_aux_index=aux_index, + state_indices=state_indices, + ) + ) + return None def update_transfer_status(self): # Process notifications from received transfers. @@ -1115,7 +1206,6 @@ def __init__( pp_rank: int, ): super().__init__(mgr, bootstrap_addr, bootstrap_room, dest_tp_ranks, pp_rank) - self.xfer_handles = [] self.has_sent = False self.chunk_id = 0 self._send_failed = False @@ -1159,26 +1249,16 @@ def send( ): self._transfer_start_time = time.perf_counter() - try: - new_xfer_handles = self.kv_mgr.add_transfer_request( - self.bootstrap_room, - kv_indices, - index_slice, - is_last, - self.chunk_id, - self.aux_index, - state_indices, - ) - except _NIXL_TRANSPORT_ERRORS as e: - logger.warning( - f"KVSender transfer request failed for room {self.bootstrap_room}: {e}" - ) - self._send_failed = True - self._send_error = e - return - + self.kv_mgr.add_transfer_request( + self.bootstrap_room, + kv_indices, + index_slice, + is_last, + self.chunk_id, + self.aux_index, + state_indices, + ) self._record_transfer_indices(kv_indices, state_indices) - self.xfer_handles.extend(new_xfer_handles) self.chunk_id += 1 if is_last: self.has_sent = True @@ -1186,37 +1266,26 @@ def send( def poll(self) -> KVPoll: if self._send_failed: return KVPoll.Failed # type: ignore - if not self.has_sent: - return self.kv_mgr.check_status(self.bootstrap_room) - try: - states = [self.kv_mgr.agent.check_xfer_state(x) for x in self.xfer_handles] - except _NIXL_TRANSPORT_ERRORS as e: - logger.warning( - f"KVSender check_xfer_state failed for room {self.bootstrap_room}: {e}" - ) - self._send_failed = True - self._send_error = e - return KVPoll.Failed # type: ignore - if all(x == "DONE" for x in states): - if ( - self._transfer_start_time is not None - and self._transfer_metric.transfer_latency_s is None - ): - self._transfer_metric.transfer_latency_s = ( - time.perf_counter() - self._transfer_start_time - ) - return KVPoll.Success # type: ignore - if any(x == "ERR" for x in states): - self._send_failed = True - self._send_error = RuntimeError( - f"NIXL transfer error for room {self.bootstrap_room}" + status = self.kv_mgr.check_status(self.bootstrap_room) + if ( + status == KVPoll.Success + and self._transfer_start_time is not None + and self._transfer_metric.transfer_latency_s is None + ): + self._transfer_metric.transfer_latency_s = ( + time.perf_counter() - self._transfer_start_time ) - return KVPoll.Failed # type: ignore - return KVPoll.WaitingForInput # type: ignore + return status + + def clear(self): + super().clear() def failure_exception(self): if self._send_error is not None: raise self._send_error + exc = self.kv_mgr.exceptions.pop(self.bootstrap_room, None) + if exc is not None: + raise exc raise RuntimeError("NIXL KVSender Exception")