-
Notifications
You must be signed in to change notification settings - Fork 5k
Evict swa kv cache during decoding #17220
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
56a9e56
31cd842
1aab927
a982ada
3d35138
829c299
e1bd8ce
46915dc
08f03a5
822c525
943d6f6
db2c5ba
b05fd78
61d78ff
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -363,6 +363,9 @@ def __init__(self, params: CacheInitParams, sliding_window_size: int): | |
| ##### Public API ##### | ||
|
|
||
| def supports_swa(self) -> bool: | ||
| assert ( | ||
| self.sliding_window_size is not None | ||
| ), "sliding_window_size must be set for SWARadixCache" | ||
| return True | ||
|
|
||
| def reset(self) -> None: | ||
|
|
@@ -418,7 +421,13 @@ def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult: | |
| last_host_node=last_node, | ||
| ) | ||
|
|
||
| def insert(self, key: RadixKey, value=None, prev_prefix_len: int = 0) -> int: | ||
| def insert( | ||
| self, | ||
| key: RadixKey, | ||
| value=None, | ||
| prev_prefix_len: int = 0, | ||
| swa_evicted_seqlen: int = 0, | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I see. You are actually evicting them instead of unlocking.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yes, the generated tokens are not in the tree until the sequence has finished, so we cannot unlock them. |
||
| ) -> int: | ||
| if self.disable: | ||
| return 0 | ||
|
|
||
|
|
@@ -431,7 +440,9 @@ def insert(self, key: RadixKey, value=None, prev_prefix_len: int = 0) -> int: | |
| # Make sure the value len equal to the EAGLE bigram key len | ||
| value = value[: len(key)] | ||
|
|
||
| return self._insert_helper(self.root_node, key, value, prev_prefix_len) | ||
| return self._insert_helper( | ||
| self.root_node, key, value, prev_prefix_len, swa_evicted_seqlen | ||
| ) | ||
|
|
||
| def cache_finished_req(self, req: Req, is_insert: bool = True) -> None: | ||
| """Cache request when it finishes.""" | ||
|
|
@@ -481,6 +492,7 @@ def cache_finished_req(self, req: Req, is_insert: bool = True) -> None: | |
| RadixKey(token_ids[:page_aligned_token_len], req.extra_key), | ||
| page_aligned_kv_indices, | ||
| old_prefix_len, | ||
| req.swa_evicted_seqlen, | ||
| ) | ||
| else: | ||
| self.token_to_kv_pool_allocator.free( | ||
|
|
@@ -878,6 +890,7 @@ def _split_node(self, key: RadixKey, child: TreeNode, split_len: int) -> TreeNod | |
| new_node.full_lock_ref = child.full_lock_ref | ||
| new_node.swa_lock_ref = child.swa_lock_ref | ||
| new_node.key = child.key[:split_len] | ||
| assert len(new_node.key) > 0, f"new_node.key should not be empty" | ||
| new_node.value = child.value[:split_len].clone() | ||
| # parent inherits the swa_uuid from child for swa lock ref | ||
| new_node.swa_uuid = child.swa_uuid | ||
|
|
@@ -891,6 +904,7 @@ def _split_node(self, key: RadixKey, child: TreeNode, split_len: int) -> TreeNod | |
| self.swa_lru_list.remove_node(child) | ||
| child.parent = new_node | ||
| child.key = child.key[split_len:] | ||
| assert len(child.key) > 0, f"child.key should not be empty" | ||
| child.value = child.value[split_len:].clone() | ||
| new_node.parent.children[self.get_child_key_fn(key)] = new_node | ||
|
|
||
|
|
@@ -904,7 +918,12 @@ def _split_node(self, key: RadixKey, child: TreeNode, split_len: int) -> TreeNod | |
| return new_node | ||
|
|
||
| def _insert_helper( | ||
| self, node: TreeNode, key: RadixKey, value, update_kv_after_len: int | ||
| self, | ||
| node: TreeNode, | ||
| key: RadixKey, | ||
| value, | ||
| update_kv_after_len: int, | ||
| swa_evicted_seqlen: int = 0, | ||
| ) -> int: | ||
| # Update the last access time from root to leaf, so that | ||
| # swa will tombstone the node closer to root first | ||
|
|
@@ -936,23 +955,44 @@ def _insert_helper( | |
| # contains tombstone. If this is the case and we don't update the kv value, then | ||
| # the prefill prefix matching will stuck. | ||
| if update_kv_after_len < total_prefix_length + prefix_len: | ||
| first_diff_idx = max(0, update_kv_after_len - total_prefix_length) | ||
| # For page_size > 1 and chunked prefill case, update_kv_after_len may be not page-aligned due to a trailing partial page | ||
| # (kept in the request but not inserted into the radix tree) appended to prefix_indices. | ||
| if node.swa_tombstone: | ||
| assert ( | ||
| node.swa_lock_ref == 0 | ||
| ), f"tombstone swa_lock_ref should always be 0, {node.full_lock_ref=}, {node.swa_lock_ref=}, {node.id=}" | ||
| self.token_to_kv_pool_allocator.free(node.value[first_diff_idx:]) | ||
| node.value = value[:prefix_len] | ||
| node.swa_tombstone = False | ||
|
|
||
| # insert the node into the lru lists | ||
| self.swa_lru_list.insert_mru(node) | ||
|
|
||
| self.swa_evictable_size_ += len(node.value) | ||
| assert ( | ||
| swa_evicted_seqlen % self.page_size == 0 | ||
| ), f"swa_evicted_seqlen must be page aligned, {swa_evicted_seqlen=}, {self.page_size=}" | ||
| if swa_evicted_seqlen <= total_prefix_length: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Please add some comments for each branch. |
||
| # Branch 1: all swa tokens of value[:prefix_len] are not evicted, so we can insert it to the tree directly. | ||
| # Free full tokens in the original tree node. | ||
| self.token_to_kv_pool_allocator.free(node.value[:prefix_len]) | ||
| # Overwrite the new value in request to the tree node. | ||
| node.value = value[:prefix_len].clone() | ||
| node.swa_tombstone = False | ||
| self.swa_lru_list.insert_mru(node) | ||
| self.swa_evictable_size_ += len(node.value) | ||
| elif swa_evicted_seqlen < total_prefix_length + prefix_len: | ||
| # Branch 2: part of swa tokens of value[:prefix_len] are evicted, so we need to split the node and insert the value to new node. | ||
| start_update_idx = swa_evicted_seqlen - total_prefix_length | ||
| self.token_to_kv_pool_allocator.free( | ||
| node.value[start_update_idx:prefix_len] | ||
| ) | ||
| self._split_node(node.key, node, start_update_idx) | ||
| # Here node is the new node after split, so we can overwrite the value to the new node. | ||
| # The old node is still swa tombstone and the full token is not freed. | ||
| node.value = value[start_update_idx:prefix_len].clone() | ||
| self.token_to_kv_pool_allocator.free(value[:start_update_idx]) | ||
| node.swa_tombstone = False | ||
| self.swa_lru_list.insert_mru(node) | ||
| self.swa_evictable_size_ += len(node.value) | ||
| else: | ||
| # Branch 3: all swa tokens of value[:prefix_len] are evicted, so we don't need to update the node. | ||
| self.token_to_kv_pool_allocator.free(value[:prefix_len]) | ||
| else: | ||
| self.token_to_kv_pool_allocator.free( | ||
| value[first_diff_idx:prefix_len] | ||
| ) | ||
| # The node is not tombstone, so we don't need to update the node. | ||
| self.token_to_kv_pool_allocator.free(value[:prefix_len]) | ||
|
|
||
| total_prefix_length += prefix_len | ||
| key = key[prefix_len:] | ||
|
|
@@ -962,16 +1002,43 @@ def _insert_helper( | |
| child_key = self.get_child_key_fn(key) | ||
|
|
||
| if len(key): | ||
| new_node = TreeNode() | ||
| new_node.parent = node | ||
| new_node.key = key | ||
| new_node.value = value | ||
| self.full_lru_list.insert_mru(new_node) | ||
| if ( | ||
| swa_evicted_seqlen > total_prefix_length | ||
| and swa_evicted_seqlen < total_prefix_length + len(key) | ||
| ): | ||
| swa_tombstone_len = swa_evicted_seqlen - total_prefix_length | ||
| node = self._add_new_node( | ||
| node, | ||
| key[:swa_tombstone_len], | ||
| value[:swa_tombstone_len], | ||
| swa_tombstone=True, | ||
| ) | ||
| key = key[swa_tombstone_len:] | ||
| value = value[swa_tombstone_len:] | ||
|
|
||
| self._add_new_node(node, key, value, swa_tombstone=False) | ||
| return total_prefix_length | ||
|
|
||
| def _add_new_node( | ||
| self, | ||
| parent: TreeNode, | ||
| key: RadixKey, | ||
| value: torch.Tensor, | ||
| swa_tombstone: bool = False, | ||
| ) -> TreeNode: | ||
| assert len(key) > 0, f"key should not be empty" | ||
| new_node = TreeNode() | ||
| new_node.parent = parent | ||
| new_node.key = key | ||
| new_node.value = value.clone() | ||
| new_node.swa_tombstone = swa_tombstone | ||
| parent.children[self.get_child_key_fn(key)] = new_node | ||
| self.full_lru_list.insert_mru(new_node) | ||
| self.full_evictable_size_ += len(value) | ||
| if not swa_tombstone: | ||
| self.swa_lru_list.insert_mru(new_node) | ||
| node.children[child_key] = new_node | ||
| self.full_evictable_size_ += len(value) | ||
| self.swa_evictable_size_ += len(value) | ||
| return total_prefix_length | ||
| return new_node | ||
|
|
||
| def _iteratively_delete_tombstone_leaf( | ||
| self, node: TreeNode | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
neat