Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions python/sglang/srt/disaggregation/common/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,14 @@ def __init__(
f"Unsupported DisaggregationMode: {self.disaggregation_mode}"
)

def maybe_prepare_async_kv(self, sch, batch):
"""Optional hook for layerwise async KV transfers.

Default implementation returns None (feature unsupported).
"""

return None

def check_status(self, bootstrap_room: int) -> KVPoll:
return self.request_status[bootstrap_room]

Expand Down
633 changes: 633 additions & 0 deletions python/sglang/srt/disaggregation/mooncake/async_kv_mixin.py

Large diffs are not rendered by default.

106 changes: 106 additions & 0 deletions python/sglang/srt/disaggregation/mooncake/async_kv_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
from __future__ import annotations

import dataclasses
import logging
import os
import queue
import threading
from typing import Callable, Optional, Tuple

import numpy as np
import numpy.typing as npt

from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous


logger = logging.getLogger(__name__)


@dataclasses.dataclass
class TransferKVChunkSet:
rooms: Tuple[int, ...] = dataclasses.field(default_factory=tuple)
prefill_kv_indices: Tuple[npt.NDArray[np.int64], ...] = dataclasses.field(
default_factory=tuple
)
index_slices: Tuple[slice, ...] = dataclasses.field(default_factory=tuple)
prefill_state_indices: Tuple[int, ...] = dataclasses.field(default_factory=tuple)


@dataclasses.dataclass
class AsyncInfo:
layer_ids: Tuple[int, ...] = dataclasses.field(default_factory=tuple)
kv_chunk_info: TransferKVChunkSet = dataclasses.field(
default_factory=TransferKVChunkSet
)


class StreamAsyncSubmitter:
"""Single-worker async submitter with counters.

The worker thread runs as a daemon and never exits. We use monotonically
increasing counters to let the caller wait until submitted work has finished.
"""

def __init__(self, submit_func: Callable[[], None]):
self._submit_func = submit_func
self._queue: queue.SimpleQueue[None] = queue.SimpleQueue()
self._lock = threading.Lock()
self._cond = threading.Condition(self._lock)
self._submitted = 0
self._finished = 0
self._exc: Optional[BaseException] = None
threading.Thread(target=self._worker, daemon=True).start()

def _worker(self):
while True:
self._queue.get()
try:
self._submit_func()
except BaseException as e:
# Persist the exception so waiters can fail fast.
with self._cond:
self._exc = e
self._cond.notify_all()
logger.exception("Unhandled exception in StreamAsyncSubmitter worker.")
finally:
with self._cond:
self._finished += 1
self._cond.notify_all()

def step_async(self) -> int:
with self._cond:
if self._exc is not None:
raise RuntimeError("StreamAsyncSubmitter worker has failed") from self._exc
self._submitted += 1
self._queue.put(None)
return self._submitted

def get_step_count(self) -> int:
with self._cond:
return self._submitted

def wait_sent_finish(self, target_count: int) -> None:
with self._cond:
if self._exc is not None:
raise RuntimeError("StreamAsyncSubmitter worker has failed") from self._exc
while self._finished < target_count:
self._cond.wait()
if self._exc is not None:
raise RuntimeError("StreamAsyncSubmitter worker has failed") from self._exc


def cached_group_concurrent_contiguous(
src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
):
# NOTE: despite the name, this function is not memoized; it only normalizes
# dtypes before calling the grouping helper.
src = np.asarray(src_indices, dtype=np.int32)
dst = np.asarray(dst_indices, dtype=np.int32)
return group_concurrent_contiguous(src, dst)


def env_int(name: str, default: str) -> int:
try:
return int(os.getenv(name, default))
except Exception:
return int(default)
190 changes: 104 additions & 86 deletions python/sglang/srt/disaggregation/mooncake/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
FastQueue,
group_concurrent_contiguous,
)
from sglang.srt.disaggregation.mooncake.async_kv_mixin import MooncakeKVAsyncMixin
from sglang.srt.disaggregation.mooncake.utils import (
check_mooncake_custom_mem_pool_enabled,
)
Expand Down Expand Up @@ -184,7 +185,7 @@ def deserialize_data_to_buffer(kv_args, buffer_index, aux_index, data):
return


class MooncakeKVManager(CommonKVManager):
class MooncakeKVManager(MooncakeKVAsyncMixin, CommonKVManager):
AUX_DATA_HEADER = b"AUX_DATA"

