[KVConnector] Scheduler: Fix num_computed_tokens after async KV load #34616
Changes from all commits
```diff
@@ -638,6 +638,7 @@ def schedule(self) -> SchedulerOutput:
                 num_computed_tokens = (
                     num_new_local_computed_tokens + num_external_computed_tokens
                 )
+                assert num_computed_tokens <= request.num_tokens
             else:
                 # KVTransfer: WAITING reqs have num_computed_tokens > 0
                 # after async KV recvs are completed.
```
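The new assertion pins down the invariant behind the local-plus-external sum: tokens counted as "computed" (local prefix-cache hits plus KV blocks reported by the connector) can never exceed the request's token count. A minimal standalone sketch of the same check, with illustrative names that are not vLLM's actual API:

```python
# Sketch of the invariant the new assert enforces. Names here are
# illustrative, not taken from vLLM's scheduler.

def count_computed_tokens(
    num_new_local_computed_tokens: int,
    num_external_computed_tokens: int,
    num_request_tokens: int,
) -> int:
    num_computed_tokens = (
        num_new_local_computed_tokens + num_external_computed_tokens
    )
    # A connector that over-reports external hits would otherwise let the
    # scheduler skip tokens whose KVs were never actually produced.
    assert num_computed_tokens <= num_request_tokens
    return num_computed_tokens

print(count_computed_tokens(16, 32, 100))  # 48
```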
```diff
@@ -773,6 +774,20 @@ def schedule(self) -> SchedulerOutput:
                     # into the WAITING_FOR_REMOTE_KV state.
                     skipped_waiting_requests.prepend_request(request)
                     request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
+                    # Set num_computed_tokens even though KVs are not yet loaded.
+                    # request.num_computed_tokens will not be used anywhere until
+                    # the request finished the KV transfer.
+                    #
+                    # If a transfer error is reported by the connector,
+                    # request.num_computed_tokens will be re-set accordingly in
+                    # _update_requests_with_invalid_blocks.
+                    #
+                    # When the transfer is finished, either successfully or not,
+                    # request.num_computed_tokens will correctly reflect the number
+                    # of computed tokens.
+                    # _update_waiting_for_remote_kv will then cache
+                    # only the successfully loaded tokens.
+                    request.num_computed_tokens = num_computed_tokens
                     continue

             self.running.append(request)
```
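The comment block above describes a three-stage lifecycle: `num_computed_tokens` is recorded optimistically when the request is parked for an async load, is not read while the transfer is in flight, and is consumed (or corrected downward) once the transfer is reported finished. A toy sketch of that lifecycle, with an illustrative `Request` dataclass and handlers that are not vLLM code:

```python
# Illustrative sketch (not vLLM's implementation) of the lifecycle the PR's
# comment block describes for requests waiting on an async KV transfer.
from dataclasses import dataclass
from enum import Enum, auto

class Status(Enum):
    WAITING = auto()
    WAITING_FOR_REMOTE_KVS = auto()

@dataclass
class Request:
    num_tokens: int
    num_computed_tokens: int = 0
    status: Status = Status.WAITING

def park_for_async_load(req: Request, num_computed_tokens: int) -> None:
    # Set num_computed_tokens even though the KVs are not yet loaded;
    # nothing reads it until the transfer finishes.
    req.status = Status.WAITING_FOR_REMOTE_KVS
    req.num_computed_tokens = num_computed_tokens

def on_transfer_finished(req: Request, num_valid_tokens: int) -> None:
    # On a partial failure the connector reports how many tokens actually
    # arrived, and the optimistic value is corrected downward.
    req.num_computed_tokens = min(req.num_computed_tokens, num_valid_tokens)
    req.status = Status.WAITING

req = Request(num_tokens=100)
park_for_async_load(req, num_computed_tokens=96)
on_transfer_finished(req, num_valid_tokens=64)  # only part of the KVs landed
print(req.num_computed_tokens)  # 64
```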
```diff
@@ -1994,17 +2009,17 @@ def _update_waiting_for_remote_kv(self, request: Request) -> bool:
             self.failed_recving_kv_req_ids.remove(request.request_id)
         else:
             # Now that the blocks are ready, actually cache them.
-            (block_ids,) = self.kv_cache_manager.get_block_ids(request.request_id)
-            num_computed_tokens = len(block_ids) * self.block_size
-            # Handle the case where num request tokens less than one block.
-            num_computed_tokens = min(num_computed_tokens, request.num_tokens)
-            if num_computed_tokens == request.num_tokens:
-                num_computed_tokens -= 1
             # This will cache the blocks iff caching is enabled.
-            self.kv_cache_manager.cache_blocks(request, num_computed_tokens)
+            self.kv_cache_manager.cache_blocks(request, request.num_computed_tokens)

             # Update the request state for scheduling.
-            request.num_computed_tokens = num_computed_tokens
+            # on a full prompt hit, we need to re-compute the last token
+            # in order to be able to sample the next token
+            if request.num_computed_tokens == request.num_tokens:
+                request.num_computed_tokens = request.num_tokens - 1

+            # Count the number of prefix cached tokens.
+            if request.num_cached_tokens < 0:
+                request.num_cached_tokens = request.num_computed_tokens

         # Return that we are ready.
         self.finished_recving_kv_req_ids.remove(request.request_id)
```

Collaborator: why is this change needed? seems unrelated

Author: It is unrelated, but I noticed while copying from the sync case that we don't update …

orozery marked this conversation as resolved.
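The full-prompt-hit rule in this hunk exists because a request whose every prompt token is already "computed" would give the model nothing to run, and therefore no logits from which to sample the next token; decrementing by one forces a forward pass over the last prompt token. A minimal sketch of that rule, under hypothetical names:

```python
# Sketch of the full-prompt-hit rule from the hunk above (illustrative
# names, not vLLM's helpers): if all prompt tokens are computed, re-run the
# last one so there are logits to sample the next token from.

def tokens_to_schedule(num_computed_tokens: int, num_prompt_tokens: int) -> int:
    if num_computed_tokens == num_prompt_tokens:
        # Full hit: back off by one token so the model produces output.
        num_computed_tokens = num_prompt_tokens - 1
    return num_prompt_tokens - num_computed_tokens

print(tokens_to_schedule(100, 100))  # 1  (re-compute last token to sample)
print(tokens_to_schedule(64, 100))   # 36
```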
```diff
@@ -2084,13 +2099,8 @@ def _update_requests_with_invalid_blocks(
             # We iterate only over blocks that may contain externally computed
             # tokens
             if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
-                # Async loading. If num_computed_tokens is set it implies we
-                # already processed some block failures for it in a prior step
-                req_num_computed_tokens = (
-                    request.num_computed_tokens
-                    if req_id in self.failed_recving_kv_req_ids
-                    else len(req_block_ids) * self.block_size
-                )
+                # Async loading. num_computed_tokens does not include new tokens
+                req_num_computed_tokens = request.num_computed_tokens
             else:
                 # Sync loading. num_computed_tokens includes new tokens
                 req_num_computed_tokens = request.num_cached_tokens
```
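For context on what this hunk feeds into: when some externally loaded KV blocks are reported invalid, every token at or beyond the first bad block must be recomputed, so the computed-token count is truncated to the last fully valid block boundary. A hedged sketch of that truncation, with illustrative names and block size that are assumptions, not vLLM's actual helpers:

```python
# Hedged sketch (illustrative, not vLLM code) of truncating a request's
# computed-token count when some externally loaded KV blocks are invalid.

BLOCK_SIZE = 16  # assumed block size for illustration

def truncate_to_valid_prefix(
    req_num_computed_tokens: int,
    invalid_block_indices: list[int],
) -> int:
    if not invalid_block_indices:
        return req_num_computed_tokens
    first_invalid = min(invalid_block_indices)
    # Keep only tokens covered by fully valid blocks before the first failure.
    return min(req_num_computed_tokens, first_invalid * BLOCK_SIZE)

print(truncate_to_valid_prefix(100, [4, 6]))  # 64
print(truncate_to_valid_prefix(100, []))      # 100
```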
Rather than piggy-backing on `num_computed_tokens`, I think this implementation would be cleaner if we had another attribute that tracked this information. `num_computed_tokens` has a very specific meaning [i.e. these tokens have their KVs ready to go]. This changes the definition of `num_computed_tokens` in a way that I feel is uncomfortable.