diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py
index a0e03b002b34..305ab1d2ed16 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py
@@ -38,11 +38,15 @@
 ids of requests that have completed async sending/recving.
 """
 
+import asyncio
 import enum
+import threading
+import time
 from abc import ABC, abstractmethod
 from collections.abc import Callable, Iterable
 from typing import TYPE_CHECKING, Any, Literal
 
+import aiohttp
 import torch
 
 from vllm.logger import init_logger
@@ -167,6 +171,7 @@ def __init__(
             "Initializing KVConnectorBase_V1. This API is experimental and "
             "subject to change in the future as we iterate the design."
         )
+        self._reqs_need_lease: set[str] = set()
         self._connector_metadata: KVConnectorMetadata | None = None
         self._vllm_config = vllm_config
         if vllm_config.kv_transfer_config is not None:
@@ -183,6 +188,22 @@ def __init__(
         )
         self._role = role
 
+        if role == KVConnectorRole.SCHEDULER:
+            self._last_lease_refresh_time = time.perf_counter()
+            self._lease_refresh_loop = asyncio.new_event_loop()
+
+            def background_loop(loop: asyncio.AbstractEventLoop):
+                asyncio.set_event_loop(loop)
+                loop.run_forever()
+
+            self._lease_refresh_thread = threading.Thread(
+                target=background_loop,
+                args=(self._lease_refresh_loop,),
+                daemon=True,
+                name="kv-lease-refresh-thread",
+            )
+            self._lease_refresh_thread.start()
+
     @property
     def role(self) -> KVConnectorRole:
         return self._role
@@ -413,6 +434,62 @@ def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None:
     # Scheduler-side methods
     # ==============================
 
+    def add_request(self, request: "Request"):
+        """
+        Add a request to the connector's state. This is called when a new
+        request is added to the scheduler, and can be used by the connector
+        to track the requests it needs to handle.
+
+        Args:
+            request (Request): the request object.
+        """
+        if (
+            request.kv_transfer_params is not None
+            and "remote_engine_id" in request.kv_transfer_params
+            and request.kv_transfer_params.get("do_remote_prefill", False)
+        ):
+            self._reqs_need_lease.add(request.request_id)
+
+    def refresh_leases(self):
+        """
+        Refresh the leases for requests that need it. This is called periodically
+        by the scheduler to ensure that the connector can maintain any necessary
+        leases for the requests it is handling.
+        """
+
+        LEASE_REFRESH_TIME_S = 5
+        if time.perf_counter() - self._last_lease_refresh_time < LEASE_REFRESH_TIME_S:
+            return
+
+        async def _http_lease_refresh(request_id: str):
+            async with aiohttp.ClientSession() as session:
+                url = "http://localhost:7000/internal/kv_connector_refresh_lease"
+                async with session.post(
+                    url, json={"request_id": request_id}
+                ) as response:
+                    logger.debug("[BG] [%s] %s -> %s", request_id, url, response.status)
+
+        for request_id in self._reqs_need_lease:
+            # TODO: get result of the future and check if the remote engine
+            # is still running so we can avoid KV transfer failure.
+            _ = asyncio.run_coroutine_threadsafe(
+                _http_lease_refresh(request_id), self._lease_refresh_loop
+            )
+        self._last_lease_refresh_time = time.perf_counter()
+
+    def finish_lease_refresh(self, request_id: str):
+        """
+        Stop lease refresh for a request. This is called when a request is finished
+        to stop refreshing its lease.
+
+        Args:
+            request_id (str): the ID of the request.
+        """
+        self._reqs_need_lease.discard(request_id)
+
+    def handle_refresh_lease(self, request_id: str):
+        return
+
     @abstractmethod
     def get_num_new_matched_tokens(
         self,
@@ -513,6 +590,9 @@ def request_finished(
             Optional KVTransferParams to be included in the request outputs
             returned by the engine.
         """
