From 76640b2a4d9824492249c55f9b97879472d543d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=84=E6=AD=A6?= Date: Mon, 16 Mar 2026 15:16:28 +0800 Subject: [PATCH 1/5] improve nixl performance --- python/sglang/srt/disaggregation/nixl/conn.py | 269 ++++++++++++------ python/sglang/srt/disaggregation/prefill.py | 6 +- 2 files changed, 192 insertions(+), 83 deletions(-) diff --git a/python/sglang/srt/disaggregation/nixl/conn.py b/python/sglang/srt/disaggregation/nixl/conn.py index 764fd9e42689..ec3f31f0bbae 100644 --- a/python/sglang/srt/disaggregation/nixl/conn.py +++ b/python/sglang/srt/disaggregation/nixl/conn.py @@ -1,7 +1,9 @@ from __future__ import annotations +import concurrent.futures import dataclasses import logging +import os import struct import threading import time @@ -19,7 +21,7 @@ CommonKVReceiver, CommonKVSender, ) -from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous +from sglang.srt.disaggregation.common.utils import FastQueue, group_concurrent_contiguous from sglang.srt.disaggregation.utils import ( DisaggregationMode, filter_kv_indices_for_cp_rank, @@ -68,6 +70,17 @@ def from_zmq(cls, msg: List[bytes]): ) +@dataclasses.dataclass +class TransferKVChunk: + room: int + prefill_kv_indices: npt.NDArray[np.int32] + index_slice: slice + is_last: bool + chunk_id: int + prefill_aux_index: Optional[int] + state_indices: Optional[List[int]] + + @dataclasses.dataclass class KVArgsRegisterInfo: """Contains base pointers and other info which only needs to be sent once by KVReceiver. Received by prefill bootstrap thread.""" @@ -190,6 +203,30 @@ def __init__( self.register_buffer_to_engine() if self.disaggregation_mode == DisaggregationMode.PREFILL: + cpu_count = os.cpu_count() or 1 + transfer_thread_pool_size = envs.SGLANG_DISAGGREGATION_THREAD_POOL_SIZE.get() + if transfer_thread_pool_size is None: + transfer_thread_pool_size = min( + max(4, int(0.5 * cpu_count) // 8), 12 + ) + transfer_queue_size = envs.SGLANG_DISAGGREGATION_QUEUE_SIZE.get() + self.transfer_queues: List[FastQueue] = [ + FastQueue() for _ in range(transfer_queue_size) + ] + assert transfer_thread_pool_size >= transfer_queue_size, ( + f"The environment variable SGLANG_DISAGGREGATION_THREAD_POOL_SIZE={transfer_thread_pool_size} must be " + f"greater than or equal to SGLANG_DISAGGREGATION_QUEUE_SIZE={transfer_queue_size}." + ) + self.executors = [ + concurrent.futures.ThreadPoolExecutor( + max_workers=max(1, transfer_thread_pool_size // transfer_queue_size) + ) + for _ in range(transfer_queue_size) + ] + for queue, executor in zip(self.transfer_queues, self.executors): + threading.Thread( + target=self.transfer_worker, args=(queue, executor), daemon=True + ).start() self._start_bootstrap_thread() elif self.disaggregation_mode == DisaggregationMode.DECODE: self.transfer_statuses: Dict[int, TransferStatus] = defaultdict( @@ -287,6 +324,127 @@ def _handle_node_failure(self, failed_bootstrap_addr): logger.error(f"Let room {room} be failed due to prefill down") self.update_status(room, KVPoll.Failed) + def check_status(self, bootstrap_room: int): + return self.request_status.get(bootstrap_room, KVPoll.Bootstrapping) + + def transfer_worker( + self, queue: FastQueue, executor: concurrent.futures.Executor + ): + while True: + kv_chunk: TransferKVChunk = queue.get() + room = kv_chunk.room + try: + if ( + room in self.request_status + and self.check_status(room) == KVPoll.Failed + ): + continue + + if room not in self.transfer_infos: + time.sleep(0.001) + queue.put(kv_chunk) + continue + + self.update_status(room, KVPoll.Transferring) + + reqs_to_be_processed = list(self.transfer_infos[room].values()) + handles: List = [] + + for req in reqs_to_be_processed: + assert room == req.room + if req.is_dummy(): + continue + + chunked_dst_kv_indice = req.dst_kv_indices[kv_chunk.index_slice] + if len(chunked_dst_kv_indice) < len(kv_chunk.prefill_kv_indices): + kv_chunk.prefill_kv_indices = kv_chunk.prefill_kv_indices[ + : len(chunked_dst_kv_indice) + ] + assert req.agent_name in self.decode_kv_args_table + + notif = f"{req.room}_kv_{kv_chunk.chunk_id}_{int(kv_chunk.is_last)}_{self.kv_args.pp_rank}" + decode_tp_size = self.decode_kv_args_table[ + req.agent_name + ].decode_tp_size + + if self.is_mla_backend or ( + decode_tp_size == self.attn_tp_size + ): + kv_xfer_handle = self.send_kvcache( + req.agent_name, + kv_chunk.prefill_kv_indices, + self.decode_kv_args_table[req.agent_name].dst_kv_ptrs, + chunked_dst_kv_indice, + self.decode_kv_args_table[req.agent_name].gpu_id, + notif, + ) + else: + kv_xfer_handle = self.send_kvcache_slice( + req.agent_name, + kv_chunk.prefill_kv_indices, + self.decode_kv_args_table[req.agent_name].dst_kv_ptrs, + chunked_dst_kv_indice, + self.decode_kv_args_table[req.agent_name].gpu_id, + notif, + prefill_tp_size=self.attn_tp_size, + decode_tp_size=decode_tp_size, + decode_tp_rank=self.decode_kv_args_table[ + req.agent_name + ].decode_tp_rank, + dst_kv_item_len=self.decode_kv_args_table[ + req.agent_name + ].dst_kv_item_len, + ) + handles.append(kv_xfer_handle) + + if kv_chunk.is_last: + if kv_chunk.state_indices is not None: + dst_info = self.decode_kv_args_table[req.agent_name] + state_xfer_handle = self.maybe_send_extra( + req.agent_name, + kv_chunk.state_indices, + dst_info.dst_state_data_ptrs, + req.dst_state_indices, + dst_info.gpu_id, + f"{req.room}_state_{self.kv_args.pp_rank}", + decode_tp_size, + ) + if state_xfer_handle is not None: + handles.append(state_xfer_handle) + + if kv_chunk.prefill_aux_index is None: + raise RuntimeError("Missing aux index for last chunk") + aux_xfer_handle = self.send_aux( + req.agent_name, + kv_chunk.prefill_aux_index, + self.decode_kv_args_table[req.agent_name].dst_aux_ptrs, + req.dst_aux_index, + f"{req.room}_aux", + ) + handles.append(aux_xfer_handle) + + while handles: + states = [self.agent.check_xfer_state(h) for h in handles] + if any(s == "ERR" for s in states): + raise RuntimeError( + f"NIXL transfer encountered ERR room={room}" + ) + if all(s == "DONE" for s in states): + break + time.sleep(0.001) + + if kv_chunk.is_last: + if room in self.transfer_infos: + del self.transfer_infos[room] + self.update_status(room, KVPoll.Success) + else: + self.update_status(room, KVPoll.Transferring) + except Exception as e: + reason = f"Prefill transfer worker error room={room}: {e}" + logger.exception(reason) + self.record_failure(room, reason) + self.update_status(room, KVPoll.Failed) + def register_buffer_to_engine(self): kv_addrs = [] for kv_data_ptr, kv_data_len in zip( @@ -737,76 +895,22 @@ def add_transfer_request( assert self.disaggregation_mode == DisaggregationMode.PREFILL assert not is_last or (is_last and aux_index is not None) - reqs_to_be_processed = self.transfer_infos[bootstrap_room].values() - handles = [] - for req in reqs_to_be_processed: - assert bootstrap_room == req.room - if req.is_dummy(): - continue - - chunked_dst_kv_indice = req.dst_kv_indices[index_slice] - assert len(chunked_dst_kv_indice) == len(kv_indices) - assert req.agent_name in self.decode_kv_args_table - - notif = f"{req.room}_kv_{chunk_id}_{int(is_last)}_{self.kv_args.pp_rank}" - decode_tp_size = self.decode_kv_args_table[req.agent_name].decode_tp_size - - if self.is_mla_backend or (decode_tp_size == self.attn_tp_size): - kv_xfer_handle = self.send_kvcache( - req.agent_name, - kv_indices, - self.decode_kv_args_table[req.agent_name].dst_kv_ptrs, - chunked_dst_kv_indice, - self.decode_kv_args_table[req.agent_name].gpu_id, - notif, - ) - else: - kv_xfer_handle = self.send_kvcache_slice( - req.agent_name, - kv_indices, - self.decode_kv_args_table[req.agent_name].dst_kv_ptrs, - chunked_dst_kv_indice, - self.decode_kv_args_table[req.agent_name].gpu_id, - notif, - prefill_tp_size=self.attn_tp_size, - decode_tp_size=decode_tp_size, - decode_tp_rank=self.decode_kv_args_table[ - req.agent_name - ].decode_tp_rank, - dst_kv_item_len=self.decode_kv_args_table[ - req.agent_name - ].dst_kv_item_len, - ) - - handles.append(kv_xfer_handle) - # Only the last chunk we need to send the aux data. - if is_last: - if state_indices is not None: - dst_info = self.decode_kv_args_table[req.agent_name] - state_xfer_handle = self.maybe_send_extra( - req.agent_name, - state_indices, - dst_info.dst_state_data_ptrs, - req.dst_state_indices, - dst_info.gpu_id, - f"{req.room}_state_{self.kv_args.pp_rank}", - decode_tp_size, - ) - if state_xfer_handle is not None: - handles.append(state_xfer_handle) - - assert aux_index is not None - aux_xfer_handle = self.send_aux( - req.agent_name, - aux_index, - self.decode_kv_args_table[req.agent_name].dst_aux_ptrs, - req.dst_aux_index, - f"{req.room}_aux", - ) - handles.append(aux_xfer_handle) - if is_last: - del self.transfer_infos[bootstrap_room] - return handles + if bootstrap_room not in self.request_status: + self.update_status(bootstrap_room, KVPoll.Bootstrapping) + + shard_idx = bootstrap_room % len(self.transfer_queues) + self.transfer_queues[shard_idx].put( + TransferKVChunk( + room=bootstrap_room, + prefill_kv_indices=kv_indices, + index_slice=index_slice, + is_last=is_last, + chunk_id=chunk_id, + prefill_aux_index=aux_index, + state_indices=state_indices, + ) + ) + return None def update_transfer_status(self): # Process notifications from received transfers. @@ -895,7 +999,6 @@ def __init__( pp_rank: int, ): super().__init__(mgr, bootstrap_addr, bootstrap_room, dest_tp_ranks, pp_rank) - self.xfer_handles = [] self.has_sent = False self.chunk_id = 0 @@ -922,7 +1025,7 @@ def send( self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Success) return - new_xfer_handles = self.kv_mgr.add_transfer_request( + self.kv_mgr.add_transfer_request( self.bootstrap_room, kv_indices, index_slice, @@ -931,21 +1034,23 @@ def send( self.aux_index, state_indices, ) - self.xfer_handles.extend(new_xfer_handles) self.chunk_id += 1 if is_last: self.has_sent = True - del self.kv_mgr.request_status[self.bootstrap_room] def poll(self) -> KVPoll: + status = self.kv_mgr.check_status(self.bootstrap_room) if not self.has_sent: - return self.kv_mgr.check_status(self.bootstrap_room) - states = [self.kv_mgr.agent.check_xfer_state(x) for x in self.xfer_handles] - if all([x == "DONE" for x in states]): - return KVPoll.Success # type: ignore - if any([x == "ERR" for x in states]): - raise Exception("KVSender transfer encountered an error.") - return KVPoll.WaitingForInput # type: ignore + return status + if status in (KVPoll.Success, KVPoll.Failed): + return status + return status + + def clear(self): + try: + self.kv_mgr.request_status.pop(self.bootstrap_room, None) + except Exception: + pass def failure_exception(self): raise RuntimeError("NIXL KVSender Exception") diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py index 381e92590cb7..f459b29e306e 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -571,7 +571,11 @@ def process_disagg_prefill_inflight_queue( assert poll == KVPoll.Success or poll == KVPoll.Failed - if poll in [KVPoll.WaitingForInput, KVPoll.Transferring]: + if poll in [ + KVPoll.Bootstrapping, + KVPoll.WaitingForInput, + KVPoll.Transferring, + ]: undone_reqs.append(req) elif poll == KVPoll.Success: # transfer done release_kv_cache(req, self.tree_cache) # unlock the tree From 563b805be126d194446cc71afe9bdb8c1dae52b4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=84=E6=AD=A6?= Date: Mon, 16 Mar 2026 15:38:51 +0800 Subject: [PATCH 2/5] fix format --- python/sglang/srt/disaggregation/nixl/conn.py | 25 ++++++++----------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/python/sglang/srt/disaggregation/nixl/conn.py b/python/sglang/srt/disaggregation/nixl/conn.py index ec3f31f0bbae..d9902a92563b 100644 --- a/python/sglang/srt/disaggregation/nixl/conn.py +++ b/python/sglang/srt/disaggregation/nixl/conn.py @@ -21,7 +21,10 @@ CommonKVReceiver, CommonKVSender, ) -from sglang.srt.disaggregation.common.utils import FastQueue, group_concurrent_contiguous +from sglang.srt.disaggregation.common.utils import ( + FastQueue, + group_concurrent_contiguous, +) from sglang.srt.disaggregation.utils import ( DisaggregationMode, filter_kv_indices_for_cp_rank, @@ -204,11 +207,11 @@ def __init__( if self.disaggregation_mode == DisaggregationMode.PREFILL: cpu_count = os.cpu_count() or 1 - transfer_thread_pool_size = envs.SGLANG_DISAGGREGATION_THREAD_POOL_SIZE.get() + transfer_thread_pool_size = ( + envs.SGLANG_DISAGGREGATION_THREAD_POOL_SIZE.get() + ) if transfer_thread_pool_size is None: - transfer_thread_pool_size = min( - max(4, int(0.5 * cpu_count) // 8), 12 - ) + transfer_thread_pool_size = min(max(4, int(0.5 * cpu_count) // 8), 12) transfer_queue_size = envs.SGLANG_DISAGGREGATION_QUEUE_SIZE.get() self.transfer_queues: List[FastQueue] = [ FastQueue() for _ in range(transfer_queue_size) @@ -327,9 +330,7 @@ def _handle_node_failure(self, failed_bootstrap_addr): def check_status(self, bootstrap_room: int): return self.request_status.get(bootstrap_room, KVPoll.Bootstrapping) - def transfer_worker( - self, queue: FastQueue, executor: concurrent.futures.Executor - ): + def transfer_worker(self, queue: FastQueue, executor: concurrent.futures.Executor): while True: kv_chunk: TransferKVChunk = queue.get() room = kv_chunk.room @@ -367,9 +368,7 @@ def transfer_worker( req.agent_name ].decode_tp_size - if self.is_mla_backend or ( - decode_tp_size == self.attn_tp_size - ): + if self.is_mla_backend or (decode_tp_size == self.attn_tp_size): kv_xfer_handle = self.send_kvcache( req.agent_name, kv_chunk.prefill_kv_indices, @@ -426,9 +425,7 @@ def transfer_worker( while handles: states = [self.agent.check_xfer_state(h) for h in handles] if any(s == "ERR" for s in states): - raise RuntimeError( - f"NIXL transfer encountered ERR room={room}" - ) + raise RuntimeError(f"NIXL transfer encountered ERR room={room}") if all(s == "DONE" for s in states): break time.sleep(0.001) From 7c0f8914fa4ce663f005ff0822befe609ab1401a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=84=E6=AD=A6?= Date: Mon, 16 Mar 2026 19:06:55 +0800 Subject: [PATCH 3/5] optimize --- python/sglang/srt/disaggregation/nixl/conn.py | 24 +++---------------- 1 file changed, 3 insertions(+), 21 deletions(-) diff --git a/python/sglang/srt/disaggregation/nixl/conn.py b/python/sglang/srt/disaggregation/nixl/conn.py index d9902a92563b..b1381cd80102 100644 --- a/python/sglang/srt/disaggregation/nixl/conn.py +++ b/python/sglang/srt/disaggregation/nixl/conn.py @@ -1,9 +1,7 @@ from __future__ import annotations -import concurrent.futures import dataclasses import logging -import os import struct import threading import time @@ -206,29 +204,13 @@ def __init__( self.register_buffer_to_engine() if self.disaggregation_mode == DisaggregationMode.PREFILL: - cpu_count = os.cpu_count() or 1 - transfer_thread_pool_size = ( - envs.SGLANG_DISAGGREGATION_THREAD_POOL_SIZE.get() - ) - if transfer_thread_pool_size is None: - transfer_thread_pool_size = min(max(4, int(0.5 * cpu_count) // 8), 12) transfer_queue_size = envs.SGLANG_DISAGGREGATION_QUEUE_SIZE.get() self.transfer_queues: List[FastQueue] = [ FastQueue() for _ in range(transfer_queue_size) ] - assert transfer_thread_pool_size >= transfer_queue_size, ( - f"The environment variable SGLANG_DISAGGREGATION_THREAD_POOL_SIZE={transfer_thread_pool_size} must be " - f"greater than or equal to SGLANG_DISAGGREGATION_QUEUE_SIZE={transfer_queue_size}." - ) - self.executors = [ - concurrent.futures.ThreadPoolExecutor( - max_workers=max(1, transfer_thread_pool_size // transfer_queue_size) - ) - for _ in range(transfer_queue_size) - ] - for queue, executor in zip(self.transfer_queues, self.executors): + for queue in self.transfer_queues: threading.Thread( - target=self.transfer_worker, args=(queue, executor), daemon=True + target=self.transfer_worker, args=(queue,), daemon=True ).start() self._start_bootstrap_thread() elif self.disaggregation_mode == DisaggregationMode.DECODE: @@ -330,7 +312,7 @@ def _handle_node_failure(self, failed_bootstrap_addr): def check_status(self, bootstrap_room: int): return self.request_status.get(bootstrap_room, KVPoll.Bootstrapping) - def transfer_worker(self, queue: FastQueue, executor: concurrent.futures.Executor): + def transfer_worker(self, queue: FastQueue): while True: kv_chunk: TransferKVChunk = queue.get() room = kv_chunk.room From da64ed07eb7d6cd0b5e9fdfd5111587a91678da9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=84=E6=AD=A6?= Date: Mon, 16 Mar 2026 19:28:04 +0800 Subject: [PATCH 4/5] remove redundant try-catch --- python/sglang/srt/disaggregation/nixl/conn.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/disaggregation/nixl/conn.py b/python/sglang/srt/disaggregation/nixl/conn.py index b1381cd80102..5cb46eaf0d60 100644 --- a/python/sglang/srt/disaggregation/nixl/conn.py +++ b/python/sglang/srt/disaggregation/nixl/conn.py @@ -1026,10 +1026,8 @@ def poll(self) -> KVPoll: return status def clear(self): - try: - self.kv_mgr.request_status.pop(self.bootstrap_room, None) - except Exception: - pass + if self.bootstrap_room in self.kv_mgr.request_status: + self.kv_mgr.request_status.pop(self.bootstrap_room) def failure_exception(self): raise RuntimeError("NIXL KVSender Exception") From 89467c93d92a6224236dbf024fb3719f40b4a0ea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=84=E6=AD=A6?= Date: Wed, 1 Apr 2026 16:37:36 +0800 Subject: [PATCH 5/5] optimize --- python/sglang/srt/disaggregation/nixl/conn.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/python/sglang/srt/disaggregation/nixl/conn.py b/python/sglang/srt/disaggregation/nixl/conn.py index 5cb46eaf0d60..39a2f415de4c 100644 --- a/python/sglang/srt/disaggregation/nixl/conn.py +++ b/python/sglang/srt/disaggregation/nixl/conn.py @@ -317,16 +317,10 @@ def transfer_worker(self, queue: FastQueue): kv_chunk: TransferKVChunk = queue.get() room = kv_chunk.room try: - if ( - room in self.request_status - and self.check_status(room) == KVPoll.Failed - ): + if self.check_status(room) == KVPoll.Failed: continue - if room not in self.transfer_infos: - time.sleep(0.001) - queue.put(kv_chunk) - continue + assert room in self.transfer_infos self.update_status(room, KVPoll.Transferring) @@ -410,7 +404,7 @@ def transfer_worker(self, queue: FastQueue): raise RuntimeError(f"NIXL transfer encountered ERR room={room}") if all(s == "DONE" for s in states): break - time.sleep(0.001) + time.sleep(0) if kv_chunk.is_last: if room in self.transfer_infos: