diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py index ab0bdb83e..843d553ef 100644 --- a/mlx_lm/generate.py +++ b/mlx_lm/generate.py @@ -948,6 +948,9 @@ def __init__( completion_batch_size: int = 32, prefill_batch_size: int = 8, prefill_step_size: int = 2048, + prompt_checkpoint_callback: Optional[ + Callable[[List[Tuple[int, int, List[Any]]]], None] + ] = None, prompt_progress_callback: Optional[ Callable[[List[Tuple[int, int, int]]], None] ] = None, @@ -963,6 +966,7 @@ def __init__( self.prefill_step_size = prefill_step_size self.prefill_batch_size = prefill_batch_size self.completion_batch_size = max(completion_batch_size, prefill_batch_size) + self.prompt_checkpoint_callback = prompt_checkpoint_callback self.prompt_progress_callback = prompt_progress_callback or (lambda *_: None) self._stats = BatchStats() self._next_count = 0 @@ -993,12 +997,16 @@ def insert( caches=None, samplers: list | None = None, logits_processors: list | None = None, + prompt_checkpoints: list | int | None = None, ): uids = [] if max_tokens is None or isinstance(max_tokens, int): max_tokens = [max_tokens or self.max_tokens] * len(prompts) + if prompt_checkpoints is None or isinstance(prompt_checkpoints, int): + prompt_checkpoints = [prompt_checkpoints or -1] * len(prompts) + if caches is None: caches = [None] * len(prompts) for i in range(len(prompts)): @@ -1008,10 +1016,10 @@ def insert( samplers = samplers or [None] * len(prompts) logits_processors = logits_processors or [self.logits_processors] * len(prompts) - for p, m, c, s, lp in zip( - prompts, max_tokens, caches, samplers, logits_processors + for p, m, c, s, lp, pc in zip( + prompts, max_tokens, caches, samplers, logits_processors, prompt_checkpoints ): - self.unprocessed_prompts.append((self.uid_count, p, m, c, s, lp)) + self.unprocessed_prompts.append((self.uid_count, p, m, c, s, lp, pc)) uids.append(self.uid_count) self.uid_count += 1 # Sort in ascending order of length @@ -1052,12 +1060,28 @@ def 
prompt_cache_nbytes(self): return total def _process_prompts(self, prompts): - uids, inputs, max_tokens, caches, samplers, logits_processors = zip(*prompts) + ( + uids, + inputs, + max_tokens, + caches, + samplers, + logits_processors, + prompt_checkpoints, + ) = zip(*prompts) lengths = [len(p) for p in inputs] max_length = max(lengths) padding = [max_length - l for l in lengths] + # Get the checkpoint token as an offset from the end of each prompt. + # Then select the largest one so that we perform the checkpoint at + # least `pc` before the end. + prompt_checkpoints = [ + (l - pc if pc > 0 else -pc) for l, pc in zip(lengths, prompt_checkpoints) + ] + prompt_checkpoint = max(1, max(prompt_checkpoints)) + self._stats.prompt_tokens += sum(lengths) tokens = [mx.array(inp) for inp in inputs] @@ -1070,8 +1094,10 @@ def _process_prompts(self, prompts): inputs = _left_pad_prompts(inputs, max_length=max_length) prompt_cache = _make_cache(self.model, padding, self.max_kv_size) - while inputs.shape[1] > 1: - n_to_process = min(self.prefill_step_size, inputs.shape[1] - 1) + while inputs.shape[1] > prompt_checkpoint: + n_to_process = min( + self.prefill_step_size, inputs.shape[1] - prompt_checkpoint + ) self.model(inputs[:, :n_to_process], cache=prompt_cache) mx.eval([c.state for c in prompt_cache]) inputs = inputs[:, n_to_process:] @@ -1090,16 +1116,22 @@ def _process_prompts(self, prompts): # 2. Process # 3. 
Finalize the KV caches so they are left padded again else: - last_inputs = mx.array([p[-1:] for p in inputs]) + last_inputs = mx.array([p[-prompt_checkpoint:] for p in inputs]) inputs = _right_pad_prompts(inputs, max_length=max_length) prompt_cache = _merge_caches(caches) for c in prompt_cache: - # subtract one from lengths since we don't process the last token during prefill - c.prepare(lengths=[l - 1 for l in lengths], right_padding=padding) + # subtract from lengths since we don't process the last + # `prompt_checkpoint` tokens during prefill + c.prepare( + lengths=[l - prompt_checkpoint for l in lengths], + right_padding=padding, + ) - while inputs.shape[1] > 1: - n_to_process = min(self.prefill_step_size, inputs.shape[1] - 1) + while inputs.shape[1] > prompt_checkpoint: + n_to_process = min( + self.prefill_step_size, inputs.shape[1] - prompt_checkpoint + ) self.model(inputs[:, :n_to_process], cache=prompt_cache) mx.eval([c.state for c in prompt_cache]) inputs = inputs[:, n_to_process:] @@ -1117,6 +1149,20 @@ def _process_prompts(self, prompts): for c in prompt_cache: c.finalize() + + # We processed L - prompt_checkpoint tokens so call the checkpoint + # callback. 
+ if self.prompt_checkpoint_callback is not None: + self.prompt_checkpoint_callback( + [ + (uid, prompt_checkpoint, [c.extract(i) for c in prompt_cache]) + for i, uid in enumerate(uids) + ] + ) + # Process the remaining prompt_checkpoint-1 tokens + if prompt_checkpoint > 1: + self.model(inputs[:, : prompt_checkpoint - 1], cache=prompt_cache) + mx.eval([c.state for c in prompt_cache]) mx.clear_cache() y, logprobs = self._step( diff --git a/mlx_lm/server.py b/mlx_lm/server.py index 36695aa88..80425a434 100644 --- a/mlx_lm/server.py +++ b/mlx_lm/server.py @@ -2,6 +2,7 @@ import argparse import copy +import heapq import json import logging import pickle @@ -187,9 +188,32 @@ class LRUPromptCache: @dataclass class CacheEntry: prompt_cache: List[Any] - count: int nbytes: int + class CacheOrder: + def __init__(self): + self._lru_checkpoints = deque() + self._lru = deque() + + def __len__(self): + return len(self._lru) + len(self._lru_checkpoints) + + def push(self, model, tokens, checkpoint: bool = False): + c = self._lru_checkpoints if checkpoint else self._lru + c.append((model, tokens)) + + def remove(self, model, tokens): + try: + self._lru.remove((model, tokens)) + except ValueError: + self._lru_checkpoints.remove((model, tokens)) + + def pop(self): + if len(self._lru) >= len(self._lru_checkpoints): + return self._lru.popleft() + else: + return self._lru_checkpoints.popleft() + @dataclass class SearchResult: model: Any @@ -202,7 +226,7 @@ def __init__(self, max_size: int = 10, max_bytes: int = 1 << 63): self.max_size = max_size self.max_bytes = max_bytes self._cache = {} - self._lru = deque() + self._lru = self.CacheOrder() self._n_bytes = 0 def __len__(self): @@ -239,7 +263,7 @@ def _search(self, model, tokens): # Check for caches that are longer longer = None common_prefix = index - if index > 0 and last_cache_index <= 0: + if index > 0: best = None stack = [(current, [])] while stack: @@ -272,32 +296,14 @@ def _delete(self, model, tokens): break del d_prev[t] - 
logging.debug(f"[LRUPromptCache] Removed {cache_bytes} bytes from the cache") - - def _extract(self, model, tokens): - cache_entry = self._get(model, tokens) - if cache_entry.count == 1: - self._delete(model, tokens) - self._lru.remove((model, tokens)) - return cache_entry - - cache_entry.count -= 1 - return self.CacheEntry( - copy.deepcopy(cache_entry.prompt_cache), 1, cache_entry.nbytes - ) - def fetch_nearest_cache(self, model, tokens): result = self._search(model, tokens) if result.exact is not None: - cache_entry = self._extract(result.model, result.exact) - return cache_entry.prompt_cache, [] - - if result.shorter is not None: - cache_entry = self._extract(result.model, result.shorter) - prefix_len = len(result.shorter) - return cache_entry.prompt_cache, tokens[prefix_len:] + cache_entry = self._get(result.model, result.exact) + return copy.deepcopy(cache_entry.prompt_cache), [] - if result.longer is not None: + short_length = len(result.shorter) if result.shorter is not None else 0 + if result.longer is not None and result.common_prefix > short_length: cache_entry = self._get(result.model, result.longer) if can_trim_prompt_cache(cache_entry.prompt_cache): cache = copy.deepcopy(cache_entry.prompt_cache) @@ -306,32 +312,40 @@ def fetch_nearest_cache(self, model, tokens): trim_prompt_cache(cache, num_to_trim) return cache, tokens[prefix:] + if short_length > 0: + cache_entry = self._get(result.model, result.shorter) + return copy.deepcopy(cache_entry.prompt_cache), tokens[short_length:] + return None, tokens - def insert_cache(self, model, tokens, prompt_cache): + def insert_cache(self, model, tokens, prompt_cache, checkpoint: bool = False): + is_trimmable = can_trim_prompt_cache(prompt_cache) + if model not in self._cache: self._cache[model] = {} current = self._cache[model] - for tok in tokens: + for i, tok in enumerate(tokens): if tok not in current: current[tok] = {} + if is_trimmable and "cache" in current: + self._n_bytes -= current["cache"].nbytes + del 
current["cache"] + self._lru.remove(model, tokens[:i]) current = current[tok] if "cache" in current: - current["cache"].count += 1 - self._lru.remove((model, tokens)) + self._lru.remove(model, tokens) else: cache_bytes = sum(c.nbytes for c in prompt_cache) - current["cache"] = self.CacheEntry(prompt_cache, 1, cache_bytes) + current["cache"] = self.CacheEntry(prompt_cache, cache_bytes) self._n_bytes += cache_bytes - logging.debug(f"[LRUPromptCache] Adding {cache_bytes} to the cache") - self._lru.append((model, tokens)) + self._lru.push(model, tokens, checkpoint=checkpoint) if len(self._lru) > self.max_size: - model, tokens = self._lru.popleft() + model, tokens = self._lru.pop() self._delete(model, tokens) while self._n_bytes > self.max_bytes and len(self._lru) > 1: - model, tokens = self._lru.popleft() + model, tokens = self._lru.pop() self._delete(model, tokens) def trim_to( @@ -341,12 +355,23 @@ def trim_to( n_bytes = max(0, n_bytes) if n_bytes is not None else 1 << 63 while len(self._lru) > n_sequences: - model, tokens = self._lru.popleft() + model, tokens = self._lru.pop() self._delete(model, tokens) while self._n_bytes > n_bytes: - model, tokens = self._lru.popleft() + model, tokens = self._lru.pop() self._delete(model, tokens) + def log_cache_stats(self): + ncaches, nbytes = len(self), self.nbytes + ntok = ( + len(self._lru._lru_checkpoints[-1][1]) + if len(self._lru._lru_checkpoints) > 0 + else 0 + ) + logging.info( + f"KV Caches: {ncaches} seq, {nbytes/1e9:.2f} GB, latest user cache {ntok} tokens" + ) + @dataclass class ModelDescription: @@ -714,6 +739,24 @@ def _tokenize(self, tokenizer, request, args): else: return tokenizer.encode(request.prompt) + def _compute_prompt_checkpoint(self, tokenizer, request, prompt): + if request.request_type != "chat": + return False, -1 + if request.messages[-1]["role"] != "user": + return False, -1 + + # Save the KV cache at the end of the prompt just before + # the think start token which will likely be removed in the + # 
next turn. + prompt_checkpoint = -1 + if tokenizer.has_thinking: + for i in range(1, min(11, len(prompt)) - 1, 1): + if prompt[-i] == tokenizer.think_start_id: + prompt_checkpoint = -i - 1 + break + + return True, prompt_checkpoint + def _is_batchable(self, args): if not self.model_provider.is_batchable: return False @@ -744,6 +787,18 @@ def progress_callback(info): if uid in batch_results: batch_results[uid]["rqueue"].put((min(processed, total), total)) + def checkpoint_callback(prompts): + for uid, prompt_end, cache in prompts: + rs = batch_results[uid] + if not rs["checkpoint"]: + continue + self.prompt_cache.insert_cache( + current_model_key, + rs["cache_key"][:-prompt_end], + list(cache), + checkpoint=True, + ) + if self._is_distributed: seed = mx.distributed.all_sum(mx.random.state[0]).view(mx.uint64).item() mx.random.seed(seed) @@ -792,6 +847,7 @@ def progress_callback(info): ) rqueue.put(ctx) + self.prompt_cache.log_cache_stats() cache, rest = self.prompt_cache.fetch_nearest_cache( current_model_key, prompt ) @@ -799,9 +855,8 @@ def progress_callback(info): if cache is None: cache = make_prompt_cache(self.model_provider.model) - ncaches, nbytes = len(self.prompt_cache), self.prompt_cache.nbytes - logging.info( - f"We have {ncaches} kv caches that take {nbytes/1e9:.2f} GB" + do_checkpoint, checkpoint_position = ( + self._compute_prompt_checkpoint(tokenizer, request, prompt) ) (uid,) = batch_generator.insert( @@ -810,12 +865,14 @@ def progress_callback(info): caches=[cache], samplers=[_make_sampler(args, tokenizer)], logits_processors=[_make_logits_processors(args)], + prompt_checkpoints=[checkpoint_position], ) batch_results[uid] = { "ctx": ctx, "cache_key": prompt[:], "rqueue": rqueue, "detokenizer": tokenizer.detokenizer, + "checkpoint": do_checkpoint, } # just making sure we don't leave a reference around del cache @@ -852,6 +909,7 @@ def progress_callback(info): completion_batch_size=self.cli_args.decode_concurrency, 
prefill_batch_size=self.cli_args.prompt_concurrency, prompt_progress_callback=progress_callback, + prompt_checkpoint_callback=checkpoint_callback, ) unprocessed_requests.append((rqueue, request, args)) continue @@ -969,6 +1027,7 @@ def progress(tokens_processed, tokens_total): logits_processors = _make_logits_processors(args) # Load the KV cache + self.prompt_cache.log_cache_stats() cache, rest = self.prompt_cache.fetch_nearest_cache( self.model_provider.model_key, prompt ) @@ -979,9 +1038,6 @@ def progress(tokens_processed, tokens_total): if self.model_provider.draft_model is not None: cache += make_prompt_cache(self.model_provider.draft_model) - ncaches, nbytes = len(self.prompt_cache), self.prompt_cache.nbytes - logging.info(f"We have {ncaches} kv caches that take {nbytes/1e9:.2f} GB") - # Process the prompt and generate tokens for gen in stream_generate( model=model, diff --git a/tests/test_server.py b/tests/test_server.py index f8b0be404..636cd6f8a 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -61,8 +61,9 @@ def load(self, model, adapter=None, draft_model=None): class MockCache: - def __init__(self, value): + def __init__(self, value, is_trimmable: bool = True): self.value = value + self._is_trimmable = is_trimmable @property def nbytes(self): @@ -71,6 +72,13 @@ def nbytes(self): def __eq__(self, other): return other.value == self.value + def is_trimmable(self): + return self._is_trimmable + + def trim(self, n): + assert self._is_trimmable + return n + class TestServer(unittest.TestCase): @classmethod @@ -436,18 +444,23 @@ def get_kv(n): c[0].update_and_fetch(*get_kv(24)) cache.insert_cache(model, t, c) + # Fetching a cache that is strictly a prefix doesn't remove it from the + # lru cache tokens = tokens + [20] * 5 c, t = cache.fetch_nearest_cache(model, tokens) k, v = c[0].state self.assertTrue((k == v).all().item()) self.assertTrue((k.flatten() == mx.arange(24)).all().item()) self.assertEqual(t, [20] * 5) - 
self.assertEqual(len(cache._lru), 0) + self.assertEqual(len(cache), 1) + # Inserting a trimmable cache with shared prefix removes the prefixes tokens = tokens + [30] * 3 c[0].update_and_fetch(*get_kv(8)) cache.insert_cache(model, tokens, c) + self.assertEqual(len(cache), 1) + # Fetching a cache with a shared prefix doesn't remove it either tokens = tokens[:26] + [40] * 8 c, t = cache.fetch_nearest_cache(model, tokens) k, v = c[0].state @@ -456,23 +469,34 @@ def get_kv(n): (k.flatten() == mx.concatenate([mx.arange(24), mx.arange(2)])).all().item() ) self.assertEqual(t, [40] * 8) - self.assertEqual(len(cache._lru), 1) + self.assertEqual(len(cache), 1) + + # Inserting a diverged cache actually creates another entry + c[0].update_and_fetch(*get_kv(8)) + cache.insert_cache(model, tokens, c) + self.assertEqual(len(cache), 2) def test_lru(self): cache = LRUPromptCache(max_size=2) model = ("test", None, None) cache.insert_cache(model, [1, 2], [MockCache("test1")]) - cache.insert_cache(model, [1, 2], [MockCache("test1")]) + cache.insert_cache(model, [2, 3], [MockCache("test2")]) c, t = cache.fetch_nearest_cache(model, [1, 2]) self.assertEqual(c, [MockCache("test1")]) self.assertEqual(t, []) - c, t = cache.fetch_nearest_cache(model, [1, 2]) + c, t = cache.fetch_nearest_cache(model, [1]) self.assertEqual(c, [MockCache("test1")]) - self.assertEqual(t, []) - c, t = cache.fetch_nearest_cache(model, [1, 2]) - self.assertEqual(c, None) - self.assertEqual(t, [1, 2]) + self.assertEqual(t, [1]) + c, t = cache.fetch_nearest_cache(model, [1, 3, 4]) + self.assertEqual(c, [MockCache("test1")]) + self.assertEqual(t, [3, 4]) + c, t = cache.fetch_nearest_cache(model, [2, 3, 4]) + self.assertEqual(c, [MockCache("test2")]) + self.assertEqual(t, [4]) + c, t = cache.fetch_nearest_cache(model, [2, 4, 5]) + self.assertEqual(c, [MockCache("test2")]) + self.assertEqual(t, [4, 5]) cache.insert_cache(model, [1, 2], [MockCache("test1")]) cache.insert_cache(model, [2, 3], [MockCache("test2")]) @@ 
-488,6 +512,29 @@ def test_lru(self): self.assertEqual(c, [MockCache("test3")]) self.assertEqual(t, []) + cache.insert_cache(model, [4, 5], [MockCache("test4")], checkpoint=True) + c, t = cache.fetch_nearest_cache(model, [2, 3]) + self.assertEqual(c, None) + self.assertEqual(t, [2, 3]) + c, t = cache.fetch_nearest_cache(model, [3, 4]) + self.assertEqual(c, [MockCache("test3")]) + self.assertEqual(t, []) + c, t = cache.fetch_nearest_cache(model, [4, 5]) + self.assertEqual(c, [MockCache("test4")]) + self.assertEqual(t, []) + + cache.insert_cache(model, [5, 6], [MockCache("test5")]) + cache.insert_cache(model, [6, 7], [MockCache("test6")]) + c, t = cache.fetch_nearest_cache(model, [5, 6]) + self.assertEqual(c, None) + self.assertEqual(t, [5, 6]) + c, t = cache.fetch_nearest_cache(model, [6, 7]) + self.assertEqual(c, [MockCache("test6")]) + self.assertEqual(t, []) + c, t = cache.fetch_nearest_cache(model, [4, 5]) + self.assertEqual(c, [MockCache("test4")]) + self.assertEqual(t, []) + def test_lru_bytes(self): cache = LRUPromptCache(max_size=100, max_bytes=10) model = ("test", None, None)