-
Notifications
You must be signed in to change notification settings - Fork 5.3k
[Auto Sync] Update schedule_batch.py, schedule_policy.py, b... (20251122) #13763
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -180,10 +180,19 @@ def _compute_prefix_matches( | |
| extra_key = r.extra_key | ||
|
|
||
| # NOTE: the prefix_indices must always be aligned with last_node | ||
| r.prefix_indices, r.last_node, r.last_host_node, r.host_hit_length = ( | ||
| self.tree_cache.match_prefix( | ||
| rid=r.rid, key=RadixKey(token_ids=prefix_ids, extra_key=extra_key) | ||
| ) | ||
| match_result = self.tree_cache.match_prefix( | ||
| rid=r.rid, key=RadixKey(token_ids=prefix_ids, extra_key=extra_key) | ||
| ) | ||
| ( | ||
| r.prefix_indices, | ||
| r.last_node, | ||
| r.last_host_node, | ||
| r.host_hit_length, | ||
| ) = ( | ||
| match_result.device_indices, | ||
| match_result.last_device_node, | ||
| match_result.last_host_node, | ||
| match_result.host_hit_length, | ||
| ) | ||
|
Comment on lines
+186
to
196
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The tuple unpacking here is a bit verbose. You can directly assign the attributes of r.prefix_indices = match_result.device_indices
r.last_node = match_result.last_device_node
r.last_host_node = match_result.last_host_node
r.host_hit_length = match_result.host_hit_length |
||
|
|
||
| # NOTE(sang): This logic is for in-batch prefix caching; | ||
|
|
@@ -194,12 +203,11 @@ def _compute_prefix_matches( | |
| # threshold means we cannot use in-batch prefix caching for short prefixes. | ||
| # It is kind of common when the engine is long running (e.g., imagine the prefix "the"). | ||
| if len(r.prefix_indices) <= IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD: | ||
| in_batch_matching_prefixes, _, _, _ = ( | ||
| self.waiting_queue_radix_tree.match_prefix( | ||
| rid=r.rid, | ||
| key=RadixKey(token_ids=prefix_ids, extra_key=extra_key), | ||
| ) | ||
| match_result = self.waiting_queue_radix_tree.match_prefix( | ||
| rid=r.rid, | ||
| key=RadixKey(token_ids=prefix_ids, extra_key=extra_key), | ||
| ) | ||
| in_batch_matching_prefixes = match_result.device_indices | ||
| if ( | ||
| len(in_batch_matching_prefixes) | ||
| >= IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -520,9 +520,13 @@ def cache_unfinished_req(self, req: Req, chunked=False) -> None: | |
| self.req_to_token_pool.mamba_pool.free(mamba_value_forked) | ||
|
|
||
| # The prefix indices could be updated, reuse it | ||
| new_indices, new_last_node, _, _ = self.match_prefix( | ||
| match_result = self.match_prefix( | ||
| RadixKey(page_aligned_token_ids, req.extra_key) | ||
| ) | ||
| (new_indices, new_last_node) = ( | ||
| match_result.device_indices, | ||
| match_result.last_device_node, | ||
| ) | ||
|
Comment on lines
+526
to
+529
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
| if not mamba_exist: | ||
| assert torch.equal(new_last_node.mamba_value, mamba_value_forked) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -428,7 +428,11 @@ def cache_unfinished_req(self, req: Req, chunked=False): | |
| ) | ||
|
|
||
| # The prefix indices could be updated, reuse it | ||
| new_indices, new_last_node, _, _ = self.match_prefix(radix_key) | ||
| match_result = self.match_prefix(radix_key) | ||
| (new_indices, new_last_node) = ( | ||
| match_result.device_indices, | ||
| match_result.last_device_node, | ||
| ) | ||
|
Comment on lines
+432
to
+435
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| assert len(new_indices) == len(keys), f"{len(new_indices)=}, {len(keys)=}" | ||
|
|
||
| self.req_to_token_pool.write( | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -550,9 +550,14 @@ def cache_unfinished_req(self, req: Req, chunked=False) -> None: | |
| ) | ||
|
|
||
| # The prefix indices could be updated, reuse it | ||
| new_indices, new_last_node, _, _ = self.match_prefix( | ||
| match_result = self.match_prefix( | ||
| RadixKey(page_aligned_token_ids, req.extra_key) | ||
| ) | ||
| (new_indices, new_last_node) = ( | ||
| match_result.device_indices, | ||
| match_result.last_device_node, | ||
| ) | ||
|
Comment on lines
+556
to
+559
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
| assert old_prefix_len <= len( | ||
| new_indices | ||
| ), f"{req.prefix_indices=}, {new_indices=}" | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The tuple unpacking here is a bit verbose. You can directly assign the attributes of
match_resultto theselfattributes for better readability and conciseness.