+        # NOTE(rob): we need to ensure all subclasses call this super() method,
+        # else we will get a leak from not cleaning up _reqs_need_lease.
+ self._reqs_need_lease.discard(request.request_id) return False, None def take_events(self) -> Iterable["KVCacheEvent"]: diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 87091d650b17..76fdffb56749 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -256,6 +256,7 @@ def __init__(self): self.reqs_to_recv: dict[ReqId, ReqMeta] = {} self.reqs_to_save: dict[ReqId, ReqMeta] = {} self.reqs_to_send: dict[ReqId, float] = {} + self.reqs_to_refresh: set[ReqId] = set() self.reqs_in_batch: set[ReqId] = set() self.reqs_not_processed: set[ReqId] = set() @@ -281,6 +282,12 @@ def add_new_req_to_save( local_block_ids, kv_transfer_params ) + def add_new_req_to_refresh( + self, + request_id: ReqId, + ): + self.reqs_to_refresh.add(request_id) + def add_new_req_to_recv( self, request_id: ReqId, @@ -366,6 +373,12 @@ def get_required_kvcache_layout(cls, vllm_config: VllmConfig): # Scheduler Side Methods ############################################################ + def handle_refresh_lease(self, request_id: str): + assert self.connector_scheduler is not None + return self.connector_scheduler.handle_refresh_lease( + request_id, + ) + def get_num_new_matched_tokens( self, request: "Request", num_computed_tokens: int ) -> tuple[int | None, bool]: @@ -394,6 +407,8 @@ def request_finished( request: "Request", block_ids: list[int], ) -> tuple[bool, dict[str, Any] | None]: + # NOTE(rob): we need to ensure the subclasses all call this. 
+ super().request_finished(request, block_ids) assert self.connector_scheduler is not None return self.connector_scheduler.request_finished(request, block_ids) @@ -433,7 +448,8 @@ def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp): def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]: """Get the finished recving and sending requests.""" assert self.connector_worker is not None - return self.connector_worker.get_finished() + assert isinstance(self._connector_metadata, NixlConnectorMetadata) + return self.connector_worker.get_finished(self._connector_metadata) def get_block_ids_with_load_errors(self) -> set[int]: """Get block IDs that failed to load via NIXL.""" @@ -546,6 +562,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): self._reqs_need_save: dict[ReqId, Request] = {} # Reqs to send and their expiration time self._reqs_need_send: dict[ReqId, float] = {} + self._reqs_need_refresh: set[ReqId] = set() self._reqs_in_batch: set[ReqId] = set() # Reqs to remove from processed set because they're not to send after # remote prefill or aborted. @@ -557,6 +574,11 @@ def shutdown(self): self._nixl_handshake_listener_t.join() self._nixl_handshake_listener_t = None + def handle_refresh_lease(self, request_id: str): + # We will refresh the lease by extending the expiration time. + # TODO: should check if in reqs_need_send? + self._reqs_need_refresh.add(request_id) + def set_xfer_handshake_metadata( self, metadata: dict[int, KVConnectorHandshakeMetadata] ) -> None: @@ -777,6 +799,7 @@ def build_connector_meta( self._reqs_in_batch = set() self._reqs_not_processed = set() self._reqs_need_send = {} + self._reqs_need_refresh = set() return meta @@ -989,6 +1012,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): self._recving_transfers = defaultdict[ReqId, list[TransferHandle]](list) # Track the expiration time of requests that are waiting to be sent. 
self._reqs_to_send: dict[ReqId, float] = {} + self._reqs_to_refresh: set[ReqId] = set() # Set of requests that have been part of a batch, regardless of status. self._reqs_to_process: set[ReqId] = set() @@ -1931,7 +1955,9 @@ def post_process_device_kv_on_receive( cache, indices, block_size_ratio ) - def get_finished(self) -> tuple[set[str], set[str]]: + def get_finished( + self, connector_metadata: NixlConnectorMetadata + ) -> tuple[set[str], set[str]]: """ Get requests that are done sending or recving on this specific worker. The scheduler process (via the MultiprocExecutor) will use this output @@ -1979,17 +2005,25 @@ def get_finished(self) -> tuple[set[str], set[str]]: ) in block_ids_for_blocksize_post_process.items(): self.post_process_device_kv_on_receive(block_size_ratio, block_ids_list) - # Handle timeout to avoid stranding blocks on remote. + # Handle timeout: free P-side KV blocks for requests whose lease expired. + # Full scan (not early-break) since lease refreshes may change expiry + # order arbitrarily. now = time.perf_counter() - while self._reqs_to_send: - req_id, expires = next(iter(self._reqs_to_send.items())) - # Sorted dict, oldest requests are put first so we can exit early. - if now < expires: - break + for req_id in self._reqs_to_send: + # If we have a lease refresh, update it. + if req_id in connector_metadata.reqs_to_refresh: + self._reqs_to_send[req_id] = now + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT + + # Lease not yet expired, continue. + expires_t = self._reqs_to_send[req_id] + if expires_t > now: + continue + + # Lease expired, free it and push back to scheduler. 
count = self.consumer_notification_counts_by_req.pop(req_id, 0) self.xfer_stats.record_kv_expired_req() logger.warning( - "Releasing expired KV blocks for request %s which were " + "Releasing expired KV blocks for request %s which were not " "retrieved by %d decode worker(s) within %d seconds.", req_id, count, diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index d76a7446d2a9..a9bdad47c326 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -188,6 +188,12 @@ def build_app( register_models_api_router(app) + from vllm.entrypoints.openai.internal import ( + attach_router as attach_internal_router, + ) + + attach_internal_router(app) + from vllm.entrypoints.sagemaker.api_router import ( attach_router as register_sagemaker_api_router, ) diff --git a/vllm/entrypoints/openai/internal.py b/vllm/entrypoints/openai/internal.py new file mode 100644 index 000000000000..7ea049a76da9 --- /dev/null +++ b/vllm/entrypoints/openai/internal.py @@ -0,0 +1,39 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import json + +from fastapi import APIRouter, FastAPI, Request +from fastapi.responses import Response + +from vllm.logger import init_logger + +logger = init_logger(__name__) + +router = APIRouter() + + +@router.post("/internal/kv_connector_refresh_lease") +async def kv_connector_refresh_lease(raw_request: Request) -> Response: + """Receive KV lease refresh requests from D workers. + + D workers POST here periodically while requests are queued, before the + KV transfer begins, to prevent P from expiring and freeing KV blocks + prematurely. 
+    """
+    try:
+        body = await raw_request.json()
+        request_id: str = body["request_id"]
+    except Exception as e:
+        logger.warning(
+            "kv_connector_refresh_lease: failed to parse request body: %s", e
+        )
+        return Response(status_code=400)
+
+    engine_client = raw_request.app.state.engine_client
+    await engine_client.call_utility_async("kv_connector_refresh_lease", request_id)
+    return Response(status_code=200)
+
+
+def attach_router(app: FastAPI) -> None:
+    app.include_router(router)
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index bf397ad681ca..fe7b745af341 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -1676,6 +1676,8 @@ def add_request(self, request: Request) -> None:
         if request.resumable:
             request.streaming_queue = deque()
         self.waiting.add_request(request)
+        if self.connector is not None:
+            self.connector.add_request(request)
         self.requests[request.request_id] = request
         if self.log_stats:
             request.record_event(EngineCoreEventType.QUEUED)
@@ -1976,6 +1978,9 @@ def _update_waiting_for_remote_kv(self, request: Request) -> bool:
         if request.request_id not in self.finished_recving_kv_req_ids:
             return False
 
+        # Stop KV lease refresh for this request as we are done.
+        self.connector.finish_lease_refresh(request.request_id)
+
         if request.request_id in self.failed_recving_kv_req_ids:
             # Request had KV load failures; num_computed_tokens was already
             # updated in _update_requests_with_invalid_blocks
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index 4de3e4ea7d3a..9551b3b1c9ec 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -576,6 +576,16 @@ def reset_prefix_cache(
             reset_running_requests, reset_connector
         )
 
+    def kv_connector_refresh_lease(self, request_id: str) -> None:
+        """Refresh the KV block lease for pending remote-decode requests.
+
+        Called on P-side OpenAI API server when D workers POST
+        /internal/kv_connector_refresh_lease.
+ """ + connector = self.scheduler.get_kv_connector() + if connector is not None: + connector.handle_refresh_lease(request_id) + def reset_encoder_cache(self) -> None: """Reset the encoder cache to invalidate all cached encoder outputs.