Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions vllm/distributed/kv_transfer/kv_connector/v1/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,15 @@
ids of requests that have completed async sending/recving.
"""

import asyncio
import enum
import threading
import time
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable
from typing import TYPE_CHECKING, Any, Literal

import aiohttp
import torch

from vllm.logger import init_logger
Expand Down Expand Up @@ -167,6 +171,7 @@
"Initializing KVConnectorBase_V1. This API is experimental and "
"subject to change in the future as we iterate the design."
)
self._reqs_need_lease: set[Request] = set()
self._connector_metadata: KVConnectorMetadata | None = None
self._vllm_config = vllm_config
if vllm_config.kv_transfer_config is not None:
Expand All @@ -183,6 +188,22 @@
)
self._role = role

if role == KVConnectorRole.SCHEDULER:
self._last_lease_refresh_time = time.perf_counter()
self._lease_refresh_loop = asyncio.new_event_loop()

def background_loop(loop: asyncio.AbstractEventLoop):
asyncio.set_event_loop(loop)
loop.run_forever()

self._lease_refresh_thread = threading.Thread(
target=background_loop,
args=(self._lease_refresh_loop,),
daemon=True,
name="kv-lease-refresh-thread",
)
self._lease_refresh_thread.start()

@property
def role(self) -> KVConnectorRole:
return self._role
Expand Down Expand Up @@ -413,6 +434,62 @@
# Scheduler-side methods
# ==============================

def add_request(self, request: "Request"):
    """
    Register a request with the connector's lease-tracking state.

    Called when a new request is added to the scheduler. Requests doing a
    remote prefill (KV blocks produced by a remote engine) are tracked so
    refresh_leases() can keep the remote KV lease alive until the blocks
    are pulled.

    Args:
        request (Request): the request object.
    """
    params = request.kv_transfer_params
    if params is None:
        return
    if isinstance(params, dict):
        # kv_transfer_params is a plain dict of transfer options, so key
        # membership / .get() must be used; hasattr()/getattr() on a dict
        # check *attributes*, not keys, and would never match.
        has_remote_engine = "remote_engine_id" in params
        do_remote_prefill = params.get("do_remote_prefill", False)
    else:
        # Fallback for object-style params — preserves the original
        # behavior; TODO confirm whether this shape can actually occur.
        has_remote_engine = hasattr(params, "remote_engine_id")
        do_remote_prefill = getattr(params, "do_remote_prefill", False)
    if has_remote_engine and do_remote_prefill:
        # NOTE(review): entries are request_id strings, but the set is
        # declared as set[Request] in __init__ (mypy flags this); the
        # declaration should be set[str].
        self._reqs_need_lease.add(request.request_id)

Check failure on line 451 in vllm/distributed/kv_transfer/kv_connector/v1/base.py

View workflow job for this annotation

GitHub Actions / pre-commit

Argument 1 to "add" of "set" has incompatible type "str"; expected "Request" [arg-type]

Check failure on line 451 in vllm/distributed/kv_transfer/kv_connector/v1/base.py

View workflow job for this annotation

GitHub Actions / pre-commit

Argument 1 to "add" of "set" has incompatible type "str"; expected "Request" [arg-type]

Check failure on line 451 in vllm/distributed/kv_transfer/kv_connector/v1/base.py

View workflow job for this annotation

GitHub Actions / pre-commit

Argument 1 to "add" of "set" has incompatible type "str"; expected "Request" [arg-type]

def refresh_leases(self):
    """
    Refresh the KV-block leases for requests that still need them.

    Called periodically from the scheduler loop; throttled so refreshes
    are issued at most once every LEASE_REFRESH_TIME_S seconds. The HTTP
    POSTs run on the connector's background event loop so the scheduler
    thread never blocks on network I/O.
    """
    LEASE_REFRESH_TIME_S = 5
    if time.perf_counter() - self._last_lease_refresh_time < LEASE_REFRESH_TIME_S:
        return

    async def _http_lease_refresh(request_id: str):
        # Must match the route registered by
        # vllm/entrypoints/openai/internal.py on the P-side API server
        # ("/internal/kv_connector_refresh_lease"); the previous
        # "/refresh_kv_lease" path does not exist there.
        # NOTE(review): host/port are hard-coded — should come from the
        # kv_transfer_config; TODO confirm.
        url = "http://localhost:7000/internal/kv_connector_refresh_lease"
        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    url, json={"request_id": request_id}
                ) as response:
                    logger.debug(
                        "[BG] [%s] %s -> %s", request_id, url, response.status
                    )
        except aiohttp.ClientError as e:
            # Best-effort: a failed refresh only risks an early lease
            # expiry on the remote side; never crash the background loop.
            logger.warning(
                "KV lease refresh for request %s failed: %s", request_id, e
            )

    # Iterate a snapshot: finish_lease_refresh()/request_finished() may
    # discard entries while the refreshes are being scheduled.
    for request_id in list(self._reqs_need_lease):
        # TODO: get result of the future and check if the remote engine
        # is still running so we can avoid KV transfer failure.
        _ = asyncio.run_coroutine_threadsafe(
            _http_lease_refresh(request_id), self._lease_refresh_loop
        )
    self._last_lease_refresh_time = time.perf_counter()

def finish_lease_refresh(self, request_id: str) -> None:
    """
    Stop refreshing the KV lease for a finished request.

    No-op if the request was never tracked (or already removed).

    Args:
        request_id (str): the ID of the request.
    """
    if request_id in self._reqs_need_lease:
        self._reqs_need_lease.remove(request_id)

Check failure on line 488 in vllm/distributed/kv_transfer/kv_connector/v1/base.py

View workflow job for this annotation

GitHub Actions / pre-commit

Argument 1 to "discard" of "set" has incompatible type "str"; expected "Request" [arg-type]

Check failure on line 488 in vllm/distributed/kv_transfer/kv_connector/v1/base.py

View workflow job for this annotation

GitHub Actions / pre-commit

Argument 1 to "discard" of "set" has incompatible type "str"; expected "Request" [arg-type]

def handle_refresh_lease(self, request_id: str) -> None:
    """
    Hook invoked when a remote worker asks to extend the KV lease for
    `request_id`. The base implementation is a no-op; connectors that
    expire KV blocks override this to push the expiry forward.
    """
    return None

@abstractmethod
def get_num_new_matched_tokens(
self,
Expand Down Expand Up @@ -513,6 +590,9 @@
Optional KVTransferParams to be included in the request outputs
returned by the engine.
"""
# NOTE(rob): we need to ensure all subclasses call this super() method,
# else we will get a leak from not cleaning up _reqs_need_lease.
self._reqs_need_lease.discard(request.request_id)

