diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake_connector.py
index 2c604617297c..56beda4e5b9c 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake_connector.py
@@ -3,7 +3,6 @@
 import asyncio
 import threading
 import time
-import uuid
 from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass
@@ -83,28 +82,10 @@ class RecvReqMeta:
 @dataclass
 class SendBlockMeta:
     local_block_ids: list[int]
-    ready: threading.Event
+    ready: asyncio.Event
     expire_time: float = float("inf")
 
 
-@dataclass
-class SendReqMeta:
-    reqs: dict[ReqId, SendBlockMeta]
-    lock: threading.Lock
-
-
-@dataclass
-class FinishedSendReqSet:
-    set: set[ReqId]
-    lock: threading.Lock
-
-
-@dataclass
-class FinishedReceiveReqSet:
-    set: set[ReqId]
-    lock: asyncio.Lock
-
-
 class MooncakeConnectorMetadata(KVConnectorMetadata):
     def __init__(self):
         self.reqs_to_recv: dict[ReqId, RecvReqMeta] = {}
@@ -437,39 +418,50 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         assert vllm_config.kv_transfer_config
         self.kv_role = vllm_config.kv_transfer_config.kv_role
-        self.num_workers = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
-            "num_workers", 10
+        self.num_sender_workers = (
+            vllm_config.kv_transfer_config.kv_connector_extra_config.get(
+                "num_workers", 10
+            )
         )
+        # Create more tasks than workers to keep the thread pool saturated.
+        # Tasks can await async events, so a surplus (2x is a robust heuristic)
+        # prevents workers from idling.
+        self.num_sender_tasks = self.num_sender_workers * 2
         self.kv_caches_base_addr: list[int] = []
         self.device_kv_caches: dict[str, torch.Tensor] = {}
-        self.reqs_need_send: SendReqMeta = SendReqMeta(reqs={}, lock=threading.Lock())
+        self.reqs_need_send: dict[ReqId, SendBlockMeta] = {}
 
         # For kv_both, we will act both prefiller and decoder.
         if self.kv_role != "kv_consumer":
-            # Background thread for sending kvcaches to D.
-            self._mooncake_sender_t: threading.Thread | None = None
-            # Background thread for processing new sending requests.
+            # Background threads for sending kvcaches to D.
             self._sender_executor = ThreadPoolExecutor(
-                max_workers=self.num_workers, thread_name_prefix="vllm-mooncake-sender"
+                max_workers=self.num_sender_workers,
+                thread_name_prefix="vllm-mooncake-sender",
             )
             logger.debug(
-                "Mooncake Prefiller: use %d workers to send kvcaches", self.num_workers
+                "Mooncake Prefiller: use %d workers to send kvcaches",
+                self.num_sender_workers,
            )
+            # An asyncio queue to buffer incoming requests for the sender
+            self.sender_worker_queue = asyncio.Queue[tuple[bytes, bytes]]()
+            self.sender_loop = asyncio.new_event_loop()
+            # Background thread for processing new sending requests.
+            self._sender_listener_t = threading.Thread(
+                target=_async_loop, args=(self.sender_loop,), daemon=True
+            )
+            self._sender_listener_t.start()
+
         if self.kv_role != "kv_producer":
             self.receiver_loop = asyncio.new_event_loop()
             self._mooncake_receiver_t = threading.Thread(
-                target=self._receiver_loop, args=(self.receiver_loop,), daemon=True
+                target=_async_loop, args=(self.receiver_loop,), daemon=True
             )
             self._mooncake_receiver_t.start()
             logger.debug("Mooncake Decoder: start receiver thread")
 
-        self.finished_sending_reqs: FinishedSendReqSet = FinishedSendReqSet(
-            set(), threading.Lock()
-        )
-        self.finished_recving_reqs: FinishedReceiveReqSet = FinishedReceiveReqSet(
-            set(), asyncio.Lock()
-        )
+        self.finished_sending_reqs: set[ReqId] = set()
+        self.finished_recving_reqs: set[ReqId] = set()
 
         self.block_size = vllm_config.cache_config.block_size
         self.model_config = vllm_config.model_config
@@ -500,7 +492,6 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
             attn_backend=backend,
         )
 
-        self.zmq_ctx = zmq.Context()
         self.async_zmq_ctx = zmq.asyncio.Context()
         self._encoder = msgspec.msgpack.Encoder()
         self._decoder = msgspec.msgpack.Decoder(MooncakeAgentMetadata)
@@ -510,21 +501,17 @@ def __del__(self):
     def shutdown(self):
         """Cleanup background threads on destruction."""
-        self.zmq_ctx.term()
         self.async_zmq_ctx.term()
         if self.kv_role != "kv_consumer":
             self._sender_executor.shutdown(wait=False)
-            if self._mooncake_sender_t:
-                self._mooncake_sender_t.join()
+            if self.sender_loop.is_running():
+                self.sender_loop.call_soon_threadsafe(self.sender_loop.stop)
+            self._sender_listener_t.join()
         if self.kv_role != "kv_producer" and self.receiver_loop.is_running():
             self.receiver_loop.call_soon_threadsafe(self.receiver_loop.stop)
             self._mooncake_receiver_t.join()
 
-    def _receiver_loop(self, loop: asyncio.AbstractEventLoop):
-        asyncio.set_event_loop(loop)
-        loop.run_forever()
-
-    def _mooncake_sender(
+    async def _mooncake_sender_listener(
         self, ready_event: threading.Event, base_port: int, tp_rank: int
     ):
         """
@@ -532,93 +519,86 @@ def _mooncake_sender(
         to a thread pool, and sends acknowledgments upon completion.
         """
-        frontend_path = make_zmq_path("tcp", self.hostname, base_port + tp_rank)
-        frontend = make_zmq_socket(self.zmq_ctx, frontend_path, zmq.ROUTER)
-        logger.debug("Mooncake sender starting listening on path: %s", frontend_path)
-
-        backend_path = make_zmq_path("inproc", str(uuid.uuid4()))
-        backend = make_zmq_socket(self.zmq_ctx, backend_path, zmq.PULL)
+        path = make_zmq_path("tcp", self.hostname, base_port + tp_rank)
+        sock = make_zmq_socket(self.async_zmq_ctx, path, zmq.ROUTER)
+        logger.debug("Mooncake sender starting listening on path: %s", path)
 
-        poller = zmq.Poller()
-        poller.register(frontend, zmq.POLLIN)
-        poller.register(backend, zmq.POLLIN)
+        # Create async worker tasks that process items from the queue
+        sender_tasks = [
+            asyncio.create_task(self._sender_worker(sock))
+            for _ in range(self.num_sender_tasks)
+        ]
 
         ready_event.set()
 
         try:
             while True:
-                sockets = dict(poller.poll())
-
-                if frontend in sockets:
-                    identity, _, metadata_bytes = frontend.recv_multipart()
-                    self._sender_executor.submit(
-                        self._sender_worker,
-                        identity,
-                        metadata_bytes,
-                        backend_path,
-                    )
-
-                if backend in sockets:
-                    identity, status = backend.recv_multipart()
-                    frontend.send_multipart((identity, b"", status))
-
+                identity, _, metadata_bytes = await sock.recv_multipart()
+                await self.sender_worker_queue.put((identity, metadata_bytes))
         except zmq.ContextTerminated:
             logger.debug("ZMQ context terminated, exiting Mooncake sender thread.")
         except Exception as e:
             logger.error("Error in Mooncake sender thread: %s. Exiting thread.", str(e))
         finally:
-            frontend.close()
-            backend.close()
-
-    def _sender_worker(
-        self, identity: bytes, metadata_bytes: bytes, worker_channel_path: str
-    ):
-        status = TRANS_ERROR
+            # Clean up worker tasks
+            for task in sender_tasks:
+                task.cancel()
+            await asyncio.gather(*sender_tasks, return_exceptions=True)
+            sock.close()
 
-        try:
-            metadata = self._decoder.decode(metadata_bytes)
-            self.send_kv_to_decode(metadata)
-            status = TRANS_DONE
-        except Exception as e:
-            logger.error("Error processing Mooncake handshake: %s", e)
-        finally:
-            pusher = make_zmq_socket(self.zmq_ctx, worker_channel_path, zmq.PUSH)
+    async def _sender_worker(self, sock: zmq.asyncio.Socket):
+        while True:
             try:
-                pusher.send_multipart((identity, status))
-            except zmq.ZMQError as e:
-                logger.warning(
-                    "Internal error, maybe the server is shutting down. Error: %s",
-                    e,
-                )
-            finally:
-                pusher.close()
-
-    def send_kv_to_decode(self, meta: MooncakeAgentMetadata):
+                identity, metadata_bytes = await self.sender_worker_queue.get()
+                try:
+                    metadata = self._decoder.decode(metadata_bytes)
+                    await self.send_kv_to_decode(metadata)
+                    await sock.send_multipart((identity, b"", TRANS_DONE))
+                except Exception as e:
+                    logger.error("Error processing Mooncake xfer request: %s", e)
+                    await sock.send_multipart((identity, b"", TRANS_ERROR))
+                finally:
+                    self.sender_worker_queue.task_done()
+            except asyncio.CancelledError:
+                break
+            except Exception as e:
+                logger.error("Error in _sender_worker: %s", e)
+
+    async def send_kv_to_decode(self, meta: MooncakeAgentMetadata):
         send_reqs: list[tuple[ReqId, SendBlockMeta]] = []
-        with self.reqs_need_send.lock:
-            for req_id in meta.request_ids:
-                send_meta = self.reqs_need_send.reqs.get(req_id)
-                if send_meta is None:
-                    logger.warning("Request %s not found in reqs_need_send", req_id)
-                    return
-                # Mark it as not expired. We will send it now.
-                send_meta.expire_time = float("inf")
-                send_reqs.append((req_id, send_meta))
+        for req_id in meta.request_ids:
+            send_meta = self.reqs_need_send.get(req_id)
+            if send_meta is None:
+                logger.warning("Request %s not found in reqs_need_send", req_id)
+                return
+            # Mark it as not expired. We will send it now.
+            send_meta.expire_time = float("inf")
+            send_reqs.append((req_id, send_meta))
+
+        src_ptrs, dst_ptrs, lengths = await self._build_transfer_params(send_reqs, meta)
+        remote_session = f"{meta.remote_hostname}:{meta.remote_port}"
+        ret_value = await self.sender_loop.run_in_executor(
+            self._sender_executor,
+            self._send_blocks,
+            remote_session,
+            src_ptrs,
+            dst_ptrs,
+            lengths,
+        )
 
-        self._send_blocks(send_reqs, meta)
+        if ret_value != 0:
+            raise RuntimeError(f"Error in batch_transfer_sync_write: {ret_value}")
 
-        with self.reqs_need_send.lock:
-            for req_id in meta.request_ids:
-                del self.reqs_need_send.reqs[req_id]
+        for req_id in meta.request_ids:
+            del self.reqs_need_send[req_id]
 
-        with self.finished_sending_reqs.lock:
-            self.finished_sending_reqs.set.update(meta.request_ids)
+        self.finished_sending_reqs.update(meta.request_ids)
 
-    def _send_blocks(
+    async def _build_transfer_params(
         self,
         send_reqs: list[tuple[ReqId, SendBlockMeta]],
         agent_meta: MooncakeAgentMetadata,
-    ):
+    ) -> tuple[list[int], list[int], list[int]]:
         src_ptrs = []
         dst_ptrs = []
         lengths = []
@@ -631,7 +611,7 @@ def _send_blocks(
         for (req_id, send_meta), remote_block_ids in zip(
             send_reqs, agent_meta.block_ids
         ):
-            send_meta.ready.wait()
+            await send_meta.ready.wait()
 
             num_remote_blocks = len(remote_block_ids)
             if num_remote_blocks == 0:
@@ -670,18 +650,26 @@ def _send_blocks(
             remote_session,
         )
 
+        return src_ptrs, dst_ptrs, lengths
+
+    def _send_blocks(
+        self,
+        remote_session: str,
+        src_ptrs: list[int],
+        dst_ptrs: list[int],
+        lengths: list[int],
+    ) -> int:
         start_time = time.perf_counter()
         ret_value = self.engine.batch_transfer_sync_write(
             remote_session, src_ptrs, dst_ptrs, lengths
         )
-        if ret_value != 0:
-            raise RuntimeError(f"Error in batch_transfer_sync_write: {ret_value}")
-
-        logger.debug(
-            "Sending to %s done, took %s",
-            remote_session,
-            time.perf_counter() - start_time,
-        )
+        if ret_value == 0:
+            logger.debug(
+                "Sending to %s done, took %s",
+                remote_session,
+                time.perf_counter() - start_time,
+            )
+        return ret_value
 
     def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         """Register the KV Cache data in mooncake."""
@@ -740,41 +728,63 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
             return
 
         ready_event = threading.Event()
-        self._mooncake_sender_t = threading.Thread(
-            target=self._mooncake_sender,
-            args=(ready_event, self.side_channel_port, self.tp_rank),
-            daemon=True,
-            name="mooncake_sender",
+        asyncio.run_coroutine_threadsafe(
+            self._mooncake_sender_listener(
+                ready_event, self.side_channel_port, self.tp_rank
+            ),
+            self.sender_loop,
         )
-        self._mooncake_sender_t.start()
         ready_event.wait()  # Wait for listener ZMQ socket to be ready.
 
     async def fetch_finished_recving_reqs(self) -> set[ReqId]:
-        async with self.finished_recving_reqs.lock:
-            finished_recving_reqs = self.finished_recving_reqs.set
-            self.finished_recving_reqs.set = set()
+        finished_recving_reqs = self.finished_recving_reqs
+        self.finished_recving_reqs = set()
         return finished_recving_reqs
 
+    async def fetch_finished_sending_reqs(self) -> set[ReqId]:
+        finished_sending_reqs = self.finished_sending_reqs
+        self.finished_sending_reqs = set()
+
+        # Handle timeout to avoid stranding blocks on remote.
+        now = time.perf_counter()
+        expired_reqs = [
+            req_id
+            for req_id, send_meta in self.reqs_need_send.items()
+            if send_meta.expire_time < now
+        ]
+        for req_id in expired_reqs:
+            logger.warning(
+                "Request %s timed out after %d seconds without "
+                "being sent. Freeing its blocks on the producer side.",
+                req_id,
+                envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT,
+            )
+            del self.reqs_need_send[req_id]
+        if expired_reqs:
+            finished_sending_reqs.update(expired_reqs)
+
+        return finished_sending_reqs
+
     def get_finished(self) -> tuple[set[str] | None, set[str] | None]:
         """
         Get requests that are done sending or recving on this specific worker.
         The scheduler process (via the MultiprocExecutor) will use this output
         to track which workers are done.
         """
-        fut = None
+        recv_fut = None
+        send_fut = None
        if self.kv_role != "kv_producer":
-            fut = asyncio.run_coroutine_threadsafe(
+            recv_fut = asyncio.run_coroutine_threadsafe(
                 self.fetch_finished_recving_reqs(), self.receiver_loop
             )
         if self.kv_role != "kv_consumer":
-            with self.finished_sending_reqs.lock:
-                finished_sending_reqs = self.finished_sending_reqs.set
-                self.finished_sending_reqs.set = set()
-        else:
-            finished_sending_reqs = set()
+            send_fut = asyncio.run_coroutine_threadsafe(
+                self.fetch_finished_sending_reqs(), self.sender_loop
+            )
 
-        finished_recving_reqs = fut.result() if fut else set()
+        finished_recving_reqs = recv_fut.result() if recv_fut else set()
+        finished_sending_reqs = send_fut.result() if send_fut else set()
 
         if finished_sending_reqs or finished_recving_reqs:
             logger.debug(
@@ -785,25 +795,6 @@ def get_finished(self) -> tuple[set[str] | None, set[str] | None]:
                 len(finished_sending_reqs),
                 len(finished_recving_reqs),
             )
 
-        # Handle timeout to avoid stranding blocks on remote.
-        now = time.perf_counter()
-        with self.reqs_need_send.lock:
-            expired_reqs = [
-                req_id
-                for req_id, send_meta in self.reqs_need_send.reqs.items()
-                if send_meta.expire_time < now
-            ]
-            for req_id in expired_reqs:
-                logger.warning(
-                    "Request %s timed out after %d seconds without "
-                    "being sent. Freeing its blocks on the producer side.",
-                    req_id,
-                    envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT,
-                )
-                del self.reqs_need_send.reqs[req_id]
-            if expired_reqs:
-                finished_sending_reqs.update(expired_reqs)
-
         return finished_sending_reqs or None, finished_recving_reqs or None
 
@@ -844,8 +835,7 @@ async def receive_kv(self, path: str, req_blocks: list[tuple[str, list[int]]]):
         finally:
             sock.close()
 
-        async with self.finished_recving_reqs.lock:
-            self.finished_recving_reqs.set.update(req_ids)
+        self.finished_recving_reqs.update(req_ids)
 
         logger.debug("pulling kv_caches for %s finished", req_ids)
 
@@ -865,6 +855,24 @@ def group_kv_pull(self, metadata: MooncakeConnectorMetadata):
 
         return kv_pulls
 
+    async def record_send_reqs(self, metadata: MooncakeConnectorMetadata):
+        for req_id, block_ids in metadata.reqs_to_send.items():
+            if block_ids:
+                # Already gone through request_finished()
+                send_meta = self.reqs_need_send[req_id]
+                send_meta.local_block_ids = block_ids
+                send_meta.expire_time = (
+                    time.perf_counter() + envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT
+                )
+                send_meta.ready.set()
+            else:
+                # From update_state_after_alloc(),
+                # but not reach request_finished() yet
+                self.reqs_need_send[req_id] = SendBlockMeta(
+                    local_block_ids=[],
+                    ready=asyncio.Event(),
+                )
+
     def start_load_kv(self, metadata: MooncakeConnectorMetadata):
         if self.kv_role != "kv_producer":
             kv_pulls = self.group_kv_pull(metadata)
@@ -874,23 +882,9 @@ def start_load_kv(self, metadata: MooncakeConnectorMetadata):
             )
 
         if self.kv_role != "kv_consumer":
-            with self.reqs_need_send.lock:
-                for req_id, block_ids in metadata.reqs_to_send.items():
-                    if block_ids:
-                        # Already gone through request_finished()
-                        send_meta = self.reqs_need_send.reqs[req_id]
-                        send_meta.local_block_ids = block_ids
-                        send_meta.ready.set()
-                        send_meta.expire_time = (
-                            time.perf_counter()
-                            + envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT
-                        )
-                    else:
-                        # From update_state_after_alloc(),
-                        # but not reach request_finished() yet
-                        self.reqs_need_send.reqs[req_id] = SendBlockMeta(
-                            local_block_ids=[], ready=threading.Event()
-                        )
+            asyncio.run_coroutine_threadsafe(
+                self.record_send_reqs(metadata), self.sender_loop
+            )
 
 
 def group_concurrent_contiguous(
@@ -917,3 +911,8 @@ def get_mooncake_side_channel_port(vllm_config: VllmConfig) -> int:
         + vllm_config.parallel_config.data_parallel_index
         * vllm_config.parallel_config.tensor_parallel_size
     )
+
+
+def _async_loop(loop: asyncio.AbstractEventLoop):
+    asyncio.set_event_loop(loop)
+    loop.run_forever()
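
For readers unfamiliar with the pattern this change adopts, the sketch below is a minimal, self-contained illustration (not part of the patch) of running a dedicated asyncio event loop in a daemon thread, handing work to it from other threads with asyncio.run_coroutine_threadsafe, buffering work in an asyncio.Queue consumed by worker tasks, and pushing blocking transfers onto a ThreadPoolExecutor via run_in_executor. The names SenderSketch, submit, and _blocking_send are hypothetical and stand in for the connector's real sender plumbing.

```python
import asyncio
import threading
import time
from concurrent.futures import ThreadPoolExecutor


def _async_loop(loop: asyncio.AbstractEventLoop) -> None:
    # Same role as the patch's _async_loop helper: pin the loop to this thread and run it.
    asyncio.set_event_loop(loop)
    loop.run_forever()


class SenderSketch:
    """Hypothetical stand-in for the sender side of the connector."""

    def __init__(self, num_workers: int = 2, num_tasks: int = 4) -> None:
        self.executor = ThreadPoolExecutor(max_workers=num_workers)
        self.loop = asyncio.new_event_loop()
        threading.Thread(target=_async_loop, args=(self.loop,), daemon=True).start()
        # Start queue-consuming tasks on the background loop from this (foreign) thread.
        asyncio.run_coroutine_threadsafe(
            self._start_workers(num_tasks), self.loop
        ).result()

    async def _start_workers(self, num_tasks: int) -> None:
        # Create the queue on the loop thread so it binds to this loop, then
        # spawn more consumer tasks than pool workers to keep the pool busy.
        self.queue: asyncio.Queue[bytes] = asyncio.Queue()
        self.tasks = [asyncio.create_task(self._worker()) for _ in range(num_tasks)]

    async def _worker(self) -> None:
        while True:
            item = await self.queue.get()
            try:
                # The blocking transfer runs in the thread pool; the loop stays
                # free to accept more requests while it is in flight.
                await self.loop.run_in_executor(self.executor, self._blocking_send, item)
            finally:
                self.queue.task_done()

    def _blocking_send(self, item: bytes) -> None:
        time.sleep(0.1)  # stand-in for a synchronous bulk transfer call
        print("sent", item)

    def submit(self, item: bytes) -> None:
        # Thread-safe hand-off into the loop, mirroring how the patch schedules
        # coroutines on sender_loop with run_coroutine_threadsafe.
        asyncio.run_coroutine_threadsafe(self.queue.put(item), self.loop)


if __name__ == "__main__":
    sender = SenderSketch()
    sender.submit(b"block-0")
    sender.submit(b"block-1")
    time.sleep(0.5)  # give the background loop time to drain the queue
```

The design point the sketch illustrates is the same one the patch relies on: only the loop thread touches the queue and the bookkeeping dicts and sets, so the per-structure locks of the old implementation are no longer needed, while slow transfers are still parallelized through the executor.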