Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions python/sglang/srt/disaggregation/ascend/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,11 @@ def register_buffer_to_engine(self):
self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
)
# Batch register state/extra pool data buffers
if self.kv_args.state_data_ptrs and self.kv_args.state_data_lens:
self.engine.batch_register(
self.kv_args.state_data_ptrs, self.kv_args.state_data_lens
)
for component_ptrs, component_lens in zip(
self.kv_args.state_data_ptrs or [],
self.kv_args.state_data_lens or [],
):
self.engine.batch_register(component_ptrs, component_lens)

def send_kvcache(
self,
Expand Down
23 changes: 15 additions & 8 deletions python/sglang/srt/disaggregation/base/conn.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import dataclasses
import enum
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, List, Optional

Expand All @@ -13,6 +14,12 @@
from sglang.srt.disaggregation.utils import DisaggregationMode


class StateType(str, enum.Enum):
    """Tag identifying one kind of extra (non-KV) state component.

    The ``str`` mixin makes each member compare equal to its plain string
    value (e.g. ``StateType.MAMBA == "mamba"``).
    """

    # NOTE(review): member meanings below are inferred from the names and the
    # decode-side dispatch; confirm against the pool implementations.
    MAMBA = "mamba"  # presumably Mamba (state-space) per-request state
    SWA = "swa"  # presumably sliding-window-attention KV state
    NSA = "nsa"  # presumably NSA indexer state


@dataclasses.dataclass
class KVTransferMetric:
# Backends that cannot isolate transfer latency can leave this as None.
Expand All @@ -28,12 +35,12 @@ class KVArgs:
aux_data_ptrs: List[int]
aux_data_lens: List[int]
aux_item_lens: List[int]
state_data_ptrs: List[int]
state_data_lens: List[int]
state_item_lens: List[int]
state_type: str # "none", "mamba", "swa", "nsa"
# for mamba state different tp slice transfer
state_dim_per_tensor: List[int] # dimension to slice for each state tensor
state_types: List[StateType]
state_data_ptrs: List[List[int]]
state_data_lens: List[List[int]]
state_item_lens: List[List[int]]
# Per-tensor TP slice dim, used when prefill/decode attn_tp_size differ.
state_dim_per_tensor: List[List[int]]
ib_device: str
ib_traffic_class: str
gpu_id: int
Expand Down Expand Up @@ -96,7 +103,7 @@ def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
def send(
self,
kv_indices: npt.NDArray[np.int32],
state_indices: Optional[List[int]] = None,
state_indices: Optional[List] = None,
):
"""
Send the kv cache at the given kv indices and the extra cache/state at the given indices to the decoder server.
Expand Down Expand Up @@ -154,7 +161,7 @@ def send_metadata(
self,
kv_indices: npt.NDArray[np.int32],
aux_index: Optional[int] = None,
state_indices: Optional[List[int]] = None,
state_indices: Optional[List] = None,
decode_prefix_len: Optional[int] = None,
):
"""
Expand Down
12 changes: 7 additions & 5 deletions python/sglang/srt/disaggregation/common/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def __init__(
):
self.kv_args = args
self.kv_item_lens_sum = sum(args.kv_item_lens)
self.state_item_lens_sum = sum(args.state_item_lens)
self.state_item_lens_sum = sum(x for comp in args.state_item_lens for x in comp)
self.is_mla_backend = is_mla_backend
self.disaggregation_mode = disaggregation_mode
self.server_args = server_args
Expand Down Expand Up @@ -520,16 +520,18 @@ def get_transfer_metric(self) -> KVTransferMetric:
def _record_transfer_indices(
self,
kv_indices: npt.NDArray[np.int32],
state_indices: Optional[List[int]],
state_indices: Optional[List],
):
self._transfer_num_kv_indices += len(kv_indices)
if state_indices is not None:
self._transfer_num_state_indices += len(state_indices)
if state_indices:
for component_indices in state_indices:
if component_indices is not None:
self._transfer_num_state_indices += len(component_indices)

def send(
self,
kv_indices: npt.NDArray[np.int32],
state_indices: Optional[List[int]] = None,
state_indices: Optional[List] = None,
):
pass

Expand Down
34 changes: 34 additions & 0 deletions python/sglang/srt/disaggregation/common/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import struct
import threading
from collections import deque
from typing import List, Tuple
Expand All @@ -6,6 +7,39 @@
import numpy.typing as npt


def pack_list_of_buffers(buffers: List[bytes]) -> bytes:
    """Serialize a list of byte strings into one length-prefixed blob.

    Wire layout (all integers little-endian uint32):
    ``[count][len_0]...[len_{count-1}][payload_0]...[payload_{count-1}]``.

    An empty list is encoded as the empty byte string (no header at all).
    """
    if not buffers:
        return b""
    sizes = [len(chunk) for chunk in buffers]
    count = len(sizes)
    # One struct call emits the count followed by every per-buffer length.
    prefix = struct.pack(f"<{count + 1}I", count, *sizes)
    return b"".join([prefix, *buffers])


def unpack_list_of_buffers(buf: bytes) -> List[bytes]:
    """Decode a blob produced by ``pack_list_of_buffers``.

    Expected layout (all integers little-endian uint32):
    ``[count][len_0]...[len_{count-1}][payload_0]...[payload_{count-1}]``.

    Args:
        buf: The serialized blob. The empty byte string decodes to ``[]``.

    Returns:
        The payloads, in their original order.

    Raises:
        struct.error: If the header is truncated, or if the declared payload
            lengths extend past the end of ``buf``.
    """
    if buf == b"":
        return []
    # unpack_from reads in place, avoiding intermediate slice copies.
    (n,) = struct.unpack_from("<I", buf, 0)
    lens = struct.unpack_from(f"<{n}I", buf, 4)
    offset = 4 + 4 * n
    # Slicing past the end would silently yield short/empty chunks, so a
    # truncated message must be rejected explicitly.
    expected_end = offset + sum(lens)
    if expected_end > len(buf):
        raise struct.error(
            f"truncated buffer list: need {expected_end} bytes, got {len(buf)}"
        )
    out = []
    for length in lens:
        out.append(buf[offset : offset + length])
        offset += length
    return out


def pack_int_lists(lists, fmt: str) -> bytes:
    """Serialize a sequence of integer sequences into one blob.

    Each inner sequence is packed as little-endian values of struct code
    ``fmt`` and the resulting byte strings are framed with
    ``pack_list_of_buffers``.
    """
    encoded = []
    for values in lists:
        encoded.append(struct.pack(f"<{len(values)}{fmt}", *values))
    return pack_list_of_buffers(encoded)


def unpack_int_lists(buf: bytes, fmt: str) -> List[List[int]]:
    """Decode a blob produced by ``pack_int_lists`` back into lists of ints.

    Args:
        buf: Blob framed by ``pack_list_of_buffers``; each inner buffer holds
            little-endian values of struct code ``fmt``.
        fmt: Single struct format code for one element (e.g. ``"Q"``).

    Returns:
        One list of integers per inner buffer, in order.

    Raises:
        struct.error: If an inner buffer's size is not a whole multiple of the
            element width, or the framing is malformed.
    """
    # The element width must be computed in the same standard little-endian
    # mode ("<") used for packing; a bare code is sized in native mode, which
    # differs for some codes (e.g. "l"/"L" on LP64 platforms).
    width = struct.calcsize(f"<{fmt}")
    return [
        list(struct.unpack(f"<{len(b) // width}{fmt}", b))
        for b in unpack_list_of_buffers(buf)
    ]


class FastQueue:
def __init__(self):
self._buf = deque()
Expand Down
54 changes: 33 additions & 21 deletions python/sglang/srt/disaggregation/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from sglang.srt.configs.mamba_utils import Mamba2CacheParams
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
from sglang.srt.disaggregation.base import KVPoll
from sglang.srt.disaggregation.base.conn import StateType
from sglang.srt.disaggregation.common.conn import CommonKVManager, CommonKVReceiver
from sglang.srt.disaggregation.utils import (
FAKE_BOOTSTRAP_HOST,
Expand All @@ -56,17 +57,14 @@
from sglang.srt.managers.utils import GenerationBatchResult
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, EvictParams
from sglang.srt.mem_cache.base_swa_memory_pool import BaseSWAKVPool
from sglang.srt.mem_cache.common import (
kv_to_page_indices,
page_align_floor,
release_kv_cache,
)
from sglang.srt.mem_cache.memory_pool import (
HybridLinearKVPool,
HybridReqToTokenPool,
KVCache,
NSATokenToKVPool,
ReqToTokenPool,
)
from sglang.srt.observability.req_time_stats import (
Expand Down Expand Up @@ -366,7 +364,12 @@ def _init_kv_manager(self) -> CommonKVManager:
self.metadata_buffers.get_buf_infos()
)

setup_state_kv_args(kv_args, self.token_to_kv_pool, self.draft_token_to_kv_pool)
setup_state_kv_args(
kv_args,
self.token_to_kv_pool,
self.draft_token_to_kv_pool,
req_to_token_pool=getattr(self, "req_to_token_pool", None),
)

kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
kv_args.gpu_id = self.scheduler.gpu_id
Expand Down Expand Up @@ -809,45 +812,54 @@ def pop_preallocated(
)
page_size = self.token_to_kv_pool_allocator.page_size

# Prepare extra pool indices for hybrid models
if isinstance(self.token_to_kv_pool, HybridLinearKVPool):
# Mamba hybrid model: single mamba state index
state_indices = [
seq_len = len(decode_req.req.origin_input_ids)

def _mamba_payload():
return [
self.req_to_token_pool.req_index_to_mamba_index_mapping[
decode_req.req.req_pool_idx
]
.cpu()
.numpy()
]
elif isinstance(self.token_to_kv_pool, BaseSWAKVPool):
seq_len = len(decode_req.req.origin_input_ids)
window_size = self.scheduler.sliding_window_size

def _swa_payload():
window_size = self.scheduler.sliding_window_size
window_start = max(0, seq_len - window_size)
window_start = page_align_floor(window_start, page_size)
window_kv_indices_full = self.req_to_token_pool.req_to_token[
decode_req.req.req_pool_idx, window_start:seq_len
]

# Translate to SWA pool indices
window_kv_indices_swa = (
self.token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
window_kv_indices_full
)
)
state_indices = window_kv_indices_swa.cpu().numpy()
state_indices = kv_to_page_indices(state_indices, page_size)
elif isinstance(self.token_to_kv_pool, NSATokenToKVPool):
seq_len = len(decode_req.req.origin_input_ids)
return kv_to_page_indices(
window_kv_indices_swa.cpu().numpy(), page_size
)

def _nsa_payload():
kv_indices_full = self.req_to_token_pool.req_to_token[
decode_req.req.req_pool_idx, :seq_len
]
state_indices = kv_indices_full.cpu().numpy()
# Indexer lives on device pool; always use device page_size
device_page_size = self.token_to_kv_pool.page_size
state_indices = kv_to_page_indices(state_indices, device_page_size)
else:
state_indices = None
return kv_to_page_indices(
kv_indices_full.cpu().numpy(), device_page_size
)

state_types = self.kv_manager.kv_args.state_types
state_indices: Optional[List] = []
for st in state_types:
if st == StateType.MAMBA:
state_indices.append(_mamba_payload())
elif st == StateType.SWA:
state_indices.append(_swa_payload())
elif st == StateType.NSA:
state_indices.append(_nsa_payload())
else:
state_indices.append(None)

decode_req.metadata_buffer_index = (
self.req_to_metadata_buffer_idx_allocator.alloc()
Expand Down
4 changes: 2 additions & 2 deletions python/sglang/srt/disaggregation/fake/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def init(
def send(
self,
kv_indices: npt.NDArray[np.int32],
state_indices: Optional[List[int]] = None,
state_indices: Optional[List] = None,
):
self.has_sent = True
logger.debug(
Expand Down Expand Up @@ -111,7 +111,7 @@ def send_metadata(
self,
kv_indices: list[int],
aux_index: Optional[int] = None,
state_indices: Optional[List[int]] = None,
state_indices: Optional[List] = None,
decode_prefix_len: Optional[int] = None,
):
self.has_sent_metadata = True
Expand Down
Loading
Loading