Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/v1/kv_connector/unit/test_error_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def test_error_propagation_async_load(fail_scheduler: Scheduler):

assert len(fail_scheduler.waiting) == 1
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
assert request.num_computed_tokens == 0
assert request.num_computed_tokens == num_external_computed_tokens

(req_block_ids,) = fail_scheduler.kv_cache_manager.get_block_ids(request.request_id)
invalid_block_ids = {req_block_ids[invalid_block_idx]}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ def test_async_recompute_blocks_not_cached_when_invalid(
# request should be waiting for remote KVs
assert len(recompute_scheduler.waiting) == 1
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
assert request.num_computed_tokens == 0
assert request.num_computed_tokens == num_external_computed_tokens

# get the allocated block IDs
(req_block_ids,) = recompute_scheduler.kv_cache_manager.get_block_ids(
Expand Down
6 changes: 3 additions & 3 deletions tests/v1/kv_connector/unit/test_kv_load_failure_recovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def test_async_load_failure(

assert len(scheduler.waiting) == 3
for request in scheduler.waiting:
assert request.num_computed_tokens == 0
assert request.num_computed_tokens == num_external_computed_tokens
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
assert scheduler.connector.get_num_new_matched_tokens.call_count == 3

Expand All @@ -103,7 +103,7 @@ def test_async_load_failure(
min_invalid_block_idx * scheduler.block_size
)
else:
assert request.num_computed_tokens == 0
assert request.num_computed_tokens == num_external_computed_tokens
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
assert scheduler.failed_recving_kv_req_ids == {request2.request_id}
assert scheduler.connector.get_num_new_matched_tokens.call_count == 3
Expand Down Expand Up @@ -305,7 +305,7 @@ def test_async_progressive_load_failure(

assert len(scheduler.waiting) == 1
assert scheduler.waiting.peek_request().request_id == request.request_id
assert request.num_computed_tokens == 0
assert request.num_computed_tokens == num_external_computed_tokens
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
assert scheduler.connector.get_num_new_matched_tokens.call_count == 1

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def test_basic_lifecycle():
assert len(scheduler.waiting) == 1
assert request in scheduler.waiting
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
assert request.num_computed_tokens == 0
assert request.num_computed_tokens == NUM_TOKENS

# ... but should have (uncached) blocks allocated to it.
block_pool = scheduler.kv_cache_manager.block_pool
Expand Down
42 changes: 26 additions & 16 deletions vllm/v1/core/sched/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,6 +638,7 @@ def schedule(self) -> SchedulerOutput:
num_computed_tokens = (
num_new_local_computed_tokens + num_external_computed_tokens
)
assert num_computed_tokens <= request.num_tokens
else:
# KVTransfer: WAITING reqs have num_computed_tokens > 0
# after async KV recvs are completed.
Expand Down Expand Up @@ -773,6 +774,20 @@ def schedule(self) -> SchedulerOutput:
# into the WAITING_FOR_REMOTE_KV state.
skipped_waiting_requests.prepend_request(request)
request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
# Set num_computed_tokens even though KVs are not yet loaded.
# request.num_computed_tokens will not be used anywhere until
# the request finished the KV transfer.
#
# If a transfer error is reported by the connector,
# request.num_computed_tokens will be re-set accordingly in
# _update_requests_with_invalid_blocks.
#
# When the transfer is finished, either successfully or not,
# request.num_computed_tokens will correctly reflect the number
# of computed tokens.
# _update_waiting_for_remote_kv will then cache
# only the successfully loaded tokens.
request.num_computed_tokens = num_computed_tokens
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than piggybacking on num_computed_tokens, I think this implementation would be cleaner if we had another attribute that tracked this information.

num_computed_tokens has a very specific meaning [i.e. these tokens have their KVs ready to go]. This changes the definition of num_computed_tokens in a way that makes me uncomfortable.

continue

self.running.append(request)
Expand Down Expand Up @@ -1994,17 +2009,17 @@ def _update_waiting_for_remote_kv(self, request: Request) -> bool:
self.failed_recving_kv_req_ids.remove(request.request_id)
else:
# Now that the blocks are ready, actually cache them.
(block_ids,) = self.kv_cache_manager.get_block_ids(request.request_id)
num_computed_tokens = len(block_ids) * self.block_size
# Handle the case where num request tokens less than one block.
num_computed_tokens = min(num_computed_tokens, request.num_tokens)
if num_computed_tokens == request.num_tokens:
num_computed_tokens -= 1
# This will cache the blocks iff caching is enabled.
self.kv_cache_manager.cache_blocks(request, num_computed_tokens)
self.kv_cache_manager.cache_blocks(request, request.num_computed_tokens)

# Update the request state for scheduling.
request.num_computed_tokens = num_computed_tokens
# on a full prompt hit, we need to re-compute the last token
# in order to be able to sample the next token
if request.num_computed_tokens == request.num_tokens:
request.num_computed_tokens = request.num_tokens - 1

# Count the number of prefix cached tokens.
if request.num_cached_tokens < 0:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is this change needed? seems unrelated

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is unrelated, but I noticed, while copying from the sync case, that we don't update num_cached_tokens for async requests.
It's a very small fix, so I thought I could include it here.
I can also defer it to another PR.

request.num_cached_tokens = request.num_computed_tokens

# Return that we are ready.
self.finished_recving_kv_req_ids.remove(request.request_id)
Expand Down Expand Up @@ -2084,13 +2099,8 @@ def _update_requests_with_invalid_blocks(
# We iterate only over blocks that may contain externally computed
# tokens
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
# Async loading. If num_computed_tokens is set it implies we
# already processed some block failures for it in a prior step
req_num_computed_tokens = (
request.num_computed_tokens
if req_id in self.failed_recving_kv_req_ids
else len(req_block_ids) * self.block_size
)
# Async loading. num_computed_tokens does not include new tokens
req_num_computed_tokens = request.num_computed_tokens
else:
# Sync loading. num_computed_tokens includes new tokens
req_num_computed_tokens = request.num_cached_tokens
Expand Down