Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions tests/v1/kv_connector/unit/test_nixl_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,34 @@ def test_prompt_less_than_block_size():
assert len(scheduler_output.scheduled_new_reqs) == 0


def test_abort_immediately_remote_prefill_enqueues_empty_recv():
    """Verify the abort-immediately path for a remote-prefill request.

    The request is first placed on the scheduler's waiting queue and then
    aborted right away; the NIXL connector's request_finished hook must
    respond by enqueuing an empty recv so the prefill instance is told to
    free its blocks.
    """
    from vllm.v1.request import RequestStatus

    scheduler = create_scheduler(create_vllm_config())

    req = create_request(request_id=42, num_tokens=10, do_remote_prefill=True)
    assert req.kv_transfer_params is not None
    assert req.kv_transfer_params["do_remote_prefill"] is True

    # Mimic the EngineCore.add_request path for an abort-immediately req.
    scheduler.add_request(req)
    scheduler.finish_requests([req.request_id], RequestStatus.FINISHED_ABORTED)

    output = scheduler.schedule()
    connector_meta = output.kv_connector_metadata
    assert isinstance(connector_meta, NixlConnectorMetadata)
    assert set(connector_meta.reqs_to_recv) == {req.request_id}
    recv_meta = connector_meta.reqs_to_recv[req.request_id]
    assert recv_meta.local_block_ids == []
    assert recv_meta.remote.request_id == f"prefill-{42}"
    # do_remote_prefill is consumed by request_finished to prevent re-issuing.
    assert req.kv_transfer_params["do_remote_prefill"] is False


