def _send_mamba_state_slice(
    self,
    peer_name: str,
    prefill_state_indices: List[int],
    dst_state_data_ptrs: list[int],
    dst_state_indices: List[int],
    dst_gpu_id: int,
    notif: str,
    dst_state_item_lens: list[int],
    dst_state_dim_per_tensor: list[int],
    decode_tp_size: int,
    decode_tp_rank: int,
):
    """Transfer Mamba states with TP slice support via RDMA.

    When prefill and decode have different attn_tp_size, we slice the
    TP-sharded dimension (3rd dim) of conv_state and temporal_state
    accordingly, mirroring Mooncake's _send_mamba_state_slice.

    Two cases are handled per state tensor:

    * ``self.attn_tp_size > decode_tp_size``: this prefill rank sends its
      whole shard (``src_dim`` units) and lands it at an offset inside the
      decode-side item selected by this rank's position among the writers
      that share one decode rank.
    * otherwise: this prefill rank sends only the ``dst_dim`` units that
      the target decode rank owns, written at the start of the decode item.

    Args:
        peer_name: Remote agent name passed to ``initialize_xfer``.
        prefill_state_indices: Local state slot indices; exactly one entry.
        dst_state_data_ptrs: Base address of each state tensor on the peer.
        dst_state_indices: Remote state slot indices; the first is used.
        dst_gpu_id: Device id recorded in the destination descriptors.
        notif: ASCII notification string attached to the transfer.
        dst_state_item_lens: Per-tensor item length (bytes) on the peer.
        dst_state_dim_per_tensor: Per-tensor sliced-dimension size on the peer.
        decode_tp_size: attn TP size of the decode side.
        decode_tp_rank: Rank of the target decode worker.

    Returns:
        The transfer handle from ``self.agent.initialize_xfer`` (or whatever
        ``self._send_mamba_state`` returns when slice metadata is missing).

    Raises:
        ValueError: If the index lists are malformed, the per-tensor
            metadata is shorter than ``dst_state_data_ptrs``, a dimension is
            non-positive, or a computed slice would fall outside the source
            or destination item.
        Exception: If the transfer cannot be created or posted.
    """
    logger.warning_once(
        "Using Mamba state slice transfer for different TP sizes. "
        f"Prefill attn_tp_size={self.attn_tp_size}, "
        f"Decode attn_tp_size={decode_tp_size}."
    )
    # Explicit checks instead of `assert`: asserts vanish under `python -O`,
    # and a bad index here would turn into a wrong RDMA address.
    if len(prefill_state_indices) != 1:
        raise ValueError("Mamba should have single state index")
    if not dst_state_indices:
        raise ValueError("Mamba state slice transfer needs a destination state index")

    prefill_state_data_ptrs = self.kv_args.state_data_ptrs
    prefill_state_item_lens = self.kv_args.state_item_lens
    src_state_dim_per_tensor = getattr(self.kv_args, "state_dim_per_tensor", [])

    # Without dimension metadata on either side no slice can be computed;
    # defer to the unsliced path.
    # NOTE(review): this path is reached only when TP sizes differ — confirm
    # that an unsliced copy is really the intended best-effort behavior.
    if not src_state_dim_per_tensor or not dst_state_dim_per_tensor:
        return self._send_mamba_state(
            peer_name,
            prefill_state_indices,
            dst_state_data_ptrs,
            dst_state_indices,
            dst_gpu_id,
            notif,
        )

    # Every per-tensor list is indexed in lockstep with dst_state_data_ptrs.
    # The destination lists come from the peer, so fail with a clear message
    # instead of an IndexError halfway through descriptor construction.
    num_tensors = len(dst_state_data_ptrs)
    per_tensor_metadata = (
        ("prefill state_data_ptrs", prefill_state_data_ptrs),
        ("prefill state_item_lens", prefill_state_item_lens),
        ("prefill state_dim_per_tensor", src_state_dim_per_tensor),
        ("decode state_item_lens", dst_state_item_lens),
        ("decode state_dim_per_tensor", dst_state_dim_per_tensor),
    )
    for label, values in per_tensor_metadata:
        if len(values) < num_tensors:
            raise ValueError(
                f"Mamba state slice transfer: {label} has {len(values)} entries "
                f"but {num_tensors} destination state tensors were registered"
            )

    src_slot = int(prefill_state_indices[0])
    dst_slot = int(dst_state_indices[0])

    local_tp_rank_in_group = self.kv_args.engine_rank % self.attn_tp_size
    dst_tp_rank_in_group = decode_tp_rank % decode_tp_size

    # The direction of the TP mismatch is the same for every tensor, so
    # resolve it (and the writer position) once, outside the loop.
    prefill_is_wider = self.attn_tp_size > decode_tp_size
    local_writer_idx = 0
    if prefill_is_wider:
        writers_per_decode = self.attn_tp_size // decode_tp_size
        local_writer_idx = local_tp_rank_in_group % writers_per_decode

    src_addrs = []
    dst_addrs = []

    for i, dst_state_ptr in enumerate(dst_state_data_ptrs):
        src_item_len = prefill_state_item_lens[i]
        dst_item_len = dst_state_item_lens[i]
        src_dim = src_state_dim_per_tensor[i]
        dst_dim = dst_state_dim_per_tensor[i]

        if src_dim <= 0 or dst_dim <= 0:
            raise ValueError(
                f"Mamba state slice transfer: non-positive dimension for state "
                f"tensor {i} (src_dim={src_dim}, dst_dim={dst_dim})"
            )

        src_bytes_per_dim = src_item_len // src_dim
        dst_bytes_per_dim = dst_item_len // dst_dim

        if prefill_is_wider:
            # Send the whole local shard into this writer's region of the
            # (larger) decode-side item.
            src_dim_start = 0
            num_dims_to_send = src_dim
            dst_dim_start = local_writer_idx * src_dim
        else:
            # Send only the part of the local shard owned by the decode rank.
            src_dim_start = (dst_tp_rank_in_group * dst_dim) % src_dim
            num_dims_to_send = dst_dim
            dst_dim_start = 0

        src_dim_offset = src_dim_start * src_bytes_per_dim
        dst_dim_offset = dst_dim_start * dst_bytes_per_dim
        bytes_to_send = num_dims_to_send * src_bytes_per_dim

        # A slice that leaves its item would read or overwrite the adjacent
        # state slot; refuse rather than post a corrupting write.
        if (
            src_dim_offset + bytes_to_send > src_item_len
            or dst_dim_offset + bytes_to_send > dst_item_len
        ):
            raise ValueError(
                f"Mamba state slice transfer: slice for state tensor {i} is out "
                f"of bounds (src_offset={src_dim_offset}, "
                f"dst_offset={dst_dim_offset}, bytes={bytes_to_send}, "
                f"src_item_len={src_item_len}, dst_item_len={dst_item_len})"
            )

        src_addr = prefill_state_data_ptrs[i] + src_item_len * src_slot + src_dim_offset
        dst_addr = dst_state_ptr + dst_item_len * dst_slot + dst_dim_offset
        src_addrs.append((src_addr, bytes_to_send, self.kv_args.gpu_id))
        dst_addrs.append((dst_addr, bytes_to_send, dst_gpu_id))

    src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM")
    dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM")

    xfer_handle = self.agent.initialize_xfer(
        "WRITE",
        src_descs,
        dst_descs,
        peer_name,
        notif.encode("ascii"),
    )
    if not xfer_handle:
        raise Exception("Failed to create Mamba state slice transfer")
    state = self.agent.transfer(xfer_handle)
    if state == "ERR":
        raise Exception("Failed to post Mamba state slice transfer")
    return xfer_handle
