-
Notifications
You must be signed in to change notification settings - Fork 5.2k
Fix eagle radix cache #10846
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix eagle radix cache #10846
Changes from all commits
c2e27b4
6c0a58e
a44351b
b188f44
0bb18da
e680778
d6f0206
a401aef
74ce777
aaeea64
0da4297
7397449
520d404
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,7 +23,7 @@ | |
| import time | ||
| from collections import defaultdict | ||
| from functools import partial | ||
| from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Union | ||
| from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Tuple, Union | ||
|
|
||
| import torch | ||
|
|
||
|
|
@@ -159,6 +159,16 @@ def get_child_key(key: RadixKey, page_size: int = 1): | |
| return (key.extra_key, plain_key) | ||
|
|
||
|
|
||
| def _convert_to_bigram_key(tokens: List[int]) -> List[Tuple[int, int]]: | ||
| # EAGLE uses bigram keys in the radix tree since draft sequence is the one-token-shifted version of target | ||
| # [1, 2, 3, 4] -> [(1,2), (2,3), (3,4)] | ||
| if len(tokens) < 2: | ||
| return [] | ||
| if isinstance(tokens[0], tuple): | ||
| return tokens | ||
| return [(tokens[i], tokens[i + 1]) for i in range(len(tokens) - 1)] | ||
|
|
||
|
|
||
| class RadixCache(BasePrefixCache): | ||
| def __init__( | ||
| self, | ||
|
|
@@ -168,13 +178,15 @@ def __init__( | |
| disable: bool = False, | ||
| enable_kv_cache_events: bool = False, | ||
| eviction_policy: str = "lru", | ||
| is_eagle: bool = False, | ||
| ): | ||
| self.req_to_token_pool = req_to_token_pool | ||
| self.token_to_kv_pool_allocator = token_to_kv_pool_allocator | ||
| self.page_size = page_size | ||
| self.disable = disable | ||
| self.enable_kv_cache_events = enable_kv_cache_events | ||
| self.kv_event_queue = [] | ||
| self.is_eagle = is_eagle | ||
|
|
||
| if self.token_to_kv_pool_allocator: | ||
| self.device = self.token_to_kv_pool_allocator.device | ||
|
|
@@ -188,6 +200,11 @@ def __init__( | |
| self.key_match_fn = partial(_key_match_paged, page_size=page_size) | ||
| self.get_child_key_fn = partial(get_child_key, page_size=page_size) | ||
|
|
||
| if is_eagle: | ||
| self.key_convert_fn = _convert_to_bigram_key | ||
| else: | ||
| self.key_convert_fn = lambda key: key | ||
|
|
||
| if eviction_policy.lower() == "lru": | ||
| self.eviction_strategy: EvictionStrategy = LRUStrategy() | ||
| elif eviction_policy.lower() == "lfu": | ||
|
|
@@ -248,6 +265,8 @@ def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult: | |
| to expose a precise boundary; this structural refinement improves | ||
| subsequent match efficiency and does not duplicate data. | ||
| """ | ||
| key.token_ids = self.key_convert_fn(key.token_ids) | ||
|
|
||
| if self.disable or len(key) == 0: | ||
| return MatchResult( | ||
| device_indices=torch.empty( | ||
|
|
@@ -278,8 +297,15 @@ def insert(self, key: RadixKey, value=None, chunked=False): | |
| if self.disable: | ||
| return 0 | ||
|
|
||
| key.token_ids = self.key_convert_fn(key.token_ids) | ||
|
|
||
| if value is None: | ||
| value = torch.tensor(key.token_ids, dtype=torch.int64) | ||
|
|
||
| if self.is_eagle: | ||
| # Make sure the value len equal to the EAGLE bigram key len | ||
| value = value[: len(key)] | ||
|
|
||
| return self._insert_helper(self.root_node, key, value) | ||
|
|
||
| def cache_finished_req(self, req: Req): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. while
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The SWA cache inherits from BaseRadixCache, so it seems all of these changes would need to be implemented again on it. HiCache inherits from RadixCache, so we only need some adaptation on it, with fewer overrides. But for HiCache, the main thing I'm concerned about is that the effective chunked prefill size changes slightly: if the chunked prefill size is 64, only 63 bigram keys are actually inserted into the tree. That may not be efficient for block-based cache offloading.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. My understanding is that what we are doing here is primarily to resolve a conflict with eagle workers — since they share the same radix tree but have their own pool — and not to add HiCache support for eagle workers (i.e., having eagle workers fetch KV caches from host memory), which seems unnecessary and potentially complicated. Is that correct?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I agree that it is unnecessary to store the eagle worker's KV cache in host memory, since it's only one layer. If we use HiCache only for the target model, can we still share the KV indices between the target and draft pools?
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yes, I think it should be fine; I just wanted to confirm that we are aligned on this.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. cc: @merrymercy |
||
|
|
@@ -293,28 +319,39 @@ def cache_finished_req(self, req: Req): | |
| return | ||
|
|
||
| token_ids = (req.origin_input_ids + req.output_ids)[:-1] | ||
| all_token_len = len(token_ids) | ||
| actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len | ||
| kv_indices = self.req_to_token_pool.req_to_token[ | ||
| req.req_pool_idx, : len(token_ids) | ||
| req.req_pool_idx, :all_token_len | ||
| ] | ||
|
|
||
| if self.page_size != 1: | ||
| page_aligned_len = len(kv_indices) // self.page_size * self.page_size | ||
| page_aligned_len = actual_kv_len // self.page_size * self.page_size | ||
| page_aligned_kv_indices = kv_indices[:page_aligned_len].to( | ||
| dtype=torch.int64, copy=True | ||
| ) | ||
| self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:]) | ||
| else: | ||
| page_aligned_len = len(kv_indices) | ||
| page_aligned_len = actual_kv_len | ||
| page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True) | ||
| if self.is_eagle: | ||
| self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:]) | ||
xiezhq-hermann marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| page_aligned_token_len = ( | ||
| page_aligned_len + 1 if self.is_eagle else page_aligned_len | ||
| ) | ||
|
|
||
| old_prefix_len = len(req.prefix_indices) | ||
| if self.is_eagle and old_prefix_len > req.last_matched_prefix_len: | ||
| # prefix_indices attached partial part (for page_size > 1) and one unmatched token (for EAGLE) | ||
| old_prefix_len -= 1 | ||
|
|
||
| # Radix Cache takes one ref in memory pool | ||
| new_prefix_len = self.insert( | ||
| RadixKey(token_ids[:page_aligned_len], req.extra_key), | ||
| RadixKey(token_ids[:page_aligned_token_len], req.extra_key), | ||
| page_aligned_kv_indices, | ||
| ) | ||
| self.token_to_kv_pool_allocator.free( | ||
| kv_indices[len(req.prefix_indices) : new_prefix_len] | ||
| ) | ||
| self.token_to_kv_pool_allocator.free(kv_indices[old_prefix_len:new_prefix_len]) | ||
|
|
||
| # Remove req slot release the cache lock | ||
| self.req_to_token_pool.free(req.req_pool_idx) | ||
|
|
@@ -326,49 +363,73 @@ def cache_unfinished_req(self, req: Req, chunked=False): | |
| return | ||
|
|
||
| token_ids = req.fill_ids | ||
| all_token_len = len(token_ids) | ||
| # The actual kv len for EAGLE is len(token_ids), since EAGLE uses bigram key | ||
| actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len | ||
| kv_indices = self.req_to_token_pool.req_to_token[ | ||
| req.req_pool_idx, : len(token_ids) | ||
| req.req_pool_idx, :all_token_len | ||
| ] | ||
|
|
||
| if self.page_size != 1: | ||
| page_aligned_len = len(kv_indices) // self.page_size * self.page_size | ||
| page_aligned_len = actual_kv_len // self.page_size * self.page_size | ||
| page_aligned_kv_indices = kv_indices[:page_aligned_len].to( | ||
| dtype=torch.int64, copy=True | ||
| ) | ||
| else: | ||
| page_aligned_len = len(kv_indices) | ||
| page_aligned_len = actual_kv_len | ||
| page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True) | ||
| page_aligned_token_ids = token_ids[:page_aligned_len] | ||
|
|
||
| # For EAGLE, the page_aligned_len is for the bigram key, the normal key len should +1 | ||
| page_aligned_token_len = ( | ||
| page_aligned_len + 1 if self.is_eagle else page_aligned_len | ||
| ) | ||
| page_aligned_token_ids = token_ids[:page_aligned_token_len] | ||
|
|
||
| old_prefix_len = len(req.prefix_indices) | ||
| if self.is_eagle and old_prefix_len > req.last_matched_prefix_len: | ||
| # prefix_indices attached partial part (for page_size > 1) and one unmatched token (for EAGLE) | ||
| old_prefix_len -= 1 | ||
|
|
||
| # Radix Cache takes one ref in memory pool | ||
| new_prefix_len = self.insert( | ||
| RadixKey(page_aligned_token_ids, req.extra_key), | ||
| page_aligned_kv_indices, | ||
| chunked=chunked, | ||
| ) | ||
| self.token_to_kv_pool_allocator.free( | ||
| kv_indices[len(req.prefix_indices) : new_prefix_len] | ||
| ) | ||
| self.token_to_kv_pool_allocator.free(kv_indices[old_prefix_len:new_prefix_len]) | ||
|
|
||
| # The prefix indices could be updated, reuse it | ||
| new_indices, new_last_node, _, _ = self.match_prefix( | ||
| RadixKey(token_ids=page_aligned_token_ids, extra_key=req.extra_key) | ||
| ) | ||
| self.req_to_token_pool.write( | ||
| (req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))), | ||
| new_indices[len(req.prefix_indices) :], | ||
| (req.req_pool_idx, slice(old_prefix_len, len(new_indices))), | ||
| new_indices[old_prefix_len:], | ||
| ) | ||
|
|
||
| # The last_matched_prefix_len is not always equal to len(req.prefix_indices) | ||
| # since for page_size > 1, the partial part is added to req.prefix_indices, but that part of kv indices is not added to the tree. | ||
| # It should be freed in the next cache_unfinished_req and final cache_finished_req to avoid memory leak. | ||
| # So we introduce this `last_matched_prefix_len` field to make sure the partial part can be freed correctly. | ||
| req.last_matched_prefix_len = len(new_indices) | ||
|
|
||
| self.dec_lock_ref(req.last_node) | ||
| self.inc_lock_ref(new_last_node) | ||
|
|
||
| # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later | ||
| if self.page_size != 1: | ||
| # Handle partial page, the partial part should be freed in the next cache_unfinished_req and final cache_finished_req. | ||
| req.prefix_indices = torch.cat( | ||
| [new_indices, kv_indices[len(new_indices) :]] | ||
| ) | ||
| else: | ||
| req.prefix_indices = new_indices | ||
| if self.is_eagle: | ||
| # Attach the kv index of the last token for EAGLE, it can be used in chunked prefill | ||
| req.prefix_indices = torch.cat( | ||
| [new_indices, kv_indices[actual_kv_len:]] | ||
| ) | ||
| else: | ||
| req.prefix_indices = new_indices | ||
| req.last_node = new_last_node | ||
|
|
||
| def pretty_print(self): | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hiradix (and other trees like swa) override the
`insert` function — would that be a problem, since the eagle worker shares the same tree? There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, in the current design we need to adapt this change to other trees, like swa and hiradix, if they override these functions. This PR just makes the main radix tree ready; HiCache and SWA need extra work and testing to make them ready.