Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
223 changes: 223 additions & 0 deletions mlx_lm/models/cache.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright © 2023-2024 Apple Inc.

import copy
from collections import deque
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import mlx.core as mx
Expand Down Expand Up @@ -1381,3 +1383,224 @@ def nbytes(self):
if self.keys is None:
return 0
return self.keys.nbytes + self.values.nbytes


@dataclass
class PromptTrieResult:
    """Outcome of a ``PromptTrie.search`` for a token sequence.

    At most one of ``exact``, (``shorter``, ``longer``) is meaningful:
    callers should check ``exact`` first.
    """

    model: Any
    exact: Optional[List[int]]  # Exact match found
    shorter: Optional[List[int]]  # Longest prefix with a value
    longer: Optional[List[int]]  # Shortest value that extends beyond tokens
    common_prefix: int  # Length of common prefix with any path


class PromptTrie:
    """A trie keyed by ``(model, token sequence)`` for prompt-cache lookup.

    Each node is a plain dict mapping a token id to a child node; a node
    that terminates a stored sequence additionally holds its payload under
    the reserved ``"__value__"`` key. Empty branches are pruned on ``pop``
    so the trie only contains paths that lead to a value.
    """

    def __init__(self):
        self._trie = {}

    def add(self, model: Any, tokens: List[int], value: Any):
        """Store ``value`` at ``tokens`` for ``model``.

        Returns the previous value stored at exactly that key, or ``None``
        if there was none.
        """
        if model not in self._trie:
            self._trie[model] = {}

        current = self._trie[model]
        for tok in tokens:
            if tok not in current:
                current[tok] = {}
            current = current[tok]
        prev = current.get("__value__", None)
        current["__value__"] = value
        return prev

    def get(self, model: Any, tokens: List[int]):
        """Return the value stored at ``tokens``.

        Raises ``KeyError`` if the model or the exact sequence is absent.
        """
        current = self._trie[model]
        for tok in tokens:
            current = current[tok]
        return current["__value__"]

    def pop(self, model: Any, tokens: List[int]):
        """Remove and return the value at ``tokens``, pruning empty nodes.

        Raises ``KeyError`` if no value is stored at that key.
        """
        path = [self._trie[model]]
        for tok in tokens:
            path.append(path[-1][tok])
        value = path[-1].pop("__value__")
        # Walk back up deleting nodes that became empty so the trie does
        # not accumulate dead branches.
        for i in range(len(tokens), 0, -1):
            node = path[i]
            parent = path[i - 1]
            tok = tokens[i - 1]
            if len(node) > 0:
                break
            del parent[tok]
        return value

    def pop_prefixes(self, model: Any, tokens: List[int]):
        """Remove the values stored at every *strict* prefix of ``tokens``.

        The value at ``tokens`` itself is left untouched. Returns a list of
        ``(prefix_length, value)`` pairs for the removed entries, ordered by
        increasing prefix length.
        """
        values = []
        current = self._trie[model]
        # Check the node for prefix length i before descending; the final
        # descent lands on the node for the full sequence, which is never
        # inspected, so the exact entry is never popped. (Iterating only to
        # len(tokens) - 1 would wrongly skip the longest strict prefix.)
        for i in range(len(tokens)):
            if "__value__" in current:
                values.append((i, current.pop("__value__")))
            current = current[tokens[i]]
        return values

    def search(self, model: Any, tokens: List[int]) -> PromptTrieResult:
        """Find the stored sequence(s) closest to ``tokens``.

        Returns a ``PromptTrieResult`` describing, in order of preference:
        an exact match, the longest stored strict prefix, and the shortest
        stored sequence sharing the deepest common prefix with ``tokens``.
        """
        if model not in self._trie:
            return PromptTrieResult(model, None, None, None, 0)

        # Walk the tokens as far as we can, remembering the deepest
        # position at which a stored value terminated.
        current = self._trie[model]
        last_index = -1
        index = 0
        while index < len(tokens) and tokens[index] in current:
            current = current[tokens[index]]
            if "__value__" in current:
                last_index = index
            index += 1

        # Got an exact match
        if last_index == len(tokens) - 1:
            return PromptTrieResult(model, tokens, None, None, 0)

        # Check if we found a prefix at any point.
        # NOTE(review): a single-token prefix (last_index == 0) is ignored
        # here — presumably because a one-token cache saves nothing; confirm
        # this is intended rather than an off-by-one.
        shorter = None
        if last_index > 0:
            shorter = tokens[: last_index + 1]

        # Check for stored sequences extending past the divergence point:
        # DFS from the deepest shared node for the shortest terminating
        # extension, pruning branches already no shorter than the best.
        longer = None
        common_prefix = index
        if index > 0:
            best = None
            stack = [(current, [])]
            while stack:
                current, extra = stack.pop()
                if "__value__" in current:
                    if best is None or len(extra) < len(best):
                        best = extra
                elif best is None or len(extra) < len(best):
                    for tok in current:
                        stack.append((current[tok], extra + [tok]))
            longer = tokens[:index] + best
        return PromptTrieResult(model, None, shorter, longer, common_prefix)


