From da7e9d492a2b03ed91214fa5fb907c411b5e4e1f Mon Sep 17 00:00:00 2001 From: Shangming Cai Date: Fri, 25 Apr 2025 13:26:04 +0800 Subject: [PATCH 1/2] [PD] Add kvargs table and thread pool for kvcache sender of mooncake Signed-off-by: Shangming Cai --- .../srt/disaggregation/mooncake/conn.py | 122 ++++++++++++++---- 1 file changed, 100 insertions(+), 22 deletions(-) diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py index 9e8fa476fa2..947114131dc 100644 --- a/python/sglang/srt/disaggregation/mooncake/conn.py +++ b/python/sglang/srt/disaggregation/mooncake/conn.py @@ -1,8 +1,10 @@ from __future__ import annotations import asyncio +import concurrent.futures import dataclasses import logging +import os import queue import socket import struct @@ -73,9 +75,7 @@ class TransferInfo: endpoint: str dst_port: int mooncake_session_id: str - dst_kv_ptrs: list[int] dst_kv_indices: npt.NDArray[np.int64] - dst_aux_ptrs: list[int] dst_aux_index: int @classmethod @@ -85,10 +85,29 @@ def from_zmq(cls, msg: List[bytes]): endpoint=msg[1].decode("ascii"), dst_port=int(msg[2].decode("ascii")), mooncake_session_id=msg[3].decode("ascii"), + dst_kv_indices=np.frombuffer(msg[4], dtype=np.int64), + dst_aux_index=int(msg[5].decode("ascii")), + ) + + +@dataclasses.dataclass +class KVArgsRegisterInfo: + room: str + endpoint: str + dst_port: int + mooncake_session_id: str + dst_kv_ptrs: list[int] + dst_aux_ptrs: list[int] + + @classmethod + def from_zmq(cls, msg: List[bytes]): + return cls( + room=str(msg[0].decode("ascii")), + endpoint=msg[1].decode("ascii"), + dst_port=int(msg[2].decode("ascii")), + mooncake_session_id=msg[3].decode("ascii"), dst_kv_ptrs=list(struct.unpack(f"{len(msg[4])//8}Q", msg[4])), - dst_kv_indices=np.frombuffer(msg[5], dtype=np.int64), - dst_aux_ptrs=list(struct.unpack(f"{len(msg[6])//8}Q", msg[6])), - dst_aux_index=int(msg[7].decode("ascii")), + dst_aux_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])), ) @@ -123,8 +142,15 @@ def __init__( if self.disaggregation_mode == DisaggregationMode.PREFILL: self.transfer_queue = queue.Queue() self.transfer_infos: Dict[int, TransferInfo] = {} + self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {} self.start_prefill_thread() self._register_to_bootstrap() + + # Determine the number of threads to use for kv sender + cpu_count = os.cpu_count() + self.executor = concurrent.futures.ThreadPoolExecutor( + max_workers=cpu_count if cpu_count is not None else 64 + ) elif self.disaggregation_mode == DisaggregationMode.DECODE: self.start_decode_thread() self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {} @@ -158,28 +184,53 @@ def send_kvcache( dst_kv_ptrs: list[int], dst_kv_indices: npt.NDArray[np.int64], ): - # group by indices + # Group by indices prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous( prefill_kv_indices, dst_kv_indices ) num_layers = len(self.kv_args.kv_data_ptrs) - for layer_id in range(num_layers): - src_ptr = self.kv_args.kv_data_ptrs[layer_id] - dst_ptr = dst_kv_ptrs[layer_id] - item_len = self.kv_args.kv_item_lens[layer_id] + layers_params = [ + ( + self.kv_args.kv_data_ptrs[layer_id], + dst_kv_ptrs[layer_id], + self.kv_args.kv_item_lens[layer_id], + ) + for layer_id in range(num_layers) + ] + # Worker function for processing a single layer + def process_layer(src_ptr: int, dst_ptr: int, item_len: int) -> int: for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks): src_addr = src_ptr + int(prefill_index[0]) * item_len dst_addr = dst_ptr + int(decode_index[0]) * item_len length = item_len * len(prefill_index) - # TODO: make async later status = self.engine.transfer_sync( mooncake_session_id, src_addr, dst_addr, length ) if status != 0: return status + return 0 + + futures = [ + self.executor.submit( + process_layer, + src_ptr, + dst_ptr, + item_len, + ) + for (src_ptr, dst_ptr, item_len) in layers_params + ] + + for future in concurrent.futures.as_completed(futures): + status = future.result() + if status != 0: + # Immediate shutdown on first error (existing tasks will finish) + executor.shutdown(wait=False) + for f in futures: + f.cancel() + return status return 0 @@ -223,6 +274,13 @@ def bootstrap_thread(): waiting_req_bytes = self.server_socket.recv_multipart() room = waiting_req_bytes[0].decode("ascii") if room == "None": + mooncake_session_id = waiting_req_bytes[3].decode("ascii") + self.decode_kv_args_table[mooncake_session_id] = ( + KVArgsRegisterInfo.from_zmq(waiting_req_bytes) + ) + logger.debug( + f"Register KVArgs from {mooncake_session_id} successfully" + ) continue room = int(room) self.transfer_infos[room] = TransferInfo.from_zmq(waiting_req_bytes) @@ -244,7 +302,7 @@ def transfer_thread(): ret = self.send_kvcache( req.mooncake_session_id, kv_chunk.prefill_kv_indices, - req.dst_kv_ptrs, + self.decode_kv_args_table[req.mooncake_session_id].dst_kv_ptrs, chunked_dst_kv_indice, ) if ret != 0: @@ -259,7 +317,9 @@ def transfer_thread(): ret = self.send_aux( req.mooncake_session_id, kv_chunk.prefill_aux_index, - req.dst_aux_ptrs, + self.decode_kv_args_table[ + req.mooncake_session_id + ].dst_aux_ptrs, req.dst_aux_index, ) self.request_status[req.room] = ( @@ -460,6 +520,8 @@ def __init__( ) else: self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_info + # Register kv_args only once to prefill KVManager according to the info fetched from the bootstrap server + self._register_kv_args() else: self.bootstrap_info = self.kv_mgr.connection_pool[bootstrap_key] @@ -502,6 +564,30 @@ def _get_prefill_dp_size_from_server(self) -> int: logger.error(f"Error fetching prefill parallel info from bootstrap: {e}") return None + def _register_kv_args(self): + self.prefill_server_url = ( + f"{self.bootstrap_info['rank_ip']}:{self.bootstrap_info['rank_port']}" + ) + + packed_kv_data_ptrs = b"".join( + struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs + ) + packed_aux_data_ptrs = b"".join( + struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs + ) + sock, lock = self._connect("tcp://" + self.prefill_server_url) + with lock: + sock.send_multipart( + [ + "None".encode("ascii"), + get_local_ip_by_remote().encode("ascii"), + str(self.kv_mgr.rank_port).encode("ascii"), + self.session_id.encode("ascii"), + packed_kv_data_ptrs, + packed_aux_data_ptrs, + ] + ) + @classmethod def _connect(cls, endpoint: str): with cls._global_lock: @@ -520,12 +606,6 @@ def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = Non f"Fetched bootstrap info: {self.bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}" ) - packed_kv_data_ptrs = b"".join( - struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs - ) - packed_aux_data_ptrs = b"".join( - struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs - ) sock, lock = self._connect("tcp://" + self.prefill_server_url) with lock: sock.send_multipart( @@ -534,9 +614,7 @@ def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = Non get_local_ip_by_remote().encode("ascii"), str(self.kv_mgr.rank_port).encode("ascii"), self.session_id.encode("ascii"), - packed_kv_data_ptrs, kv_indices.tobytes(), - packed_aux_data_ptrs, str(aux_index).encode("ascii"), ] ) @@ -610,7 +688,7 @@ async def _handle_route_put(self, request: web.Request): "rank_port": rank_port, } logger.debug( - f"Registered Prefill bootstrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}" + f"Registere Prefill bootstrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}" ) return web.Response(text="OK", status=200) From 278a2a0e9d554c9b5290aec1df0c03422e338036 Mon Sep 17 00:00:00 2001 From: Shangming Cai Date: Fri, 25 Apr 2025 17:34:30 +0800 Subject: [PATCH 2/2] fix typo Signed-off-by: Shangming Cai --- python/sglang/srt/disaggregation/mooncake/conn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py index 947114131dc..c2a516c4673 100644 --- a/python/sglang/srt/disaggregation/mooncake/conn.py +++ b/python/sglang/srt/disaggregation/mooncake/conn.py @@ -688,7 +688,7 @@ async def _handle_route_put(self, request: web.Request): "rank_port": rank_port, } logger.debug( - f"Registere Prefill bootstrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}" + f"Register Prefill bootstrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}" ) return web.Response(text="OK", status=200)