+ return self._send_mamba_state_slice( + peer_name, + prefill_state_indices, + dst_state_data_ptrs, + dst_state_indices, + dst_gpu_id, + notif, + dst_state_item_lens or [], + dst_state_dim_per_tensor or [], + decode_tp_size, + decode_tp_rank, ) return self._send_mamba_state( peer_name, @@ -791,6 +916,9 @@ def add_transfer_request( dst_info.gpu_id, f"{req.room}_state_{self.kv_args.pp_rank}", decode_tp_size, + decode_tp_rank=dst_info.decode_tp_rank, + dst_state_item_lens=dst_info.dst_state_item_lens, + dst_state_dim_per_tensor=dst_info.dst_state_dim_per_tensor, ) if state_xfer_handle is not None: handles.append(state_xfer_handle) @@ -1068,6 +1196,17 @@ def _register_kv_args(self): struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.state_data_ptrs ) + packed_state_item_lens = b"".join( + struct.pack("I", item_len) + for item_len in self.kv_mgr.kv_args.state_item_lens + ) + state_dim_per_tensor = getattr( + self.kv_mgr.kv_args, "state_dim_per_tensor", [] + ) + packed_state_dim_per_tensor = b"".join( + struct.pack("I", dim) for dim in state_dim_per_tensor + ) + with lock: sock.send_multipart( [ @@ -1084,6 +1223,8 @@ def _register_kv_args(self): str(self.kv_mgr.attn_tp_size).encode("ascii"), str(self.kv_mgr.kv_args.engine_rank).encode("ascii"), str(self.kv_mgr.kv_args.kv_item_lens[0]).encode("ascii"), + packed_state_item_lens, + packed_state_dim_per_tensor, ] )