diff --git a/python/sglang/srt/disaggregation/nixl/conn.py b/python/sglang/srt/disaggregation/nixl/conn.py index 005d5b05c286..43c4b9bd3e72 100644 --- a/python/sglang/srt/disaggregation/nixl/conn.py +++ b/python/sglang/srt/disaggregation/nixl/conn.py @@ -19,7 +19,10 @@ CommonKVReceiver, CommonKVSender, ) -from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous +from sglang.srt.disaggregation.common.utils import ( + FastQueue, + group_concurrent_contiguous, +) from sglang.srt.disaggregation.utils import ( DisaggregationMode, filter_kv_indices_for_cp_rank, @@ -68,6 +71,17 @@ def from_zmq(cls, msg: List[bytes]): ) +@dataclasses.dataclass +class TransferKVChunk: + room: int + prefill_kv_indices: npt.NDArray[np.int32] + index_slice: slice + is_last: bool + chunk_id: int + prefill_aux_index: Optional[int] + state_indices: Optional[List[int]] + + @dataclasses.dataclass class KVArgsRegisterInfo: """Contains base pointers and other info which only needs to be sent once by KVReceiver. Received by prefill bootstrap thread.""" @@ -203,6 +217,14 @@ def __init__( self.register_buffer_to_engine() if self.disaggregation_mode == DisaggregationMode.PREFILL: + transfer_queue_size = envs.SGLANG_DISAGGREGATION_QUEUE_SIZE.get() + self.transfer_queues: List[FastQueue] = [ + FastQueue() for _ in range(transfer_queue_size) + ] + for queue in self.transfer_queues: + threading.Thread( + target=self.transfer_worker, args=(queue,), daemon=True + ).start() self._start_bootstrap_thread() elif self.disaggregation_mode == DisaggregationMode.DECODE: self.transfer_statuses: Dict[int, TransferStatus] = defaultdict( @@ -300,6 +322,118 @@ def _handle_node_failure(self, failed_bootstrap_addr): logger.error(f"Let room {room} be failed due to prefill down") self.update_status(room, KVPoll.Failed) + def check_status(self, bootstrap_room: int): + return self.request_status.get(bootstrap_room, KVPoll.Bootstrapping) + + def transfer_worker(self, queue: FastQueue): + while True: + kv_chunk: TransferKVChunk = queue.get() + room = kv_chunk.room + try: + if self.check_status(room) == KVPoll.Failed: + continue + + assert room in self.transfer_infos + + self.update_status(room, KVPoll.Transferring) + + reqs_to_be_processed = list(self.transfer_infos[room].values()) + handles: List = [] + + for req in reqs_to_be_processed: + assert room == req.room + if req.is_dummy(): + continue + + chunked_dst_kv_indice = req.dst_kv_indices[kv_chunk.index_slice] + if len(chunked_dst_kv_indice) < len(kv_chunk.prefill_kv_indices): + kv_chunk.prefill_kv_indices = kv_chunk.prefill_kv_indices[ + : len(chunked_dst_kv_indice) + ] + assert req.agent_name in self.decode_kv_args_table + + notif = f"{req.room}_kv_{kv_chunk.chunk_id}_{int(kv_chunk.is_last)}_{self.kv_args.engine_rank}" + decode_tp_size = self.decode_kv_args_table[ + req.agent_name + ].decode_tp_size + + if self.is_mla_backend or (decode_tp_size == self.attn_tp_size): + kv_xfer_handle = self.send_kvcache( + req.agent_name, + kv_chunk.prefill_kv_indices, + self.decode_kv_args_table[req.agent_name].dst_kv_ptrs, + chunked_dst_kv_indice, + self.decode_kv_args_table[req.agent_name].gpu_id, + notif, + ) + else: + kv_xfer_handle = self.send_kvcache_slice( + req.agent_name, + kv_chunk.prefill_kv_indices, + self.decode_kv_args_table[req.agent_name].dst_kv_ptrs, + chunked_dst_kv_indice, + self.decode_kv_args_table[req.agent_name].gpu_id, + notif, + prefill_tp_size=self.attn_tp_size, + decode_tp_size=decode_tp_size, + decode_tp_rank=self.decode_kv_args_table[ + req.agent_name + ].decode_tp_rank, + dst_kv_item_len=self.decode_kv_args_table[ + req.agent_name + ].dst_kv_item_len, + ) + handles.append(kv_xfer_handle) + + if kv_chunk.is_last: + if kv_chunk.state_indices is not None: + dst_info = self.decode_kv_args_table[req.agent_name] + state_xfer_handle = self.maybe_send_extra( + req.agent_name, + kv_chunk.state_indices, + dst_info.dst_state_data_ptrs, + req.dst_state_indices, + dst_info.gpu_id, + f"{req.room}_state_{self.kv_args.engine_rank}", + decode_tp_size, + decode_tp_rank=dst_info.decode_tp_rank, + dst_state_item_lens=dst_info.dst_state_item_lens, + dst_state_dim_per_tensor=dst_info.dst_state_dim_per_tensor, + ) + if state_xfer_handle is not None: + handles.append(state_xfer_handle) + + if kv_chunk.prefill_aux_index is None: + raise RuntimeError("Missing aux index for last chunk") + aux_xfer_handle = self.send_aux( + req.agent_name, + kv_chunk.prefill_aux_index, + self.decode_kv_args_table[req.agent_name].dst_aux_ptrs, + req.dst_aux_index, + f"{req.room}_aux", + ) + handles.append(aux_xfer_handle) + + while handles: + states = [self.agent.check_xfer_state(h) for h in handles] + if any(s == "ERR" for s in states): + raise RuntimeError(f"NIXL transfer encountered ERR room={room}") + if all(s == "DONE" for s in states): + break + time.sleep(0) + + if kv_chunk.is_last: + if room in self.transfer_infos: + del self.transfer_infos[room] + self.update_status(room, KVPoll.Success) + else: + self.update_status(room, KVPoll.Transferring) + except Exception as e: + reason = f"Prefill transfer worker error room={room}: {e}" + logger.exception(reason) + self.record_failure(room, reason) + self.update_status(room, KVPoll.Failed) + def register_buffer_to_engine(self): kv_addrs = [] for kv_data_ptr, kv_data_len in zip( @@ -872,81 +1006,22 @@ def add_transfer_request( assert self.disaggregation_mode == DisaggregationMode.PREFILL assert not is_last or (is_last and aux_index is not None) - reqs_to_be_processed = self.transfer_infos[bootstrap_room].values() - handles = [] - for req in reqs_to_be_processed: - assert bootstrap_room == req.room - if req.is_dummy(): - continue - - chunked_dst_kv_indice = req.dst_kv_indices[index_slice] - assert len(chunked_dst_kv_indice) == len(kv_indices) - assert req.agent_name in self.decode_kv_args_table - - notif = ( - f"{req.room}_kv_{chunk_id}_{int(is_last)}_{self.kv_args.engine_rank}" + if bootstrap_room not in self.request_status: + self.update_status(bootstrap_room, KVPoll.Bootstrapping) + + shard_idx = bootstrap_room % len(self.transfer_queues) + self.transfer_queues[shard_idx].put( + TransferKVChunk( + room=bootstrap_room, + prefill_kv_indices=kv_indices, + index_slice=index_slice, + is_last=is_last, + chunk_id=chunk_id, + prefill_aux_index=aux_index, + state_indices=state_indices, ) - decode_tp_size = self.decode_kv_args_table[req.agent_name].decode_tp_size - - if self.is_mla_backend or (decode_tp_size == self.attn_tp_size): - kv_xfer_handle = self.send_kvcache( - req.agent_name, - kv_indices, - self.decode_kv_args_table[req.agent_name].dst_kv_ptrs, - chunked_dst_kv_indice, - self.decode_kv_args_table[req.agent_name].gpu_id, - notif, - ) - else: - kv_xfer_handle = self.send_kvcache_slice( - req.agent_name, - kv_indices, - self.decode_kv_args_table[req.agent_name].dst_kv_ptrs, - chunked_dst_kv_indice, - self.decode_kv_args_table[req.agent_name].gpu_id, - notif, - prefill_tp_size=self.attn_tp_size, - decode_tp_size=decode_tp_size, - decode_tp_rank=self.decode_kv_args_table[ - req.agent_name - ].decode_tp_rank, - dst_kv_item_len=self.decode_kv_args_table[ - req.agent_name - ].dst_kv_item_len, - ) - - handles.append(kv_xfer_handle) - # Only the last chunk we need to send the aux data. - if is_last: - if state_indices is not None: - dst_info = self.decode_kv_args_table[req.agent_name] - state_xfer_handle = self.maybe_send_extra( - req.agent_name, - state_indices, - dst_info.dst_state_data_ptrs, - req.dst_state_indices, - dst_info.gpu_id, - f"{req.room}_state_{self.kv_args.engine_rank}", - decode_tp_size, - decode_tp_rank=dst_info.decode_tp_rank, - dst_state_item_lens=dst_info.dst_state_item_lens, - dst_state_dim_per_tensor=dst_info.dst_state_dim_per_tensor, - ) - if state_xfer_handle is not None: - handles.append(state_xfer_handle) - - assert aux_index is not None - aux_xfer_handle = self.send_aux( - req.agent_name, - aux_index, - self.decode_kv_args_table[req.agent_name].dst_aux_ptrs, - req.dst_aux_index, - f"{req.room}_aux", - ) - handles.append(aux_xfer_handle) - if is_last: - del self.transfer_infos[bootstrap_room] - return handles + ) + return None def update_transfer_status(self): # Process notifications from received transfers. @@ -1035,7 +1110,6 @@ def __init__( pp_rank: int, ): super().__init__(mgr, bootstrap_addr, bootstrap_room, dest_tp_ranks, pp_rank) - self.xfer_handles = [] self.has_sent = False self.chunk_id = 0 @@ -1062,7 +1136,7 @@ def send( self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Success) return - new_xfer_handles = self.kv_mgr.add_transfer_request( + self.kv_mgr.add_transfer_request( self.bootstrap_room, kv_indices, index_slice, @@ -1071,21 +1145,21 @@ def send( self.aux_index, state_indices, ) - self.xfer_handles.extend(new_xfer_handles) self.chunk_id += 1 if is_last: self.has_sent = True - del self.kv_mgr.request_status[self.bootstrap_room] def poll(self) -> KVPoll: + status = self.kv_mgr.check_status(self.bootstrap_room) if not self.has_sent: - return self.kv_mgr.check_status(self.bootstrap_room) - states = [self.kv_mgr.agent.check_xfer_state(x) for x in self.xfer_handles] - if all([x == "DONE" for x in states]): - return KVPoll.Success # type: ignore - if any([x == "ERR" for x in states]): - raise Exception("KVSender transfer encountered an error.") - return KVPoll.WaitingForInput # type: ignore + return status + if status in (KVPoll.Success, KVPoll.Failed): + return status + return status + + def clear(self): + if self.bootstrap_room in self.kv_mgr.request_status: + self.kv_mgr.request_status.pop(self.bootstrap_room) def failure_exception(self): raise RuntimeError("NIXL KVSender Exception") diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py index d9016d0f94d3..a3ba821ff43f 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -628,7 +628,11 @@ def process_disagg_prefill_inflight_queue( undone_reqs.append(req) continue - if poll in [KVPoll.WaitingForInput, KVPoll.Transferring]: + if poll in [ + KVPoll.Bootstrapping, + KVPoll.WaitingForInput, + KVPoll.Transferring, + ]: undone_reqs.append(req) elif poll == KVPoll.Success: # transfer done release_kv_cache(req, self.tree_cache) # unlock the tree