diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py index d986f686657f..2b7c7f652259 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py @@ -227,6 +227,12 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: assert isinstance(self._connector_metadata, MooncakeConnectorMetadata) self.connector_worker.start_load_kv(self._connector_metadata) + def get_block_ids_with_load_errors(self) -> set[int]: + """Get the set of block IDs that failed to load.""" + if self.connector_worker is not None: + return self.connector_worker.get_block_ids_with_load_errors() + return set() + def wait_for_layer_load(self, layer_name: str) -> None: """MooncakeConnector does not do layerwise saving.""" pass @@ -541,6 +547,9 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): self.finished_sending_reqs: set[ReqId] = set() self.finished_recving_reqs: set[ReqId] = set() + self.reqs_to_recv: dict[EngineId, dict[ReqId, PullReqMeta]] = defaultdict(dict) + # Track invalid block IDs for failed transfers (similar to nixl_connector) + self._invalid_block_ids: set[int] = set() self.block_size = vllm_config.cache_config.block_size self.model_config = vllm_config.model_config @@ -974,6 +983,39 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): async def fetch_finished_recving_reqs(self) -> set[ReqId]: finished_recving_reqs = self.finished_recving_reqs self.finished_recving_reqs = set() + + # Handle timeout to avoid stranding blocks on remote + now = time.perf_counter() + + expired_req_ids = [] + # Create a copy of items to avoid concurrent modification issues + reqs_to_recv_items = list(self.reqs_to_recv.items()) + for remote_engine_id, pull_metas in reqs_to_recv_items: + pull_metas_items = list(pull_metas.items()) + for req_id, pull_meta in pull_metas_items: + if pull_meta.expire_time < now: + logger.warning( + "Request %s timed out after %d seconds without " + "finishing KV transfer from remote engine %s. " + "Marking %d blocks as invalid to prevent garbage output.", + pull_meta.d_req_id, + envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT, + remote_engine_id, + len(pull_meta.local_block_ids), + ) + # Mark blocks as invalid to prevent garbage output + self._invalid_block_ids.update(pull_meta.local_block_ids) + finished_recving_reqs.add(pull_meta.d_req_id) + expired_req_ids.append((remote_engine_id, req_id)) + + # Remove expired requests from tracking + for remote_engine_id, req_id in expired_req_ids: + if ( + remote_engine_id in self.reqs_to_recv + and req_id in self.reqs_to_recv[remote_engine_id] + ): + del self.reqs_to_recv[remote_engine_id][req_id] + return finished_recving_reqs async def fetch_finished_sending_reqs(self) -> set[ReqId]: @@ -1088,7 +1130,17 @@ async def receive_kv_from_single_worker( except zmq.ContextTerminated: logger.debug("ZMQ context terminated, exiting Mooncake receiver thread.") except Exception as e: - logger.error("MooncakeXferMetadata transfer failed for %s: %s", req_ids, e) + logger.error( + "MooncakeXferMetadata transfer failed for %s: %s. " + "Marking associated blocks as invalid to prevent garbage output.", + req_ids, + e, + ) + # Add failed requests to finished_recving_reqs for scheduler cleanup + for req_id, pull_meta in pull_metas.items(): + # Mark blocks as invalid to prevent garbage output + self._invalid_block_ids.update(pull_meta.local_block_ids) + self.finished_recving_reqs.add(pull_meta.d_req_id) return def process_pulling_result( @@ -1110,10 +1162,18 @@ def process_pulling_result( if response.err_reqs: logger.error( - "pulling kv_caches for %s failed: %s", + "pulling kv_caches for %s failed: %s. " + "Marking associated blocks as invalid to prevent garbage output.", response.err_reqs, response.err_msg, ) + # Add failed requests to finished_recving_reqs for scheduler cleanup + for req_id in response.err_reqs: + if req_id in pull_metas: + pull_meta = pull_metas[req_id] + # Mark blocks as invalid to prevent garbage output + self._invalid_block_ids.update(pull_meta.local_block_ids) + self.finished_recving_reqs.add(pull_meta.d_req_id) async def _connect_to_prefiller_bootstrap(self, remote_bootstrap_addr: str): url = remote_bootstrap_addr + "/query" @@ -1157,8 +1217,13 @@ def receive_kv( raise NotImplementedError( "Mooncake: Heterogeneous TP is not supported yet." ) + expire_time = time.perf_counter() + envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT for pull_meta in pull_metas.values(): pull_meta.pull_tasks_count = count + # Set expire time to avoid infinitely waiting for remote KV transfer + # Only set if not already set to avoid race conditions + if pull_meta.expire_time == float("inf"): + pull_meta.expire_time = expire_time for remote_tp_rank in remote_tp_ranks: worker_addr = self._remote_agents[remote_engine_id][remote_tp_rank][0] asyncio.create_task( @@ -1228,6 +1293,8 @@ async def record_send_reqs(self, metadata: MooncakeConnectorMetadata): def start_load_kv(self, metadata: MooncakeConnectorMetadata): if not self.is_kv_producer and metadata.reqs_to_recv: + # Store requests to receive for timeout tracking + self.reqs_to_recv.update(metadata.reqs_to_recv) asyncio.run_coroutine_threadsafe( self._start_load_kv(metadata.reqs_to_recv), self.receiver_loop ) @@ -1239,6 +1306,20 @@ def start_load_kv(self, metadata: MooncakeConnectorMetadata): self.record_send_reqs(metadata), self.sender_loop ) + def get_block_ids_with_load_errors(self) -> set[int]: + """ + Return and clear the set of block IDs that failed to load. + + This is called by the scheduler to identify blocks that need + to be retried after a Mooncake transfer failure or timeout. + + Returns: + Set of block IDs that encountered load errors. Empty set if none. + """ + result = self._invalid_block_ids + self._invalid_block_ids = set() + return result + def group_concurrent_contiguous( src_indices: list[int], dst_indices: list[int]