EAGLE cache fix for SWARadixCache #11231
Conversation
Code Review
This pull request introduces a fix for the EAGLE speculative decoding algorithm within the SWARadixCache. The changes primarily adapt the cache to handle bigram keys used by EAGLE, which involves modifying key conversion, prefix matching, and insertion logic. Supporting changes include plumbing the is_eagle flag and dtype through various components. My review identifies a potential bug in the insert method when handling None values for EAGLE, and a maintainability concern regarding in-place modification of RadixKey objects which violates its type hint. Overall, the changes are logical and address the caching issue for EAGLE.
    if value is None:
        value = torch.tensor([x for x in key.token_ids], dtype=torch.int64)
This logic can lead to a bug when is_eagle is true. In that case, key.token_ids becomes a list of tuples (bigrams) after self.key_convert_fn. torch.tensor on a list of tuples will create a 2D tensor, while the rest of the code expects value to be a 1D tensor of KV cache indices. Although this code path (value is None) might not be triggered in production, it's a latent bug that can affect tests or future use cases.
I suggest raising an error for this case or creating a dummy 1D tensor if it's needed for tests.
Suggested change:

    if value is None:
        if self.is_eagle:
            # This path is not expected in production for EAGLE.
            # The value should be a 1D tensor of indices, but creating it from bigram keys is ambiguous.
            raise NotImplementedError("insert with value=None is not supported for EAGLE mode.")
        value = torch.tensor(key.token_ids, dtype=torch.int64)
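The 2D-tensor pitfall behind this suggestion can be reproduced in isolation. A minimal sketch, using NumPy as a stand-in for `torch.tensor` (both build a 2D array from a list of tuples); the variable names here are illustrative, not taken from the PR:

```python
import numpy as np

# Plain token ids (non-EAGLE case): a list of ints becomes a 1D array,
# which is the shape the rest of the cache code expects.
token_ids = [1, 2, 3, 4]
flat = np.array(token_ids, dtype=np.int64)
assert flat.ndim == 1  # shape (4,)

# EAGLE case: after key conversion, token_ids is a list of bigram tuples.
# A list of tuples silently becomes a 2D array instead of 1D indices.
bigram_ids = [(1, 2), (2, 3), (3, 4)]
nested = np.array(bigram_ids, dtype=np.int64)
assert nested.ndim == 2  # shape (3, 2) -- not valid KV cache indices
```

This is why the suggested code raises `NotImplementedError` rather than silently constructing a tensor of the wrong rank.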
        The last node creates a new child if the prefix is shorter
        than the last node's value.
        """
        key.token_ids = self.key_convert_fn(key.token_ids)
This line modifies the input key object in-place, which can be an unexpected side effect for callers. Additionally, when is_eagle is true, key.token_ids is converted from List[int] to List[Tuple[int, int]], which violates the type hint in the RadixKey class definition (List[int]). This makes the code harder to understand and maintain.
While creating a new RadixKey object might have performance implications, it would be safer. A less disruptive change would be to update the type hint for RadixKey.token_ids to List[Union[int, Tuple[int, int]]] and add a comment here explaining the in-place modification.
hanming-lu left a comment
At a high level this makes sense. I'm finding it not straightforward to fully understand the +1 and -1 logic and its implications. Was there a doc for the radix cache changes?
    token_ids = (req.origin_input_ids + req.output_ids)[:-1]
    all_token_len = len(token_ids)
    actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
Can we have more comments on the reason behind this -1?
If we convert the key to a bigram key, e.g. [1,2,3,4] -> [(1,2), (2,3), (3,4)], the length decreases by 1 (len([(1,2), (2,3), (3,4)]) = len([1,2,3,4]) - 1).
So the corresponding KV length should also be reduced by 1. That gives us actual_kv_len, which is used for the later calculations and slicing.
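The length accounting described above can be sketched in plain Python. The `to_bigram_key` helper below is a hypothetical stand-in for the PR's `key_convert_fn`, written only to illustrate the off-by-one:

```python
def to_bigram_key(token_ids):
    """Convert [t0, t1, ..., tn] into overlapping bigrams [(t0,t1), (t1,t2), ...]."""
    return list(zip(token_ids, token_ids[1:]))

token_ids = [1, 2, 3, 4]
bigrams = to_bigram_key(token_ids)
assert bigrams == [(1, 2), (2, 3), (3, 4)]

# The bigram key is one element shorter than the token list, so the
# number of KV entries it can match must shrink by one as well.
all_token_len = len(token_ids)   # 4
actual_kv_len = len(bigrams)     # 3 == all_token_len - 1
assert actual_kv_len == all_token_len - 1
```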
    page_aligned_token_len = (
        page_aligned_len + 1 if self.is_eagle else page_aligned_len
    )
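One way to read the +1 here: the KV length (already shortened by one for EAGLE) is aligned down to a page boundary, and the +1 converts that KV count back into a token count, since EAGLE tokens outnumber KV entries by one. A sketch under stated assumptions: `page_size` and the floor-to-page alignment rule are assumed from typical paged-cache code, not taken verbatim from the PR:

```python
page_size = 4  # assumed page size for illustration

def compute_lengths(all_token_len, is_eagle):
    # For EAGLE the bigram key drops one element, so one fewer KV slot matches.
    actual_kv_len = all_token_len - 1 if is_eagle else all_token_len
    # Align the KV length down to a page boundary (assumed alignment rule).
    page_aligned_len = actual_kv_len // page_size * page_size
    # Convert back to a token count: EAGLE has one more token than KV entries.
    page_aligned_token_len = page_aligned_len + 1 if is_eagle else page_aligned_len
    return page_aligned_len, page_aligned_token_len

# 9 tokens under EAGLE -> 8 matchable KV entries, aligned to 8, covering 9 tokens.
assert compute_lengths(9, is_eagle=True) == (8, 9)
# 9 tokens without EAGLE -> 9 KV entries, aligned down to 8 tokens.
assert compute_lengths(9, is_eagle=False) == (8, 8)
```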
    old_prefix_len = len(req.prefix_indices)
    if self.is_eagle and old_prefix_len > req.last_matched_prefix_len:
        # prefix_indices attached partial part (for page_size > 1) and one unmatched token (for EAGLE)
        old_prefix_len -= 1
Wishing for more comments here 🙏 The +1 and -1 logic and its implications are not straightforward to fully understand :(
In the chunked prefill case, the chunked KV should be cached in the radix cache. But in the EAGLE case, the last token is not inserted into the tree, because the bigram key is one element shorter. We still add it to req.prefix_indices (ref), since its KV is still part of the sequence. Here we do old_prefix_len - 1 to make sure that additional KV entry is freed correctly; otherwise we would get a memory leak.
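The memory-accounting argument can be sketched as a small helper. `adjust_old_prefix_len` is a hypothetical function written for illustration; the real code performs this adjustment inline on `old_prefix_len`:

```python
def adjust_old_prefix_len(prefix_indices_len, last_matched_prefix_len, is_eagle):
    """Return the number of KV entries actually owned by the radix tree.

    In EAGLE mode the last token's KV lives in req.prefix_indices but was
    never inserted into the tree (the bigram key is one element shorter),
    so it must not be counted as tree-owned -- otherwise that entry is
    never freed and the KV pool leaks.
    """
    old_prefix_len = prefix_indices_len
    if is_eagle and old_prefix_len > last_matched_prefix_len:
        old_prefix_len -= 1
    return old_prefix_len

# Chunked prefill under EAGLE: 10 entries attached, only 9 matched in the tree,
# so one entry must be discounted and freed by the caller.
assert adjust_old_prefix_len(10, 9, is_eagle=True) == 9
# Non-EAGLE: nothing to discount.
assert adjust_old_prefix_len(10, 9, is_eagle=False) == 10
```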
Co-authored-by: Hanming Lu <69857889+hanming-lu@users.noreply.github.com>
Motivation
follow-up of #10846
Accept Length Test
start server with this line commented:
run requests:
This PR w/ radix cache:
main w/ radix cache:
main w/o radix cache: