Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions python/sglang/srt/disaggregation/ascend/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,11 @@ def register_buffer_to_engine(self):
self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
)
# Batch register state/extra pool data buffers
if self.kv_args.state_data_ptrs and self.kv_args.state_data_lens:
self.engine.batch_register(
self.kv_args.state_data_ptrs, self.kv_args.state_data_lens
)
for component_ptrs, component_lens in zip(
self.kv_args.state_data_ptrs or [],
self.kv_args.state_data_lens or [],
):
self.engine.batch_register(component_ptrs, component_lens)

def send_kvcache(
self,
Expand Down
23 changes: 15 additions & 8 deletions python/sglang/srt/disaggregation/base/conn.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import dataclasses
import enum
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, List, Optional

Expand All @@ -13,6 +14,12 @@
from sglang.srt.disaggregation.utils import DisaggregationMode


@enum.unique
class StateType(str, enum.Enum):
    """Tag identifying the kind of an extra (non-KV) state component.

    Because the enum mixes in ``str``, every member compares equal to its
    plain string value (e.g. ``StateType.MAMBA == "mamba"``) and can be
    built from that string (``StateType("mamba")``).
    """

    # presumably the mamba / linear-attention state pool -- confirm with callers
    MAMBA = "mamba"
    # presumably the sliding-window-attention KV pool -- confirm with callers
    SWA = "swa"
    # presumably the NSA indexer pool -- confirm with callers
    NSA = "nsa"


@dataclasses.dataclass
class KVTransferMetric:
# Backends that cannot isolate transfer latency can leave this as None.
Expand All @@ -28,12 +35,12 @@ class KVArgs:
aux_data_ptrs: List[int]
aux_data_lens: List[int]
aux_item_lens: List[int]
state_data_ptrs: List[int]
state_data_lens: List[int]
state_item_lens: List[int]
state_type: str # "none", "mamba", "swa", "nsa"
# for mamba state different tp slice transfer
state_dim_per_tensor: List[int] # dimension to slice for each state tensor
state_types: List[StateType]
state_data_ptrs: List[List[int]]
state_data_lens: List[List[int]]
state_item_lens: List[List[int]]
# Per-tensor TP slice dim, used when prefill/decode attn_tp_size differ.
state_dim_per_tensor: List[List[int]]
ib_device: str
ib_traffic_class: str
gpu_id: int
Expand Down Expand Up @@ -96,7 +103,7 @@ def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
def send(
self,
kv_indices: npt.NDArray[np.int32],
state_indices: Optional[List[int]] = None,
state_indices: Optional[List] = None,
):
"""
Send the kv cache at the given kv indices and the extra cache/state at the given indices to the decoder server.
Expand Down Expand Up @@ -154,7 +161,7 @@ def send_metadata(
self,
kv_indices: npt.NDArray[np.int32],
aux_index: Optional[int] = None,
state_indices: Optional[List[int]] = None,
state_indices: Optional[List] = None,
decode_prefix_len: Optional[int] = None,
):
"""
Expand Down
12 changes: 7 additions & 5 deletions python/sglang/srt/disaggregation/common/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def __init__(
):
self.kv_args = args
self.kv_item_lens_sum = sum(args.kv_item_lens)
self.state_item_lens_sum = sum(args.state_item_lens)
self.state_item_lens_sum = sum(x for comp in args.state_item_lens for x in comp)
self.is_mla_backend = is_mla_backend
self.disaggregation_mode = disaggregation_mode
self.server_args = server_args
Expand Down Expand Up @@ -520,16 +520,18 @@ def get_transfer_metric(self) -> KVTransferMetric:
def _record_transfer_indices(
self,
kv_indices: npt.NDArray[np.int32],
state_indices: Optional[List[int]],
state_indices: Optional[List],
):
self._transfer_num_kv_indices += len(kv_indices)
if state_indices is not None:
self._transfer_num_state_indices += len(state_indices)
if state_indices:
for component_indices in state_indices:
if component_indices is not None:
self._transfer_num_state_indices += len(component_indices)

def send(
self,
kv_indices: npt.NDArray[np.int32],
state_indices: Optional[List[int]] = None,
state_indices: Optional[List] = None,
):
pass

Expand Down
34 changes: 34 additions & 0 deletions python/sglang/srt/disaggregation/common/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import struct
import threading
from collections import deque
from typing import List, Tuple
Expand All @@ -6,6 +7,39 @@
import numpy.typing as npt


def pack_list_of_buffers(buffers: List[bytes]) -> bytes:
    """Serialize a list of byte strings into one length-prefixed blob.

    Layout (all integers are little-endian unsigned 32-bit):
    ``count | len_0 | ... | len_{count-1} | payload_0 ... payload_{count-1}``.

    An empty (or otherwise falsy) ``buffers`` is encoded as ``b""``.
    """
    if not buffers:
        return b""
    sizes = [len(chunk) for chunk in buffers]
    count = len(sizes)
    # The header carries the count followed by one length per buffer.
    pieces = [struct.pack(f"<{count + 1}I", count, *sizes)]
    pieces.extend(buffers)
    return b"".join(pieces)


def unpack_list_of_buffers(buf: bytes) -> List[bytes]:
    """Split a length-prefixed blob back into its individual buffers.

    Expected layout (all integers are little-endian unsigned 32-bit):
    ``count | len_0 | ... | len_{count-1} | payload_0 ... payload_{count-1}``.
    An empty ``buf`` decodes to an empty list.

    Args:
        buf: The serialized blob.

    Returns:
        The payload slices, in the order given by the header.

    Raises:
        struct.error: If ``buf`` is too short to contain the header.
        ValueError: If ``buf`` holds fewer payload bytes than the header
            declares (slicing would otherwise return short buffers silently).
    """
    if len(buf) == 0:
        return []
    # unpack_from reads in place and raises struct.error on a short header.
    (n,) = struct.unpack_from("<I", buf, 0)
    lens = struct.unpack_from(f"<{n}I", buf, 4)
    offset = 4 + 4 * n
    required = offset + sum(lens)
    if required > len(buf):
        raise ValueError(
            f"truncated buffer: header declares {required} bytes, got {len(buf)}"
        )
    out = []
    for length in lens:
        out.append(buf[offset : offset + length])
        offset += length
    return out


def pack_int_lists(lists, fmt: str) -> bytes:
    """Serialize an iterable of integer sequences into one blob.

    Each sequence is packed as little-endian values of struct code ``fmt``
    and the resulting byte strings are combined by ``pack_list_of_buffers``.
    """
    encoded = []
    for values in lists:
        encoded.append(struct.pack(f"<{len(values)}{fmt}", *values))
    return pack_list_of_buffers(encoded)


def unpack_int_lists(buf: bytes, fmt: str) -> List[List[int]]:
    """Decode a blob into a list of integer lists.

    ``buf`` is first split with ``unpack_list_of_buffers``; each piece is then
    read as consecutive little-endian values of struct code ``fmt``.

    Args:
        buf: The serialized blob.
        fmt: A struct format code for one element (e.g. ``"I"`` or ``"Q"``).

    Returns:
        One list of integers per encoded buffer, in order.
    """
    # Measure the element width with the same "<" (standard-size) prefix used
    # for unpacking below. Without a prefix, calcsize reports the native
    # size, which is platform-dependent and can differ from the standard one
    # (e.g. "l" is 4 bytes in standard mode but may be 8 natively).
    width = struct.calcsize(f"<{fmt}")
    return [
        list(struct.unpack(f"<{len(b) // width}{fmt}", b))
        for b in unpack_list_of_buffers(buf)
    ]


class FastQueue:
def __init__(self):
self._buf = deque()
Expand Down
54 changes: 33 additions & 21 deletions python/sglang/srt/disaggregation/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from sglang.srt.configs.mamba_utils import Mamba2CacheParams
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
from sglang.srt.disaggregation.base import KVPoll
from sglang.srt.disaggregation.base.conn import StateType
from sglang.srt.disaggregation.common.conn import CommonKVManager, CommonKVReceiver
from sglang.srt.disaggregation.utils import (
FAKE_BOOTSTRAP_HOST,
Expand All @@ -56,17 +57,14 @@
from sglang.srt.managers.utils import GenerationBatchResult
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, EvictParams
from sglang.srt.mem_cache.base_swa_memory_pool import BaseSWAKVPool
from sglang.srt.mem_cache.common import (
kv_to_page_indices,
page_align_floor,
release_kv_cache,
)
from sglang.srt.mem_cache.memory_pool import (
HybridLinearKVPool,
HybridReqToTokenPool,
KVCache,
NSATokenToKVPool,
ReqToTokenPool,
)
from sglang.srt.observability.req_time_stats import (
Expand Down Expand Up @@ -366,7 +364,12 @@ def _init_kv_manager(self) -> CommonKVManager:
self.metadata_buffers.get_buf_infos()
)

setup_state_kv_args(kv_args, self.token_to_kv_pool, self.draft_token_to_kv_pool)
setup_state_kv_args(
kv_args,
self.token_to_kv_pool,
self.draft_token_to_kv_pool,
req_to_token_pool=getattr(self, "req_to_token_pool", None),
)

kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
kv_args.gpu_id = self.scheduler.gpu_id
Expand Down Expand Up @@ -809,45 +812,54 @@ def pop_preallocated(
)
page_size = self.token_to_kv_pool_allocator.page_size

# Prepare extra pool indices for hybrid models
if isinstance(self.token_to_kv_pool, HybridLinearKVPool):
# Mamba hybrid model: single mamba state index
state_indices = [
seq_len = len(decode_req.req.origin_input_ids)

def _mamba_payload():
return [
self.req_to_token_pool.req_index_to_mamba_index_mapping[
decode_req.req.req_pool_idx
]
.cpu()
.numpy()
]
elif isinstance(self.token_to_kv_pool, BaseSWAKVPool):
seq_len = len(decode_req.req.origin_input_ids)
window_size = self.scheduler.sliding_window_size

def _swa_payload():
window_size = self.scheduler.sliding_window_size
window_start = max(0, seq_len - window_size)
window_start = page_align_floor(window_start, page_size)
window_kv_indices_full = self.req_to_token_pool.req_to_token[
decode_req.req.req_pool_idx, window_start:seq_len
]

# Translate to SWA pool indices
window_kv_indices_swa = (
self.token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
window_kv_indices_full
)
)
state_indices = window_kv_indices_swa.cpu().numpy()
state_indices = kv_to_page_indices(state_indices, page_size)
elif isinstance(self.token_to_kv_pool, NSATokenToKVPool):
seq_len = len(decode_req.req.origin_input_ids)
return kv_to_page_indices(
window_kv_indices_swa.cpu().numpy(), page_size
)

def _nsa_payload():
kv_indices_full = self.req_to_token_pool.req_to_token[
decode_req.req.req_pool_idx, :seq_len
]
state_indices = kv_indices_full.cpu().numpy()
# Indexer lives on device pool; always use device page_size
device_page_size = self.token_to_kv_pool.page_size
state_indices = kv_to_page_indices(state_indices, device_page_size)
else:
state_indices = None
return kv_to_page_indices(
kv_indices_full.cpu().numpy(), device_page_size
)

state_types = self.kv_manager.kv_args.state_types
state_indices: Optional[list] = []
Comment thread
ispobock marked this conversation as resolved.
Outdated
for st in state_types:
if st == StateType.MAMBA:
state_indices.append(_mamba_payload())
elif st == StateType.SWA:
state_indices.append(_swa_payload())
elif st == StateType.NSA:
state_indices.append(_nsa_payload())
else:
state_indices.append(None)

decode_req.metadata_buffer_index = (
self.req_to_metadata_buffer_idx_allocator.alloc()
Expand Down
4 changes: 2 additions & 2 deletions python/sglang/srt/disaggregation/fake/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def init(
def send(
self,
kv_indices: npt.NDArray[np.int32],
state_indices: Optional[List[int]] = None,
state_indices: Optional[List] = None,
):
self.has_sent = True
logger.debug(
Expand Down Expand Up @@ -111,7 +111,7 @@ def send_metadata(
self,
kv_indices: list[int],
aux_index: Optional[int] = None,
state_indices: Optional[List[int]] = None,
state_indices: Optional[List] = None,
decode_prefix_len: Optional[int] = None,
):
self.has_sent_metadata = True
Expand Down
Loading
Loading