diff --git a/python/sglang/srt/disaggregation/ascend/conn.py b/python/sglang/srt/disaggregation/ascend/conn.py index 6c2bf00c569c..098b905da691 100644 --- a/python/sglang/srt/disaggregation/ascend/conn.py +++ b/python/sglang/srt/disaggregation/ascend/conn.py @@ -35,10 +35,11 @@ def register_buffer_to_engine(self): self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens ) # Batch register state/extra pool data buffers - if self.kv_args.state_data_ptrs and self.kv_args.state_data_lens: - self.engine.batch_register( - self.kv_args.state_data_ptrs, self.kv_args.state_data_lens - ) + for component_ptrs, component_lens in zip( + self.kv_args.state_data_ptrs or [], + self.kv_args.state_data_lens or [], + ): + self.engine.batch_register(component_ptrs, component_lens) def send_kvcache( self, diff --git a/python/sglang/srt/disaggregation/base/conn.py b/python/sglang/srt/disaggregation/base/conn.py index 7266c0271740..edc44895d896 100644 --- a/python/sglang/srt/disaggregation/base/conn.py +++ b/python/sglang/srt/disaggregation/base/conn.py @@ -1,6 +1,7 @@ from __future__ import annotations import dataclasses +import enum from abc import ABC, abstractmethod from typing import TYPE_CHECKING, List, Optional @@ -13,6 +14,12 @@ from sglang.srt.disaggregation.utils import DisaggregationMode +class StateType(str, enum.Enum): + MAMBA = "mamba" + SWA = "swa" + NSA = "nsa" + + @dataclasses.dataclass class KVTransferMetric: # Backends that cannot isolate transfer latency can leave this as None. @@ -28,12 +35,12 @@ class KVArgs: aux_data_ptrs: List[int] aux_data_lens: List[int] aux_item_lens: List[int] - state_data_ptrs: List[int] - state_data_lens: List[int] - state_item_lens: List[int] - state_type: str # "none", "mamba", "swa", "nsa" - # for mamba state different tp slice transfer - state_dim_per_tensor: List[int] # dimension to slice for each state tensor + state_types: List[StateType] + state_data_ptrs: List[List[int]] + state_data_lens: List[List[int]] + state_item_lens: List[List[int]] + # Per-tensor TP slice dim, used when prefill/decode attn_tp_size differ. + state_dim_per_tensor: List[List[int]] ib_device: str ib_traffic_class: str gpu_id: int @@ -96,7 +103,7 @@ def init(self, num_kv_indices: int, aux_index: Optional[int] = None): def send( self, kv_indices: npt.NDArray[np.int32], - state_indices: Optional[List[int]] = None, + state_indices: Optional[List] = None, ): """ Send the kv cache at the given kv indices and the extra cache/state at the given indices to the decoder server. @@ -154,7 +161,7 @@ def send_metadata( self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None, - state_indices: Optional[List[int]] = None, + state_indices: Optional[List] = None, decode_prefix_len: Optional[int] = None, ): """ diff --git a/python/sglang/srt/disaggregation/common/conn.py b/python/sglang/srt/disaggregation/common/conn.py index 1fcd91d9cdfc..4d0e582c9cb0 100644 --- a/python/sglang/srt/disaggregation/common/conn.py +++ b/python/sglang/srt/disaggregation/common/conn.py @@ -96,7 +96,7 @@ def __init__( ): self.kv_args = args self.kv_item_lens_sum = sum(args.kv_item_lens) - self.state_item_lens_sum = sum(args.state_item_lens) + self.state_item_lens_sum = sum(x for comp in args.state_item_lens for x in comp) self.is_mla_backend = is_mla_backend self.disaggregation_mode = disaggregation_mode self.server_args = server_args @@ -520,16 +520,18 @@ def get_transfer_metric(self) -> KVTransferMetric: def _record_transfer_indices( self, kv_indices: npt.NDArray[np.int32], - state_indices: Optional[List[int]], + state_indices: Optional[List], ): self._transfer_num_kv_indices += len(kv_indices) - if state_indices is not None: - self._transfer_num_state_indices += len(state_indices) + if state_indices: + for component_indices in state_indices: + if component_indices is not None: + self._transfer_num_state_indices += len(component_indices) def send( self, kv_indices: npt.NDArray[np.int32], - state_indices: Optional[List[int]] = None, + state_indices: Optional[List] = None, ): pass diff --git a/python/sglang/srt/disaggregation/common/utils.py b/python/sglang/srt/disaggregation/common/utils.py index 6f3da21285a4..4e5e96c6f205 100644 --- a/python/sglang/srt/disaggregation/common/utils.py +++ b/python/sglang/srt/disaggregation/common/utils.py @@ -1,3 +1,4 @@ +import struct import threading from collections import deque from typing import List, Tuple @@ -6,6 +7,39 @@ import numpy.typing as npt +def pack_list_of_buffers(buffers: List[bytes]) -> bytes: + if not buffers: + return b"" + n = len(buffers) + header = struct.pack(f"<{n+1}I", n, *(len(b) for b in buffers)) + return header + b"".join(buffers) + + +def unpack_list_of_buffers(buf: bytes) -> List[bytes]: + if buf == b"": + return [] + (n,) = struct.unpack(" bytes: + return pack_list_of_buffers([struct.pack(f"<{len(a)}{fmt}", *a) for a in lists]) + + +def unpack_int_lists(buf: bytes, fmt: str) -> List[List[int]]: + width = struct.calcsize(fmt) + return [ + list(struct.unpack(f"<{len(b)//width}{fmt}", b)) + for b in unpack_list_of_buffers(buf) + ] + + class FastQueue: def __init__(self): self._buf = deque() diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py index 5d3d64f1cff2..228edd0b97a3 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -34,6 +34,7 @@ from sglang.srt.configs.mamba_utils import Mamba2CacheParams from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE from sglang.srt.disaggregation.base import KVPoll +from sglang.srt.disaggregation.base.conn import StateType from sglang.srt.disaggregation.common.conn import CommonKVManager, CommonKVReceiver from sglang.srt.disaggregation.utils import ( FAKE_BOOTSTRAP_HOST, @@ -56,17 +57,14 @@ from sglang.srt.managers.utils import GenerationBatchResult from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, EvictParams -from sglang.srt.mem_cache.base_swa_memory_pool import BaseSWAKVPool from sglang.srt.mem_cache.common import ( kv_to_page_indices, page_align_floor, release_kv_cache, ) from sglang.srt.mem_cache.memory_pool import ( - HybridLinearKVPool, HybridReqToTokenPool, KVCache, - NSATokenToKVPool, ReqToTokenPool, ) from sglang.srt.observability.req_time_stats import ( @@ -366,7 +364,12 @@ def _init_kv_manager(self) -> CommonKVManager: self.metadata_buffers.get_buf_infos() ) - setup_state_kv_args(kv_args, self.token_to_kv_pool, self.draft_token_to_kv_pool) + setup_state_kv_args( + kv_args, + self.token_to_kv_pool, + self.draft_token_to_kv_pool, + req_to_token_pool=getattr(self, "req_to_token_pool", None), + ) kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device kv_args.gpu_id = self.scheduler.gpu_id @@ -809,45 +812,54 @@ def pop_preallocated( ) page_size = self.token_to_kv_pool_allocator.page_size - # Prepare extra pool indices for hybrid models - if isinstance(self.token_to_kv_pool, HybridLinearKVPool): - # Mamba hybrid model: single mamba state index - state_indices = [ + seq_len = len(decode_req.req.origin_input_ids) + + def _mamba_payload(): + return [ self.req_to_token_pool.req_index_to_mamba_index_mapping[ decode_req.req.req_pool_idx ] .cpu() .numpy() ] - elif isinstance(self.token_to_kv_pool, BaseSWAKVPool): - seq_len = len(decode_req.req.origin_input_ids) - window_size = self.scheduler.sliding_window_size + def _swa_payload(): + window_size = self.scheduler.sliding_window_size window_start = max(0, seq_len - window_size) window_start = page_align_floor(window_start, page_size) window_kv_indices_full = self.req_to_token_pool.req_to_token[ decode_req.req.req_pool_idx, window_start:seq_len ] - - # Translate to SWA pool indices window_kv_indices_swa = ( self.token_to_kv_pool_allocator.translate_loc_from_full_to_swa( window_kv_indices_full ) ) - state_indices = window_kv_indices_swa.cpu().numpy() - state_indices = kv_to_page_indices(state_indices, page_size) - elif isinstance(self.token_to_kv_pool, NSATokenToKVPool): - seq_len = len(decode_req.req.origin_input_ids) + return kv_to_page_indices( + window_kv_indices_swa.cpu().numpy(), page_size + ) + + def _nsa_payload(): kv_indices_full = self.req_to_token_pool.req_to_token[ decode_req.req.req_pool_idx, :seq_len ] - state_indices = kv_indices_full.cpu().numpy() # Indexer lives on device pool; always use device page_size device_page_size = self.token_to_kv_pool.page_size - state_indices = kv_to_page_indices(state_indices, device_page_size) - else: - state_indices = None + return kv_to_page_indices( + kv_indices_full.cpu().numpy(), device_page_size + ) + + state_types = self.kv_manager.kv_args.state_types + state_indices: Optional[List] = [] + for st in state_types: + if st == StateType.MAMBA: + state_indices.append(_mamba_payload()) + elif st == StateType.SWA: + state_indices.append(_swa_payload()) + elif st == StateType.NSA: + state_indices.append(_nsa_payload()) + else: + state_indices.append(None) decode_req.metadata_buffer_index = ( self.req_to_metadata_buffer_idx_allocator.alloc() diff --git a/python/sglang/srt/disaggregation/fake/conn.py b/python/sglang/srt/disaggregation/fake/conn.py index 638834207263..d59641c3c428 100644 --- a/python/sglang/srt/disaggregation/fake/conn.py +++ b/python/sglang/srt/disaggregation/fake/conn.py @@ -71,7 +71,7 @@ def init( def send( self, kv_indices: npt.NDArray[np.int32], - state_indices: Optional[List[int]] = None, + state_indices: Optional[List] = None, ): self.has_sent = True logger.debug( @@ -111,7 +111,7 @@ def send_metadata( self, kv_indices: list[int], aux_index: Optional[int] = None, - state_indices: Optional[List[int]] = None, + state_indices: Optional[List] = None, decode_prefix_len: Optional[int] = None, ): self.has_sent_metadata = True diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py index 0e7e7aac12e3..7b26ec0aa66b 100644 --- a/python/sglang/srt/disaggregation/mooncake/conn.py +++ b/python/sglang/srt/disaggregation/mooncake/conn.py @@ -14,7 +14,7 @@ import numpy as np import numpy.typing as npt -from sglang.srt.disaggregation.base.conn import KVArgs, KVPoll +from sglang.srt.disaggregation.base.conn import KVArgs, KVPoll, StateType from sglang.srt.disaggregation.common.conn import ( CommonKVBootstrapServer, CommonKVManager, @@ -30,6 +30,8 @@ from sglang.srt.disaggregation.common.utils import ( FastQueue, group_concurrent_contiguous, + pack_int_lists, + unpack_int_lists, ) from sglang.srt.disaggregation.mooncake.utils import ( check_mooncake_custom_mem_pool_enabled, @@ -64,7 +66,7 @@ class TransferKVChunk: index_slice: slice is_last_chunk: bool prefill_aux_index: Optional[int] - state_indices: Optional[List[int]] + state_indices: Optional[List] # decode @@ -76,7 +78,7 @@ class TransferInfo: mooncake_session_id: str dst_kv_indices: npt.NDArray[np.int32] dst_aux_index: int - dst_state_indices: List[int] + dst_state_indices: List[List[int]] # parallel to receiver's state_types required_dst_info_num: int is_dummy: bool decode_prefix_len: Optional[int] = None @@ -93,10 +95,7 @@ def from_zmq(cls, msg: List[bytes]): else: dst_kv_indices = np.frombuffer(msg[4], dtype=np.int32) dst_aux_index = int(msg[5].decode("ascii")) - if msg[6] == b"": - dst_state_indices = [] - else: - dst_state_indices = list(np.frombuffer(msg[6], dtype=np.int32)) + dst_state_indices = unpack_int_lists(msg[6], "i") is_dummy = False return cls( room=int(msg[0].decode("ascii")), @@ -123,13 +122,13 @@ class KVArgsRegisterInfo: mooncake_session_id: str dst_kv_ptrs: list[int] dst_aux_ptrs: list[int] - dst_state_data_ptrs: list[int] + dst_state_data_ptrs: List[List[int]] # parallel to state_types (same below) dst_tp_rank: int dst_attn_tp_size: int dst_kv_item_len: int # for mamba state different tp slice transfer - dst_state_item_lens: list[int] - dst_state_dim_per_tensor: list[int] + dst_state_item_lens: List[List[int]] + dst_state_dim_per_tensor: List[List[int]] # HiSparse: decode host pool stores KV at token granularity enable_hisparse: bool = False # Note: always put the staging field at the final (since the staging field is optional and contains multiple inputs) @@ -144,19 +143,15 @@ def from_zmq(cls, msg: List[bytes]): mooncake_session_id=msg[3].decode("ascii"), dst_kv_ptrs=list(struct.unpack(f"{len(msg[4])//8}Q", msg[4])), dst_aux_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])), - dst_state_data_ptrs=list(struct.unpack(f"{len(msg[6])//8}Q", msg[6])), + dst_state_data_ptrs=unpack_int_lists(msg[6], "Q"), dst_tp_rank=int(msg[7].decode("ascii")), dst_attn_tp_size=int(msg[8].decode("ascii")), dst_kv_item_len=int(msg[9].decode("ascii")), dst_state_item_lens=( - list(struct.unpack(f"{len(msg[10])//4}I", msg[10])) - if len(msg) > 10 and len(msg[10]) > 0 - else [] + unpack_int_lists(msg[10], "I") if len(msg) > 10 else [] ), dst_state_dim_per_tensor=( - list(struct.unpack(f"{len(msg[11])//4}I", msg[11])) - if len(msg) > 11 and len(msg[11]) > 0 - else [] + unpack_int_lists(msg[11], "I") if len(msg) > 11 else [] ), enable_hisparse=( msg[12].decode("ascii") == "1" if len(msg) > 12 else False @@ -272,11 +267,11 @@ def register_buffer_to_engine(self): self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens ) - # Batch register state/extra pool data buffers - if self.kv_args.state_data_ptrs and self.kv_args.state_data_lens: - self.engine.batch_register( - self.kv_args.state_data_ptrs, self.kv_args.state_data_lens - ) + for ptrs, lens in zip( + self.kv_args.state_data_ptrs, self.kv_args.state_data_lens + ): + if ptrs and lens: + self.engine.batch_register(ptrs, lens) # ------------------------------------------------------------------ # Staging buffer methods (all delegate to staging_handler.py) @@ -966,88 +961,133 @@ def _handle_aux_data(self, msg: List[bytes]): def maybe_send_extra( self, req: TransferInfo, - prefill_state_indices: list[int], - dst_state_data_ptrs: list[int], + prefill_state_indices: List, executor: concurrent.futures.ThreadPoolExecutor, target_rank_registration_info: Optional[KVArgsRegisterInfo] = None, ): - """Send state or extra pool data with type-specific handling.""" - state_type = getattr(self.kv_args, "state_type", "none") - - if state_type == "mamba": - # Check if we need slice transfer for different TP sizes - if ( - target_rank_registration_info is not None - and self.attn_tp_size != target_rank_registration_info.dst_attn_tp_size - ): - return self._send_mamba_state_slice( - req, - prefill_state_indices, - dst_state_data_ptrs, - target_rank_registration_info.dst_state_item_lens, - target_rank_registration_info.dst_state_dim_per_tensor, - target_rank_registration_info.dst_tp_rank, - target_rank_registration_info.dst_attn_tp_size, - ) - else: - return self._send_mamba_state( - req, - prefill_state_indices, - dst_state_data_ptrs, - ) - elif state_type in ["swa", "nsa"]: - # Non-MLA SWA / NSA hybrid models do not support different TP sizes yet. - if ( - target_rank_registration_info is not None - and not self.is_mla_backend - and self.attn_tp_size != target_rank_registration_info.dst_attn_tp_size - ): - raise RuntimeError( - f"PD Disaggregation does NOT support PD different TP sizes for non-MLA {state_type.upper()} hybrid models yet." + rc = 0 + state_types = getattr(self.kv_args, "state_types", []) + for i, st in enumerate(state_types): + indices = ( + prefill_state_indices[i] if i < len(prefill_state_indices) else None + ) + if indices is None: + continue + src_data_ptrs = self.kv_args.state_data_ptrs[i] + src_item_lens = self.kv_args.state_item_lens[i] + src_dim_per_tensor = ( + self.kv_args.state_dim_per_tensor[i] + if i < len(self.kv_args.state_dim_per_tensor) + else [] + ) + if target_rank_registration_info is not None: + dst_data_ptrs = ( + target_rank_registration_info.dst_state_data_ptrs[i] + if i < len(target_rank_registration_info.dst_state_data_ptrs) + else [] ) - dst_state_indices = req.dst_state_indices - if len(prefill_state_indices) > len(dst_state_indices): - logger.warning( - f"len(prefill_state_indices) = {len(prefill_state_indices)}, len(dst_state_indices) = {len(dst_state_indices)}" + dst_item_lens = ( + target_rank_registration_info.dst_state_item_lens[i] + if i < len(target_rank_registration_info.dst_state_item_lens) + else [] ) - prefill_state_indices = prefill_state_indices[: len(dst_state_indices)] - elif len(prefill_state_indices) < len(dst_state_indices): - logger.warning( - f"len(prefill_state_indices) = {len(prefill_state_indices)}, len(dst_state_indices) = {len(dst_state_indices)}" + dst_dim_per_tensor = ( + target_rank_registration_info.dst_state_dim_per_tensor[i] + if i < len(target_rank_registration_info.dst_state_dim_per_tensor) + else [] ) - dst_state_indices = dst_state_indices[: len(prefill_state_indices)] - # Reuse _send_kvcache_generic interface to send extra pool data - prefill_state_indices = np.array(prefill_state_indices, dtype=np.int32) - dst_state_indices = np.array(dst_state_indices, dtype=np.int32) - return self._send_kvcache_generic( - mooncake_session_id=req.mooncake_session_id, - src_data_ptrs=self.kv_args.state_data_ptrs, - dst_data_ptrs=dst_state_data_ptrs, - item_lens=self.kv_args.state_item_lens, - prefill_data_indices=prefill_state_indices, - dst_data_indices=dst_state_indices, - executor=executor, + else: + dst_data_ptrs, dst_item_lens, dst_dim_per_tensor = [], [], [] + dst_indices = ( + req.dst_state_indices[i] if i < len(req.dst_state_indices) else [] ) - else: - return 0 + + if st == StateType.MAMBA: + if ( + target_rank_registration_info is not None + and self.attn_tp_size + != target_rank_registration_info.dst_attn_tp_size + ): + rc = ( + self._send_mamba_state_slice( + req, + indices, + src_data_ptrs, + src_item_lens, + src_dim_per_tensor, + dst_data_ptrs, + dst_indices, + dst_item_lens, + dst_dim_per_tensor, + target_rank_registration_info.dst_tp_rank, + target_rank_registration_info.dst_attn_tp_size, + ) + or rc + ) + else: + rc = ( + self._send_mamba_state( + req, + indices, + src_data_ptrs, + src_item_lens, + dst_data_ptrs, + dst_indices, + ) + or rc + ) + elif st in (StateType.SWA, StateType.NSA): + if ( + target_rank_registration_info is not None + and not self.is_mla_backend + and self.attn_tp_size + != target_rank_registration_info.dst_attn_tp_size + ): + raise RuntimeError( + f"PD Disaggregation does NOT support PD different TP sizes for non-MLA {st.upper()} hybrid models yet." + ) + src_indices = list(indices) + dst_indices_local = list(dst_indices) + if len(src_indices) > len(dst_indices_local): + logger.warning( + f"len(prefill_state_indices) = {len(src_indices)}, len(dst_state_indices) = {len(dst_indices_local)}" + ) + src_indices = src_indices[: len(dst_indices_local)] + elif len(src_indices) < len(dst_indices_local): + logger.warning( + f"len(prefill_state_indices) = {len(src_indices)}, len(dst_state_indices) = {len(dst_indices_local)}" + ) + dst_indices_local = dst_indices_local[: len(src_indices)] + rc = ( + self._send_kvcache_generic( + mooncake_session_id=req.mooncake_session_id, + src_data_ptrs=src_data_ptrs, + dst_data_ptrs=dst_data_ptrs, + item_lens=src_item_lens, + prefill_data_indices=np.array(src_indices, dtype=np.int32), + dst_data_indices=np.array(dst_indices_local, dtype=np.int32), + executor=executor, + ) + or rc + ) + return rc def _send_mamba_state( self, req: TransferInfo, - prefill_mamba_index: list[int], + prefill_mamba_index: list, + src_state_data_ptrs: list[int], + src_state_item_lens: list[int], dst_state_data_ptrs: list[int], + dst_mamba_index: list, ): - """Transfer Mamba states.""" assert len(prefill_mamba_index) == 1, "Mamba should have single state index" transfer_blocks = [] - prefill_state_data_ptrs = self.kv_args.state_data_ptrs - prefill_state_item_lens = self.kv_args.state_item_lens - for i, dst_state_ptr in enumerate(dst_state_data_ptrs): - length = prefill_state_item_lens[i] - src_addr = prefill_state_data_ptrs[i] + length * int(prefill_mamba_index[0]) - dst_addr = dst_state_ptr + length * int(req.dst_state_indices[0]) + length = src_state_item_lens[i] + src_addr = src_state_data_ptrs[i] + length * int(prefill_mamba_index[0]) + dst_addr = dst_state_ptr + length * int(dst_mamba_index[0]) transfer_blocks.append((src_addr, dst_addr, length)) return self._transfer_data(req.mooncake_session_id, transfer_blocks) @@ -1055,8 +1095,12 @@ def _send_mamba_state( def _send_mamba_state_slice( self, req: TransferInfo, - prefill_mamba_index: list[int], + prefill_mamba_index: list, + src_state_data_ptrs: list[int], + src_state_item_lens: list[int], + src_state_dim_per_tensor: list[int], dst_state_data_ptrs: list[int], + dst_mamba_index: list, dst_state_item_lens: list[int], dst_state_dim_per_tensor: list[int], dst_tp_rank: int, @@ -1078,33 +1122,33 @@ def _send_mamba_state_slice( ) assert len(prefill_mamba_index) == 1, "Mamba should have single state index" - transfer_blocks = [] - prefill_state_data_ptrs = self.kv_args.state_data_ptrs - prefill_state_item_lens = self.kv_args.state_item_lens - src_state_dim_per_tensor = getattr(self.kv_args, "state_dim_per_tensor", []) - # If no dimension info available, fall back to regular transfer if not src_state_dim_per_tensor or not dst_state_dim_per_tensor: - return self._send_mamba_state(req, prefill_mamba_index, dst_state_data_ptrs) + return self._send_mamba_state( + req, + prefill_mamba_index, + src_state_data_ptrs, + src_state_item_lens, + dst_state_data_ptrs, + dst_mamba_index, + ) local_tp_rank_in_group = self.kv_args.engine_rank % self.attn_tp_size dst_tp_rank_in_group = dst_tp_rank % dst_attn_tp_size + transfer_blocks = [] for i, dst_state_ptr in enumerate(dst_state_data_ptrs): - src_item_len = prefill_state_item_lens[i] + src_item_len = src_state_item_lens[i] dst_item_len = dst_state_item_lens[i] src_dim = src_state_dim_per_tensor[i] dst_dim = dst_state_dim_per_tensor[i] - # Calculate bytes per dimension slice # item_len = dim * trailing_dims_size, so trailing_dims_size = item_len / dim src_bytes_per_dim = src_item_len // src_dim dst_bytes_per_dim = dst_item_len // dst_dim - # Determine slicing parameters based on TP configuration if self.attn_tp_size > dst_attn_tp_size: # Multiple prefill ranks send to 1 decode rank - # Each prefill sends all its dims to the appropriate offset in decode src_dim_start = 0 num_dims_to_send = src_dim writers_per_decode = self.attn_tp_size // dst_attn_tp_size @@ -1112,26 +1156,21 @@ def _send_mamba_state_slice( dst_dim_start = local_writer_idx * src_dim else: # 1 prefill rank sends to multiple decode ranks - # Prefill sends a slice of its dims to each decode rank src_dim_start = (dst_tp_rank_in_group * dst_dim) % src_dim num_dims_to_send = dst_dim dst_dim_start = 0 - # Calculate byte offsets src_dim_offset = src_dim_start * src_bytes_per_dim dst_dim_offset = dst_dim_start * dst_bytes_per_dim bytes_to_send = num_dims_to_send * src_bytes_per_dim - # Calculate addresses for this state tensor src_addr = ( - prefill_state_data_ptrs[i] + src_state_data_ptrs[i] + src_item_len * int(prefill_mamba_index[0]) + src_dim_offset ) dst_addr = ( - dst_state_ptr - + dst_item_len * int(req.dst_state_indices[0]) - + dst_dim_offset + dst_state_ptr + dst_item_len * int(dst_mamba_index[0]) + dst_dim_offset ) transfer_blocks.append((src_addr, dst_addr, bytes_to_send)) @@ -1297,11 +1336,10 @@ def transfer_worker( break if kv_chunk.is_last_chunk: - if kv_chunk.state_indices is not None: + if kv_chunk.state_indices: self.maybe_send_extra( req, kv_chunk.state_indices, - target_rank_registration_info.dst_state_data_ptrs, executor, target_rank_registration_info, ) @@ -1576,7 +1614,7 @@ def add_transfer_request( index_slice: slice, is_last_chunk: bool, aux_index: Optional[int] = None, - state_indices: Optional[List[int]] = None, + state_indices: Optional[List] = None, ): assert self.disaggregation_mode == DisaggregationMode.PREFILL assert not is_last_chunk or (is_last_chunk and aux_index is not None) @@ -1672,7 +1710,7 @@ def should_send_kv_chunk(self, num_pages: int, last_chunk: bool) -> bool: def send( self, kv_indices: npt.NDArray[np.int32], - state_indices: Optional[List[int]] = None, + state_indices: Optional[List] = None, ): index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices)) self.curr_idx += len(kv_indices) @@ -1769,19 +1807,14 @@ def _register_kv_args(self): packed_aux_data_ptrs = b"".join( struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs ) - packed_state_data_ptrs = b"".join( - struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.state_data_ptrs - ) - # Pack state_item_lens and state_dim_per_tensor for mamba state slice transfer - packed_state_item_lens = b"".join( - struct.pack("I", item_len) - for item_len in self.kv_mgr.kv_args.state_item_lens + packed_state_data_ptrs = pack_int_lists( + self.kv_mgr.kv_args.state_data_ptrs, "Q" ) - state_dim_per_tensor = getattr( - self.kv_mgr.kv_args, "state_dim_per_tensor", [] + packed_state_item_lens = pack_int_lists( + self.kv_mgr.kv_args.state_item_lens, "I" ) - packed_state_dim_per_tensor = b"".join( - struct.pack("I", dim) for dim in state_dim_per_tensor + packed_state_dim_per_tensor = pack_int_lists( + getattr(self.kv_mgr.kv_args, "state_dim_per_tensor", []) or [], "I" ) # Note(shangming): No need to add pp rank here since decode pp size should be equal to prefill pp size or 1 tp_rank = self.kv_mgr.kv_args.engine_rank @@ -1834,7 +1867,7 @@ def send_metadata( self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None, - state_indices: Optional[List[int]] = None, + state_indices: Optional[List] = None, decode_prefix_len: Optional[int] = None, ): if self.bootstrap_infos is None: @@ -1868,11 +1901,8 @@ def send_metadata( kv_indices.tobytes() if not is_dummy else b"", str(aux_index).encode("ascii") if not is_dummy else b"", ( - np.array( - state_indices, - dtype=np.int32, - ).tobytes() - if not is_dummy and state_indices is not None + pack_int_lists(state_indices, "i") + if not is_dummy and state_indices else b"" ), str(self.required_dst_info_num).encode("ascii"), diff --git a/python/sglang/srt/disaggregation/mori/conn.py b/python/sglang/srt/disaggregation/mori/conn.py index 523b3d8232b8..5f1925af6193 100644 --- a/python/sglang/srt/disaggregation/mori/conn.py +++ b/python/sglang/srt/disaggregation/mori/conn.py @@ -351,16 +351,18 @@ def _register_local_buffers(self) -> None: MemoryLocationType.CPU, ) self.aux_mem_descs.append(desc) - for ptr, length in zip( - self.kv_args.state_data_ptrs, getattr(self.kv_args, "state_data_lens", []) + for component_ptrs, component_lens in zip( + self.kv_args.state_data_ptrs, + getattr(self.kv_args, "state_data_lens", []), ): - desc = self.engine.register_memory( - ptr, - length, - self.kv_args.gpu_id, - MemoryLocationType.GPU, - ) - self.state_mem_descs.append(desc) + for ptr, length in zip(component_ptrs, component_lens): + desc = self.engine.register_memory( + ptr, + length, + self.kv_args.gpu_id, + MemoryLocationType.GPU, + ) + self.state_mem_descs.append(desc) def update_status(self, bootstrap_room: int, status: KVPoll): current = self.request_status.get(bootstrap_room) @@ -1239,7 +1241,7 @@ def __init__( def send( self, kv_indices: npt.NDArray[np.int32], - state_indices: Optional[List[int]] = None, + state_indices: Optional[List] = None, ): index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices)) self.curr_idx += len(kv_indices) @@ -1453,7 +1455,7 @@ def send_metadata( self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None, - state_indices: Optional[List[int]] = None, + state_indices: Optional[List] = None, decode_prefix_len: Optional[int] = None, ): if self.bootstrap_infos is None or self.bootstrap_room is None: diff --git a/python/sglang/srt/disaggregation/nixl/conn.py b/python/sglang/srt/disaggregation/nixl/conn.py index cb21fd5c945b..3292d6cc89be 100644 --- a/python/sglang/srt/disaggregation/nixl/conn.py +++ b/python/sglang/srt/disaggregation/nixl/conn.py @@ -13,7 +13,7 @@ import numpy as np import numpy.typing as npt -from sglang.srt.disaggregation.base.conn import KVArgs, KVPoll +from sglang.srt.disaggregation.base.conn import KVArgs, KVPoll, StateType from sglang.srt.disaggregation.common.conn import ( CommonKVBootstrapServer, CommonKVManager, @@ -23,6 +23,8 @@ from sglang.srt.disaggregation.common.utils import ( FastQueue, group_concurrent_contiguous, + pack_int_lists, + unpack_int_lists, ) from sglang.srt.disaggregation.utils import ( DisaggregationMode, @@ -62,7 +64,7 @@ class TransferInfo: dst_kv_indices: npt.NDArray[np.int32] dst_aux_index: int required_dst_info_num: int - dst_state_indices: List[int] + dst_state_indices: List[List[int]] decode_prefix_len: Optional[int] = None # for decode radix cache def is_dummy(self): @@ -76,11 +78,9 @@ def is_dummy(self): @classmethod def from_zmq(cls, msg: List[bytes]): - # Parse state_indices from msg[7] if present - if len(msg) > 7 and msg[7] != b"": - dst_state_indices = list(np.frombuffer(msg[7], dtype=np.int32)) - else: - dst_state_indices = [] + dst_state_indices = ( + unpack_int_lists(msg[7], "i") if len(msg) > 7 and msg[7] != b"" else [] + ) return cls( room=int(msg[0].decode("ascii")), @@ -105,7 +105,7 @@ class TransferKVChunk: is_last: bool chunk_id: int prefill_aux_index: Optional[int] - state_indices: Optional[List[int]] + state_indices: Optional[List] @dataclasses.dataclass @@ -119,30 +119,25 @@ class KVArgsRegisterInfo: agent_metadata: bytes dst_kv_ptrs: list[int] dst_aux_ptrs: list[int] - dst_state_data_ptrs: list[int] + dst_state_data_ptrs: List[List[int]] gpu_id: int decode_tp_size: int decode_tp_rank: int dst_kv_item_len: int - dst_state_item_lens: list[int] = dataclasses.field(default_factory=list) - dst_state_dim_per_tensor: list[int] = dataclasses.field(default_factory=list) + dst_state_item_lens: List[List[int]] = dataclasses.field(default_factory=list) + dst_state_dim_per_tensor: List[List[int]] = dataclasses.field(default_factory=list) @classmethod def from_zmq(cls, msg: List[bytes]): - # Parse state_data_ptrs from msg[7] if present - if len(msg) > 7 and msg[7] != b"": - dst_state_data_ptrs = list(struct.unpack(f"{len(msg[7]) // 8}Q", msg[7])) - else: - dst_state_data_ptrs = [] - - dst_state_item_lens = [] - dst_state_dim_per_tensor = [] - if len(msg) > 12 and len(msg[12]) > 0: - dst_state_item_lens = list(struct.unpack(f"{len(msg[12]) // 4}I", msg[12])) - if len(msg) > 13 and len(msg[13]) > 0: - dst_state_dim_per_tensor = list( - struct.unpack(f"{len(msg[13]) // 4}I", msg[13]) - ) + dst_state_data_ptrs = ( + unpack_int_lists(msg[7], "Q") if len(msg) > 7 and msg[7] != b"" else [] + ) + dst_state_item_lens = ( + unpack_int_lists(msg[12], "I") if len(msg) > 12 and len(msg[12]) > 0 else [] + ) + dst_state_dim_per_tensor = ( + unpack_int_lists(msg[13], "I") if len(msg) > 13 and len(msg[13]) > 0 else [] + ) return cls( room=str(msg[0].decode("ascii")), @@ -445,23 +440,21 @@ def transfer_worker(self, queue: FastQueue): handles.append(kv_xfer_handle) - if kv_chunk.is_last: - if kv_chunk.state_indices is not None: - dst_info = self.decode_kv_args_table[req.agent_name] - state_xfer_handle = self.maybe_send_extra( - req.agent_name, - kv_chunk.state_indices, - dst_info.dst_state_data_ptrs, - req.dst_state_indices, - dst_info.gpu_id, - f"{req.room}_state_{self.kv_args.engine_rank}", - decode_tp_size, - decode_tp_rank=dst_info.decode_tp_rank, - dst_state_item_lens=dst_info.dst_state_item_lens, - dst_state_dim_per_tensor=dst_info.dst_state_dim_per_tensor, - ) - if state_xfer_handle is not None: - handles.append(state_xfer_handle) + if kv_chunk.is_last and kv_chunk.state_indices: + dst_info = self.decode_kv_args_table[req.agent_name] + state_xfer_handles = self.maybe_send_extra( + req.agent_name, + kv_chunk.state_indices, + dst_info.dst_state_data_ptrs, + req.dst_state_indices, + dst_info.gpu_id, + f"{req.room}_state_{self.kv_args.engine_rank}", + decode_tp_size, + decode_tp_rank=dst_info.decode_tp_rank, + dst_state_item_lens=dst_info.dst_state_item_lens, + dst_state_dim_per_tensor=dst_info.dst_state_dim_per_tensor, + ) + handles.extend(h for h in state_xfer_handles if h is not None) if kv_chunk.prefill_aux_index is None: raise RuntimeError("Missing aux index for last chunk") @@ -528,15 +521,18 @@ def register_buffer_to_engine(self): if not self.aux_descs: raise Exception("NIXL memory registration failed for aux tensors") - # Register state/extra pool data buffers if present - if self.kv_args.state_data_ptrs and self.kv_args.state_data_lens: - state_addrs = [] - for state_data_ptr, state_data_len in zip( - self.kv_args.state_data_ptrs, self.kv_args.state_data_lens - ): + state_addrs = [] + for comp_ptrs, comp_lens in zip( + self.kv_args.state_data_ptrs or [], + self.kv_args.state_data_lens or [], + ): + for state_data_ptr, state_data_len in zip(comp_ptrs, comp_lens): + if state_data_ptr == 0 or state_data_len == 0: + continue state_addrs.append( (state_data_ptr, state_data_len, self.kv_args.gpu_id, "") ) + if state_addrs: self.state_descs = self.agent.register_memory(state_addrs, "VRAM") logger.debug( f"Register state tensors, len(state_addrs)= {len(state_addrs)}" @@ -867,6 +863,8 @@ def _send_mamba_state( self, peer_name: str, prefill_state_indices: List[int], + src_state_data_ptrs: list[int], + src_state_item_lens: list[int], dst_state_data_ptrs: list[int], dst_state_indices: List[int], dst_gpu_id: int, @@ -881,14 +879,11 @@ def _send_mamba_state( src_addrs = [] dst_addrs = [] - prefill_state_data_ptrs = self.kv_args.state_data_ptrs - prefill_state_item_lens = self.kv_args.state_item_lens - for i, dst_state_ptr in enumerate(dst_state_data_ptrs): - length = prefill_state_item_lens[i] - src_addr = prefill_state_data_ptrs[i] + length * int( - prefill_state_indices[0] - ) + length = src_state_item_lens[i] + if length == 0 or src_state_data_ptrs[i] == 0 or dst_state_ptr == 0: + continue + src_addr = src_state_data_ptrs[i] + length * int(prefill_state_indices[0]) dst_addr = dst_state_ptr + length * int(dst_state_indices[0]) src_addrs.append((src_addr, length, self.kv_args.gpu_id)) dst_addrs.append((dst_addr, length, dst_gpu_id)) @@ -914,12 +909,15 @@ def _send_mamba_state_slice( self, peer_name: str, prefill_state_indices: List[int], + src_state_data_ptrs: list[int], + src_state_item_lens: list[int], + src_state_dim_per_tensor: list[int], dst_state_data_ptrs: list[int], dst_state_indices: List[int], - dst_gpu_id: int, - notif: str, dst_state_item_lens: list[int], dst_state_dim_per_tensor: list[int], + dst_gpu_id: int, + notif: str, decode_tp_size: int, decode_tp_rank: int, ): @@ -936,14 +934,12 @@ def _send_mamba_state_slice( ) assert len(prefill_state_indices) == 1, "Mamba should have single state index" - prefill_state_data_ptrs = self.kv_args.state_data_ptrs - prefill_state_item_lens = self.kv_args.state_item_lens - src_state_dim_per_tensor = getattr(self.kv_args, "state_dim_per_tensor", []) - if not src_state_dim_per_tensor or not dst_state_dim_per_tensor: return self._send_mamba_state( peer_name, prefill_state_indices, + src_state_data_ptrs, + src_state_item_lens, dst_state_data_ptrs, dst_state_indices, dst_gpu_id, @@ -957,8 +953,10 @@ def _send_mamba_state_slice( dst_addrs = [] for i, dst_state_ptr in enumerate(dst_state_data_ptrs): - src_item_len = prefill_state_item_lens[i] + src_item_len = src_state_item_lens[i] dst_item_len = dst_state_item_lens[i] + if src_item_len == 0 or src_state_data_ptrs[i] == 0 or dst_state_ptr == 0: + continue src_dim = src_state_dim_per_tensor[i] dst_dim = dst_state_dim_per_tensor[i] @@ -981,7 +979,7 @@ def _send_mamba_state_slice( bytes_to_send = num_dims_to_send * src_bytes_per_dim src_addr = ( - prefill_state_data_ptrs[i] + src_state_data_ptrs[i] + src_item_len * int(prefill_state_indices[0]) + src_dim_offset ) @@ -1013,67 +1011,101 @@ def _send_mamba_state_slice( def maybe_send_extra( self, peer_name: str, - prefill_state_indices: List[int], - dst_state_data_ptrs: list[int], - dst_state_indices: List[int], + prefill_state_indices: List[List[int]], + dst_state_data_ptrs: List[List[int]], + dst_state_indices: List[List[int]], dst_gpu_id: int, notif: str, decode_tp_size: int, decode_tp_rank: int = 0, - dst_state_item_lens: list[int] | None = None, - dst_state_dim_per_tensor: list[int] | None = None, + dst_state_item_lens: List[List[int]] | None = None, + dst_state_dim_per_tensor: List[List[int]] | None = None, ): - """Send state or extra pool data with type-specific handling.""" - state_type = getattr(self.kv_args, "state_type", "none") - - if state_type == "mamba": - if self.attn_tp_size != decode_tp_size: - return self._send_mamba_state_slice( - peer_name, - prefill_state_indices, - dst_state_data_ptrs, - dst_state_indices, - dst_gpu_id, - notif, - dst_state_item_lens or [], - dst_state_dim_per_tensor or [], - decode_tp_size, - decode_tp_rank, - ) - return self._send_mamba_state( - peer_name, - prefill_state_indices, - dst_state_data_ptrs, - dst_state_indices, - dst_gpu_id, - notif, + """Send state per hybrid component, dispatching by state_type[i].""" + state_types = getattr(self.kv_args, "state_types", []) or [] + src_state_data_ptrs = self.kv_args.state_data_ptrs or [] + src_state_item_lens = self.kv_args.state_item_lens or [] + src_state_dim_per_tensor = ( + getattr(self.kv_args, "state_dim_per_tensor", []) or [] + ) + dst_state_item_lens = dst_state_item_lens or [] + dst_state_dim_per_tensor = dst_state_dim_per_tensor or [] + + handles = [] + for i, st in enumerate(state_types): + src_indices = ( + prefill_state_indices[i] if i < len(prefill_state_indices) else None ) - elif state_type in ["swa", "nsa"]: - if not self.is_mla_backend and self.attn_tp_size != decode_tp_size: - raise RuntimeError( - f"PD Disaggregation does NOT support PD different TP sizes for non-MLA {state_type.upper()} hybrid models yet." - ) - if len(prefill_state_indices) != len(dst_state_indices): - raise RuntimeError( - f"State index length mismatch: prefill={len(prefill_state_indices)}, " - f"dst={len(dst_state_indices)}" - ) - return self._send_kvcache_generic( - peer_name=peer_name, - src_data_ptrs=self.kv_args.state_data_ptrs, - dst_data_ptrs=dst_state_data_ptrs, - item_lens=self.kv_args.state_item_lens, - prefill_data_indices=np.array(prefill_state_indices, dtype=np.int32), - dst_data_indices=np.array(dst_state_indices, dtype=np.int32), - dst_gpu_id=dst_gpu_id, - notif=notif, + if src_indices is None or len(src_indices) == 0: + continue + src_ptrs = src_state_data_ptrs[i] if i < len(src_state_data_ptrs) else [] + src_lens = src_state_item_lens[i] if i < len(src_state_item_lens) else [] + src_dims = ( + src_state_dim_per_tensor[i] if i < len(src_state_dim_per_tensor) else [] ) - else: - if state_type != "none": + dst_ptrs = dst_state_data_ptrs[i] if i < len(dst_state_data_ptrs) else [] + dst_indices = dst_state_indices[i] if i < len(dst_state_indices) else [] + dst_lens = dst_state_item_lens[i] if i < len(dst_state_item_lens) else [] + dst_dims = ( + dst_state_dim_per_tensor[i] if i < len(dst_state_dim_per_tensor) else [] + ) + comp_notif = f"{notif}_{i}" + + if st == StateType.MAMBA: + if self.attn_tp_size != decode_tp_size: + h = self._send_mamba_state_slice( + peer_name, + src_indices, + src_ptrs, + src_lens, + src_dims, + dst_ptrs, + dst_indices, + dst_lens, + dst_dims, + dst_gpu_id, + comp_notif, + decode_tp_size, + decode_tp_rank, + ) + else: + h = self._send_mamba_state( + peer_name, + src_indices, + src_ptrs, + src_lens, + dst_ptrs, + dst_indices, + dst_gpu_id, + comp_notif, + ) + elif st in (StateType.SWA, StateType.NSA): + if not self.is_mla_backend and self.attn_tp_size != decode_tp_size: + raise RuntimeError( + f"PD Disaggregation does NOT support PD different TP sizes for non-MLA {st.upper()} hybrid models yet." + ) + if len(src_indices) != len(dst_indices): + raise RuntimeError( + f"State index length mismatch at component {i}: " + f"prefill={len(src_indices)}, dst={len(dst_indices)}" + ) + h = self._send_kvcache_generic( + peer_name=peer_name, + src_data_ptrs=src_ptrs, + dst_data_ptrs=dst_ptrs, + item_lens=src_lens, + prefill_data_indices=np.array(src_indices, dtype=np.int32), + dst_data_indices=np.array(dst_indices, dtype=np.int32), + dst_gpu_id=dst_gpu_id, + notif=comp_notif, + ) + else: raise RuntimeError( - f"PD Disaggregation via NIXL does NOT support {state_type} hybrid models yet." + f"PD Disaggregation via NIXL does NOT support {st} hybrid models yet." ) - return None + if h is not None: + handles.append(h) + return handles def add_transfer_request( self, @@ -1083,7 +1115,7 @@ def add_transfer_request( is_last: bool, chunk_id: int, aux_index: Optional[int] = None, - state_indices: Optional[List[int]] = None, + state_indices: Optional[List] = None, ): assert self.disaggregation_mode == DisaggregationMode.PREFILL assert not is_last or (is_last and aux_index is not None) @@ -1221,7 +1253,7 @@ def should_send_kv_chunk(self, num_pages: int, last_chunk: bool) -> bool: def send( self, kv_indices: npt.NDArray[np.int32], - state_indices: Optional[List[int]] = None, + state_indices: Optional[List] = None, ): if self._send_failed: return @@ -1310,7 +1342,7 @@ def send_metadata( self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None, - state_indices: Optional[List[int]] = None, + state_indices: Optional[List] = None, decode_prefix_len: Optional[int] = None, ): if self.bootstrap_infos is None: @@ -1329,6 +1361,13 @@ def send_metadata( logger.debug( f"Sending to prefill server with bootstrap room {self.bootstrap_room} {is_dummy=}" ) + packed_state_indices = ( + pack_int_lists( + [(idx if idx is not None else []) for idx in state_indices], "i" + ) + if not is_dummy and state_indices is not None + else b"" + ) with lock: sock.send_multipart( [ @@ -1340,11 +1379,7 @@ def send_metadata( kv_indices.tobytes() if not is_dummy else b"", str(aux_index).encode("ascii"), str(self.required_dst_info_num).encode("ascii"), - ( - np.array(state_indices, dtype=np.int32).tobytes() - if not is_dummy and state_indices is not None - else b"" - ), + packed_state_indices, str(decode_prefix_len or 0).encode("ascii"), ] ) @@ -1404,19 +1439,14 @@ def _register_kv_args(self): packed_aux_data_ptrs = b"".join( struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs ) - packed_state_data_ptrs = b"".join( - struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.state_data_ptrs - ) - - packed_state_item_lens = b"".join( - struct.pack("I", item_len) - for item_len in self.kv_mgr.kv_args.state_item_lens + packed_state_data_ptrs = pack_int_lists( + self.kv_mgr.kv_args.state_data_ptrs or [], "Q" ) - state_dim_per_tensor = getattr( - self.kv_mgr.kv_args, "state_dim_per_tensor", [] + packed_state_item_lens = pack_int_lists( + self.kv_mgr.kv_args.state_item_lens or [], "I" ) - packed_state_dim_per_tensor = b"".join( - struct.pack("I", dim) for dim in state_dim_per_tensor + packed_state_dim_per_tensor = pack_int_lists( + getattr(self.kv_mgr.kv_args, "state_dim_per_tensor", []) or [], "I" ) with lock: diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py index bcf77768c436..eec95ef9336e 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -27,6 +27,7 @@ import torch from sglang.srt.disaggregation.base import KVPoll +from sglang.srt.disaggregation.base.conn import StateType from sglang.srt.disaggregation.common.conn import CommonKVManager from sglang.srt.disaggregation.utils import ( FAKE_BOOTSTRAP_HOST, @@ -48,14 +49,12 @@ Req, ScheduleBatch, ) -from sglang.srt.mem_cache.base_swa_memory_pool import BaseSWAKVPool from sglang.srt.mem_cache.common import ( kv_to_page_indices, kv_to_page_num, maybe_cache_unfinished_req, release_kv_cache, ) -from sglang.srt.mem_cache.memory_pool import HybridLinearKVPool, NSATokenToKVPool from sglang.srt.observability.req_time_stats import set_schedule_time_batch if TYPE_CHECKING: @@ -177,7 +176,13 @@ def _init_kv_manager(self) -> CommonKVManager: kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device kv_args.gpu_id = self.scheduler.gpu_id - setup_state_kv_args(kv_args, self.token_to_kv_pool, self.draft_token_to_kv_pool) + req_to_token_pool = getattr(self.scheduler, "req_to_token_pool", None) + setup_state_kv_args( + kv_args, + self.token_to_kv_pool, + self.draft_token_to_kv_pool, + req_to_token_pool=req_to_token_pool, + ) kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER) kv_manager = kv_manager_class( @@ -771,52 +776,56 @@ def send_kv_chunk( .cpu() .numpy() ) - state_indices = None + state_indices: Optional[List] = None if last_chunk: self.disagg_metadata_buffers.set_buf(req) - # Prepare extra pool indices for hybrid models - if isinstance( - self.token_to_kv_pool_allocator.get_kvcache(), HybridLinearKVPool - ): - # Mamba hybrid model: send single mamba state index - state_indices = [ + seq_len = len(req.fill_ids) + + def _mamba_payload(): + return [ self.req_to_token_pool.req_index_to_mamba_index_mapping[ req.req_pool_idx ] .cpu() .numpy() ] - elif isinstance( - self.token_to_kv_pool_allocator.get_kvcache(), BaseSWAKVPool - ): - # SWA hybrid model: send last window KV indices - seq_len = len(req.fill_ids) + + def _swa_payload(): window_size = self.sliding_window_size window_start = max(0, seq_len - window_size) window_start = (window_start // page_size) * page_size - window_kv_indices_full = self.req_to_token_pool.req_to_token[ req.req_pool_idx, window_start:seq_len ] - - # Translate to SWA pool indices window_kv_indices_swa = ( self.token_to_kv_pool_allocator.translate_loc_from_full_to_swa( window_kv_indices_full ) ) - state_indices = window_kv_indices_swa.cpu().numpy() - state_indices = kv_to_page_indices(state_indices, page_size) - elif isinstance( - self.token_to_kv_pool_allocator.get_kvcache(), NSATokenToKVPool - ): - seq_len = len(req.fill_ids) + return kv_to_page_indices( + window_kv_indices_swa.cpu().numpy(), page_size + ) + + def _nsa_payload(): kv_indices_full = self.req_to_token_pool.req_to_token[ req.req_pool_idx, :seq_len ] - state_indices = kv_indices_full.cpu().numpy() - state_indices = kv_to_page_indices(state_indices, page_size) + return kv_to_page_indices(kv_indices_full.cpu().numpy(), page_size) + + state_types = ( + self.disagg_prefill_bootstrap_queue.kv_manager.kv_args.state_types + ) + state_indices = [] + for st in state_types: + if st == StateType.MAMBA: + state_indices.append(_mamba_payload()) + elif st == StateType.SWA: + state_indices.append(_swa_payload()) + elif st == StateType.NSA: + state_indices.append(_nsa_payload()) + else: + state_indices.append(None) page_indices = kv_to_page_indices(kv_indices, page_size) if not req.disagg_kv_sender.should_send_kv_chunk(len(page_indices), last_chunk): diff --git a/python/sglang/srt/disaggregation/utils.py b/python/sglang/srt/disaggregation/utils.py index ea5e379993c7..2a0b4beec070 100644 --- a/python/sglang/srt/disaggregation/utils.py +++ b/python/sglang/srt/disaggregation/utils.py @@ -5,7 +5,7 @@ from collections import deque from contextlib import nullcontext from enum import Enum -from typing import TYPE_CHECKING, Literal, Optional, Tuple, Type, overload +from typing import TYPE_CHECKING, List, Literal, Optional, Tuple, Type, overload import numpy as np import torch @@ -15,7 +15,7 @@ from sglang.srt.utils import is_npu if TYPE_CHECKING: - from sglang.srt.disaggregation.base.conn import KVArgs + from sglang.srt.disaggregation.base.conn import KVArgs, StateType from sglang.srt.disaggregation.common.conn import ( CommonKVBootstrapServer, CommonKVManager, @@ -532,57 +532,92 @@ def is_mla_backend(target_kv_pool) -> bool: return isinstance(target_kv_pool, (MLATokenToKVPool, DeepSeekV4TokenToKVPool)) +def append_state_component( + kv_args: KVArgs, + state_type: StateType, + data_ptrs: List[int], + data_lens: List[int], + item_lens: List[int], + dim_per_tensor: Optional[List[int]] = None, +) -> None: + """Append one state component. Caller orders state_types consistently + on prefill and decode sides.""" + kv_args.state_types.append(state_type) + kv_args.state_data_ptrs.append(data_ptrs) + kv_args.state_data_lens.append(data_lens) + kv_args.state_item_lens.append(item_lens) + kv_args.state_dim_per_tensor.append(dim_per_tensor or []) + + def setup_state_kv_args( kv_args: KVArgs, token_to_kv_pool, draft_token_to_kv_pool=None, + req_to_token_pool=None, ) -> None: """Populate ``kv_args`` state-buffer fields from the given pool. - Shared by prefill and decode bootstrap paths so the state_type dispatch lives in one place. """ + from sglang.srt.disaggregation.base.conn import StateType from sglang.srt.mem_cache.base_swa_memory_pool import BaseSWAKVPool from sglang.srt.mem_cache.memory_pool import HybridLinearKVPool, NSATokenToKVPool - if not hasattr(token_to_kv_pool, "get_state_buf_infos"): - kv_args.state_data_ptrs = [] - kv_args.state_data_lens = [] - kv_args.state_item_lens = [] - kv_args.state_type = "none" - return + kv_args.state_types = [] + kv_args.state_data_ptrs = [] + kv_args.state_data_lens = [] + kv_args.state_item_lens = [] + kv_args.state_dim_per_tensor = [] - state_data_ptrs, state_data_lens, state_item_lens = ( - token_to_kv_pool.get_state_buf_infos() - ) - kv_args.state_data_ptrs = state_data_ptrs - kv_args.state_data_lens = state_data_lens - kv_args.state_item_lens = state_item_lens - - # DeepSeekV4TokenToKVPool inherits BaseSWAKVPool; its heterogeneous - # state list is described per-entry via get_state_buf_infos. - if isinstance(token_to_kv_pool, BaseSWAKVPool): - kv_args.state_type = "swa" - elif isinstance(token_to_kv_pool, HybridLinearKVPool): - kv_args.state_type = "mamba" - # Get state dimension info for cross-TP slice transfer - if hasattr(token_to_kv_pool, "get_state_dim_per_tensor"): - kv_args.state_dim_per_tensor = token_to_kv_pool.get_state_dim_per_tensor() - elif isinstance(token_to_kv_pool, NSATokenToKVPool): - kv_args.state_type = "nsa" - if draft_token_to_kv_pool is not None and isinstance( - draft_token_to_kv_pool, NSATokenToKVPool - ): - ( - draft_state_data_ptrs, - draft_state_data_lens, - draft_state_item_lens, - ) = draft_token_to_kv_pool.get_state_buf_infos() - kv_args.state_data_ptrs += draft_state_data_ptrs - kv_args.state_data_lens += draft_state_data_lens - kv_args.state_item_lens += draft_state_item_lens - else: - kv_args.state_type = "none" + if hasattr(token_to_kv_pool, "get_state_buf_infos"): + data_ptrs, data_lens, item_lens = token_to_kv_pool.get_state_buf_infos() + + # DeepSeekV4TokenToKVPool inherits BaseSWAKVPool; its heterogeneous + # state list is described per-entry via get_state_buf_infos. + if isinstance(token_to_kv_pool, BaseSWAKVPool): + append_state_component( + kv_args, StateType.SWA, data_ptrs, data_lens, item_lens + ) + elif isinstance(token_to_kv_pool, HybridLinearKVPool): + dim = ( + token_to_kv_pool.get_state_dim_per_tensor() + if hasattr(token_to_kv_pool, "get_state_dim_per_tensor") + else None + ) + append_state_component( + kv_args, StateType.MAMBA, data_ptrs, data_lens, item_lens, dim + ) + elif isinstance(token_to_kv_pool, NSATokenToKVPool): + if draft_token_to_kv_pool is not None and isinstance( + draft_token_to_kv_pool, NSATokenToKVPool + ): + ( + draft_data_ptrs, + draft_data_lens, + draft_item_lens, + ) = draft_token_to_kv_pool.get_state_buf_infos() + data_ptrs = data_ptrs + draft_data_ptrs + data_lens = data_lens + draft_data_lens + item_lens = item_lens + draft_item_lens + append_state_component( + kv_args, StateType.NSA, data_ptrs, data_lens, item_lens + ) + + if ( + StateType.MAMBA not in kv_args.state_types + and req_to_token_pool is not None + and hasattr(req_to_token_pool, "get_state_buf_infos") + ): + data_ptrs, data_lens, item_lens = req_to_token_pool.get_state_buf_infos() + if data_ptrs: + dim = ( + req_to_token_pool.get_state_dim_per_tensor() + if hasattr(req_to_token_pool, "get_state_dim_per_tensor") + else None + ) + append_state_component( + kv_args, StateType.MAMBA, data_ptrs, data_lens, item_lens, dim + ) def prepare_abort(req: Req, error_message: str, status_code=None): diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 5d62bacb0c7b..442b05b4e5e8 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -624,6 +624,12 @@ def mamba2_layer_cache(self, layer_id: int): def get_speculative_mamba2_params_all_layers(self) -> MambaPool.SpeculativeState: return self.mamba_pool.get_speculative_mamba2_params_all_layers() + def get_state_buf_infos(self): + return self.mamba_pool.get_contiguous_buf_infos() + + def get_state_dim_per_tensor(self): + return self.mamba_pool.get_state_dim_per_tensor() + def get_mamba_ping_pong_other_idx(self, mamba_next_track_idx: int) -> int: if self.mamba_ping_pong_track_buffer_size == 2: return 1 - mamba_next_track_idx diff --git a/test/registered/unit/disaggregation/test_disaggregation_wire.py b/test/registered/unit/disaggregation/test_disaggregation_wire.py new file mode 100644 index 000000000000..2c635b1d625a --- /dev/null +++ b/test/registered/unit/disaggregation/test_disaggregation_wire.py @@ -0,0 +1,49 @@ +import unittest + +import numpy as np + +from sglang.srt.disaggregation.common.utils import ( + pack_int_lists, + pack_list_of_buffers, + unpack_int_lists, + unpack_list_of_buffers, +) +from sglang.test.ci.ci_register import register_cpu_ci + +register_cpu_ci(est_time=2, suite="stage-a-test-cpu") + + +class TestDisaggregationWire(unittest.TestCase): + def test_int_lists_roundtrip(self): + cases = [ + ("Q", [[1, 2, 3], [4]]), + ("I", [[10, 20], [30, 40, 50]]), + ("i", [[-1, 2], [3, -4, 5]]), + ] + for fmt, sample in cases: + packed = pack_int_lists(sample, fmt) + self.assertEqual(unpack_int_lists(packed, fmt), sample, msg=fmt) + + def test_pack_accepts_ndarray(self): + arrs = [ + np.array([1, 2, 3], dtype=np.int32), + np.array([4, 5], dtype=np.int32), + ] + packed = pack_int_lists(arrs, "i") + self.assertEqual(unpack_int_lists(packed, "i"), [[1, 2, 3], [4, 5]]) + + def test_empty_outer_list(self): + self.assertEqual(pack_int_lists([], "Q"), b"") + self.assertEqual(unpack_int_lists(b"", "Q"), []) + + def test_empty_inner_list(self): + packed = pack_int_lists([[]], "I") + self.assertEqual(unpack_int_lists(packed, "I"), [[]]) + + def test_list_of_buffers_roundtrip(self): + bufs = [b"abc", b"", b"de", b"x" * 17] + self.assertEqual(unpack_list_of_buffers(pack_list_of_buffers(bufs)), bufs) + + +if __name__ == "__main__": + unittest.main()