Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions tests/v1/core/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1549,6 +1549,12 @@ def test_kv_connector_handles_preemption(is_async, use_ec_connector, ec_role):

# All can be scheduled - 1st token.
output = scheduler.schedule()

# verify request-level cache hit stats are set
for request in requests:
assert request.num_cached_tokens == NUM_MATCHED_NEW_TOKENS
assert request.num_external_computed_tokens == NUM_MATCHED_NEW_TOKENS

if is_async:
assert _num_waiting_requests(scheduler) == 2
assert scheduler.running == []
Expand Down Expand Up @@ -1607,6 +1613,12 @@ def test_kv_connector_handles_preemption(is_async, use_ec_connector, ec_role):
# Restarts the preempted request - generate 3rd token.
# This will have a local and remote cache hit.
output = scheduler.schedule()

# verify request level hit stats are NOT re-set
for request in requests:
assert request.num_cached_tokens == NUM_MATCHED_NEW_TOKENS
assert request.num_external_computed_tokens == NUM_MATCHED_NEW_TOKENS

if is_async:
waiting_req_ids = [
req.request_id
Expand Down Expand Up @@ -1649,6 +1661,11 @@ def test_kv_connector_handles_preemption(is_async, use_ec_connector, ec_role):
# All memory should be freed since nothing is running.
assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == NUM_BLOCKS - 1

# final verification request-level cache hit stats are NOT re-set
for request in requests:
assert request.num_cached_tokens == NUM_MATCHED_NEW_TOKENS
assert request.num_external_computed_tokens == NUM_MATCHED_NEW_TOKENS


def make_output(scheduler: Scheduler):
return ModelRunnerOutput(
Expand Down
31 changes: 26 additions & 5 deletions vllm/v1/core/sched/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,7 +619,6 @@ def schedule(self) -> SchedulerOutput:
step_skipped_waiting.prepend_request(request)
continue

request.num_external_computed_tokens = ext_tokens
num_external_computed_tokens = ext_tokens

connector_prefix_cache_queries = (
Expand All @@ -632,6 +631,16 @@ def schedule(self) -> SchedulerOutput:
num_new_local_computed_tokens + num_external_computed_tokens
)
assert num_computed_tokens <= request.num_tokens

if request.num_preemptions == 0:
# For request-level stats,
# track hits only the first time a request gets scheduled.
# If allocation will later fail, we will get back here
# the next time the request re-tries scheduling
request.num_cached_tokens = num_computed_tokens
request.num_external_computed_tokens = (
num_external_computed_tokens
)
else:
# KVTransfer: WAITING reqs have num_computed_tokens > 0
# after async KV recvs are completed.
Expand Down Expand Up @@ -802,9 +811,6 @@ def schedule(self) -> SchedulerOutput:
token_budget -= num_new_tokens
request.status = RequestStatus.RUNNING
request.num_computed_tokens = num_computed_tokens
# Count the number of prefix cached tokens.
if request.num_cached_tokens < 0:
request.num_cached_tokens = num_computed_tokens
# Encoder-related.
if encoder_inputs_to_schedule:
scheduled_encoder_inputs[request_id] = encoder_inputs_to_schedule
Expand Down Expand Up @@ -2158,6 +2164,7 @@ def _update_requests_with_invalid_blocks(
req_num_computed_tokens = request.num_computed_tokens
else:
# Sync loading. num_computed_tokens includes new tokens
# TODO(orozery): Bug below! Incorrect for preempted requests!
req_num_computed_tokens = request.num_cached_tokens

req_num_computed_blocks = (
Expand Down Expand Up @@ -2192,7 +2199,19 @@ def _update_requests_with_invalid_blocks(
req_num_computed_tokens - request.num_computed_tokens
)
total_affected_tokens += num_affected_tokens
request.num_external_computed_tokens -= num_affected_tokens
if request.num_preemptions == 0:
# For request-level stats,
# track hits only the first time a request gets scheduled.
num_local_hit_tokens = (
request.num_cached_tokens - request.num_external_computed_tokens
)
assert num_local_hit_tokens >= 0
request.num_cached_tokens = min(
request.num_cached_tokens, request.num_computed_tokens
)
request.num_external_computed_tokens = max(
0, request.num_cached_tokens - num_local_hit_tokens
)
# collect invalid block and all downstream dependent blocks
if evict_blocks:
blocks_to_evict.update(req_block_ids[idx:])
Expand All @@ -2204,6 +2223,8 @@ def _update_requests_with_invalid_blocks(
# Revert to considering only cached tokens as computed.
# Currently this only applies to sync loading; Async
# loading does not yet support block sharing

# TODO(orozery): Bug! Incorrect computation for preempted requests!
total_affected_tokens += (
request.num_computed_tokens - request.num_cached_tokens
)
Expand Down
12 changes: 8 additions & 4 deletions vllm/v1/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,6 @@ def __init__(
self.all_token_ids = ConstantList(self._all_token_ids)
# trace_headers
self.trace_headers = trace_headers
# State
# The number of tokens with prefix cache hits.
self.num_cached_tokens = -1

# True if this request is scheduled as a non-final prefill chunk.
self.is_prefill_chunk = False
Expand All @@ -159,7 +156,14 @@ def __init__(
# The number of times this request has been preempted by the scheduler.
self.num_preemptions = 0

# The number of tokens that have been computed remotely.
# Fields used for request-level cache stats
# These fields are only set on the first time a request gets scheduled
# Cache hits following request preemption are currently not tracked.

# Total number of KV cache hit tokens:
# local prefix cache hits + external (connector-based) hits
self.num_cached_tokens = 0
# Number of external tokens hit (excluding local prefix cache hits)
self.num_external_computed_tokens = 0

self.block_hashes: list[BlockHash] = []
Expand Down
Loading