def __init__(
Expand All @@ -198,6 +199,9 @@ def __init__(
self.init_engine()
self.register_buffer_to_engine()
self.enable_staging = envs.SGLANG_DISAGG_STAGING_BUFFER.get()

# Async KV (layerwise) state lives in MooncakeKVAsyncMixin.
self._async_kv_init_state()
if self.disaggregation_mode == DisaggregationMode.PREFILL:
self.start_prefill_thread()
self.session_failures = defaultdict(int)
Expand Down Expand Up @@ -246,6 +250,9 @@ def __init__(
),
daemon=True,
).start()

if envs.SGLANG_MOONCAKE_ASYNC_KV.get():
self._async_kv_enable()
elif self.disaggregation_mode == DisaggregationMode.DECODE:
self._staging_ctx = DecodeStagingContext() if self.enable_staging else None
if self.enable_staging:
Expand Down Expand Up @@ -276,6 +283,12 @@ def register_buffer_to_engine(self):
self.kv_args.state_data_ptrs, self.kv_args.state_data_lens
)

def _use_async_for_room(self, room: int) -> bool:
return self._async_use_for_room(room)

def _flush_all_layers(self, rid: int) -> None:
return self._async_flush_all_layers(rid)

# ------------------------------------------------------------------
# Staging buffer methods (all delegate to staging_handler.py)
# ------------------------------------------------------------------
Expand Down Expand Up @@ -1173,6 +1186,11 @@ def transfer_worker(
+ self.pp_rank * self.attn_cp_size
+ self.attn_cp_rank
)

use_async = self._use_async_for_room(kv_chunk.room)
if use_async and kv_chunk.is_last_chunk:
# Wait for any in-flight per-layer submits before finalizing the request.
self._flush_all_layers(kv_chunk.room)
# When staging transfer is not yet ready (watermark/allocation pending),
# the chunk is re-enqueued and we break out of the req loop to retry later.
staging_deferred = False
Expand All @@ -1195,107 +1213,107 @@ def transfer_worker(
)
break

chunked_dst_kv_indice = req.dst_kv_indices[kv_chunk.index_slice]

# NOTE: This is temporarily a workaround to deal with the case where the prefill_kv_indices
# is mismatched with the dst_kv_indices when page size > 1, this should never happen.
if len(chunked_dst_kv_indice) < len(
kv_chunk.prefill_kv_indices
):
logger.warning(
f"len(chunked_dst_kv_indice) = {len(chunked_dst_kv_indice)}, len(kv_chunk.prefill_kv_indices) = {len(kv_chunk.prefill_kv_indices)}"
)
kv_chunk.prefill_kv_indices = kv_chunk.prefill_kv_indices[
: len(chunked_dst_kv_indice)
]