Check failure on line 595 in vllm/distributed/kv_transfer/kv_connector/v1/base.py

View workflow job for this annotation

GitHub Actions / pre-commit

Argument 1 to "discard" of "set" has incompatible type "str"; expected "Request" [arg-type]

Check failure on line 595 in vllm/distributed/kv_transfer/kv_connector/v1/base.py

View workflow job for this annotation

GitHub Actions / pre-commit

Argument 1 to "discard" of "set" has incompatible type "str"; expected "Request" [arg-type]
return False, None

def take_events(self) -> Iterable["KVCacheEvent"]:
Expand Down
52 changes: 43 additions & 9 deletions vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,7 @@ def __init__(self):
self.reqs_to_recv: dict[ReqId, ReqMeta] = {}
self.reqs_to_save: dict[ReqId, ReqMeta] = {}
self.reqs_to_send: dict[ReqId, float] = {}
self.reqs_to_refresh: set[ReqId] = set()
self.reqs_in_batch: set[ReqId] = set()
self.reqs_not_processed: set[ReqId] = set()

Expand All @@ -281,6 +282,12 @@ def add_new_req_to_save(
local_block_ids, kv_transfer_params
)

def add_new_req_to_refresh(
    self,
    request_id: ReqId,
) -> None:
    """Mark `request_id` for a KV-lease refresh in this scheduler step."""
    self.reqs_to_refresh.add(request_id)