class LRUPromptCache:
    """An LRU cache of prompt KV-caches keyed by (model, token sequence).

    Entries live in a ``PromptTrie`` so lookups can reuse exact matches,
    shorter prefixes, or (for trimmable caches) longer stored sequences.
    Eviction follows a priority ordering over cache types: "assistant"
    entries are evicted before "user" entries, which are evicted before
    "system" entries, so long-lived conversation prefixes survive longer.
    """

    @dataclass
    class CacheEntry:
        # Per-layer KV cache objects and their total size in bytes.
        prompt_cache: List[Any]
        nbytes: int

    class CacheOrder:
        """One LRU queue per cache type, drained in priority order."""

        def __init__(self, ordering: List[str] = ("assistant", "user", "system")):
            # Copy into a list so a (previously mutable-default) caller
            # argument is never shared or mutated.
            self._ordering = list(ordering)
            self._lrus = {k: deque() for k in self._ordering}

        def __len__(self):
            return sum(len(lru) for lru in self._lrus.values())

        def push(self, model: Any, tokens: List[Any], cache_type: str = "assistant"):
            self._lrus[cache_type].append((model, tokens))

        def remove(self, model: Any, tokens: List[Any]):
            # The entry's cache type is unknown, so try each queue in turn;
            # deque.remove raises ValueError when the entry is absent.
            for cache_type in self._ordering:
                try:
                    self._lrus[cache_type].remove((model, tokens))
                    break
                except ValueError:
                    pass

        def pop(self):
            """Evict and return one ``(model, tokens)`` entry.

            Queues earlier in the ordering (e.g. "assistant") are drained
            first, but only while they are at least as large as the next
            queue; this keeps assistant caches from evicting the user and
            system caches they depend on. Raises ``IndexError`` if every
            queue is empty.
            """
            i = 0
            while i + 1 < len(self._ordering):
                lru_a = self._lrus[self._ordering[i]]
                lru_b = self._lrus[self._ordering[i + 1]]
                # Never popleft() from an empty queue (e.g. when only a
                # later, higher-priority queue holds entries).
                if lru_a and len(lru_a) >= len(lru_b):
                    return lru_a.popleft()
                i += 1
            # Fall back to the last (highest-priority) queue.
            return self._lrus[self._ordering[-1]].popleft()

    def __init__(self, max_size: int = 10, max_bytes: int = 1 << 63):
        """Create a cache holding at most ``max_size`` sequences totalling
        at most ``max_bytes`` bytes."""
        self.max_size = max_size
        self.max_bytes = max_bytes
        self._trie = PromptTrie()
        self._lru = LRUPromptCache.CacheOrder()
        self._n_bytes = 0

    def __len__(self):
        return len(self._lru)

    @property
    def nbytes(self):
        # Total bytes across all cached entries.
        return self._n_bytes

    def _evict_one(self):
        """Remove the next eviction candidate and update the byte count."""
        model, tokens = self._lru.pop()
        entry = self._trie.pop(model, tokens)
        self._n_bytes -= entry.nbytes

    def fetch_nearest_cache(self, model: Any, tokens: List[int]):
        """Return ``(prompt_cache, remaining_tokens)`` for the stored
        sequence closest to ``tokens``.

        Preference order: an exact match (no tokens left to process), a
        longer trimmable cache sharing a deeper common prefix than any
        stored prefix, the longest stored prefix, and finally
        ``(None, tokens)`` on a miss. Returned caches are deep copies so
        the stored entries are never mutated by the caller.
        """
        result = self._trie.search(model, tokens)
        if result.exact is not None:
            cache_entry = self._trie.get(result.model, result.exact)
            return copy.deepcopy(cache_entry.prompt_cache), []

        short_length = len(result.shorter) if result.shorter is not None else 0
        if result.longer is not None and result.common_prefix > short_length:
            cache_entry = self._trie.get(result.model, result.longer)
            if can_trim_prompt_cache(cache_entry.prompt_cache):
                cache = copy.deepcopy(cache_entry.prompt_cache)
                # Keep at least one token to process so generation has an
                # input to run on.
                prefix = min(len(tokens) - 1, result.common_prefix)
                num_to_trim = len(result.longer) - prefix
                trim_prompt_cache(cache, num_to_trim)
                return cache, tokens[prefix:]

        if short_length > 0:
            cache_entry = self._trie.get(result.model, result.shorter)
            return copy.deepcopy(cache_entry.prompt_cache), tokens[short_length:]

        return None, tokens

    def insert_cache(
        self,
        model: Any,
        tokens: List[int],
        prompt_cache: List[Any],
        *,
        cache_type: str = "assistant",
    ):
        """Insert ``prompt_cache`` for ``(model, tokens)``.

        Replaces any existing entry at the same key, drops now-redundant
        prefix entries when the cache is trimmable, and evicts entries as
        needed to respect ``max_size`` and ``max_bytes``.
        """
        # Make the cache entry
        entry = LRUPromptCache.CacheEntry(
            prompt_cache, sum(c.nbytes for c in prompt_cache)
        )

        # Insert into the trie and update the byte counter and lru position
        self._n_bytes += entry.nbytes
        prev = self._trie.add(model, tokens, entry)
        if prev is not None:
            self._n_bytes -= prev.nbytes
            self._lru.remove(model, tokens)
        self._lru.push(model, tokens, cache_type)

        # A trimmable cache subsumes all of its prefixes, so remove them —
        # they just take space.
        if can_trim_prompt_cache(prompt_cache):
            for prefix_len, prefix_entry in self._trie.pop_prefixes(model, tokens):
                self._n_bytes -= prefix_entry.nbytes
                self._lru.remove(model, tokens[:prefix_len])

        # Ensure we match the constraints (a single insert adds at most one
        # sequence, so one eviction restores the size bound).
        if len(self._lru) > self.max_size:
            self._evict_one()
        while self._n_bytes > self.max_bytes:
            self._evict_one()

    def trim_to(
        self, *, n_sequences: Optional[int] = None, n_bytes: Optional[int] = None
    ):
        """Evict entries until at most ``n_sequences`` sequences and
        ``n_bytes`` bytes remain. A limit of ``None`` is unbounded."""
        n_sequences = max(0, n_sequences) if n_sequences is not None else 1 << 63
        n_bytes = max(0, n_bytes) if n_bytes is not None else 1 << 63

        while len(self._lru) > n_sequences:
            self._evict_one()
        while self._n_bytes > n_bytes:
            self._evict_one()
Loading
Loading