target_rank_registration_info: KVArgsRegisterInfo = (
self.decode_kv_args_table[req.mooncake_session_id]
)
if self.is_mla_backend or (
self.attn_tp_size
== target_rank_registration_info.dst_attn_tp_size
):
if target_rank_registration_info.enable_hisparse:
ret = self.send_kvcache_hisparse(
req.mooncake_session_id,
kv_chunk.prefill_kv_indices,
target_rank_registration_info.dst_kv_ptrs,
req.dst_kv_indices,
kv_chunk.index_slice,
if not use_async:
chunked_dst_kv_indice = req.dst_kv_indices[kv_chunk.index_slice]

# NOTE: This is temporarily a workaround to deal with the case where the prefill_kv_indices
# is mismatched with the dst_kv_indices when page size > 1, this should never happen.
if len(chunked_dst_kv_indice) < len(kv_chunk.prefill_kv_indices):
logger.warning(
f"len(chunked_dst_kv_indice) = {len(chunked_dst_kv_indice)}, len(kv_chunk.prefill_kv_indices) = {len(kv_chunk.prefill_kv_indices)}"
)
kv_chunk.prefill_kv_indices = kv_chunk.prefill_kv_indices[
: len(chunked_dst_kv_indice)
]

if self.is_mla_backend or (
self.attn_tp_size
== target_rank_registration_info.dst_attn_tp_size
):
if target_rank_registration_info.enable_hisparse:
ret = self.send_kvcache_hisparse(
req.mooncake_session_id,
kv_chunk.prefill_kv_indices,
target_rank_registration_info.dst_kv_ptrs,
req.dst_kv_indices,
kv_chunk.index_slice,
executor,
)
else:
ret = self.send_kvcache(
req.mooncake_session_id,
kv_chunk.prefill_kv_indices,
target_rank_registration_info.dst_kv_ptrs,
chunked_dst_kv_indice,
executor,
)
elif (
self.enable_staging
and staging_strategy is not None
and target_rank_registration_info.staging is not None
):
ret, deferred = self._do_staging_transfer(
staging_strategy,
kv_chunk,
req,
target_rank_registration_info,
chunked_dst_kv_indice,
executor,
queue,
prefill_unique_rank,
)
if deferred:
staging_deferred = True
# Chunk re-enqueued; stop processing remaining reqs for this chunk
break
else:
ret = self.send_kvcache(
ret = self.send_kvcache_slice(
req.mooncake_session_id,
kv_chunk.prefill_kv_indices,
target_rank_registration_info.dst_kv_ptrs,
chunked_dst_kv_indice,
target_rank_registration_info.dst_tp_rank,
target_rank_registration_info.dst_attn_tp_size,
target_rank_registration_info.dst_kv_item_len,
executor,
)
elif (
self.enable_staging
and staging_strategy is not None
and target_rank_registration_info.staging is not None
):
ret, deferred = self._do_staging_transfer(
staging_strategy,
kv_chunk,
req,
target_rank_registration_info,
chunked_dst_kv_indice,
executor,
queue,
prefill_unique_rank,
)
if deferred:
staging_deferred = True
# Chunk re-enqueued; stop processing remaining reqs for this chunk
if ret != 0:
with self.session_lock:
self.session_failures[req.mooncake_session_id] += 1
# Failures should never happen if the session is not dead, if the session fails once, mark it as failed
if self.session_failures[req.mooncake_session_id] >= 1:
self.failed_sessions.add(req.mooncake_session_id)
logger.error(
f"Session {req.mooncake_session_id} failed."
)
self.record_failure(
kv_chunk.room,
f"Failed to send kv chunk of {kv_chunk.room} to "
f"{NetworkAddress(req.endpoint, req.dst_port).to_host_port_str()}",
)
self.update_status(kv_chunk.room, KVPoll.Failed)
self.sync_status_to_decode_endpoint(
req.endpoint,
req.dst_port,
req.room,
KVPoll.Failed,
prefill_unique_rank,
)
break
else:
ret = self.send_kvcache_slice(
req.mooncake_session_id,
kv_chunk.prefill_kv_indices,
target_rank_registration_info.dst_kv_ptrs,
chunked_dst_kv_indice,
target_rank_registration_info.dst_tp_rank,
target_rank_registration_info.dst_attn_tp_size,
target_rank_registration_info.dst_kv_item_len,
executor,
)
if ret != 0:
with self.session_lock:
self.session_failures[req.mooncake_session_id] += 1
# Failures should never happen if the session is not dead, if the session fails once, mark it as failed
if self.session_failures[req.mooncake_session_id] >= 1:
self.failed_sessions.add(req.mooncake_session_id)
logger.error(
f"Session {req.mooncake_session_id} failed."
)
self.record_failure(
kv_chunk.room,
f"Failed to send kv chunk of {kv_chunk.room} to "
f"{NetworkAddress(req.endpoint, req.dst_port).to_host_port_str()}",
)
self.update_status(kv_chunk.room, KVPoll.Failed)
self.sync_status_to_decode_endpoint(
req.endpoint,
req.dst_port,
req.room,
KVPoll.Failed,
prefill_unique_rank,
)
break

if kv_chunk.is_last_chunk:
if kv_chunk.state_indices is not None:
self.maybe_send_extra(
req,
kv_chunk.state_indices,
target_rank_registration_info.dst_state_data_ptrs,
executor,
target_rank_registration_info,
)
if not use_async:
self.maybe_send_extra(
req,
kv_chunk.state_indices,
target_rank_registration_info.dst_state_data_ptrs,
executor,
target_rank_registration_info,
)

# Only the last chunk we need to send the aux data
ret = self.send_aux(
Expand Down
1 change: 1 addition & 0 deletions python/sglang/srt/disaggregation/prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
Req,
ScheduleBatch,
)
from sglang.srt.environ import envs
from sglang.srt.mem_cache.common import release_kv_cache
from sglang.srt.mem_cache.memory_pool import HybridLinearKVPool, NSATokenToKVPool
from sglang.srt.mem_cache.swa_memory_pool import SWAKVPool
Expand Down
1 change: 1 addition & 0 deletions python/sglang/srt/environ.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,7 @@ class Envs:
ENABLE_ASCEND_TRANSFER_WITH_MOONCAKE = EnvBool(False)
ASCEND_NPU_PHY_ID = EnvInt(-1)
SGLANG_MOONCAKE_SEND_AUX_TCP = EnvBool(False)
SGLANG_MOONCAKE_ASYNC_KV = EnvBool(False)

# Mooncake Store
SGLANG_HICACHE_MOONCAKE_CONFIG_PATH = EnvStr(None)
Expand Down
Loading
Loading