
[KVConnector] Scheduler: Fix num_computed_tokens after async KV load#34616

Merged
NickLucche merged 1 commit into vllm-project:main from orozery:scheduler-async-load-num-computed-tokens
Mar 5, 2026

Conversation

@orozery (Collaborator) commented Feb 16, 2026

Previously, following a successful async KV load, the scheduler
would update the request's num_computed_tokens by counting the number of
allocated blocks, multiplying by the block size,
and rounding down to request.num_tokens if necessary.
This worked as long as the last external token aligned to either:

  1. The end of the request
  2. A full block

However, if the last external token fell in the middle of a block without reaching the end of the prompt,
this logic would wrongly advance num_computed_tokens past the last external token.
This would yield wrong KV data.
Furthermore, the current logic hard-codes an assumption of a single KV cache group.

This PR moves the setting of num_computed_tokens to allocation time, as is
already the case for sync KV loading. This fixes the issue above and removes the hard-coded
assumption of a single KV cache group.
Additionally, we add a missing update to num_cached_tokens.
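The over-count described above can be sketched in a few lines of plain Python. This is a toy with assumed numbers, not vLLM code: it reproduces the old post-load arithmetic and shows how a mid-block external hit inflates num_computed_tokens.

```python
# Minimal sketch (assumed values, not vLLM code) of the old post-load
# logic: num_computed_tokens derived from the allocated-block count.

BLOCK_SIZE = 3

def old_num_computed_tokens(num_allocated_blocks: int,
                            num_request_tokens: int) -> int:
    # Old logic: allocated blocks * block size, capped at the request
    # length, and trimmed by one when it reaches the end of the request.
    n = num_allocated_blocks * BLOCK_SIZE
    n = min(n, num_request_tokens)
    if n == num_request_tokens:
        n -= 1
    return n

# 4 externally cached tokens in an 8-token prompt need 2 blocks.
num_external_tokens = 4
num_blocks = -(-num_external_tokens // BLOCK_SIZE)  # ceil(4 / 3) = 2

old = old_num_computed_tokens(num_blocks, num_request_tokens=8)
assert old == 6                      # but only 4 tokens have loaded KVs
assert old > num_external_tokens     # the over-count this PR fixes
```

The last external token (the 4th) sits mid-block and short of the prompt end, so neither of the two alignment conditions holds, and the block-count arithmetic silently claims two extra tokens as computed.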

@gemini-code-assist bot left a comment

Code Review

This pull request correctly refactors the calculation of num_computed_tokens for asynchronous KV cache loading to be done at allocation time, which fixes issues with block alignment and HMA support. The addition of updating num_cached_tokens is also a good catch for consistency.

I've found one issue in the implementation where num_cached_tokens is set after num_computed_tokens is adjusted, which could lead to incorrect metrics and error handling. I've left a specific comment with a suggestion to fix this.

Additionally, while reviewing, I noticed that _update_requests_with_invalid_blocks might still be using the old, block-aligned logic for async requests when handling KV load failures. This could be a related area to investigate to ensure the fix is comprehensive.

@orozery orozery force-pushed the scheduler-async-load-num-computed-tokens branch from 90c55f4 to 1452787 Compare February 16, 2026 10:35
@NickLucche (Collaborator) left a comment

Happy to move this kind of computation (num_computed_tokens = len(block_ids) * self.block_size) over on my end.

But wondering, how can I go about testing the PR?
I remember I had a discussion regarding this bit with @sdavidbd.
He was kind enough to share 2 examples to wrap my mind around the current logic:

Block size = 3
Request tokens: a b c d e
No locally cached tokens

Case 1
Externally cached tokens: a b c
Allocated blocks: 1
# num_computed_tokens = len(block_ids) * self.block_size
num_computed_tokens = 1 * 3 = 3
# num_computed_tokens = min(num_computed_tokens, request.num_tokens)
num_computed_tokens = min(3, 5) = 3

Case 2
Externally cached tokens: a b c d e f
Allocated blocks: 2
# num_computed_tokens = len(block_ids) * self.block_size
num_computed_tokens = 2 * 3 = 6
# num_computed_tokens = min(num_computed_tokens, request.num_tokens)
num_computed_tokens = min(6, 5) = 5
# num_computed_tokens == request.num_tokens → trim by 1
num_computed_tokens = 4
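The two cases above can be run directly as plain Python (a sketch with the stated block size of 3 and a 5-token request, not vLLM code):

```python
# Case 1 and Case 2 from the examples above, with block_size = 3 and a
# 5-token request (a b c d e).

BLOCK_SIZE = 3
NUM_REQUEST_TOKENS = 5

# Case 1: externally cached tokens a b c -> 1 allocated block.
block_ids = [0]
n1 = min(len(block_ids) * BLOCK_SIZE, NUM_REQUEST_TOKENS)  # min(3, 5)
assert n1 == 3

# Case 2: externally cached tokens a b c d e f -> 2 allocated blocks.
block_ids = [0, 1]
n2 = min(len(block_ids) * BLOCK_SIZE, NUM_REQUEST_TOKENS)  # min(6, 5)
if n2 == NUM_REQUEST_TOKENS:
    # num_computed_tokens == request.num_tokens -> trim by 1
    n2 -= 1
assert n2 == 4
```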

I suppose we have these scenarios covered in unit tests somewhere?

Anyway, it would be nice to get an ack from @njhill on the change.

@orozery (Collaborator, Author) commented Feb 17, 2026

But wondering, how can I go about testing the PR?

This change affects the basic path of connectors using load_kv_async = True, so it should be exercised by most e2e tests.
Specifically for CPU offloading there's e2e test_cpu_offloading.py::test_cpu_offloading.
There is already a unit-test which should exercise this basic path in test_scheduler.py::test_kv_connector_basic.

Case 2
Externally cached tokens: a b c d e f

I actually think it is weird to allow the connector to load more tokens than Request.num_tokens.
I added an assert for that in this PR. @sdavidbd WDYT?

@heheda12345 (Collaborator) left a comment

LGTM! Leaving this PR to @NickLucche.

@orozery orozery added the ready ONLY add when PR is ready to merge/full CI is needed label Feb 18, 2026
@orozery orozery force-pushed the scheduler-async-load-num-computed-tokens branch 3 times, most recently from 4521d10 to 4a1af87 Compare February 24, 2026 18:49
@mergify mergify bot added the kv-connector label Feb 24, 2026
@orozery (Collaborator, Author) commented Feb 24, 2026

Fixed some unit tests that needed to be adapted.

@orozery (Collaborator, Author) commented Feb 24, 2026

Fixed some more failing unit tests.

@NickLucche (Collaborator) commented:
@njhill could you take a look at this when you find the time

@orozery orozery force-pushed the scheduler-async-load-num-computed-tokens branch from cccb355 to 795f7d2 Compare March 1, 2026 15:44
@@ -771,6 +772,9 @@ def schedule(self) -> SchedulerOutput:
# into the WAITING_FOR_REMOTE_KV state.
skipped_waiting_requests.prepend_request(request)
request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
# set num_computed_tokens even though KVs are not yet loaded
# _update_requests_with_invalid_blocks can later adjust it
request.num_computed_tokens = num_computed_tokens
A Collaborator commented on this diff:

Rather than piggybacking on num_computed_tokens, I think this implementation would be cleaner if we had another attribute that tracked this information.

num_computed_tokens has a very specific meaning [i.e. these tokens have their KVs ready to go]. This changes the definition of num_computed_tokens in a way that I feel is uncomfortable.

@robertgshaw2-redhat (Collaborator) commented Mar 2, 2026

For instance, this PR immediately introduces a bug here in the fail-recv path. Since num_computed_tokens is prematurely set, we will cache invalid blocks.

This will lead to silent data corruption that will be very hard to track down if a subsequent request gets a cache hit on these tokens 🫣
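The silent-corruption hazard described here can be illustrated with a toy prefix cache (purely hypothetical code, not vLLM's implementation): if num_computed_tokens is set before the async load completes, any caching keyed off it can register blocks whose KV data was never actually written.

```python
# Toy illustration (not vLLM code) of the hazard: blocks counted as
# "computed" before the async load finishes can end up in the prefix
# cache with no valid KV payload behind them.

class ToyPrefixCache:
    def __init__(self, block_size: int):
        self.block_size = block_size
        self.blocks = {}  # block index -> KV payload, or None if unwritten

    def mark_computed(self, num_computed_tokens: int) -> None:
        # Registers every full block covered by num_computed_tokens,
        # regardless of whether its KVs were actually loaded.
        for i in range(num_computed_tokens // self.block_size):
            self.blocks.setdefault(i, None)  # None = never written

    def lookup(self, block_idx: int):
        return self.blocks.get(block_idx, "miss")

cache = ToyPrefixCache(block_size=4)
cache.mark_computed(num_computed_tokens=8)  # set before the load finished
# ... the async load then fails; KVs for blocks 0-1 are never written ...
assert cache.lookup(0) is None  # a "hit" on a block with no valid KV data
```

A later request matching this prefix would get a cache hit on empty blocks instead of a miss, which is the hard-to-debug corruption the comment warns about.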

@orozery orozery force-pushed the scheduler-async-load-num-computed-tokens branch from 795f7d2 to 00914b1 Compare March 2, 2026 18:08
request.num_computed_tokens = request.num_tokens - 1

# Count the number of prefix cached tokens.
if request.num_cached_tokens < 0:
A Collaborator commented:

why is this change needed? seems unrelated

orozery (Collaborator, Author) replied:

It is unrelated, but while copying from the sync case I noticed that we don't update num_cached_tokens for async requests.
It's a very small fix, so I thought I could include it here.
I can also defer it to another PR.

@NickLucche (Collaborator) left a comment

LGTM as discussed offline

Signed-off-by: Or Ozeri <oro@il.ibm.com>
@orozery orozery force-pushed the scheduler-async-load-num-computed-tokens branch from 45384bb to 2300d5f Compare March 5, 2026 09:27
@NickLucche NickLucche enabled auto-merge (squash) March 5, 2026 10:30
@NickLucche NickLucche merged commit 612e772 into vllm-project:main Mar 5, 2026
54 checks passed
@github-project-automation github-project-automation bot moved this from Backlog to Done in Metrics & Tracing Mar 5, 2026

Labels

kv-connector · ready (ONLY add when PR is ready to merge/full CI is needed) · v1

Projects

Status: Done


4 participants