diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py
index 3803e4fd3869..1a8612cd3fc4 100644
--- a/tests/v1/kv_connector/unit/test_nixl_connector.py
+++ b/tests/v1/kv_connector/unit/test_nixl_connector.py
@@ -321,6 +321,34 @@ def test_prompt_less_than_block_size():
     assert len(scheduler_output.scheduled_new_reqs) == 0
 
 
+def test_abort_immediately_remote_prefill_enqueues_empty_recv():
+    """A remote-prefill request added with abort_immediately=True should
+    be added to the scheduler's waiting queue then immediately aborted, so the
+    NIXL connector's request_finished hook enqueues an empty recv to notify
+    the prefill instance to free its blocks."""
+    from vllm.v1.request import RequestStatus
+
+    scheduler = create_scheduler(create_vllm_config())
+
+    request = create_request(request_id=42, num_tokens=10, do_remote_prefill=True)
+    assert request.kv_transfer_params is not None
+    assert request.kv_transfer_params["do_remote_prefill"] is True
+
+    # Mimic the EngineCore.add_request path for an abort-immediately req.
+    scheduler.add_request(request)
+    scheduler.finish_requests([request.request_id], RequestStatus.FINISHED_ABORTED)
+
+    scheduler_output = scheduler.schedule()
+    meta = scheduler_output.kv_connector_metadata
+    assert isinstance(meta, NixlConnectorMetadata)
+    assert set(meta.reqs_to_recv) == {request.request_id}
+    req_meta = meta.reqs_to_recv[request.request_id]
+    assert req_meta.local_block_ids == []
+    assert req_meta.remote.request_id == f"prefill-{42}"
+    # do_remote_prefill is consumed by request_finished to prevent re-issuing.
+    assert request.kv_transfer_params["do_remote_prefill"] is False
+
+
 @patch(
     "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.NixlWrapper",
     FakeNixlWrapper,
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/scheduler.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/scheduler.py
index 02c418ebd8d7..a053f2c32611 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/scheduler.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/scheduler.py
@@ -525,7 +525,9 @@ def request_finished(
         if params.get("do_remote_prefill"):
             # If do_remote_prefill is still True when the request is finished,
             # update_state_after_alloc must not have been called (the request
-            # must have been aborted before it was scheduled).
+            # must have been aborted before it was scheduled, e.g. via the
+            # abort_immediately path used to clean up KV-transfer requests
+            # rejected at the D-side serving layer).
             # To avoid stranding the prefill blocks in the prefill instance,
             # we must add empty block_ids to _reqs_need_recv so that our
             # worker side will notify and free blocks in the prefill instance.
diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py
index 6058d8ed86b7..225c398882e8 100644
--- a/vllm/engine/protocol.py
+++ b/vllm/engine/protocol.py
@@ -108,6 +108,20 @@ async def abort(self, request_id: str | Iterable[str]) -> None:
         """
         ...
 
+    @abstractmethod
+    async def notify_kv_transfer_request_rejected(
+        self,
+        request_id: str,
+        kv_transfer_params: dict[str, Any],
+        *,
+        data_parallel_rank: int | None = None,
+    ) -> None:
+        """Notify the engine that a KV-transfer request was rejected before
+        engine admission, so connector-side cleanup can run (e.g. free
+        prefill blocks pinned on the P node).
+        """
+        ...
+
     @abstractmethod
     async def is_tracing_enabled(self) -> bool:
         ...
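The serving-layer changes below wrap each `create_*` coroutine so that connector cleanup runs whenever the wrapped call raises or returns an ErrorResponse. The standalone sketch that follows (hypothetical names, not part of this patch) illustrates that wrap-with-cleanup-on-error pattern in isolation: cleanup fires from `finally` only when the awaitable failed or produced an error value.

import asyncio
from collections.abc import Awaitable, Callable
from typing import TypeVar

_T = TypeVar("_T")


class ErrorResult:
    """Stand-in for ErrorResponse: returned (not raised) on rejection."""


async def with_cleanup_on_error(
    awaitable: Awaitable[_T], cleanup: Callable[[], Awaitable[None]]
) -> _T:
    """Await `awaitable`; run `cleanup` if it raises or returns ErrorResult."""
    notify = True
    try:
        result = await awaitable
        if not isinstance(result, ErrorResult):
            notify = False
        return result
    finally:
        if notify:
            await cleanup()


async def _demo() -> None:
    async def rejected_request() -> ErrorResult:
        # Simulates a request rejected before it reaches the engine.
        return ErrorResult()

    async def free_remote_prefill_blocks() -> None:
        print("cleanup: notify prefill node to free its blocks")

    await with_cleanup_on_error(rejected_request(), free_remote_prefill_blocks)


if __name__ == "__main__":
    asyncio.run(_demo())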
diff --git a/vllm/entrypoints/openai/chat_completion/serving.py b/vllm/entrypoints/openai/chat_completion/serving.py
index 1026e0a1e3f7..375fe9794c1e 100644
--- a/vllm/entrypoints/openai/chat_completion/serving.py
+++ b/vllm/entrypoints/openai/chat_completion/serving.py
@@ -234,6 +234,15 @@ async def create_chat_completion(
         for the API specification. This API mimics the OpenAI
         Chat Completion API.
         """
+        return await self._with_kv_transfer_rejection_cleanup(
+            self._create_chat_completion(request, raw_request), request, raw_request
+        )
+
+    async def _create_chat_completion(
+        self,
+        request: ChatCompletionRequest,
+        raw_request: Request | None = None,
+    ) -> AsyncGenerator[str, None] | ChatCompletionResponse | ErrorResponse:
         # Streaming response
         tokenizer = self.renderer.tokenizer
         assert tokenizer is not None
diff --git a/vllm/entrypoints/openai/completion/serving.py b/vllm/entrypoints/openai/completion/serving.py
index ee4ca9f3ada3..05efe86466d4 100644
--- a/vllm/entrypoints/openai/completion/serving.py
+++ b/vllm/entrypoints/openai/completion/serving.py
@@ -118,6 +118,15 @@ async def create_completion(
         - suffix (the language models we currently support do not support
           suffix)
         """
+        return await self._with_kv_transfer_rejection_cleanup(
+            self._create_completion(request, raw_request), request, raw_request
+        )
+
+    async def _create_completion(
+        self,
+        request: CompletionRequest,
+        raw_request: Request | None = None,
+    ) -> AsyncGenerator[str, None] | CompletionResponse | ErrorResponse:
         if request.stream and request.use_beam_search:
             return self.create_error_response(
                 "Streaming is not currently supported with beam search"
diff --git a/vllm/entrypoints/openai/engine/serving.py b/vllm/entrypoints/openai/engine/serving.py
index f0f84a82204c..c7f28f4b50c4 100644
--- a/vllm/entrypoints/openai/engine/serving.py
+++ b/vllm/entrypoints/openai/engine/serving.py
@@ -4,7 +4,7 @@
 import contextlib
 import json
 import time
-from collections.abc import AsyncGenerator, Mapping
+from collections.abc import AsyncGenerator, Awaitable, Mapping
 from dataclasses import dataclass, field
 from http import HTTPStatus
 from typing import Any, ClassVar, Generic, Protocol, TypeAlias, TypeVar
@@ -118,6 +118,7 @@ def build_chat_params(
 )
 
 RequestT = TypeVar("RequestT", bound=AnyRequest)
+_T = TypeVar("_T")
 
 
 @dataclass(kw_only=True)
@@ -156,6 +157,7 @@ def __init__(
         self.model_config = engine_client.model_config
         self.renderer = engine_client.renderer
         self.input_processor = engine_client.input_processor
+        self.has_kv_connector = engine_client.vllm_config.kv_transfer_config is not None
 
         # Computed once at startup (cached by ``vllm_config`` identity) and
         # stamped on non-streaming responses. Streaming chunks deliberately
@@ -616,6 +618,40 @@ def _get_data_parallel_rank(raw_request: Request | None) -> int | None:
         except ValueError:
             return None
 
+    async def _with_kv_transfer_rejection_cleanup(
+        self,
+        awaitable: Awaitable[_T],
+        request: ChatCompletionRequest | CompletionRequest | ResponsesRequest,
+        raw_request: Request | None,
+    ) -> _T:
+        """Wrap a `create_*` coroutine so that, if it raises or returns an
+        ErrorResponse (i.e. the request never reached the engine), the KV
+        connector is notified to free any pinned remote-prefill blocks."""
+        kv_transfer_params = self.has_kv_connector and request.kv_transfer_params
+        if not kv_transfer_params or not kv_transfer_params.get("do_remote_prefill"):
+            return await awaitable
+
+        notify = True
+        try:
+            result = await awaitable
+            if not isinstance(result, ErrorResponse):
+                notify = False
+            return result
+        finally:
+            if notify:
+                try:
+                    await self.engine_client.notify_kv_transfer_request_rejected(
+                        request.request_id,
+                        kv_transfer_params,
+                        data_parallel_rank=self._get_data_parallel_rank(raw_request),
+                    )
+                except Exception:
+                    logger.warning(
+                        "Failed to notify KV connector about rejected request %s",
+                        request.request_id,
+                        exc_info=True,
+                    )
+
     @staticmethod
     def _parse_tool_calls_from_content(
         request: ResponsesRequest | ChatCompletionRequest,
diff --git a/vllm/entrypoints/openai/responses/serving.py b/vllm/entrypoints/openai/responses/serving.py
index 9c4dc48589ff..fb28a2256ad0 100644
--- a/vllm/entrypoints/openai/responses/serving.py
+++ b/vllm/entrypoints/openai/responses/serving.py
@@ -323,6 +323,17 @@ async def create_responses(
         AsyncGenerator[StreamingResponsesResponse, None]
         | ResponsesResponse
         | ErrorResponse
+    ):
+        return await self._with_kv_transfer_rejection_cleanup(
+            self._create_responses(request, raw_request), request, raw_request
+        )
+
+    async def _create_responses(
+        self, request: ResponsesRequest, raw_request: Request | None = None
+    ) -> (
+        AsyncGenerator[StreamingResponsesResponse, None]
+        | ResponsesResponse
+        | ErrorResponse
     ):
         error_check_ret = await self._check_model(request)
         if error_check_ret is not None:
diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py
index 8172ead08319..f6c80055b651 100644
--- a/vllm/v1/engine/__init__.py
+++ b/vllm/v1/engine/__init__.py
@@ -122,6 +122,12 @@ class EngineCoreRequest(
     reasoning_ended: bool | None = None
     reasoning_parser_kwargs: dict[str, Any] | None = None
 
+    # If True, the request should be added to the scheduler's waiting queue
+    # and immediately aborted, so connector-side cleanup runs via the standard
+    # request_finished hook. Used to free P-side prefill blocks when a
+    # KV-transfer request is rejected on the D node before engine admission.
+    abort_immediately: bool = False
+
     @property
     def params(self) -> SamplingParams | PoolingParams:
         """Return the processed params (sampling or pooling)."""
diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py
index 0e55a685bd94..431a0dc5e89a 100644
--- a/vllm/v1/engine/async_llm.py
+++ b/vllm/v1/engine/async_llm.py
@@ -720,6 +720,33 @@ async def abort(
         if self.log_requests:
             logger.info("Aborted request(s) %s.", ",".join(request_ids))
 
+    async def notify_kv_transfer_request_rejected(
+        self,
+        request_id: str,
+        kv_transfer_params: dict[str, Any],
+        *,
+        data_parallel_rank: int | None = None,
+    ) -> None:
+        """Submit a pre-aborted request so the connector's request_finished
+        hook runs to free any pre-admission KV-transfer resources (e.g. NIXL
+        prefill blocks pinned on the P node)."""
+        request = EngineCoreRequest(
+            request_id=request_id,
+            prompt_token_ids=[0],
+            mm_features=None,
+            sampling_params=SamplingParams(
+                max_tokens=1,
+                extra_args={"kv_transfer_params": dict(kv_transfer_params)},
+            ),
+            pooling_params=None,
+            arrival_time=time.time(),
+            lora_request=None,
+            cache_salt=None,
+            data_parallel_rank=data_parallel_rank,
+            abort_immediately=True,
+        )
+        await self.engine_core.add_request_async(request)
+
     async def pause_generation(
         self,
         *,
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index 11c5ee19a664..122cdd8f965d 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -344,6 +344,10 @@ def add_request(self, request: Request, request_wave: int = 0):
         )
 
         self.scheduler.add_request(request)
+        if request.abort_immediately:
+            # Immediately abort so the connector's request_finished hook runs
+            # to free any pre-admission KV-transfer resources.
+            self.abort_requests([request.request_id])
 
     def abort_requests(self, request_ids: list[str]):
         """Abort requests from the scheduler."""
diff --git a/vllm/v1/request.py b/vllm/v1/request.py
index 738a68c83680..0d435deb3f05 100644
--- a/vllm/v1/request.py
+++ b/vllm/v1/request.py
@@ -76,6 +76,7 @@ def __init__(
         resumable: bool = False,
         reasoning_ended: bool | None = None,
         reasoning_parser_kwargs: dict[str, Any] | None = None,
+        abort_immediately: bool = False,
     ) -> None:
         self.request_id = request_id
         self.client_index = client_index
@@ -182,6 +183,10 @@ def __init__(
         # None entry in the queue means finished.
         self.streaming_queue: deque[StreamingUpdate | None] | None = None
 
+        # If True, request should be aborted immediately after being added to
+        # the scheduler so the connector's request_finished hook runs.
+        self.abort_immediately = abort_immediately
+
     @classmethod
     def from_engine_core_request(
         cls,
@@ -206,6 +211,7 @@ def from_engine_core_request(
             resumable=request.resumable,
             reasoning_ended=request.reasoning_ended,
             reasoning_parser_kwargs=request.reasoning_parser_kwargs,
+            abort_immediately=request.abort_immediately,
         )
 
     def append_output_token_ids(
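As a usage note: any EngineClient caller that rejects a disaggregated-prefill request before handing it to the engine can trigger the same cleanup. Below is a minimal sketch, assuming an already-constructed AsyncLLM instance and a connector-specific kv_transfer_params dict; per the serving-layer change above, only the do_remote_prefill key is consulted by this cleanup path, and the import path and class name are taken from the file touched in this diff.

from typing import Any

from vllm.v1.engine.async_llm import AsyncLLM


async def reject_before_admission(
    engine: AsyncLLM, request_id: str, kv_transfer_params: dict[str, Any]
) -> None:
    # Only remote-prefill requests pin blocks on the P node that need freeing.
    if kv_transfer_params.get("do_remote_prefill"):
        await engine.notify_kv_transfer_request_rejected(
            request_id, kv_transfer_params
        )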