def add_new_req_to_recv(
self,
request_id: ReqId,
Expand Down Expand Up @@ -366,6 +373,12 @@ def get_required_kvcache_layout(cls, vllm_config: VllmConfig):
# Scheduler Side Methods
############################################################

def handle_refresh_lease(self, request_id: str):
    """Forward a lease-refresh request to the scheduler-side connector."""
    scheduler = self.connector_scheduler
    assert scheduler is not None
    return scheduler.handle_refresh_lease(request_id)

def get_num_new_matched_tokens(
self, request: "Request", num_computed_tokens: int
) -> tuple[int | None, bool]:
Expand Down Expand Up @@ -394,6 +407,8 @@ def request_finished(
request: "Request",
block_ids: list[int],
) -> tuple[bool, dict[str, Any] | None]:
# NOTE(rob): we need to ensure the subclasses all call this.
super().request_finished(request, block_ids)
assert self.connector_scheduler is not None
return self.connector_scheduler.request_finished(request, block_ids)

Expand Down Expand Up @@ -433,7 +448,8 @@ def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp):
def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
    """Return the (finished recving, finished sending) request ids."""
    worker = self.connector_worker
    assert worker is not None
    metadata = self._connector_metadata
    # The worker consumes this step's metadata (e.g. lease refreshes).
    assert isinstance(metadata, NixlConnectorMetadata)
    return worker.get_finished(metadata)

def get_block_ids_with_load_errors(self) -> set[int]:
"""Get block IDs that failed to load via NIXL."""
Expand Down Expand Up @@ -546,6 +562,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
self._reqs_need_save: dict[ReqId, Request] = {}
# Reqs to send and their expiration time
self._reqs_need_send: dict[ReqId, float] = {}
self._reqs_need_refresh: set[ReqId] = set()
self._reqs_in_batch: set[ReqId] = set()
# Reqs to remove from processed set because they're not to send after
# remote prefill or aborted.
Expand All @@ -557,6 +574,11 @@ def shutdown(self):
self._nixl_handshake_listener_t.join()
self._nixl_handshake_listener_t = None

def handle_refresh_lease(self, request_id: str) -> None:
    """Record that `request_id`'s KV lease should be extended.

    The refresh set is consumed when building this step's connector
    metadata (and reset there); the worker side then pushes the request's
    expiration time forward.
    """
    # TODO: should check if in reqs_need_send?
    self._reqs_need_refresh.add(request_id)

