62 changes: 61 additions & 1 deletion tests/v1/core/test_scheduler.py
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
from unittest.mock import Mock
from unittest.mock import Mock, patch

import pytest
import torch
@@ -20,6 +20,7 @@
MultiModalKwargsItem,
PlaceholderRange,
)
from vllm.platforms.cpu import CpuPlatform
from vllm.sampling_params import SamplingParams, StructuredOutputsParams
from vllm.utils.hashing import sha256
from vllm.v1.core.encoder_cache_manager import EncoderCacheManager
@@ -106,6 +107,65 @@ def test_schedule(enable_prefix_caching: bool, prompt_logprobs: int | None):
assert scheduler.running[i] == request


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("reserve_full_isl", [True, False])
def test_schedule_rejects_waiting_request_exceeding_kv_capacity(
reserve_full_isl: bool,
):
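    # Patch in CpuPlatform so scheduler construction does not depend on a GPU.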
with patch("vllm.platforms.current_platform", CpuPlatform()):
scheduler = create_scheduler(
max_num_seqs=2,
max_num_batched_tokens=128,
max_model_len=128,
num_blocks=5,
block_size=16,
)
scheduler.scheduler_reserve_full_isl = reserve_full_isl
scheduler.scheduler_config.scheduler_reserve_full_isl = reserve_full_isl

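        # num_blocks=5 at block_size=16 leaves 4 usable blocks (one block is
        # reserved, per num_gpu_blocks - 1), i.e. 64 tokens, so the 65-token
        # request can never fit while the 8-token request always can.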
request_too_large = create_requests(
num_requests=1,
num_tokens=65,
req_ids=["too_large"],
)[0]
request_small = create_requests(
num_requests=1,
num_tokens=8,
req_ids=["small"],
)[0]
scheduler.add_request(request_too_large)
scheduler.add_request(request_small)

output = scheduler.schedule()

assert len(output.scheduled_new_reqs) == 1
assert output.scheduled_new_reqs[0].req_id == request_small.request_id
assert request_too_large.request_id in output.finished_req_ids
assert request_too_large.request_id not in scheduler.requests
assert request_too_large.status == RequestStatus.FINISHED_ERROR
assert not scheduler.waiting
assert not scheduler.skipped_waiting
assert len(scheduler.running) == 1
assert scheduler.running[0].request_id == request_small.request_id

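        # The rejection output is buffered in scheduler.pending_outputs and is
        # only flushed by the next update_from_output call, so drive one step.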
model_output = ModelRunnerOutput(
req_ids=[request_small.request_id],
req_id_to_index={request_small.request_id: 0},
sampled_token_ids=[[]],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
engine_core_outputs = scheduler.update_from_output(output, model_output)

assert len(engine_core_outputs[0].outputs) == 1
error_output = engine_core_outputs[0].outputs[0]
assert error_output.request_id == request_too_large.request_id
assert error_output.new_token_ids == []
assert error_output.finish_reason == FinishReason.ERROR
assert "KV cache capacity" in str(error_output.stop_reason)


def test_schedule_multimodal_requests():
scheduler = create_scheduler(model="llava-hf/llava-1.5-7b-hf")
mm_positions = [[PlaceholderRange(offset=i, length=100)] for i in range(10)]
38 changes: 37 additions & 1 deletion vllm/v1/core/kv_cache_manager.py
@@ -229,6 +229,42 @@ def can_fit_full_sequence(
new_computed_blocks: KVCacheBlocks | None = None,
num_external_computed_tokens: int = 0,
num_encoder_tokens: int = 0,
) -> bool:
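        """Check against the blocks currently free in the pool."""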
return self._can_fit_full_sequence_with_block_budget(
request,
self.block_pool.get_num_free_blocks(),
num_new_computed_tokens=num_new_computed_tokens,
new_computed_blocks=new_computed_blocks,
num_external_computed_tokens=num_external_computed_tokens,
num_encoder_tokens=num_encoder_tokens,
)

def can_fit_full_sequence_in_empty_cache(
self,
request: Request,
num_new_computed_tokens: int = 0,
new_computed_blocks: KVCacheBlocks | None = None,
num_external_computed_tokens: int = 0,
num_encoder_tokens: int = 0,
) -> bool:
"""Check if the sequence fits when this request is the only resident."""
return self._can_fit_full_sequence_with_block_budget(
request,
self.block_pool.num_gpu_blocks - 1,
num_new_computed_tokens=num_new_computed_tokens,
new_computed_blocks=new_computed_blocks,
num_external_computed_tokens=num_external_computed_tokens,
num_encoder_tokens=num_encoder_tokens,
)

def _can_fit_full_sequence_with_block_budget(
self,
request: Request,
num_available_blocks: int,
num_new_computed_tokens: int = 0,
new_computed_blocks: KVCacheBlocks | None = None,
num_external_computed_tokens: int = 0,
num_encoder_tokens: int = 0,
) -> bool:
"""Check if the KV cache has enough free blocks to hold the full
sequence, accounting for prefix cache hits and sliding window.
@@ -259,7 +295,7 @@ def can_fit_full_sequence(
num_tokens_main_model=full_num_tokens,
)

return num_blocks_to_allocate <= self.block_pool.get_num_free_blocks()
return num_blocks_to_allocate <= num_available_blocks

def allocate_slots(
self,
76 changes: 69 additions & 7 deletions vllm/v1/core/sched/scheduler.py
@@ -174,6 +174,7 @@ def __init__(
# requests so that they can free the cached states for those requests.
# This is flushed at the end of each scheduling step.
self.finished_req_ids: set[str] = set()
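        # Buffered outputs for requests finished outside update_from_output
        # (e.g. rejected for exceeding KV cache capacity), keyed by client
        # index and flushed on the next update_from_output call.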
self.pending_outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)

# Counter for requests waiting for streaming input. Used to calculate
# number of unfinished requests
@@ -751,8 +752,15 @@ def schedule(self) -> SchedulerOutput:
num_encoder_tokens=num_encoder_tokens,
)
):
if request.has_encoder_inputs:
self.encoder_cache_manager.free(request)
if self._reject_or_defer_request_for_kv_capacity(
request_queue,
request,
num_new_local_computed_tokens,
new_computed_blocks,
num_external_computed_tokens,
num_encoder_tokens,
):
continue
break

new_blocks = self.kv_cache_manager.allocate_slots(
Expand All @@ -768,11 +776,15 @@ def schedule(self) -> SchedulerOutput:

if new_blocks is None:
# The request cannot be scheduled.

                # NOTE: we need to untouch the request from the encoder cache
                # manager
if request.has_encoder_inputs:
self.encoder_cache_manager.free(request)
if self._reject_or_defer_request_for_kv_capacity(
request_queue,
request,
num_new_local_computed_tokens,
new_computed_blocks,
num_external_computed_tokens,
num_encoder_tokens,
):
continue
break

# KVTransfer: the connector uses this info to determine
@@ -962,6 +974,50 @@ def _build_kv_connector_meta(
) -> KVConnectorMetadata:
return connector.build_connector_meta(scheduler_output)

def _reject_or_defer_request_for_kv_capacity(
self,
request_queue: RequestQueue,
request: Request,
num_new_computed_tokens: int,
new_computed_blocks: KVCacheBlocks | None,
num_external_computed_tokens: int,
num_encoder_tokens: int,
) -> bool:
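        """Handle a request whose blocks cannot be allocated right now.

        Returns True if the request cannot fit even in an empty cache and was
        rejected, so the caller should continue with the next queued request.
        Returns False if the request may still fit once running requests
        release their blocks; it stays queued and the caller should stop
        scheduling.
        """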
if self.kv_cache_manager.can_fit_full_sequence_in_empty_cache(
request,
num_new_computed_tokens=num_new_computed_tokens,
new_computed_blocks=new_computed_blocks,
num_external_computed_tokens=num_external_computed_tokens,
num_encoder_tokens=num_encoder_tokens,
):
if request.has_encoder_inputs:
self.encoder_cache_manager.free(request)
return False

request = request_queue.pop_request()
self._reject_request_for_kv_capacity(request)
return True

def _reject_request_for_kv_capacity(self, request: Request) -> None:
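        """Finish the request with an error and buffer an EngineCoreOutput so
        the client is notified on the next update_from_output call."""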
reason = (
f"Request with {request.num_tokens} tokens exceeds the available KV "
"cache capacity for this model."
)
logger.warning("%s Rejecting request %s.", reason, request.request_id)
request.stop_reason = reason
request.resumable = False
self.finish_requests(request.request_id, RequestStatus.FINISHED_ERROR)
self.pending_outputs[request.client_index].append(
EngineCoreOutput(
request_id=request.request_id,
new_token_ids=[],
finish_reason=request.get_finished_reason(),
stop_reason=request.stop_reason,
events=request.take_events(),
trace_headers=request.trace_headers,
)
)

def _preempt_request(self, request: Request, timestamp: float) -> None:
"""Preempt a request and put it back to the waiting queue.

@@ -1507,6 +1563,12 @@ def update_from_output(
)
)

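        # Flush outputs buffered for requests rejected during schedule() so
        # their clients still receive a finish notification.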
        if self.pending_outputs:
            for client_index, pending in self.pending_outputs.items():
                outputs[client_index].extend(pending)
            self.pending_outputs = defaultdict(list)

# KV Connector: update state for finished KV Transfers.
if kv_connector_output:
self._update_from_kv_xfer_finished(kv_connector_output)