[KVConnector] Scheduler: Fix num_computed_tokens after async KV load#34616
Conversation
Code Review
This pull request correctly refactors the calculation of num_computed_tokens for asynchronous KV cache loading to be done at allocation time, which fixes issues with block alignment and HMA support. The addition of updating num_cached_tokens is also a good catch for consistency.
I've found one issue in the implementation where num_cached_tokens is set after num_computed_tokens is adjusted, which could lead to incorrect metrics and error handling. I've left a specific comment with a suggestion to fix this.
Additionally, while reviewing, I noticed that _update_requests_with_invalid_blocks might still be using the old, block-aligned logic for async requests when handling KV load failures. This could be a related area to investigate to ensure the fix is comprehensive.
Force-pushed: 90c55f4 to 1452787
NickLucche
left a comment
Happy to move this kind of computation (`num_computed_tokens = len(block_ids) * self.block_size`) over on my end.
But I'm wondering: how can I go about testing the PR?
I remember I had a discussion regarding this bit with @sdavidbd.
He was kind enough to share two examples to wrap my mind around the current logic:
Block size = 3
Request tokens: a b c d e
No locally cached tokens

Case 1
Externally cached tokens: a b c
Allocated blocks: 1

    # num_computed_tokens = len(block_ids) * self.block_size
    num_computed_tokens = 1 * 3  # = 3
    # num_computed_tokens = min(num_computed_tokens, request.num_tokens)
    num_computed_tokens = min(3, 5)  # = 3

Case 2
Externally cached tokens: a b c d e f
Allocated blocks: 2

    # num_computed_tokens = len(block_ids) * self.block_size
    num_computed_tokens = 2 * 3  # = 6
    # num_computed_tokens = min(num_computed_tokens, request.num_tokens)
    num_computed_tokens = min(6, 5)  # = 5
    # num_computed_tokens == request.num_tokens -> trim by 1
    num_computed_tokens = 4
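The two cases above can be condensed into a small standalone function (a sketch of the pre-PR block-counting logic, not the actual vLLM scheduler code; the parameter names are illustrative):

```python
def old_num_computed_tokens(num_allocated_blocks: int,
                            block_size: int,
                            num_request_tokens: int) -> int:
    """Sketch of the old block-counting logic for async KV loads."""
    # Derive computed tokens from the number of allocated blocks.
    num_computed_tokens = num_allocated_blocks * block_size
    # Never exceed the request length.
    num_computed_tokens = min(num_computed_tokens, num_request_tokens)
    if num_computed_tokens == num_request_tokens:
        # The last token must still be computed to produce logits,
        # so trim by one when everything appears cached.
        num_computed_tokens -= 1
    return num_computed_tokens


# Case 1: 1 allocated block, block size 3, 5 request tokens
print(old_num_computed_tokens(1, 3, 5))  # 3
# Case 2: 2 allocated blocks -> capped at 5, then trimmed to 4
print(old_num_computed_tokens(2, 3, 5))  # 4
```

This reproduces both worked examples; the misalignment problem the PR fixes arises only when the externally cached tokens end mid-block before the end of the prompt.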
I suppose we have these scenarios covered in unit tests somewhere..?
Anyway, it would be nice to get an ack from @njhill on the change.
This change affects the basic path of connectors using
I actually think it is weird to allow the connector to load more tokens than
heheda12345
left a comment
LGTM! Leaving this PR to @NickLucche.
Force-pushed: 4521d10 to 4a1af87
Fixed some unit tests that needed to be adapted.
Force-pushed: 4a1af87 to cccb355
Fixed some more failing unit tests.
@njhill could you take a look at this when you find the time?
Force-pushed: cccb355 to 795f7d2
    @@ -771,6 +772,9 @@ def schedule(self) -> SchedulerOutput:
                 # into the WAITING_FOR_REMOTE_KV state.
                 skipped_waiting_requests.prepend_request(request)
                 request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
    +            # set num_computed_tokens even though KVs are not yet loaded
    +            # _update_requests_with_invalid_blocks can later adjust it
    +            request.num_computed_tokens = num_computed_tokens
Rather than piggybacking on num_computed_tokens, I think this implementation would be cleaner if we had another attribute that tracked this information.
num_computed_tokens has a very specific meaning (i.e., these tokens have their KVs ready to go). This changes the definition of num_computed_tokens in a way that I feel is uncomfortable.
For instance, this PR immediately introduces a bug here when a KV load fails: since num_computed_tokens is prematurely set, we will cache invalid blocks.
This will lead to silent data corruption that will be very hard to track down if a subsequent request gets a cache hit on these tokens 🫣
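The corruption scenario the comment describes can be illustrated with a toy prefix cache (purely illustrative; this is not vLLM's actual prefix-caching implementation, and the names are made up):

```python
# Toy prefix cache keyed by the token prefix.
prefix_cache: dict[tuple[str, ...], str] = {}


def cache_blocks(tokens: tuple[str, ...], kv_data: str) -> None:
    """Record KV blocks for a token prefix so later requests can reuse them."""
    prefix_cache[tokens] = kv_data


# Request 1: num_computed_tokens was set before the async load finished.
# The load then fails, but the blocks get cached anyway with garbage KVs.
cache_blocks(("a", "b", "c"), "GARBAGE")

# Request 2 shares the prefix and silently reuses the corrupted KVs,
# with no error surfaced anywhere.
print(prefix_cache.get(("a", "b", "c")))  # GARBAGE
```

The point is that nothing downstream can distinguish validly cached blocks from ones cached off a prematurely set num_computed_tokens, which is why the failure is silent.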
Force-pushed: 795f7d2 to 00914b1
    request.num_computed_tokens = request.num_tokens - 1

    # Count the number of prefix cached tokens.
    if request.num_cached_tokens < 0:
Why is this change needed? It seems unrelated.
It is unrelated, but while copying from the sync case I noticed that we don't update num_cached_tokens for async requests.
It's a very small fix, so I thought I could include it here.
I can also defer it to another PR.
NickLucche
left a comment
LGTM as discussed offline
Previously, following a successful async KV load, the scheduler would update the request's num_computed_tokens by counting the number of allocated blocks, multiplying by the block size, and rounding down to request.num_tokens if necessary. This worked as long as the last external token aligned with either:
1. The end of the request
2. A full block
However, if the last external token fell in the middle of a block without reaching the end of the prompt, this logic wrongly advanced num_computed_tokens past the last external token, which would yield wrong KV data. Furthermore, the logic hard-coded the assumption of a single KV cache group.
This commit instead sets num_computed_tokens at allocation time, as is already done for sync KV loading, which fixes the issue above and removes the single-KV-cache-group assumption. Additionally, it adds a missing update to num_cached_tokens.
Signed-off-by: Or Ozeri <oro@il.ibm.com>
Force-pushed: 45384bb to 2300d5f
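A minimal sketch of the misalignment bug described above (illustrative names, not the actual scheduler code): with block size 3, a 10-token prompt, and 4 externally cached tokens, the old block-derived count overshoots the valid KV data, while counting at allocation time does not:

```python
BLOCK_SIZE = 3


def old_logic(num_allocated_blocks: int, num_request_tokens: int) -> int:
    # Old: derived from allocated blocks, so always block-aligned
    # (unless capped by the request length).
    return min(num_allocated_blocks * BLOCK_SIZE, num_request_tokens)


def new_logic(num_external_computed_tokens: int) -> int:
    # New: recorded at allocation time from the actual number of
    # externally loaded tokens, independent of block alignment.
    return num_external_computed_tokens


num_external = 4          # tokens a b c d loaded from the remote KV cache
num_allocated_blocks = 2  # ceil(4 / 3) blocks are needed to hold them
num_request_tokens = 10

print(old_logic(num_allocated_blocks, num_request_tokens))  # 6: two tokens past valid KVs
print(new_logic(num_external))                              # 4: matches the loaded KVs
```

The old logic claims 6 computed tokens even though only 4 have valid KVs, which is exactly the "wrong KV data" case; tracking the external token count directly also drops the single-KV-cache-group assumption, since no per-group block math is involved.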