def set_xfer_handshake_metadata(
self, metadata: dict[int, KVConnectorHandshakeMetadata]
) -> None:
Expand Down Expand Up @@ -777,6 +799,7 @@ def build_connector_meta(
self._reqs_in_batch = set()
self._reqs_not_processed = set()
self._reqs_need_send = {}
self._reqs_need_refresh = set()

return meta

Expand Down Expand Up @@ -989,6 +1012,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
self._recving_transfers = defaultdict[ReqId, list[TransferHandle]](list)
# Track the expiration time of requests that are waiting to be sent.
self._reqs_to_send: dict[ReqId, float] = {}
self._reqs_to_refresh: set[ReqId] = set()
# Set of requests that have been part of a batch, regardless of status.
self._reqs_to_process: set[ReqId] = set()

Expand Down Expand Up @@ -1931,7 +1955,9 @@ def post_process_device_kv_on_receive(
cache, indices, block_size_ratio
)

def get_finished(self) -> tuple[set[str], set[str]]:
def get_finished(
self, connector_metadata: NixlConnectorMetadata
) -> tuple[set[str], set[str]]:
"""
Get requests that are done sending or recving on this specific worker.
The scheduler process (via the MultiprocExecutor) will use this output
Expand Down Expand Up @@ -1979,17 +2005,25 @@ def get_finished(self) -> tuple[set[str], set[str]]:
) in block_ids_for_blocksize_post_process.items():
self.post_process_device_kv_on_receive(block_size_ratio, block_ids_list)

# Handle timeout to avoid stranding blocks on remote.
# Handle timeout: free P-side KV blocks for requests whose lease expired.
# Full scan (not early-break) since lease refreshes may change expiry
# order arbitrarily.
now = time.perf_counter()
while self._reqs_to_send:
req_id, expires = next(iter(self._reqs_to_send.items()))
# Sorted dict, oldest requests are put first so we can exit early.
if now < expires:
break
for req_id in self._reqs_to_send:
# If we have a lease refresh, update it.
if req_id in connector_metadata.reqs_to_refresh:
self._reqs_to_send[req_id] = now + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT

# Lease not yet expired, continue.
expires_t = self._reqs_to_send[req_id]
if expires_t > now:
continue

# Lease expired, free it and push back to scheduler.
count = self.consumer_notification_counts_by_req.pop(req_id, 0)
self.xfer_stats.record_kv_expired_req()
logger.warning(
"Releasing expired KV blocks for request %s which were "
"Releasing expired KV blocks for request %s which were not "
"retrieved by %d decode worker(s) within %d seconds.",
req_id,
count,
Expand Down
6 changes: 6 additions & 0 deletions vllm/entrypoints/openai/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,12 @@ def build_app(

register_models_api_router(app)

from vllm.entrypoints.openai.internal import (
attach_router as attach_internal_router,
)

attach_internal_router(app)

from vllm.entrypoints.sagemaker.api_router import (
attach_router as register_sagemaker_api_router,
)
Expand Down
39 changes: 39 additions & 0 deletions vllm/entrypoints/openai/internal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import json

from fastapi import APIRouter, FastAPI, Request
from fastapi.responses import Response

from vllm.logger import init_logger

logger = init_logger(__name__)

router = APIRouter()


@router.post("/internal/kv_connector_refresh_lease")
async def kv_connector_refresh_lease(raw_request: Request) -> Response:
    """Receive KV lease refresh requests from D workers.

    D workers POST here periodically while requests are queued, before the
    KV transfer begins, to prevent P from expiring and freeing KV blocks
    prematurely.

    Returns 400 for an unparsable body or a missing/invalid request_id,
    200 once the refresh has been forwarded to the engine core.
    """
    try:
        body = await raw_request.json()
    except json.JSONDecodeError as e:
        # Catch only malformed JSON: the previous
        # `except (json.JSONDecodeError, Exception)` was redundant and
        # over-broad — unexpected server-side errors should surface as a
        # 500, not be masked as a client 400.
        logger.warning(
            "kv_connector_refresh_lease: failed to parse request body: %s", e
        )
        return Response(status_code=400)

    request_id = body.get("request_id") if isinstance(body, dict) else None
    if not isinstance(request_id, str) or not request_id:
        # Validate before forwarding: body.get() would otherwise pass
        # None (or a non-string) to the engine utility call.
        logger.warning("kv_connector_refresh_lease: missing or invalid request_id")
        return Response(status_code=400)

    engine_client = raw_request.app.state.engine_client
    await engine_client.call_utility_async("kv_connector_refresh_lease", request_id)
    return Response(status_code=200)


def attach_router(app: FastAPI) -> None:
    """Mount the internal (non-OpenAI) endpoints onto the given app."""
    app.include_router(router)
5 changes: 5 additions & 0 deletions vllm/v1/core/sched/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1676,6 +1676,8 @@ def add_request(self, request: Request) -> None:
if request.resumable:
request.streaming_queue = deque()
self.waiting.add_request(request)
if self.connector is not None:
self.connector.add_request(request)
self.requests[request.request_id] = request
if self.log_stats:
request.record_event(EngineCoreEventType.QUEUED)
Expand Down Expand Up @@ -1976,6 +1978,9 @@ def _update_waiting_for_remote_kv(self, request: Request) -> bool:
if request.request_id not in self.finished_recving_kv_req_ids:
return False

# Stop KV lease refresh for this request as we are done.
self.connector.finish_lease_refresh(request.request_id)

if request.request_id in self.failed_recving_kv_req_ids:
# Request had KV load failures; num_computed_tokens was already
# updated in _update_requests_with_invalid_blocks
Expand Down
10 changes: 10 additions & 0 deletions vllm/v1/engine/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,16 @@ def reset_prefix_cache(
reset_running_requests, reset_connector
)

def kv_connector_refresh_lease(self, request_id: str) -> None:
    """Refresh the KV block lease for pending remote-decode requests.

    Invoked as a utility call on the P-side engine core when a D worker
    POSTs /internal/kv_connector_refresh_lease on the API server.
    """
    kv_connector = self.scheduler.get_kv_connector()
    # No-op when KV transfer is not configured (no connector present).
    if kv_connector is None:
        return
    kv_connector.handle_refresh_lease(request_id)

def reset_encoder_cache(self) -> None:
"""Reset the encoder cache to invalidate all cached encoder outputs.

Expand Down