@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.NixlWrapper",
FakeNixlWrapper,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,9 @@ def request_finished(
if params.get("do_remote_prefill"):
# If do_remote_prefill is still True when the request is finished,
# update_state_after_alloc must not have been called (the request
# must have been aborted before it was scheduled).
# must have been aborted before it was scheduled, e.g. via the
# abort_immediately path used to clean up KV-transfer requests
# rejected at the D-side serving layer).
# To avoid stranding the prefill blocks in the prefill instance,
# we must add empty block_ids to _reqs_need_recv so that our
# worker side will notify and free blocks in the prefill instance.
Expand Down
14 changes: 14 additions & 0 deletions vllm/engine/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,20 @@ async def abort(self, request_id: str | Iterable[str]) -> None:
"""
...

@abstractmethod
async def notify_kv_transfer_request_rejected(
    self,
    request_id: str,
    kv_transfer_params: dict[str, Any],
    *,
    data_parallel_rank: int | None = None,
) -> None:
    """Notify the engine that a KV-transfer request was rejected before
    engine admission, so connector-side cleanup can run (e.g. free
    prefill blocks pinned on the P node).

    Args:
        request_id: ID of the rejected request; forwarded to the engine so
            the connector's request_finished hook sees the same ID the
            prefill side used.
        kv_transfer_params: The KV-transfer parameters the rejected request
            carried (expected to contain ``do_remote_prefill``).
        data_parallel_rank: Optional DP rank to route the notification to;
            ``None`` lets the engine pick.
    """
    ...

@abstractmethod
async def is_tracing_enabled(self) -> bool: ...

Expand Down
9 changes: 9 additions & 0 deletions vllm/entrypoints/openai/chat_completion/serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,15 @@ async def create_chat_completion(
for the API specification. This API mimics the OpenAI
Chat Completion API.
"""
return await self._with_kv_transfer_rejection_cleanup(
self._create_chat_completion(request, raw_request), request, raw_request
)

async def _create_chat_completion(
self,
request: ChatCompletionRequest,
raw_request: Request | None = None,
) -> AsyncGenerator[str, None] | ChatCompletionResponse | ErrorResponse:
# Streaming response
tokenizer = self.renderer.tokenizer
assert tokenizer is not None
Expand Down
9 changes: 9 additions & 0 deletions vllm/entrypoints/openai/completion/serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,15 @@ async def create_completion(
- suffix (the language models we currently support do not support
suffix)
"""
return await self._with_kv_transfer_rejection_cleanup(
self._create_completion(request, raw_request), request, raw_request
)

async def _create_completion(
self,
request: CompletionRequest,
raw_request: Request | None = None,
) -> AsyncGenerator[str, None] | CompletionResponse | ErrorResponse:
if request.stream and request.use_beam_search:
return self.create_error_response(
"Streaming is not currently supported with beam search"
Expand Down
37 changes: 36 additions & 1 deletion vllm/entrypoints/openai/engine/serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import contextlib
import json
import time
from collections.abc import AsyncGenerator, Mapping
from collections.abc import AsyncGenerator, Awaitable, Mapping
from dataclasses import dataclass, field
from http import HTTPStatus
from typing import Any, ClassVar, Generic, Protocol, TypeAlias, TypeVar
Expand Down Expand Up @@ -118,6 +118,7 @@ def build_chat_params(
)

RequestT = TypeVar("RequestT", bound=AnyRequest)
_T = TypeVar("_T")


@dataclass(kw_only=True)
Expand Down Expand Up @@ -616,6 +617,40 @@ def _get_data_parallel_rank(raw_request: Request | None) -> int | None:
except ValueError:
return None

async def _with_kv_transfer_rejection_cleanup(
    self,
    awaitable: Awaitable[_T],
    request: ChatCompletionRequest | CompletionRequest | ResponsesRequest,
    raw_request: Request | None,
) -> _T:
    """Wrap a `create_*` coroutine so that, if it raises or returns an
    ErrorResponse (i.e. the request never reached the engine), the KV
    connector is notified to free any pinned remote-prefill blocks."""
    # Fast path: requests without remote-prefill KV params need no cleanup.
    params = request.kv_transfer_params
    if not params or not params.get("do_remote_prefill"):
        return await awaitable

    # Assume the request never reached the engine until proven otherwise;
    # the finally clause then fires for both exceptions and error results.
    should_notify = True
    try:
        outcome = await awaitable
        should_notify = isinstance(outcome, ErrorResponse)
        return outcome
    finally:
        if should_notify:
            # Best effort: a failed notification must not mask the
            # original error/result, so only log it.
            try:
                await self.engine_client.notify_kv_transfer_request_rejected(
                    request.request_id,
                    params,
                    data_parallel_rank=self._get_data_parallel_rank(raw_request),
                )
            except Exception:
                logger.warning(
                    "Failed to notify KV connector about rejected request %s",
                    request.request_id,
                    exc_info=True,
                )

@staticmethod
def _parse_tool_calls_from_content(
request: ResponsesRequest | ChatCompletionRequest,
Expand Down
11 changes: 11 additions & 0 deletions vllm/entrypoints/openai/responses/serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,17 @@ async def create_responses(
AsyncGenerator[StreamingResponsesResponse, None]
| ResponsesResponse
| ErrorResponse
):
return await self._with_kv_transfer_rejection_cleanup(
self._create_responses(request, raw_request), request, raw_request
)

async def _create_responses(
self, request: ResponsesRequest, raw_request: Request | None = None
) -> (
AsyncGenerator[StreamingResponsesResponse, None]
| ResponsesResponse
| ErrorResponse
):
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
Expand Down
6 changes: 6 additions & 0 deletions vllm/v1/engine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,12 @@ class EngineCoreRequest(
reasoning_ended: bool | None = None
reasoning_parser_kwargs: dict[str, Any] | None = None

# If True, the request should be added to the scheduler's waiting queue
# and immediately aborted, so connector-side cleanup runs via the standard
# request_finished hook. Used to free P-side prefill blocks when a
# KV-transfer request is rejected on the D node before engine admission.
abort_immediately: bool = False

@property
def params(self) -> SamplingParams | PoolingParams:
"""Return the processed params (sampling or pooling)."""
Expand Down
27 changes: 27 additions & 0 deletions vllm/v1/engine/async_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,6 +720,33 @@ async def abort(
if self.log_requests:
logger.info("Aborted request(s) %s.", ",".join(request_ids))

async def notify_kv_transfer_request_rejected(
    self,
    request_id: str,
    kv_transfer_params: dict[str, Any],
    *,
    data_parallel_rank: int | None = None,
) -> None:
    """Run connector-side cleanup for a request rejected pre-admission.

    Submits a minimal placeholder request flagged abort_immediately=True;
    the engine core admits and aborts it in one step, which drives the
    connector's request_finished hook and frees any pre-admission
    KV-transfer resources (e.g. NIXL prefill blocks pinned on the P node).
    """
    # Carry the connector params on a throwaway one-token request; copying
    # the dict guards the caller's mapping against downstream mutation.
    placeholder_params = SamplingParams(
        max_tokens=1,
        extra_args={"kv_transfer_params": dict(kv_transfer_params)},
    )
    core_req = EngineCoreRequest(
        request_id=request_id,
        prompt_token_ids=[0],
        mm_features=None,
        sampling_params=placeholder_params,
        pooling_params=None,
        arrival_time=time.time(),
        lora_request=None,
        cache_salt=None,
        data_parallel_rank=data_parallel_rank,
        abort_immediately=True,
    )
    await self.engine_core.add_request_async(core_req)

async def pause_generation(
self,
*,
Expand Down
4 changes: 4 additions & 0 deletions vllm/v1/engine/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,10 @@ def add_request(self, request: Request, request_wave: int = 0):
)

self.scheduler.add_request(request)
if request.abort_immediately:
# Immediately abort so the connector's request_finished hook runs
# to free any pre-admission KV-transfer resources.
self.abort_requests([request.request_id])

def abort_requests(self, request_ids: list[str]):
"""Abort requests from the scheduler."""
Expand Down
6 changes: 6 additions & 0 deletions vllm/v1/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def __init__(
resumable: bool = False,
reasoning_ended: bool | None = None,
reasoning_parser_kwargs: dict[str, Any] | None = None,
abort_immediately: bool = False,
) -> None:
self.request_id = request_id
self.client_index = client_index
Expand Down Expand Up @@ -182,6 +183,10 @@ def __init__(
# None entry in the queue means finished.
self.streaming_queue: deque[StreamingUpdate | None] | None = None

# If True, request should be aborted immediately after being added to
# the scheduler so the connector's request_finished hook runs.
self.abort_immediately = abort_immediately

@classmethod
def from_engine_core_request(
cls,
Expand All @@ -206,6 +211,7 @@ def from_engine_core_request(
resumable=request.resumable,
reasoning_ended=request.reasoning_ended,
reasoning_parser_kwargs=request.reasoning_parser_kwargs,
abort_immediately=request.abort_immediately,
)

def append_output_token_ids(
Expand Down
Loading