288 changes: 151 additions & 137 deletions python/sglang/srt/disaggregation/mooncake/conn.py
@@ -31,6 +31,7 @@
from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
from sglang.srt.disaggregation.utils import (
DisaggregationMode,
FastQueue,
group_concurrent_contiguous,
)
from sglang.srt.server_args import ServerArgs
@@ -151,23 +152,38 @@ def __init__(
self.server_socket = zmq.Context().socket(zmq.PULL)
self.register_buffer_to_engine()
if self.disaggregation_mode == DisaggregationMode.PREFILL:
self.transfer_queue = queue.Queue()
self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {}
self.start_prefill_thread()
self._register_to_bootstrap()
self.session_failures = defaultdict(int)
self.failed_sessions = set()
self.session_lock = threading.Lock()

# Determine the number of threads to use for kv sender
cpu_count = os.cpu_count()
self.executor = concurrent.futures.ThreadPoolExecutor(
get_int_env_var(
"SGLANG_DISAGGREGATION_THREAD_POOL_SIZE",
min(max(1, cpu_count // 8), 8),
)
transfer_thread_pool_size = get_int_env_var(
"SGLANG_DISAGGREGATION_THREAD_POOL_SIZE",
min(max(4, int(0.75 * cpu_count) // 8), 12),
)
transfer_queue_size = get_int_env_var("SGLANG_DISAGGREGATION_QUEUE_SIZE", 4)
self.transfer_queues: List[FastQueue] = [
FastQueue() for _ in range(transfer_queue_size)
]
assert transfer_thread_pool_size >= transfer_queue_size, (
f"The environment variable SGLANG_DISAGGREGATION_THREAD_POOL_SIZE={transfer_thread_pool_size} must be "
f"greater than or equal to SGLANG_DISAGGREGATION_QUEUE_SIZE={transfer_queue_size}."
)
self.executors = [
concurrent.futures.ThreadPoolExecutor(
transfer_thread_pool_size // transfer_queue_size
)
for _ in range(transfer_queue_size)
]
for queue, executor in zip(self.transfer_queues, self.executors):
threading.Thread(
target=self.transfer_worker, args=(queue, executor), daemon=True
).start()

self.bootstrap_time_out = get_int_env_var(
"SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT", 30
)
@@ -183,7 +199,7 @@ def __init__(
)
# Heartbeat failure should be at least 1
self.max_failures = max(
int(os.getenv("SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE", 2)), 1
get_int_env_var("SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE", 2), 1
)
self.start_decode_thread()
self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
@@ -220,6 +236,7 @@ def send_kvcache(
prefill_kv_indices: npt.NDArray[np.int64],
dst_kv_ptrs: list[int],
dst_kv_indices: npt.NDArray[np.int64],
executor: concurrent.futures.ThreadPoolExecutor,
):
# Group by indices
prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous(
@@ -251,7 +268,7 @@ def process_layer(src_ptr: int, dst_ptr: int, item_len: int) -> int:
return 0

futures = [
self.executor.submit(
executor.submit(
process_layer,
src_ptr,
dst_ptr,
@@ -298,6 +315,123 @@ def sync_status_to_decode_endpoint(
]
)

def transfer_worker(
self, queue: FastQueue, executor: concurrent.futures.ThreadPoolExecutor
):
while True:
try:
kv_chunk: TransferKVChunk = queue.get()
reqs_to_be_processed = (
self.transfer_infos[kv_chunk.room].values()
if kv_chunk.room in self.transfer_infos
else []
)
polls = []
dst_ranks_infos = []
for req in reqs_to_be_processed:
if not req.is_dummy:
# Early exit if the request has failed
with self.session_lock:
if req.mooncake_session_id in self.failed_sessions:
self.record_failure(
kv_chunk.room,
f"Decode instance could be dead, remote mooncake session {req.mooncake_session_id} is not alive",
)
self.update_status(kv_chunk.room, KVPoll.Failed)
self.sync_status_to_decode_endpoint(
req.endpoint,
req.dst_port,
req.room,
KVPoll.Failed,
)
break

chunked_dst_kv_indice = req.dst_kv_indices[kv_chunk.index_slice]

                        # NOTE: This is a temporary workaround for the case where prefill_kv_indices
                        # is mismatched with dst_kv_indices when page size > 1; this should never happen.
                        if len(chunked_dst_kv_indice) < len(
                            kv_chunk.prefill_kv_indices
                        ):
                            kv_chunk.prefill_kv_indices = kv_chunk.prefill_kv_indices[
                                : len(chunked_dst_kv_indice)
                            ]
logger.warning(
f"len(chunked_dst_kv_indice) = {len(chunked_dst_kv_indice)}, len(kv_chunk.prefill_kv_indices) = {len(kv_chunk.prefill_kv_indices)}"
)

ret = self.send_kvcache(
req.mooncake_session_id,
kv_chunk.prefill_kv_indices,
self.decode_kv_args_table[
req.mooncake_session_id
].dst_kv_ptrs,
chunked_dst_kv_indice,
executor,
)
if ret != 0:
with self.session_lock:
self.session_failures[req.mooncake_session_id] += 1
                                # Failures should never happen while the session is alive; if the session fails even once, mark it as failed
if self.session_failures[req.mooncake_session_id] >= 1:
self.failed_sessions.add(req.mooncake_session_id)
logger.error(
f"Session {req.mooncake_session_id} failed."
)
self.record_failure(
kv_chunk.room,
f"Failed to send kv chunk of {kv_chunk.room} to {req.endpoint}:{req.dst_port}",
)
self.update_status(kv_chunk.room, KVPoll.Failed)
self.sync_status_to_decode_endpoint(
req.endpoint, req.dst_port, req.room, KVPoll.Failed
)
break

if kv_chunk.is_last:
                            # Only the last chunk needs to send the aux data
ret = self.send_aux(
req.mooncake_session_id,
kv_chunk.prefill_aux_index,
self.decode_kv_args_table[
req.mooncake_session_id
].dst_aux_ptrs,
req.dst_aux_index,
)
polls.append(True if ret == 0 else False)
dst_ranks_infos.append(
(req.endpoint, req.dst_port, req.room)
)

# Only sync status when all the dst ranks have received the kvcache
if len(polls) == req.required_dst_info_num:
status = KVPoll.Success if all(polls) else KVPoll.Failed
self.update_status(req.room, status)
for endpoint, dst_port, room in dst_ranks_infos:
self.sync_status_to_decode_endpoint(
endpoint, dst_port, room, status
)
else:
# Dummy request means the decode instance is not used, so its status can be marked as success directly
# Dummy request does not need to sync status to decode endpoint
if kv_chunk.is_last and req.room in self.request_status:
self.update_status(req.room, KVPoll.Success)

if (
kv_chunk.room not in self.request_status
or self.check_status(kv_chunk.room) == KVPoll.Success
):
if kv_chunk.room in self.transfer_infos:
self.transfer_infos.pop(kv_chunk.room)

except queue.Empty:
continue
except Exception as e:
# NOTE(shangming): Remove this when we make sure the transfer thread is bug-free
raise RuntimeError(
f"Transfer thread failed because of {e}. Prefill instance with bootstrap_port={self.bootstrap_port} is dead."
)

def start_prefill_thread(self):
self.rank_port = get_free_port()
self.server_socket.bind(f"tcp://{get_local_ip_by_remote()}:{self.rank_port}")
@@ -335,134 +469,7 @@ def bootstrap_thread():
if len(self.transfer_infos[room]) == required_dst_info_num:
self.update_status(room, KVPoll.WaitingForInput)

def transfer_thread():
# TODO: Shall we use KVPoll.Transferring state?
while True:
try:
kv_chunk: TransferKVChunk = self.transfer_queue.get(timeout=0.01)
reqs_to_be_processed = (
self.transfer_infos[kv_chunk.room].values()
if kv_chunk.room in self.transfer_infos
else []
)
polls = []
dst_ranks_infos = []
for req in reqs_to_be_processed:
if not req.is_dummy:
# Early exit if the request has failed
with self.session_lock:
if req.mooncake_session_id in self.failed_sessions:
self.record_failure(
kv_chunk.room,
f"Decode instance could be dead, remote mooncake session {req.mooncake_session_id} is not alive",
)
self.update_status(kv_chunk.room, KVPoll.Failed)
self.sync_status_to_decode_endpoint(
req.endpoint,
req.dst_port,
req.room,
KVPoll.Failed,
)
break

chunked_dst_kv_indice = req.dst_kv_indices[
kv_chunk.index_slice
]

# NOTE: This is temporarily a workaround to deal with the case where the prefill_kv_indices
# is mismatched with the dst_kv_indices when page size > 1, this should never happen.
if len(chunked_dst_kv_indice) < len(
kv_chunk.prefill_kv_indices
):
kv_chunk.prefill_kv_indices = (
kv_chunk.prefill_kv_indices[
                                        : len(chunked_dst_kv_indice)
]
)
logger.warning(
f"len(chunked_dst_kv_indice) = {len(chunked_dst_kv_indice)}, len(kv_chunk.prefill_kv_indices) = {len(kv_chunk.prefill_kv_indices)}"
)

ret = self.send_kvcache(
req.mooncake_session_id,
kv_chunk.prefill_kv_indices,
self.decode_kv_args_table[
req.mooncake_session_id
].dst_kv_ptrs,
chunked_dst_kv_indice,
)
if ret != 0:
with self.session_lock:
self.session_failures[req.mooncake_session_id] += 1
# Failures should never happen if the session is not dead, if the session fails once, mark it as failed
if (
self.session_failures[req.mooncake_session_id]
>= 1
):
self.failed_sessions.add(
req.mooncake_session_id
)
logger.error(
f"Session {req.mooncake_session_id} failed."
)
self.record_failure(
kv_chunk.room,
f"Failed to send kv chunk of {kv_chunk.room} to {req.endpoint}:{req.dst_port}",
)
self.update_status(kv_chunk.room, KVPoll.Failed)
self.sync_status_to_decode_endpoint(
req.endpoint, req.dst_port, req.room, KVPoll.Failed
)
break

if kv_chunk.is_last:
# Only the last chunk we need to send the aux data
ret = self.send_aux(
req.mooncake_session_id,
kv_chunk.prefill_aux_index,
self.decode_kv_args_table[
req.mooncake_session_id
].dst_aux_ptrs,
req.dst_aux_index,
)
polls.append(True if ret == 0 else False)
dst_ranks_infos.append(
(req.endpoint, req.dst_port, req.room)
)

# Only sync status when all the dst ranks have received the kvcache
if len(polls) == req.required_dst_info_num:
status = (
KVPoll.Success if all(polls) else KVPoll.Failed
)
self.update_status(req.room, status)
for endpoint, dst_port, room in dst_ranks_infos:
self.sync_status_to_decode_endpoint(
endpoint, dst_port, room, status
)
else:
# Dummy request means the decode instance is not used, so its status can be marked as success directly
# Dummy request does not need to sync status to decode endpoint
if kv_chunk.is_last and req.room in self.request_status:
self.update_status(req.room, KVPoll.Success)

if (
kv_chunk.room not in self.request_status
or self.check_status(kv_chunk.room) == KVPoll.Success
):
if kv_chunk.room in self.transfer_infos:
self.transfer_infos.pop(kv_chunk.room)

except queue.Empty:
continue
except Exception as e:
# NOTE(shangming): Remove this when we make sure the transfer thread is bug-free
raise RuntimeError(
f"Transfer thread failed because of {e}. Prefill instance with bootstrap_port={self.bootstrap_port} is dead."
)

threading.Thread(target=bootstrap_thread).start()
threading.Thread(target=transfer_thread).start()

def start_decode_thread(self):
self.rank_port = get_free_port()
@@ -555,7 +562,14 @@ def add_transfer_request(
)
return

self.transfer_queue.put(
# NOTE(shangming): sharding according to the dst_infos to make sure
# requests with the same dst_sessions will be added into the same
# queue, which enables early abort with failed sessions.
dst_infos = self.transfer_infos[bootstrap_room].keys()
session_port_sum = sum(int(session.split(":")[1]) for session in dst_infos)
shard_idx = session_port_sum % len(self.transfer_queues)

self.transfer_queues[shard_idx].put(
TransferKVChunk(
room=bootstrap_room,
prefill_kv_indices=kv_indices,
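
The following standalone Python sketch (not part of the diff) illustrates the sharding rule used in add_transfer_request above: a request is routed to a queue by summing the port components of its destination mooncake session IDs modulo the number of queues, so chunks that target the same set of decode sessions always land on the same worker. The helper name pick_shard and the session values are hypothetical; the "host:port" session format and the default queue count of 4 come from the diff.

# Hypothetical illustration of the queue-sharding rule (not PR code).
def pick_shard(dst_session_ids: list[str], num_queues: int) -> int:
    # Sum the port component of every destination session id ("host:port").
    session_port_sum = sum(int(session.split(":")[1]) for session in dst_session_ids)
    # The same set of destination sessions always maps to the same shard.
    return session_port_sum % num_queues

# Both orderings of the same two decode sessions pick the same queue index.
sessions = ["10.0.0.1:7001", "10.0.0.2:7002"]
assert pick_shard(sessions, 4) == pick_shard(list(reversed(sessions)), 4)
print(pick_shard(sessions, 4))  # 14003 % 4 == 3
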
20 changes: 20 additions & 0 deletions python/sglang/srt/disaggregation/utils.py
@@ -3,6 +3,7 @@
import dataclasses
import os
import random
import threading
import warnings
from collections import deque
from enum import Enum
@@ -281,6 +282,25 @@ def set_buf(self, req: Req):
)


class FastQueue:
def __init__(self):
self._buf = deque()
self._cond = threading.Condition()

def put(self, item):
with self._cond:
self._buf.append(item)
            # wake up one thread waiting in get()
self._cond.notify()

def get(self):
with self._cond:
            # if the queue is empty, block until put() notifies us
while not self._buf:
self._cond.wait()
return self._buf.popleft()


def group_concurrent_contiguous(
src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
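
As a quick sanity check on the FastQueue added above, here is a minimal producer/consumer usage sketch. It assumes this PR is applied so the import resolves; the thread setup, payload strings, and loop count are illustrative only.

# Minimal FastQueue usage sketch: get() blocks on a condition variable until
# put() supplies an item, so no polling timeout is needed (unlike the old
# transfer_queue.get(timeout=0.01) loop in conn.py).
import threading

from sglang.srt.disaggregation.utils import FastQueue

q = FastQueue()
results = []

def consumer():
    for _ in range(3):
        results.append(q.get())  # blocks until a producer calls put()

t = threading.Thread(target=consumer, daemon=True)
t.start()

for item in ("chunk-0", "chunk-1", "chunk-2"):
    q.put(item)  # each put() wakes at most one waiting consumer

t.join(timeout=5)
print(results)  # ['chunk-0', 'chunk-1', 'chunk-2']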