From 6969cfe97f81433f24729bac9ad2ee6656a9871e Mon Sep 17 00:00:00 2001 From: Neal Vaidya Date: Tue, 17 Feb 2026 20:58:43 +0000 Subject: [PATCH 1/2] feat: add nsa and swa disagg support with nixl Signed-off-by: Neal Vaidya --- python/sglang/srt/disaggregation/nixl/conn.py | 131 +++++++++++++----- 1 file changed, 100 insertions(+), 31 deletions(-) diff --git a/python/sglang/srt/disaggregation/nixl/conn.py b/python/sglang/srt/disaggregation/nixl/conn.py index c95d2c5edf7c..584bfaaec99f 100644 --- a/python/sglang/srt/disaggregation/nixl/conn.py +++ b/python/sglang/srt/disaggregation/nixl/conn.py @@ -368,51 +368,55 @@ def _add_remote_peer(self, decode_kv_args: KVArgsRegisterInfo): self.decode_kv_args_table[agent_name] = decode_kv_args self.agent.add_remote_agent(decode_kv_args.agent_metadata) - def send_kvcache( + def _send_kvcache_generic( self, peer_name: str, - prefill_kv_indices: npt.NDArray[np.int32], - dst_kv_ptrs: list[int], - dst_kv_indices: npt.NDArray[np.int32], + src_data_ptrs: list[int], + dst_data_ptrs: list[int], + item_lens: list[int], + prefill_data_indices: npt.NDArray[np.int32], + dst_data_indices: npt.NDArray[np.int32], dst_gpu_id: int, notif: str, ): + """Generic KV cache transfer supporting both MHA and MLA architectures. + Used by both send_kvcache and maybe_send_extra.""" # group by indices prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous( - prefill_kv_indices, dst_kv_indices + prefill_data_indices, dst_data_indices ) logger.debug(f"sending kvcache to {peer_name} with notif {notif}") # Make descs if self.is_mla_backend: src_kv_ptrs, dst_kv_ptrs, layers_current_pp_stage = ( - self.get_mla_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs) + self.get_mla_kv_ptrs_with_pp(src_data_ptrs, dst_data_ptrs) ) layers_params = [ ( src_kv_ptrs[layer_id], dst_kv_ptrs[layer_id], - self.kv_args.kv_item_lens[layer_id], + item_lens[layer_id], ) for layer_id in range(layers_current_pp_stage) ] else: src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = ( - self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs) + self.get_mha_kv_ptrs_with_pp(src_data_ptrs, dst_data_ptrs) ) layers_params = [ ( src_k_ptrs[layer_id], dst_k_ptrs[layer_id], - self.kv_args.kv_item_lens[layer_id], + item_lens[layer_id], ) for layer_id in range(layers_current_pp_stage) ] + [ ( src_v_ptrs[layer_id], dst_v_ptrs[layer_id], - self.kv_args.kv_item_lens[layer_id], + item_lens[layer_id], ) for layer_id in range(layers_current_pp_stage) ] @@ -455,7 +459,7 @@ def make_req_array(addr_chunks, len_chunks, gpu): dst_reqs = make_req_array(dst_addrs, dst_lens, dst_gpu_id) logger.debug( - f"len(src_addrs): before group: {len(prefill_kv_indices)}, after group: {len(src_addrs)}" + f"len(src_addrs): before group: {len(prefill_data_indices)}, after group: {len(src_addrs)}" ) src_descs = self.agent.get_xfer_descs(src_reqs, "VRAM") dst_descs = self.agent.get_xfer_descs(dst_reqs, "VRAM") @@ -474,6 +478,26 @@ def make_req_array(addr_chunks, len_chunks, gpu): raise Exception("KVSender failed to post transfer") return xfer_handle + def send_kvcache( + self, + peer_name: str, + prefill_kv_indices: npt.NDArray[np.int32], + dst_kv_ptrs: list[int], + dst_kv_indices: npt.NDArray[np.int32], + dst_gpu_id: int, + notif: str, + ): + return self._send_kvcache_generic( + peer_name=peer_name, + src_data_ptrs=self.kv_args.kv_data_ptrs, + dst_data_ptrs=dst_kv_ptrs, + item_lens=self.kv_args.kv_item_lens, + prefill_data_indices=prefill_kv_indices, + dst_data_indices=dst_kv_indices, + dst_gpu_id=dst_gpu_id, + notif=notif, + ) + def send_kvcache_slice( self, peer_name: str, @@ -684,6 +708,60 @@ def _send_mamba_state( raise Exception("Failed to post Mamba state transfer") return xfer_handle + def maybe_send_extra( + self, + peer_name: str, + prefill_state_indices: List[int], + dst_state_data_ptrs: list[int], + dst_state_indices: List[int], + dst_gpu_id: int, + notif: str, + decode_tp_size: int, + ): + """Send state or extra pool data with type-specific handling.""" + state_type = getattr(self.kv_args, "state_type", "none") + + if state_type == "mamba": + if self.attn_tp_size != decode_tp_size: + raise RuntimeError( + "PD Disaggregation does NOT support PD different TP sizes for hybrid mamba models yet." + ) + return self._send_mamba_state( + peer_name, + prefill_state_indices, + dst_state_data_ptrs, + dst_state_indices, + dst_gpu_id, + notif, + ) + elif state_type in ["swa", "nsa"]: + if ( + not self.is_mla_backend + and self.attn_tp_size != decode_tp_size + ): + raise RuntimeError( + f"PD Disaggregation does NOT support PD different TP sizes for non-MLA {state_type.upper()} hybrid models yet." + ) + if len(prefill_state_indices) != len(dst_state_indices): + raise RuntimeError( + f"State index length mismatch: prefill={len(prefill_state_indices)}, " + f"dst={len(dst_state_indices)}" + ) + return self._send_kvcache_generic( + peer_name=peer_name, + src_data_ptrs=self.kv_args.state_data_ptrs, + dst_data_ptrs=dst_state_data_ptrs, + item_lens=self.kv_args.state_item_lens, + prefill_data_indices=np.array(prefill_state_indices, dtype=np.int32), + dst_data_indices=np.array(dst_state_indices, dtype=np.int32), + dst_gpu_id=dst_gpu_id, + notif=notif, + ) + else: + if state_type != "none": + raise RuntimeError(f"PD Disaggregation via NIXL does NOT support {state_type} hybrid models yet.") + return None + def add_transfer_request( self, bootstrap_room: int, @@ -742,26 +820,17 @@ def add_transfer_request( # Only the last chunk we need to send the aux data. if is_last: if state_indices is not None: - state_type = getattr(self.kv_args, "state_type", "none") - if ( - self.attn_tp_size - != self.decode_kv_args_table[req.agent_name].decode_tp_size - ): - raise RuntimeError( - "PD Disaggregation does NOT support PD different TP sizes for hybrid mamba models yet." - ) - - if state_type == "mamba": - state_xfer_handle = self._send_mamba_state( - req.agent_name, - state_indices, - self.decode_kv_args_table[ - req.agent_name - ].dst_state_data_ptrs, - req.dst_state_indices, - self.decode_kv_args_table[req.agent_name].gpu_id, - f"{req.room}_state_{self.kv_args.pp_rank}", - ) + dst_info = self.decode_kv_args_table[req.agent_name] + state_xfer_handle = self.maybe_send_extra( + req.agent_name, + state_indices, + dst_info.dst_state_data_ptrs, + req.dst_state_indices, + dst_info.gpu_id, + f"{req.room}_state_{self.kv_args.pp_rank}", + decode_tp_size, + ) + if state_xfer_handle is not None: handles.append(state_xfer_handle) assert aux_index is not None From 13f542110dd8c390e218389cf1868d76a5fb94ec Mon Sep 17 00:00:00 2001 From: Neal Vaidya Date: Tue, 17 Feb 2026 21:08:04 +0000 Subject: [PATCH 2/2] formatting Signed-off-by: Neal Vaidya --- python/sglang/srt/disaggregation/nixl/conn.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/disaggregation/nixl/conn.py b/python/sglang/srt/disaggregation/nixl/conn.py index 584bfaaec99f..25cd5cc7f183 100644 --- a/python/sglang/srt/disaggregation/nixl/conn.py +++ b/python/sglang/srt/disaggregation/nixl/conn.py @@ -735,10 +735,7 @@ def maybe_send_extra( notif, ) elif state_type in ["swa", "nsa"]: - if ( - not self.is_mla_backend - and self.attn_tp_size != decode_tp_size - ): + if not self.is_mla_backend and self.attn_tp_size != decode_tp_size: raise RuntimeError( f"PD Disaggregation does NOT support PD different TP sizes for non-MLA {state_type.upper()} hybrid models yet." ) @@ -759,7 +756,9 @@ def maybe_send_extra( ) else: if state_type != "none": - raise RuntimeError(f"PD Disaggregation via NIXL does NOT support {state_type} hybrid models yet.") + raise RuntimeError( + f"PD Disaggregation via NIXL does NOT support {state_type} hybrid models yet." + ) return None def add_transfer_request(