Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions tests/v1/core/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1261,10 +1261,11 @@ def test_kv_connector_unable_to_allocate(use_ec_connector, ec_role):
assert len(scheduler.waiting) == 0


@pytest.mark.parametrize("is_async", [False, True])
@pytest.mark.parametrize(
"use_ec_connector, ec_role", [(False, None), (True, "ec_consumer")]
)
def test_kv_connector_handles_preemption(use_ec_connector, ec_role):
def test_kv_connector_handles_preemption(is_async, use_ec_connector, ec_role):
"""
    Test whether the scheduler with a KVConnector is able to handle
    being unable to allocate (running out of blocks in allocate_slots()).
Expand All @@ -1277,7 +1278,9 @@ def test_kv_connector_handles_preemption(use_ec_connector, ec_role):
NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE
scheduler = create_scheduler(
enable_prefix_caching=True,
use_kv_connector=mock_kv(matched_tokens=NUM_MATCHED_NEW_TOKENS, is_async=False),
use_kv_connector=mock_kv(
matched_tokens=NUM_MATCHED_NEW_TOKENS, is_async=is_async
),
block_size=BLOCK_SIZE,
num_blocks=NUM_BLOCKS,
# encoder connector should not affect test results
Expand Down Expand Up @@ -1315,6 +1318,12 @@ def test_kv_connector_handles_preemption(use_ec_connector, ec_role):

# All can be scheduled - 1st token.
output = scheduler.schedule()
if is_async:
assert len(scheduler.waiting) == 2
assert scheduler.running == []
_step_until_kv_transfer_finished(scheduler, req_ids)
output = scheduler.schedule()

_assert_right_scheduler_output(
output,
# 2 remote kv cache hits.
Expand Down Expand Up @@ -1367,6 +1376,12 @@ def test_kv_connector_handles_preemption(use_ec_connector, ec_role):
# Restarts the preempted request - generate 3rd token.
# This will have a local and remote cache hit.
output = scheduler.schedule()
if is_async:
waiting_req_ids = [req.request_id for req in scheduler.waiting]
assert len(waiting_req_ids) == 1
_step_until_kv_transfer_finished(scheduler, waiting_req_ids)
output = scheduler.schedule()

_assert_right_scheduler_output(
output,
# 1 remote kv_cache hit!
Expand All @@ -1377,6 +1392,8 @@ def test_kv_connector_handles_preemption(use_ec_connector, ec_role):
)
assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 0
assert output.scheduled_cached_reqs.num_reqs == 1
assert output.scheduled_new_reqs == []
_ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)
assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 0
Expand All @@ -1389,6 +1406,8 @@ def test_kv_connector_handles_preemption(use_ec_connector, ec_role):
num_requests=0,
expected_num_scheduled_tokens=1,
)
assert output.scheduled_cached_reqs.num_reqs == 1
assert output.scheduled_new_reqs == []
assert len(scheduler.running) == 1
_ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)
assert len(scheduler.running) == 0
Expand Down
7 changes: 6 additions & 1 deletion vllm/v1/core/sched/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,12 @@ def schedule(self) -> SchedulerOutput:
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
is_ready = self._update_waiting_for_remote_kv(request)
if is_ready:
request.status = RequestStatus.WAITING
if request.num_preemptions:
# We must be loading for a resumed preemption
# rather than a new request.
request.status = RequestStatus.PREEMPTED
else:
request.status = RequestStatus.WAITING
else:
logger.debug(
"%s is still in WAITING_FOR_REMOTE_KVS state.",
Expand Down
2 changes: 1 addition & 1 deletion vllm/v1/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def __init__(
# indicates that the output is corrupted
self.num_nans_in_logits = 0

# The number of requests being preempted by the scheduler
# The number of times this request has been preempted by the scheduler.
self.num_preemptions = 0

# The number of tokens that have been computed remotely.
Expand Down