-
Notifications
You must be signed in to change notification settings - Fork 530
Refactor LRUPromptCache #1019
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor LRUPromptCache #1019
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,8 @@ | ||
| # Copyright © 2023-2024 Apple Inc. | ||
|
|
||
| import copy | ||
| from collections import deque | ||
| from dataclasses import dataclass | ||
| from typing import Any, Dict, List, Optional | ||
|
|
||
| import mlx.core as mx | ||
|
|
@@ -1381,3 +1383,224 @@ def nbytes(self): | |
| if self.keys is None: | ||
| return 0 | ||
| return self.keys.nbytes + self.values.nbytes | ||
|
|
||
|
|
||
@dataclass
class PromptTrieResult:
    """Outcome of a ``PromptTrie.search`` over one model's token sequences."""

    model: Any
    exact: Optional[List[int]]  # the key itself, when an exact match exists
    shorter: Optional[List[int]]  # longest strict prefix that holds a value
    longer: Optional[List[int]]  # shortest stored key extending the tokens
    common_prefix: int  # number of leading tokens shared with any stored path
|
|
||
|
|
||
| class PromptTrie: | ||
| def __init__(self): | ||
| self._trie = {} | ||
|
|
||
| def add(self, model: Any, tokens: List[int], value: Any): | ||
| if model not in self._trie: | ||
| self._trie[model] = {} | ||
|
|
||
| current = self._trie[model] | ||
| for tok in tokens: | ||
| if tok not in current: | ||
| current[tok] = {} | ||
| current = current[tok] | ||
| prev = current.get("__value__", None) | ||
| current["__value__"] = value | ||
| return prev | ||
|
|
||
| def get(self, model: Any, tokens: List[int]): | ||
| current = self._trie[model] | ||
| for tok in tokens: | ||
| current = current[tok] | ||
| return current["__value__"] | ||
|
|
||
| def pop(self, model: Any, tokens: List[int]): | ||
| path = [self._trie[model]] | ||
| for tok in tokens: | ||
| path.append(path[-1][tok]) | ||
| value = path[-1].pop("__value__") | ||
| for i in range(len(tokens), 0, -1): | ||
| node = path[i] | ||
| parent = path[i - 1] | ||
| tok = tokens[i - 1] | ||
| if len(node) > 0: | ||
| break | ||
| del parent[tok] | ||
| return value | ||
|
|
||
| def pop_prefixes(self, model: Any, tokens: List[int]): | ||
| values = [] | ||
| current = self._trie[model] | ||
| for i in range(len(tokens) - 1): | ||
| if "__value__" in current: | ||
| values.append((i, current.pop("__value__"))) | ||
| current = current[tokens[i]] | ||
| return values | ||
|
|
||
| def search(self, model: Any, tokens: List[int]) -> PromptTrieResult: | ||
| if model not in self._trie: | ||
| return PromptTrieResult(model, None, None, None, 0) | ||
|
|
||
| # Walk the tokens as far as we can | ||
| current = self._trie[model] | ||
| last_index = -1 | ||
| index = 0 | ||
| while index < len(tokens) and tokens[index] in current: | ||
| current = current[tokens[index]] | ||
| if "__value__" in current: | ||
| last_index = index | ||
| index += 1 | ||
|
|
||
| # Got an exact match | ||
| if last_index == len(tokens) - 1: | ||
| return PromptTrieResult(model, tokens, None, None, 0) | ||
|
|
||
| # Check if we found a prefix at any point | ||
| shorter = None | ||
| if last_index > 0: | ||
| shorter = tokens[: last_index + 1] | ||
|
|
||
| # Check for sequences that are longer | ||
| longer = None | ||
| common_prefix = index | ||
| if index > 0: | ||
| best = None | ||
| stack = [(current, [])] | ||
| while stack: | ||
| current, extra = stack.pop() | ||
| if "__value__" in current: | ||
| if best is None or len(extra) < len(best): | ||
| best = extra | ||
| elif best is None or len(extra) < len(best): | ||
| for tok in current: | ||
| stack.append((current[tok], extra + [tok])) | ||
| longer = tokens[:index] + best | ||
| return PromptTrieResult(model, None, shorter, longer, common_prefix) | ||
|
|
||
|
|
||
class LRUPromptCache:
    """An LRU cache of prompt KV-caches keyed by (model, token sequence).

    Lookups go through a PromptTrie so a request can reuse the cache of an
    exact match, a shorter stored prefix, or (when trimmable) a longer stored
    sequence trimmed down to the common prefix. Eviction is LRU within
    per-role queues, preferring to drop "assistant" entries before "user"
    before "system".
    """

    @dataclass
    class CacheEntry:
        # The per-layer cache objects and their precomputed total size.
        prompt_cache: List[Any]
        nbytes: int

    class CacheOrder:
        """LRU bookkeeping split across per-role queues.

        Rationale (from the design discussion): with a single LRU, frequent
        assistant-turn caches would evict the more valuable user/system
        caches, and assistant caches become useless anyway once the chat
        template rewrites history on the next user message. Keeping one queue
        per role and evicting lower-priority roles first avoids that.
        """

        def __init__(self, ordering: Optional[List[str]] = None):
            # FIX: the original used a shared mutable default list argument.
            if ordering is None:
                ordering = ["assistant", "user", "system"]
            self._ordering = ordering
            self._lrus = {role: deque() for role in ordering}

        def __len__(self):
            return sum(len(q) for q in self._lrus.values())

        def push(self, model: Any, tokens: List[Any], cache_type: str = "assistant"):
            """Append (model, tokens) as most-recently-used for its role."""
            self._lrus[cache_type].append((model, tokens))

        def remove(self, model: Any, tokens: List[Any]):
            """Best-effort removal: stop at the first queue holding the key."""
            for cache_type in self._ordering:
                try:
                    self._lrus[cache_type].remove((model, tokens))
                    break
                except ValueError:
                    pass

        def pop(self):
            """Evict one entry, preferring earlier (lower-priority) roles.

            A role's queue is chosen when it is non-empty and at least as
            long as the next role's queue; otherwise move to the next role.
            Raises IndexError when every queue is empty.
            """
            fallback = None
            for i in range(len(self._ordering) - 1):
                queue_a = self._lrus[self._ordering[i]]
                queue_b = self._lrus[self._ordering[i + 1]]
                fallback = queue_b
                # FIX: require queue_a non-empty. The original popped from an
                # empty deque when two adjacent queues tied at 0 >= 0, even
                # though a later queue still had entries.
                if queue_a and len(queue_a) >= len(queue_b):
                    return queue_a.popleft()
            if fallback is None:
                # FIX: a single-role ordering left the fallback unbound
                # (NameError in the original); fall back to that one queue.
                fallback = self._lrus[self._ordering[0]]
            return fallback.popleft()

    def __init__(self, max_size: int = 10, max_bytes: int = 1 << 63):
        self.max_size = max_size    # maximum number of cached sequences
        self.max_bytes = max_bytes  # maximum total bytes across all entries
        self._trie = PromptTrie()
        self._lru = LRUPromptCache.CacheOrder()
        self._n_bytes = 0

    def __len__(self):
        return len(self._lru)

    @property
    def nbytes(self):
        """Total bytes currently held by all cached entries."""
        return self._n_bytes

    def _evict_one(self):
        # Drop the least valuable entry per CacheOrder and fix the byte count.
        model, tokens = self._lru.pop()
        evicted = self._trie.pop(model, tokens)
        self._n_bytes -= evicted.nbytes

    def fetch_nearest_cache(self, model: Any, tokens: List[int]):
        """Return ``(prompt_cache, remaining_tokens)`` for the closest match.

        Preference order: an exact match (nothing left to process), then a
        longer stored sequence trimmed to the common prefix (only when its
        cache is trimmable), then the longest shorter prefix. With no match
        at all, returns ``(None, tokens)``. The returned cache is always a
        deep copy, so callers may mutate it freely.
        """
        result = self._trie.search(model, tokens)
        if result.exact is not None:
            cache_entry = self._trie.get(result.model, result.exact)
            return copy.deepcopy(cache_entry.prompt_cache), []

        short_length = len(result.shorter) if result.shorter is not None else 0
        if result.longer is not None and result.common_prefix > short_length:
            cache_entry = self._trie.get(result.model, result.longer)
            if can_trim_prompt_cache(cache_entry.prompt_cache):
                cache = copy.deepcopy(cache_entry.prompt_cache)
                # Keep at least one token to process so generation has input.
                prefix = min(len(tokens) - 1, result.common_prefix)
                num_to_trim = len(result.longer) - prefix
                trim_prompt_cache(cache, num_to_trim)
                return cache, tokens[prefix:]

        if short_length > 0:
            cache_entry = self._trie.get(result.model, result.shorter)
            return copy.deepcopy(cache_entry.prompt_cache), tokens[short_length:]

        return None, tokens

    def insert_cache(
        self,
        model: Any,
        tokens: List[int],
        prompt_cache: List[Any],
        *,
        cache_type: str = "assistant",
    ):
        """Store ``prompt_cache`` under (model, tokens), then enforce limits."""
        entry = LRUPromptCache.CacheEntry(
            prompt_cache, sum(c.nbytes for c in prompt_cache)
        )

        # Insert into the trie and update the byte counter and LRU position;
        # a previous entry at exactly this key is replaced and de-counted.
        self._n_bytes += entry.nbytes
        prev = self._trie.add(model, tokens, entry)
        if prev is not None:
            self._n_bytes -= prev.nbytes
            self._lru.remove(model, tokens)
        self._lru.push(model, tokens, cache_type)

        # A trimmable cache subsumes caches stored at strict prefixes of the
        # same sequence, so drop them — they just take space.
        if can_trim_prompt_cache(prompt_cache):
            # FIX: distinct loop variable; the original shadowed ``entry``.
            for prefix_len, prefix_entry in self._trie.pop_prefixes(model, tokens):
                self._n_bytes -= prefix_entry.nbytes
                self._lru.remove(model, tokens[:prefix_len])

        # Enforce the constraints. FIX: ``while`` rather than ``if`` for the
        # size bound, so the invariant holds even if max_size was lowered
        # between insertions.
        while len(self._lru) > self.max_size:
            self._evict_one()
        while self._n_bytes > self.max_bytes:
            self._evict_one()

    def trim_to(
        self, *, n_sequences: Optional[int] = None, n_bytes: Optional[int] = None
    ):
        """Shrink the cache to at most ``n_sequences`` entries / ``n_bytes``.

        A ``None`` limit means "unbounded" for that dimension.
        """
        n_sequences = max(0, n_sequences) if n_sequences is not None else 1 << 63
        n_bytes = max(0, n_bytes) if n_bytes is not None else 1 << 63

        while len(self._lru) > n_sequences:
            self._evict_one()
        while self._n_bytes > n_bytes:
            self._evict_one()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here -1 because we want to pop prefixes but not exact match because we don't want to pop something we just inserted, right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yep — `pop_prefixes` removes strict prefixes, not including this key.