From 91eff7de795afe8ac2cc559cd3d84b96f29ceb4f Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Tue, 17 Mar 2026 22:09:46 -0700 Subject: [PATCH 01/34] Add prompt processing batch and generation batch --- mlx_lm/examples/batch_generate_response.py | 5 +- mlx_lm/generate.py | 310 +++++++++++++++++++++ mlx_lm/models/cache.py | 10 + 3 files changed, 324 insertions(+), 1 deletion(-) diff --git a/mlx_lm/examples/batch_generate_response.py b/mlx_lm/examples/batch_generate_response.py index 6d07b4fba..908fefc09 100644 --- a/mlx_lm/examples/batch_generate_response.py +++ b/mlx_lm/examples/batch_generate_response.py @@ -27,9 +27,12 @@ # Set `verbose=True` to see generation statistics result = batch_generate( - model, tokenizer, prompts, verbose=False, return_prompt_caches=True + model, tokenizer, prompts, verbose=False, return_prompt_caches=True, max_tokens=2048 ) print(result.texts[-1]) +import pdb + +pdb.set_trace() prompts = [ "Could you summarize that?", diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py index 22531c644..8714ceaa8 100644 --- a/mlx_lm/generate.py +++ b/mlx_lm/generate.py @@ -6,6 +6,7 @@ import json import sys import time +from collections import deque from dataclasses import dataclass from functools import partial from typing import ( @@ -1405,6 +1406,315 @@ def batch_generate( return BatchResponse(texts, stats, caches) +class SequenceMatcher: + def __init__(self, sequences): + self._trie = {} + + # Make the trie + for seq in sequences: + current = self._trie + for x in seq: + current = current.setdefault(x, {}) + current["__seq__"] = tuple(seq) + + # Propagate failure links and matched sequences + queue = deque() + for x, child in self._trie.items(): + child["__fail__"] = self._trie + queue.append(child) + while queue: + parent = queue.popleft() + for x, child in parent.items(): + if x in ("__fail__", "__seq__"): + continue + queue.append(child) + failure_node = parent["__fail__"] + while x not in failure_node and failure_node is not 
self._trie: + failure_node = failure_node["__fail__"] + if x in failure_node: + child["__fail__"] = failure_node[x] + else: + child["__fail__"] = self._trie + if "__seq__" not in child and "__seq__" in child["__fail__"]: + child["__seq__"] = child["__fail__"]["__seq__"] + + def make_state(self): + return self._trie + + def match(self, state, x): + while x not in state and state is not self._trie: + state = state["__fail__"] + if x in state: + state = state[x] + return state, state.get("__seq__") + + +class PromptProcessingBatch: + def __init__( + self, + model, + uids, + caches, + tokens=None, + prefill_step_size: int = 2048, + samplers=None, + fallback_sampler=None, + logits_processors=None, + stop_matcher=None, + max_tokens=None, + ): + self.model = model + self.uids = uids + self.prompt_cache = _merge_caches(caches) + self.tokens = tokens or [[] for uid in self.uids] + + self.prefill_step_size = prefill_step_size + self.samplers = samplers or [] + self.fallback_sampler = fallback_sampler or (lambda x: mx.argmax(x, axis=-1)) + self.logits_processors = logits_processors or [] + self.stop_matcher = stop_matcher or SequenceMatcher([]) + self.max_tokens = max_tokens or 1 << 31 + + def prompt(self, tokens): + if len(self.uids) != len(tokens): + raise ValueError("The batch length doesn't match the number of inputs") + + # Add the tokens to the self.tokens so they represent the tokens + # contained in the KV Cache. + for sti, ti in zip(self.tokens, tokens): + sti += ti + + # Calculate if we need to pad + lengths = [len(p) for p in tokens] + max_length = max(lengths) + padding = [max_length - l for l in lengths] + max_padding = max(padding) + + # Prepare the caches and inputs. Right pad if needed otherwise just + # cast to array. 
+ if max_padding > 0: + tokens = _right_pad_prompts(tokens, max_length=max_length) + for c in self.prompt_cache: + c.prepare(lengths=lengths, right_padding=padding) + else: + tokens = mx.array(tokens) + + # Actual prompt processing loop + while tokens.shape[1] > 0: + n_to_process = min(self.prefill_step_size, tokens.shape[1]) + self.model(tokens[:, :n_to_process], cache=self.prompt_cache) + mx.eval([c.state for c in self.prompt_cache]) + mx.clear_cache() + tokens = tokens[:, n_to_process:] + + # Finalize the cache if there was any padding + if max_padding > 0: + for c in self.prompt_cache: + c.finalize() + mx.eval([c.state for c in self.prompt_cache]) + mx.clear_cache() + + def generate(self, tokens): + self.prompt([t[:-1] for t in tokens]) + last_token = mx.array([t[-1] for t in tokens]) + + generation = GenerationBatch( + self.model, + self.uids, + last_token, + self.prompt_cache, + self.tokens, + self.samplers, + self.fallback_sampler, + self.logits_processors, + self.stop_matcher, + self.max_tokens, + ) + + self.uids = [] + self.prompt_cache = None + self.tokens = [] + + return generation + + +class GenerationBatch: + @dataclass + class Response: + uid: int + token: int + logprobs: mx.array + finish_reason: Optional[str] + stop_sequence: Optional[List[int]] + prompt_cache: Optional[List[Any]] + all_tokens: Optional[List[int]] + + def __init__( + self, + model, + uids: List[int], + inputs: mx.array, + prompt_cache: List[Any], + tokens: List[List[int]], + samplers, + fallback_sampler, + logits_processors, + stop_matcher, + max_tokens, + ): + self.model = model + self.uids = uids + self.prompt_cache = prompt_cache + self.tokens = tokens + + self.samplers = samplers + self.fallback_sampler = fallback_sampler + self.logits_processors = logits_processors + self.stop_matcher = stop_matcher + self.max_tokens = max_tokens + + if self.samplers and len(self.samplers) != len(self.uids): + raise ValueError("Insufficient number of samplers provided") + if 
self.logits_processors and len(self.logits_processors) != len(self.uids): + raise ValueError("Insufficient number of logits_processors provided") + + self._current_tokens = None + self._current_logprobs = None + self._next_tokens = inputs + self._next_logprobs = None + self._token_context = [mx.array(t[-256:]) for t in tokens] + self._num_tokens = [0] * len(self.uids) + self._matcher_states = [self.stop_matcher.make_state()] * len(self.uids) + + self._step() + + def _step(self): + # Move next to current and kick of the computation of next. Also assign + # current to a local var for convenience. + self._current_tokens = self._next_tokens + self._current_logprobs = self._next_logprobs + inputs = self._current_tokens + + # Update the token context that will be used by the logits processors + for i, ti in enumerate(self._token_context): + self._token_context[i] = mx.concatenate( + [ti[1:] if len(ti) == 256 else ti, inputs[i : i + 1]] + ) + + # Forward pass + logits = self.model(inputs[:, None], cache=self.prompt_cache) + logits = logits[:, -1, :] + + # Logit processors + if any(self.logits_processors): + processed_logits = [] + for e in range(len(self.uids)): + sample_logits = logits[e : e + 1] + for processor in logits_processors[e]: + sample_logits = processor(tokens[e], sample_logits) + processed_logits.append(sample_logits) + logits = mx.concatenate(processed_logits, axis=0) + + # Normalize the logits + logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True) + + # Samplers + if any(self.samplers): + all_samples = [] + for e in range(len(self.uids)): + sample_sampler = self.samplers[e] or self.fallback_sampler + sampled = sample_sampler(logprobs[e : e + 1]) + all_samples.append(sampled) + sampled = mx.concatenate(all_samples, axis=0) + else: + sampled = self.fallback_sampler(logprobs) + + # Assign the next step to member variables and start computing it + # asynchronously + self._next_tokens = sampled + self._next_logprobs = list(logprobs) + 
mx.async_eval(self._next_tokens, self._next_logprobs, self._token_context) + + # Eval the current tokens and current logprobs. After that also add + # them to self.tokens so that it always represents the tokens contained + # in the KV Cache. + mx.eval(inputs, self._current_logprobs) + inputs = inputs.tolist() + for sti, ti in zip(self.tokens, inputs): + sti.append(ti) + return inputs, self._current_logprobs + + def _extract_cache(self, idx: int): + return [c.extract(idx) for c in self.prompt_cache] + + def _filter(self, keep: List[int]): + # Visible state + self.uids = [self.uids[idx] for idx in keep] + if not keep: + self.prompt_cache.clear() + else: + for c in self.prompt_cache: + c.filter(keep) + self.tokens = [self.tokens[idx] for idx in keep] + if any(self.samplers): + self.samplers = [self.samplers[idx] for idx in keep] + if any(self.logits_processors): + self.logits_processors = [self.logits_processors[idx] for idx in keep] + + # Internal state + self._next_tokens = self._next_tokens[keep] if keep else None + self._next_logprobs = [self._next_logprobs[idx] for idx in keep] + self._token_context = [self._token_context[idx] for idx in keep] + self._num_tokens = [self._num_tokens[idx] for idx in keep] + self._matcher_states = [self._matcher_states[idx] for idx in keep] + + def next(self): + tokens, logprobs = self._step() + + keep = [] + responses = [] + for i in range(len(self.uids)): + # Make the response object + resp = self.Response( + self.uids[i], + tokens[i], + logprobs[i], + None, + None, + None, + None, + ) + + # Check if we reached the limit of tokens to generate + self._num_tokens[i] += 1 + if self._num_tokens[i] >= self.max_tokens: + resp.finish_reason = "length" + + # Check if we produced a token sequence that means we got to stop + self._matcher_states[i], match = self.stop_matcher.match( + self._matcher_states[i], tokens[i] + ) + if match is not None: + resp.finish_reason = "stop" + resp.stop_sequence = match + + # If we are done add the cache 
and corresponding tokens to the + # response. + if resp.finish_reason is not None: + resp.prompt_cache = self._extract_cache(i) + resp.all_tokens = self.tokens[i] + else: + keep.append(i) + + responses.append(resp) + + # Remove all sequences that are done. + if len(keep) < len(self.uids): + self._filter(keep) + + return responses + + def main(): parser = setup_arg_parser() args = parser.parse_args() diff --git a/mlx_lm/models/cache.py b/mlx_lm/models/cache.py index 756ce4ecf..cc8342916 100644 --- a/mlx_lm/models/cache.py +++ b/mlx_lm/models/cache.py @@ -1024,6 +1024,11 @@ def extract(self, idx): def merge(cls, caches): lengths = [c.size() for c in caches] max_length = max(lengths) + + # No cache has content so make an empty one + if max_length == 0: + return BatchKVCache([0] * len(caches)) + padding = [max_length - l for l in lengths] B = len(caches) H = max(c.keys.shape[1] for c in caches if c.keys is not None) @@ -1351,6 +1356,11 @@ def merge(cls, caches): offsets = [c.offset for c in caches] lengths = [c.size() for c in caches] max_length = max(lengths) + + # No cache has content so make an empty one + if max_length == 0: + return cls(caches[0].max_size, [0] * len(caches)) + padding = [max_length - l for l in lengths] B = len(caches) H = max(c.keys.shape[1] for c in caches if c.keys is not None) From 9462b9267b50f07ab7a6d4fda165038432817451 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Sat, 21 Mar 2026 04:19:59 -0700 Subject: [PATCH 02/34] Fix ArraysCache merge of empty arrays --- mlx_lm/models/cache.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/mlx_lm/models/cache.py b/mlx_lm/models/cache.py index cc8342916..630e65004 100644 --- a/mlx_lm/models/cache.py +++ b/mlx_lm/models/cache.py @@ -662,6 +662,11 @@ def merge(cls, caches): n_state = len(caches[0].cache) B = len(caches) cache = cls(n_state) + + # All caches are empty so return early + if all(c.empty() for c in caches): + return cache + for e in range(n_state): c_init = next(iter(c[e] 
for c in caches if c[e] is not None)) shape = list(c_init.shape) From 1a555b7e20bef300a364dc7dc649175e1a06bb45 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Mon, 23 Mar 2026 01:25:17 -0700 Subject: [PATCH 03/34] Start a BatchGenerator2 --- mlx_lm/generate.py | 285 ++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 279 insertions(+), 6 deletions(-) diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py index 8714ceaa8..8595943aa 100644 --- a/mlx_lm/generate.py +++ b/mlx_lm/generate.py @@ -1450,6 +1450,13 @@ def match(self, state, x): class PromptProcessingBatch: + @dataclass + class Response: + uid: int + progress: tuple + end_of_segment: bool + end_of_prompt: bool + def __init__( self, model, @@ -1473,7 +1480,60 @@ def __init__( self.fallback_sampler = fallback_sampler or (lambda x: mx.argmax(x, axis=-1)) self.logits_processors = logits_processors or [] self.stop_matcher = stop_matcher or SequenceMatcher([]) - self.max_tokens = max_tokens or 1 << 31 + self.max_tokens = max_tokens or [1 << 31] * len(self.uids) + + def __len__(self): + return len(self.uids) + + def extend(self, batch): + if not self.samplers: + self.samplers = [None] * len(self.uids) + if not self.logits_processors: + self.logits_processors = [None] * len(self.uids) + samplers = batch.samplers + if not samplers: + samplers = [None] * len(batch.uids) + logits_processors = batch.logits_processors + if not logits_processors: + logits_processors = [None] * len(batch.uids) + + self.uids.extend(batch.uids) + self.prompt_cache.extend(batch.prompt_cache) + self.tokens.extend(batch.tokens) + self.samplers.extend(batch.samplers) + self.logits_processors.extend(batch.logits_processors) + self.max_tokens.extend(batch.max_tokens) + + if not any(self.samplers): + self.samplers = [] + if not any(self.logits_processors): + self.logits_processors = [] + + def split(self, indices: List[int]): + indices = sorted(indices) + indices_left = sorted(set(range(len(self.uids))) - set(indices)) + 
new_batch = copy.deepcopy(self) + self.filter(indices_left) + new_batch.filter(indices) + + return new_batch + + def filter(self, keep: List[int]): + self.uids = [self.uids[idx] for idx in keep] + if not keep: + self.prompt_cache.clear() + else: + for c in self.prompt_cache: + c.filter(keep) + self.tokens = [self.tokens[idx] for idx in keep] + if any(self.samplers): + self.samplers = [self.samplers[idx] for idx in keep] + if any(self.logits_processors): + self.logits_processors = [self.logits_processors[idx] for idx in keep] + self.max_tokens = [self.max_tokens[idx] for idx in keep] + + def extract_cache(self, idx: int): + return [c.extract(idx) for c in self.prompt_cache] def prompt(self, tokens): if len(self.uids) != len(tokens): @@ -1515,7 +1575,8 @@ def prompt(self, tokens): mx.clear_cache() def generate(self, tokens): - self.prompt([t[:-1] for t in tokens]) + if any(len(t) > 1 for t in tokens): + self.prompt([t[:-1] for t in tokens]) last_token = mx.array([t[-1] for t in tokens]) generation = GenerationBatch( @@ -1588,6 +1649,12 @@ def __init__( self._step() + def __len__(self): + return len(self.uids) + + def extend(self, batch): + pass + def _step(self): # Move next to current and kick of the computation of next. Also assign # current to a local var for convenience. @@ -1644,10 +1711,10 @@ def _step(self): sti.append(ti) return inputs, self._current_logprobs - def _extract_cache(self, idx: int): + def extract_cache(self, idx: int): return [c.extract(idx) for c in self.prompt_cache] - def _filter(self, keep: List[int]): + def filter(self, keep: List[int]): # Visible state self.uids = [self.uids[idx] for idx in keep] if not keep: @@ -1701,7 +1768,7 @@ def next(self): # If we are done add the cache and corresponding tokens to the # response. 
if resp.finish_reason is not None: - resp.prompt_cache = self._extract_cache(i) + resp.prompt_cache = self.extract_cache(i) resp.all_tokens = self.tokens[i] else: keep.append(i) @@ -1710,11 +1777,217 @@ def next(self): # Remove all sequences that are done. if len(keep) < len(self.uids): - self._filter(keep) + self.filter(keep) return responses +class BatchGenerator2: + def __init__( + self, + model, + max_tokens: int = 128, + stop_tokens: Optional[List[List[int]]] = None, + sampler: Optional[Callable[[mx.array], mx.array]] = None, + logits_processors: Optional[ + List[Callable[[mx.array, mx.array], mx.array]] + ] = None, + completion_batch_size: int = 32, + prefill_batch_size: int = 8, + prefill_step_size: int = 2048, + ): + self.model = model + self.max_tokens = max_tokens + self.stop_matcher = SequenceMatcher(stop_tokens) + self.sampler = sampler or (lambda x: mx.argmax(x, axis=-1)) + self.logits_processors = logits_processors or [] + self.uid_count = 0 + self.prefill_step_size = prefill_step_size + self.prefill_batch_size = prefill_batch_size + self.completion_batch_size = max(completion_batch_size, prefill_batch_size) + + self._uid_count = 0 + self._prompt_batch = None + self._generation_batch = None + self._unprocessed_sequences = deque() + self._currently_processing = [] + + if mx.metal.is_available(): + self._old_wired_limit = mx.set_wired_limit( + mx.device_info()["max_recommended_working_set_size"] + ) + else: + self._old_wired_limit = None + + def close(self): + if self._old_wired_limit is not None: + mx.synchronize(generation_stream) + mx.set_wired_limit(self._old_wired_limit) + self._old_wired_limit = None + + def __del__(self): + self.close() + + def insert( + self, + prompts, + max_tokens=None, + caches=None, + all_tokens=None, + samplers=None, + logits_processors=None, + ): + self.insert_segments( + [[p] for p in prompts], + max_tokens, + caches, + all_tokens, + samplers, + logits_processors, + ) + + def insert_segments( + self, + segments, + 
max_tokens=None, + caches=None, + all_tokens=None, + samplers=None, + logits_processors=None, + ): + uids = [] + + if max_tokens is None or isinstance(max_tokens, int): + max_tokens = [max_tokens or self.max_tokens] * len(prompts) + + if caches is None: + caches = [None] * len(segments) + for i in range(len(segments)): + if caches[i] is None: + caches[i] = cache.make_prompt_cache(self.model) + + all_tokens = all_tokens or [[] for _ in range(len(segments))] + + samplers = samplers or [None] * len(segments) + logits_processors = logits_processors or [self.logits_processors] * len( + segments + ) + + for seq, m, c, at, s, lp in zip( + segments, max_tokens, caches, all_tokens, samplers, logits_processors + ): + seq = list(seq) + if len(seq[-1]) != 1: + seq.append(seq[-1][-1:]) + seq[-2] = seq[-2][:-1] + self._unprocessed_sequences.append((self._uid_count, seq, m, c, at, s, lp)) + uids.append(self._uid_count) + self._uid_count += 1 + + return uids + + def _make_batch(self, n): + uids = [] + caches = [] + tokens = [] + samplers = [] + logits_processors = [] + max_tokens = [] + while n > 0: + n -= 1 + sequence = self._unprocessed_sequences.popleft() + uids.append(sequence[0]) + caches.append(sequence[3]) + tokens.append(sequence[4]) + samplers.append(sequence[5]) + logits_processors.append(sequence[6]) + max_tokens.append(sequence[2]) + self._currently_processing.append( + [sequence[1], 0, sum(len(s) for s in sequence[1])] + ) + return PromptProcessingBatch( + self.model, + uids, + caches, + tokens, + self.prefill_step_size, + samplers, + logits_processors, + self.stop_matcher, + max_tokens, + ) + + def _next(self): + generation_responses = [] + prompt_responses = [] + + # Generate tokens first + if len(self._generation_batch) > 0: + generation_responses = self._generation_batch.next() + + # Exit early because we already have our hands full with decoding + if len(self._generation_batch) >= self.completion_batch_size: + return prompt_responses, generation_responses + + # 
Check if we have sequences and add them to the prompt batch + n = min( + self.prefill_batch_size - len(self._prompt_batch), + self.completion_batch_size - len(self._generation_batch), + len(self._unprocessed_sequences), + ) + if n > 0: + self._prompt_batch.extend(self._make_batch(n)) + + # Split the prompt sequences to the ones moving to generation and the rest + keep = [] + split = [] + for i, seq in enumerate(self._currently_processing): + if len(segments) == 1 and len(segments[0]) == 1: + split.append(i) + else: + keep.append(i) + + # Actually split off part of the prompt batch and start generation + if split: + last_inputs = [self._currently_processing[i][0] for i in split] + progress = [(self._currently_processing[i][2],) * 2 for i in split] + self._currently_processing = [self._currently_processing[i] for i in keep] + gen_batch = self._prompt_batch.split(split).generate(last_inputs) + for i, p in enumerate(progress): + prompt_responses.append( + self.Response( + gen_batch.uids[i], + p, + True, + True, + ) + ) + if self._generation_batch is not None: + self._generation_batch.extend(gen_batch) + else: + self._generation_batch = gen_batch + + # Extract the next prompts input + prompts = [] + for i, seq in enumerate(self._currently_processing): + response = self.Response(self._prompt_batch.uids[i], 0, False, False) + segments = seq[0] + n = min(len(segments[0]), self.prefill_step_size) + prompts.append(segments[0][:n]) + segments[0] = segments[0][n:] + if len(segments[0]) == 0: + segments.pop(0) + response.end_of_segment = True + seq[1] += len(prompts[-1]) + response.progress = (seq[1], seq[2]) + prompt_responses.append(response) + + # Process the prompts + self._prompt_batch.prompt(prompts) + + return prompt_responses, generation_responses + + def main(): parser = setup_arg_parser() args = parser.parse_args() From 2154d166b55cea16de94881f459f2629598caeb0 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Mon, 23 Mar 2026 03:11:52 -0700 Subject: [PATCH 
04/34] Fix various bugs and add generation batch extend --- mlx_lm/generate.py | 147 ++++++++++++++++++++++++++++++++++++--------- 1 file changed, 117 insertions(+), 30 deletions(-) diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py index 8595943aa..bd01e5222 100644 --- a/mlx_lm/generate.py +++ b/mlx_lm/generate.py @@ -2,6 +2,7 @@ import argparse import contextlib +import copy import functools import json import sys @@ -918,6 +919,10 @@ def to_batch_cache(c): def _merge_caches(caches): batch_cache = [] + + if not caches: + return batch_cache + for i in range(len(caches[0])): if hasattr(caches[0][i], "merge"): batch_cache.append(caches[0][i].merge([c[i] for c in caches])) @@ -928,6 +933,16 @@ def _merge_caches(caches): return batch_cache +def _extend_cache(cache_a, cache_b): + if not cache_a: + return cache_b + if not cache_b: + return cache_b + for ca, cb in zip(cache_a, cache_b): + ca.extend(cb) + return cache_a + + def _lazy_extract_cache(cache, i): # Generators like lambdas are late bound so we can't just use it in the loop return (c.extract(i) for c in cache) @@ -1439,14 +1454,16 @@ def __init__(self, sequences): child["__seq__"] = child["__fail__"]["__seq__"] def make_state(self): - return self._trie + return (self._trie, self._trie) - def match(self, state, x): - while x not in state and state is not self._trie: + @staticmethod + def match(state, x): + state, trie = state + while x not in state and state is not trie: state = state["__fail__"] if x in state: state = state[x] - return state, state.get("__seq__") + return (state, trie), state.get("__seq__") class PromptProcessingBatch: @@ -1498,10 +1515,10 @@ def extend(self, batch): logits_processors = [None] * len(batch.uids) self.uids.extend(batch.uids) - self.prompt_cache.extend(batch.prompt_cache) + self.prompt_cache = _extend_cache(self.prompt_cache, batch.prompt_cache) self.tokens.extend(batch.tokens) - self.samplers.extend(batch.samplers) - self.logits_processors.extend(batch.logits_processors) + 
self.samplers.extend(samplers) + self.logits_processors.extend(logits_processors) self.max_tokens.extend(batch.max_tokens) if not any(self.samplers): @@ -1539,6 +1556,9 @@ def prompt(self, tokens): if len(self.uids) != len(tokens): raise ValueError("The batch length doesn't match the number of inputs") + if not tokens: + return + # Add the tokens to the self.tokens so they represent the tokens # contained in the KV Cache. for sti, ti in zip(self.tokens, tokens): @@ -1598,6 +1618,23 @@ def generate(self, tokens): return generation + @classmethod + def empty( + cls, model, fallback_sampler, stop_matcher, prefill_step_size: int = 2048 + ): + return cls( + model=model, + fallback_sampler=fallback_sampler, + stop_matcher=stop_matcher, + prefill_step_size=prefill_step_size, + uids=[], + caches=[], + tokens=[], + samplers=[], + logits_processors=[], + max_tokens=[], + ) + class GenerationBatch: @dataclass @@ -1647,13 +1684,36 @@ def __init__( self._num_tokens = [0] * len(self.uids) self._matcher_states = [self.stop_matcher.make_state()] * len(self.uids) - self._step() + if self.uids: + self._step() def __len__(self): return len(self.uids) def extend(self, batch): - pass + self.uids.extend(batch.uids) + self.prompt_cache = _extend_cache(self.prompt_cache, batch.prompt_cache) + self.tokens.extend(batch.tokens) + self.samplers.extend(batch.samplers) + self.logits_processors.extend(batch.logits_processors) + self.max_tokens.extend(batch.max_tokens) + if self._current_tokens is None: + self._current_tokens = batch._current_tokens + self._current_logprobs = batch._current_logprobs + elif batch._current_tokens is not None: + self._current_tokens = mx.concatenate( + [self._current_tokens, batch._current_tokens] + ) + self._current_logprobs.extend(batch._current_logprobs) + if self._next_tokens is None: + self._next_tokens = batch._next_tokens + self._next_logprobs = batch._next_logprobs + elif batch._next_tokens is not None: + self._next_tokens = 
mx.concatenate([self._next_tokens, batch._next_tokens]) + self._next_logprobs.extend(batch._next_logprobs) + self._token_context.extend(batch._token_context) + self._num_tokens.extend(batch._num_tokens) + self._matcher_states.extend(batch._matcher_states) def _step(self): # Move next to current and kick of the computation of next. Also assign @@ -1754,7 +1814,7 @@ def next(self): # Check if we reached the limit of tokens to generate self._num_tokens[i] += 1 - if self._num_tokens[i] >= self.max_tokens: + if self._num_tokens[i] >= self.max_tokens[i]: resp.finish_reason = "length" # Check if we produced a token sequence that means we got to stop @@ -1781,6 +1841,21 @@ def next(self): return responses + @classmethod + def empty(cls, model, fallback_sampler, stop_matcher): + return cls( + model=model, + fallback_sampler=fallback_sampler, + stop_matcher=stop_matcher, + uids=[], + inputs=None, + prompt_cache=[], + tokens=[], + samplers=[], + logits_processors=[], + max_tokens=[], + ) + class BatchGenerator2: def __init__( @@ -1807,8 +1882,15 @@ def __init__( self.completion_batch_size = max(completion_batch_size, prefill_batch_size) self._uid_count = 0 - self._prompt_batch = None - self._generation_batch = None + self._prompt_batch = PromptProcessingBatch.empty( + self.model, + self.sampler, + self.stop_matcher, + prefill_step_size=prefill_step_size, + ) + self._generation_batch = GenerationBatch.empty( + self.model, self.sampler, self.stop_matcher + ) self._unprocessed_sequences = deque() self._currently_processing = [] @@ -1837,7 +1919,7 @@ def insert( samplers=None, logits_processors=None, ): - self.insert_segments( + return self.insert_segments( [[p] for p in prompts], max_tokens, caches, @@ -1858,7 +1940,7 @@ def insert_segments( uids = [] if max_tokens is None or isinstance(max_tokens, int): - max_tokens = [max_tokens or self.max_tokens] * len(prompts) + max_tokens = [max_tokens or self.max_tokens] * len(segments) if caches is None: caches = [None] * len(segments) 
@@ -1906,15 +1988,16 @@ def _make_batch(self, n): [sequence[1], 0, sum(len(s) for s in sequence[1])] ) return PromptProcessingBatch( - self.model, - uids, - caches, - tokens, - self.prefill_step_size, - samplers, - logits_processors, - self.stop_matcher, - max_tokens, + model=self.model, + uids=uids, + caches=caches, + tokens=tokens, + prefill_step_size=self.prefill_step_size, + samplers=samplers, + fallback_sampler=self.sampler, + logits_processors=logits_processors, + stop_matcher=self.stop_matcher, + max_tokens=max_tokens, ) def _next(self): @@ -1942,6 +2025,7 @@ def _next(self): keep = [] split = [] for i, seq in enumerate(self._currently_processing): + segments = seq[0] if len(segments) == 1 and len(segments[0]) == 1: split.append(i) else: @@ -1949,28 +2033,27 @@ def _next(self): # Actually split off part of the prompt batch and start generation if split: - last_inputs = [self._currently_processing[i][0] for i in split] + last_inputs = [self._currently_processing[i][0][0] for i in split] progress = [(self._currently_processing[i][2],) * 2 for i in split] self._currently_processing = [self._currently_processing[i] for i in keep] gen_batch = self._prompt_batch.split(split).generate(last_inputs) for i, p in enumerate(progress): prompt_responses.append( - self.Response( + PromptProcessingBatch.Response( gen_batch.uids[i], p, True, True, ) ) - if self._generation_batch is not None: - self._generation_batch.extend(gen_batch) - else: - self._generation_batch = gen_batch + self._generation_batch.extend(gen_batch) # Extract the next prompts input prompts = [] for i, seq in enumerate(self._currently_processing): - response = self.Response(self._prompt_batch.uids[i], 0, False, False) + response = PromptProcessingBatch.Response( + self._prompt_batch.uids[i], 0, False, False + ) segments = seq[0] n = min(len(segments[0]), self.prefill_step_size) prompts.append(segments[0][:n]) @@ -1987,6 +2070,10 @@ def _next(self): return prompt_responses, generation_responses + def 
next(self): + with mx.stream(generation_stream): + return self._next() + def main(): parser = setup_arg_parser() From db4455092cc4f827c353a794fcbfb8dddab9fd03 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Mon, 23 Mar 2026 03:42:53 -0700 Subject: [PATCH 05/34] Fix a couple bugs add types and docstrings --- mlx_lm/generate.py | 274 ++++++++++++++++++++++++++--------------- mlx_lm/models/cache.py | 6 + 2 files changed, 182 insertions(+), 98 deletions(-) diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py index bd01e5222..5e44d3d64 100644 --- a/mlx_lm/generate.py +++ b/mlx_lm/generate.py @@ -1467,6 +1467,13 @@ def match(state, x): class PromptProcessingBatch: + """ + A batch processor for prompt tokens with support for incremental processing. + + This class handles batched prompt processing, managing KV caches and preparing + tokens for generation. It supports extending, filtering, and splitting batches. + """ + @dataclass class Response: uid: int @@ -1476,43 +1483,51 @@ class Response: def __init__( self, - model, - uids, - caches, - tokens=None, + model: nn.Module, + uids: List[int], + caches: List[List[Any]], + tokens: Optional[List[List[int]]] = None, prefill_step_size: int = 2048, - samplers=None, - fallback_sampler=None, - logits_processors=None, - stop_matcher=None, - max_tokens=None, + samplers: Optional[List[Callable[[mx.array], mx.array]]] = None, + fallback_sampler: Optional[Callable[[mx.array], mx.array]] = None, + logits_processors: Optional[ + List[List[Callable[[mx.array, mx.array], mx.array]]] + ] = None, + stop_matcher: Optional[SequenceMatcher] = None, + max_tokens: Optional[List[int]] = None, ): self.model = model self.uids = uids self.prompt_cache = _merge_caches(caches) - self.tokens = tokens or [[] for uid in self.uids] + self.tokens = tokens if tokens is not None else [[] for _ in uids] self.prefill_step_size = prefill_step_size - self.samplers = samplers or [] + self.samplers = samplers if samplers is not None else [] 
self.fallback_sampler = fallback_sampler or (lambda x: mx.argmax(x, axis=-1)) - self.logits_processors = logits_processors or [] + self.logits_processors = ( + logits_processors if logits_processors is not None else [] + ) self.stop_matcher = stop_matcher or SequenceMatcher([]) - self.max_tokens = max_tokens or [1 << 31] * len(self.uids) + self.max_tokens = ( + max_tokens + if max_tokens is not None + else [DEFAULT_MAX_TOKENS] * len(self.uids) + ) def __len__(self): return len(self.uids) def extend(self, batch): - if not self.samplers: + if not any(self.samplers): self.samplers = [None] * len(self.uids) - if not self.logits_processors: + if not any(self.logits_processors): self.logits_processors = [None] * len(self.uids) - samplers = batch.samplers - if not samplers: - samplers = [None] * len(batch.uids) - logits_processors = batch.logits_processors - if not logits_processors: - logits_processors = [None] * len(batch.uids) + samplers = batch.samplers if any(batch.samplers) else [None] * len(batch.uids) + logits_processors = ( + batch.logits_processors + if any(batch.logits_processors) + else [None] * len(batch.uids) + ) self.uids.extend(batch.uids) self.prompt_cache = _extend_cache(self.prompt_cache, batch.prompt_cache) @@ -1521,11 +1536,6 @@ def extend(self, batch): self.logits_processors.extend(logits_processors) self.max_tokens.extend(batch.max_tokens) - if not any(self.samplers): - self.samplers = [] - if not any(self.logits_processors): - self.logits_processors = [] - def split(self, indices: List[int]): indices = sorted(indices) indices_left = sorted(set(range(len(self.uids))) - set(indices)) @@ -1549,10 +1559,13 @@ def filter(self, keep: List[int]): self.logits_processors = [self.logits_processors[idx] for idx in keep] self.max_tokens = [self.max_tokens[idx] for idx in keep] - def extract_cache(self, idx: int): - return [c.extract(idx) for c in self.prompt_cache] + def prompt(self, tokens: List[List[int]]): + """ + Process prompt tokens through the model. 
- def prompt(self, tokens): + Args: + tokens: List of token sequences to process. + """ if len(self.uids) != len(tokens): raise ValueError("The batch length doesn't match the number of inputs") @@ -1594,7 +1607,16 @@ def prompt(self, tokens): mx.eval([c.state for c in self.prompt_cache]) mx.clear_cache() - def generate(self, tokens): + def generate(self, tokens: List[List[int]]): + """ + Transition from prompt processing to generation. + + Args: + tokens: Final tokens for each sequence to start generation. + + Returns: + A GenerationBatch ready for token generation. + """ if any(len(t) > 1 for t in tokens): self.prompt([t[:-1] for t in tokens]) last_token = mx.array([t[-1] for t in tokens]) @@ -1613,14 +1635,21 @@ def generate(self, tokens): ) self.uids = [] - self.prompt_cache = None + self.prompt_cache = [] self.tokens = [] + self.samplers = [] + self.logits_processors = [] + self.max_tokens = [] return generation @classmethod def empty( - cls, model, fallback_sampler, stop_matcher, prefill_step_size: int = 2048 + cls, + model: nn.Module, + fallback_sampler: Callable[[mx.array], mx.array], + stop_matcher: SequenceMatcher, + prefill_step_size: int = 2048, ): return cls( model=model, @@ -1637,6 +1666,13 @@ def empty( class GenerationBatch: + """ + A batched token generator that manages multiple sequences in parallel. + + This class handles the generation phase after prompt processing, managing + KV caches, sampling, and stop sequence detection for multiple sequences. 
+ """ + @dataclass class Response: uid: int @@ -1649,16 +1685,18 @@ class Response: def __init__( self, - model, + model: nn.Module, uids: List[int], inputs: mx.array, prompt_cache: List[Any], tokens: List[List[int]], - samplers, - fallback_sampler, - logits_processors, - stop_matcher, - max_tokens, + samplers: Optional[List[Callable[[mx.array], mx.array]]], + fallback_sampler: Callable[[mx.array], mx.array], + logits_processors: Optional[ + List[List[Callable[[mx.array, mx.array], mx.array]]] + ], + stop_matcher: SequenceMatcher, + max_tokens: List[int], ): self.model = model self.uids = uids @@ -1677,9 +1715,9 @@ def __init__( raise ValueError("Insufficient number of logits_processors provided") self._current_tokens = None - self._current_logprobs = None + self._current_logprobs = [] self._next_tokens = inputs - self._next_logprobs = None + self._next_logprobs = [] self._token_context = [mx.array(t[-256:]) for t in tokens] self._num_tokens = [0] * len(self.uids) self._matcher_states = [self.stop_matcher.make_state()] * len(self.uids) @@ -1691,6 +1729,7 @@ def __len__(self): return len(self.uids) def extend(self, batch): + """Extend this batch with another generation batch.""" self.uids.extend(batch.uids) self.prompt_cache = _extend_cache(self.prompt_cache, batch.prompt_cache) self.tokens.extend(batch.tokens) @@ -1715,9 +1754,13 @@ def extend(self, batch): self._num_tokens.extend(batch._num_tokens) self._matcher_states.extend(batch._matcher_states) - def _step(self): - # Move next to current and kick of the computation of next. Also assign - # current to a local var for convenience. + def _step(self) -> Tuple[List[int], List[mx.array]]: + """ + Perform a single generation step. + + Returns: + Tuple of token list and logprobs list. 
+ """ self._current_tokens = self._next_tokens self._current_logprobs = self._next_logprobs inputs = self._current_tokens @@ -1732,20 +1775,20 @@ def _step(self): logits = self.model(inputs[:, None], cache=self.prompt_cache) logits = logits[:, -1, :] - # Logit processors + # Logits processors if any(self.logits_processors): processed_logits = [] for e in range(len(self.uids)): sample_logits = logits[e : e + 1] - for processor in logits_processors[e]: - sample_logits = processor(tokens[e], sample_logits) + for processor in self.logits_processors[e]: + sample_logits = processor(self.tokens[e], sample_logits) processed_logits.append(sample_logits) logits = mx.concatenate(processed_logits, axis=0) # Normalize the logits logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True) - # Samplers + # Sample if any(self.samplers): all_samples = [] for e in range(len(self.uids)): @@ -1771,11 +1814,11 @@ def _step(self): sti.append(ti) return inputs, self._current_logprobs - def extract_cache(self, idx: int): + def extract_cache(self, idx: int) -> List[Any]: return [c.extract(idx) for c in self.prompt_cache] def filter(self, keep: List[int]): - # Visible state + """Filter the batch to keep only the specified indices.""" self.uids = [self.uids[idx] for idx in keep] if not keep: self.prompt_cache.clear() @@ -1788,67 +1831,86 @@ def filter(self, keep: List[int]): if any(self.logits_processors): self.logits_processors = [self.logits_processors[idx] for idx in keep] - # Internal state self._next_tokens = self._next_tokens[keep] if keep else None self._next_logprobs = [self._next_logprobs[idx] for idx in keep] self._token_context = [self._token_context[idx] for idx in keep] self._num_tokens = [self._num_tokens[idx] for idx in keep] self._matcher_states = [self._matcher_states[idx] for idx in keep] - def next(self): + def next(self) -> List[Response]: + """ + Generate the next batch of tokens. + + Returns: + List of Response objects for each sequence in the batch. 
+ """ + if not self.uids: + return [] + tokens, logprobs = self._step() keep = [] responses = [] for i in range(len(self.uids)): - # Make the response object - resp = self.Response( - self.uids[i], - tokens[i], - logprobs[i], - None, - None, - None, - None, - ) + finish_reason = None + stop_sequence = None - # Check if we reached the limit of tokens to generate self._num_tokens[i] += 1 if self._num_tokens[i] >= self.max_tokens[i]: - resp.finish_reason = "length" + finish_reason = "length" - # Check if we produced a token sequence that means we got to stop self._matcher_states[i], match = self.stop_matcher.match( self._matcher_states[i], tokens[i] ) if match is not None: - resp.finish_reason = "stop" - resp.stop_sequence = match - - # If we are done add the cache and corresponding tokens to the - # response. - if resp.finish_reason is not None: - resp.prompt_cache = self.extract_cache(i) - resp.all_tokens = self.tokens[i] + finish_reason = "stop" + stop_sequence = match + + if finish_reason is not None: + responses.append( + self.Response( + uid=self.uids[i], + token=tokens[i], + logprobs=logprobs[i], + finish_reason=finish_reason, + stop_sequence=stop_sequence, + prompt_cache=self.extract_cache(i), + all_tokens=self.tokens[i], + ) + ) else: keep.append(i) + responses.append( + self.Response( + uid=self.uids[i], + token=tokens[i], + logprobs=logprobs[i], + finish_reason=None, + stop_sequence=None, + prompt_cache=None, + all_tokens=None, + ) + ) - responses.append(resp) - - # Remove all sequences that are done. 
if len(keep) < len(self.uids): self.filter(keep) return responses @classmethod - def empty(cls, model, fallback_sampler, stop_matcher): + def empty( + cls, + model: nn.Module, + fallback_sampler: Callable[[mx.array], mx.array], + stop_matcher: SequenceMatcher, + ): + """Create an empty GenerationBatch.""" return cls( model=model, fallback_sampler=fallback_sampler, stop_matcher=stop_matcher, uids=[], - inputs=None, + inputs=mx.array([], dtype=mx.uint32), prompt_cache=[], tokens=[], samplers=[], @@ -1858,9 +1920,16 @@ def empty(cls, model, fallback_sampler, stop_matcher): class BatchGenerator2: + """ + A batch generator that manages both prompt processing and generation phases. + + This class provides automatic management of prompt processing and generation + batches, handling the transition between phases seamlessly. + """ + def __init__( self, - model, + model: nn.Module, max_tokens: int = 128, stop_tokens: Optional[List[List[int]]] = None, sampler: Optional[Callable[[mx.array], mx.array]] = None, @@ -1873,7 +1942,7 @@ def __init__( ): self.model = model self.max_tokens = max_tokens - self.stop_matcher = SequenceMatcher(stop_tokens) + self.stop_matcher = SequenceMatcher(stop_tokens or []) self.sampler = sampler or (lambda x: mx.argmax(x, axis=-1)) self.logits_processors = logits_processors or [] self.uid_count = 0 @@ -1912,12 +1981,14 @@ def __del__(self): def insert( self, - prompts, - max_tokens=None, - caches=None, - all_tokens=None, - samplers=None, - logits_processors=None, + prompts: List[List[int]], + max_tokens: Optional[Union[List[int], int]] = None, + caches: Optional[List[List[Any]]] = None, + all_tokens: Optional[List[List[int]]] = None, + samplers: Optional[List[Callable[[mx.array], mx.array]]] = None, + logits_processors: Optional[ + List[List[Callable[[mx.array, mx.array], mx.array]]] + ] = None, ): return self.insert_segments( [[p] for p in prompts], @@ -1930,12 +2001,14 @@ def insert( def insert_segments( self, - segments, - max_tokens=None, - 
caches=None, - all_tokens=None, - samplers=None, - logits_processors=None, + segments: List[List[List[int]]], + max_tokens: Optional[Union[List[int], int]] = None, + caches: Optional[List[List[Any]]] = None, + all_tokens: Optional[List[List[int]]] = None, + samplers: Optional[List[Callable[[mx.array], mx.array]]] = None, + logits_processors: Optional[ + List[List[Callable[[mx.array, mx.array], mx.array]]] + ] = None, ): uids = [] @@ -1948,11 +2021,11 @@ def insert_segments( if caches[i] is None: caches[i] = cache.make_prompt_cache(self.model) - all_tokens = all_tokens or [[] for _ in range(len(segments))] + all_tokens = all_tokens or [[] for _ in segments] - samplers = samplers or [None] * len(segments) - logits_processors = logits_processors or [self.logits_processors] * len( - segments + samplers = samplers or ([None] * len(segments)) + logits_processors = logits_processors or ( + [self.logits_processors] * len(segments) ) for seq, m, c, at, s, lp in zip( @@ -1968,15 +2041,14 @@ def insert_segments( return uids - def _make_batch(self, n): + def _make_batch(self, n: int): uids = [] caches = [] tokens = [] samplers = [] logits_processors = [] max_tokens = [] - while n > 0: - n -= 1 + for _ in range(n): sequence = self._unprocessed_sequences.popleft() uids.append(sequence[0]) caches.append(sequence[3]) @@ -2071,6 +2143,12 @@ def _next(self): return prompt_responses, generation_responses def next(self): + """ + Get the next batch of responses. + + Returns: + Tuple of prompt processing responses and generation responses. + """ with mx.stream(generation_stream): return self._next() diff --git a/mlx_lm/models/cache.py b/mlx_lm/models/cache.py index 630e65004..537009642 100644 --- a/mlx_lm/models/cache.py +++ b/mlx_lm/models/cache.py @@ -975,6 +975,9 @@ def filter(self, batch_indices): """ In-place filter to keep just the given indices in the cache. 
""" + if self.keys is None: + return + self.keys = self.keys[batch_indices] self.values = self.values[batch_indices] self.offset = self.offset[batch_indices] @@ -1297,6 +1300,9 @@ def filter(self, batch_indices): """ In-place filter to keep just the given indices in the cache. """ + if self.keys is None: + return + self.keys = self.keys[batch_indices] self.values = self.values[batch_indices] self.offset = self.offset[batch_indices] From 38c31ae0f555b5138bb78a137d55e7a33525e69a Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Wed, 25 Mar 2026 01:39:12 -0700 Subject: [PATCH 06/34] Add stats and use it in batch_generate --- mlx_lm/generate.py | 75 ++++++++++++++++++++++++++++++++++++---------- 1 file changed, 59 insertions(+), 16 deletions(-) diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py index 5e44d3d64..55dd8190c 100644 --- a/mlx_lm/generate.py +++ b/mlx_lm/generate.py @@ -1375,9 +1375,9 @@ def batch_generate( See :obj:`BatchGenerator` for more details. """ - gen = BatchGenerator( + gen = BatchGenerator2( model, - stop_tokens=tokenizer.eos_token_ids, + stop_tokens=[[t] for t in tokenizer.eos_token_ids], **kwargs, ) num_samples = len(prompts) @@ -1388,26 +1388,29 @@ def batch_generate( uids = gen.insert(prompts, max_tokens, caches=prompt_caches) results = {uid: [] for uid in uids} prompt_caches = {} - while responses := gen.next(): - for r in responses: - if r.finish_reason is not None: - if return_prompt_caches: - prompt_caches[r.uid] = r.prompt_cache - if verbose: - fin += 1 - print( - f"[batch_generate] Finished processing {fin}/{num_samples} ...", - end="\r", - ) - if r.finish_reason != "stop": - results[r.uid].append(r.token) + with gen.stats() as stats: + while True: + prompt_responses, token_responses = gen.next() + if not prompt_responses and not token_responses: + break + for r in token_responses: + if r.finish_reason is not None: + if return_prompt_caches: + prompt_caches[r.uid] = r.prompt_cache + if verbose: + fin += 1 + print( + 
f"[batch_generate] Finished processing {fin}/{num_samples} ...", + end="\r", + ) + if r.finish_reason != "stop": + results[r.uid].append(r.token) gen.close() if verbose: print(f"[batch_generate] Finished processing {fin}/{num_samples}") # Return results in correct order texts = [tokenizer.decode(results[uid]) for uid in uids] - stats = gen.stats() caches = [prompt_caches[uid] for uid in uids] if return_prompt_caches else None if verbose: print( @@ -1963,6 +1966,11 @@ def __init__( self._unprocessed_sequences = deque() self._currently_processing = [] + self._prompt_tokens_counter = 0 + self._prompt_time_counter = 0 + self._gen_tokens_counter = 0 + self._steps_counter = 0 + if mx.metal.is_available(): self._old_wired_limit = mx.set_wired_limit( mx.device_info()["max_recommended_working_set_size"] @@ -1976,6 +1984,27 @@ def close(self): mx.set_wired_limit(self._old_wired_limit) self._old_wired_limit = None + @contextlib.contextmanager + def stats(self, stats=None): + stats = stats or BatchStats() + self._prompt_tokens_counter = 0 + self._prompt_time_counter = 0 + self._gen_tokens_counter = 0 + tic = time.perf_counter() + try: + yield stats + finally: + toc = time.perf_counter() + total_time = toc - tic + gen_time = total_time - self._prompt_time_counter + stats.prompt_tokens += self._prompt_tokens_counter + stats.prompt_time += self._prompt_time_counter + stats.prompt_tps = stats.prompt_tokens / stats.prompt_time + stats.generation_tokens += self._gen_tokens_counter + stats.generation_time += gen_time + stats.generation_tps = stats.generation_tokens / stats.generation_time + stats.peak_memory = max(stats.peak_memory, mx.get_peak_memory() / 1e9) + def __del__(self): self.close() @@ -2041,6 +2070,15 @@ def insert_segments( return uids + def _find_uid(self, uid): + for i, uid_i in enumerate(self._prompt_batch.uids): + if uid_i == uid: + return (0, i) + for i, uid_i in enumerate(self._generation_batch.uids): + if uid_i == uid: + return (1, i) + return False + def 
_make_batch(self, n: int): uids = [] caches = [] @@ -2079,6 +2117,7 @@ def _next(self): # Generate tokens first if len(self._generation_batch) > 0: generation_responses = self._generation_batch.next() + self._gen_tokens_counter += len(generation_responses) # Exit early because we already have our hands full with decoding if len(self._generation_batch) >= self.completion_batch_size: @@ -2138,7 +2177,11 @@ def _next(self): prompt_responses.append(response) # Process the prompts + self._prompt_tokens_counter += sum(len(p) for p in prompts) + tic = time.perf_counter() self._prompt_batch.prompt(prompts) + toc = time.perf_counter() + self._prompt_time_counter += toc - tic return prompt_responses, generation_responses From baa8c4e36885ffb42c18196d47a7911bb39be6bf Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Wed, 25 Mar 2026 13:52:41 -0700 Subject: [PATCH 07/34] Add cache extraction and time reporting for the benchmark --- mlx_lm/benchmark.py | 3 +++ mlx_lm/generate.py | 36 +++++++++++++++++++++++++++++------- 2 files changed, 32 insertions(+), 7 deletions(-) diff --git a/mlx_lm/benchmark.py b/mlx_lm/benchmark.py index 3b2cb66a4..8bfafe898 100644 --- a/mlx_lm/benchmark.py +++ b/mlx_lm/benchmark.py @@ -148,10 +148,13 @@ def batch_bench(): for i in range(args.num_trials): if args.delay > 0: time.sleep(args.delay) + tic = time.perf_counter() response = _bench() + toc = time.perf_counter() responses.append(response) results = [(k, getattr(response, k)) for k in report_keys] results = [f"{k}={v:.3f}" for k, v in results] + results.append(f"total_time={toc - tic:.3f}") rprint(f"Trial {i+1}: " + ", ".join(results)) def avg(k): diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py index 55dd8190c..ca4d8bea8 100644 --- a/mlx_lm/generate.py +++ b/mlx_lm/generate.py @@ -2070,14 +2070,36 @@ def insert_segments( return uids - def _find_uid(self, uid): - for i, uid_i in enumerate(self._prompt_batch.uids): - if uid_i == uid: - return (0, i) + def _find_uids(self, uids): + 
uids = set(uids)
+        results = {}
         for i, uid_i in enumerate(self._generation_batch.uids):
-            if uid_i == uid:
-                return (1, i)
-        return False
+            if uid_i in uids:
+                results[uid_i] = (2, i)
+        for i, uid_i in enumerate(self._prompt_batch.uids):
+            if uid_i in uids:
+                results[uid_i] = (1, i)
+        for i, seq in enumerate(self._unprocessed_sequences):
+            if seq[0] in uids:
+                results[seq[0]] = (0, i)
+        return results
+
+    def extract_cache(self, uids):
+        results = {}
+        for uid, (stage, idx) in self._find_uids(uids).items():
+            if stage == 0:
+                results[uid] = self._unprocessed_sequences[idx][3:5]
+            elif stage == 1:
+                results[uid] = (
+                    self._prompt_batch.extract_cache(idx),
+                    self._prompt_batch.tokens[idx],
+                )
+            else:
+                results[uid] = (
+                    self._generation_batch.extract_cache(idx),
+                    self._generation_batch.tokens[idx],
+                )
+        return results
 
     def _make_batch(self, n: int):
         uids = []

From a36975a8c6fdcbcb77cbfce4e5655581e6ca485f Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Wed, 25 Mar 2026 14:02:43 -0700
Subject: [PATCH 08/34] Remove the original batch generator

---
 mlx_lm/generate.py | 576 +++++++--------------------------------------
 1 file changed, 91 insertions(+), 485 deletions(-)

diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py
index ca4d8bea8..4f98bfd10 100644
--- a/mlx_lm/generate.py
+++ b/mlx_lm/generate.py
@@ -943,487 +943,6 @@ def _extend_cache(cache_a, cache_b):
     return cache_a
 
 
-def _lazy_extract_cache(cache, i):
-    # Generators like lambdas are late bound so we can't just use it in the loop
-    return (c.extract(i) for c in cache)
-
-
-class BatchGenerator:
-    @dataclass
-    class Response:
-        uid: int
-        token: int
-        logprobs: mx.array
-        finish_reason: Optional[str]
-        prompt_cache: Callable[[], List[Any]]
-
-    def __init__(
-        self,
-        model,
-        max_tokens: int = 128,
-        stop_tokens: Optional[set] = None,
-        sampler: Optional[Callable[[mx.array], mx.array]] = None,
-        logits_processors: Optional[
-            List[Callable[[mx.array, mx.array], mx.array]]
-        ] = None,
-        completion_batch_size: 
int = 32, - prefill_batch_size: int = 8, - prefill_step_size: int = 2048, - prompt_checkpoint_callback: Optional[ - Callable[[List[Tuple[int, int, List[Any]]]], None] - ] = None, - prompt_progress_callback: Optional[ - Callable[[List[Tuple[int, int, int]]], None] - ] = None, - max_kv_size: Optional[int] = None, - ): - self.model = model - self.unprocessed_prompts = [] - self.max_tokens = max_tokens - self.stop_tokens = stop_tokens or set() - self.sampler = sampler or (lambda x: mx.argmax(x, axis=-1)) - self.logits_processors = logits_processors or [] - self.uid_count = 0 - self.prefill_step_size = prefill_step_size - self.prefill_batch_size = prefill_batch_size - self.completion_batch_size = max(completion_batch_size, prefill_batch_size) - self.prompt_checkpoint_callback = prompt_checkpoint_callback - self.prompt_progress_callback = prompt_progress_callback or (lambda *_: None) - self._stats = BatchStats() - self._next_count = 0 - self.max_kv_size = max_kv_size - - self.active_batch = None - - if mx.metal.is_available(): - self._old_wired_limit = mx.set_wired_limit( - mx.device_info()["max_recommended_working_set_size"] - ) - else: - self._old_wired_limit = None - - def close(self): - if self._old_wired_limit is not None: - mx.synchronize(generation_stream) - mx.set_wired_limit(self._old_wired_limit) - self._old_wired_limit = None - - def __del__(self): - self.close() - - def insert( - self, - prompts, - max_tokens: Union[List[int], int, None] = None, - caches=None, - samplers: list | None = None, - logits_processors: list | None = None, - prompt_checkpoints: list | int | None = None, - ): - uids = [] - - if max_tokens is None or isinstance(max_tokens, int): - max_tokens = [max_tokens or self.max_tokens] * len(prompts) - - if prompt_checkpoints is None or isinstance(prompt_checkpoints, int): - prompt_checkpoints = [prompt_checkpoints or -1] * len(prompts) - - if caches is None: - caches = [None] * len(prompts) - for i in range(len(prompts)): - if caches[i] is None: 
- caches[i] = cache.make_prompt_cache(self.model) - - samplers = samplers or [None] * len(prompts) - logits_processors = logits_processors or [self.logits_processors] * len(prompts) - - for p, m, c, s, lp, pc in zip( - prompts, max_tokens, caches, samplers, logits_processors, prompt_checkpoints - ): - self.unprocessed_prompts.append((self.uid_count, p, m, c, s, lp, pc)) - uids.append(self.uid_count) - self.uid_count += 1 - # Sort in ascending order of length - self.unprocessed_prompts = sorted( - self.unprocessed_prompts, - key=lambda x: len(x[1]) + max(c.size() for c in x[3]), - ) - return uids - - def remove(self, uids: List[int], return_prompt_caches: bool = False): - caches = {} - uids = set(uids) - if self.active_batch is not None: - batch = self.active_batch - if return_prompt_caches: - for e, uid in enumerate(batch.uids): - if uid not in uids: - continue - caches[uid] = batch.extract_cache(e) - keep_idx = [e for e, uid in enumerate(batch.uids) if uid not in uids] - if len(keep_idx) > 0: - batch.filter(keep_idx) - else: - self.active_batch = None - - for i in reversed(range(len(self.unprocessed_prompts))): - if self.unprocessed_prompts[i][0] in uids: - self.unprocessed_prompts.pop(i) - - if return_prompt_caches: - return caches - - @property - def prompt_cache_nbytes(self): - total = sum(c.nbytes for p in self.unprocessed_prompts for c in p[3]) - if self.active_batch is not None: - total += sum(c.nbytes for c in self.active_batch.cache) - return total - - def _process_prompts(self, prompts): - ( - uids, - inputs, - max_tokens, - caches, - samplers, - logits_processors, - prompt_checkpoints, - ) = zip(*prompts) - - lengths = [len(p) for p in inputs] - max_length = max(lengths) - padding = [max_length - l for l in lengths] - - # Get the checkpoint token as an offset from the end of each prompt. - # Then select the largest one so that we perform the checkpoint at - # least `pc` before the end. 
- prompt_checkpoints = [ - (l - pc if pc > 0 else -pc) for l, pc in zip(lengths, prompt_checkpoints) - ] - prompt_checkpoint = max(1, max(prompt_checkpoints)) - - self._stats.prompt_tokens += sum(lengths) - - tokens = [mx.array(inp) for inp in inputs] - processed_tokens = 0 - - # New prompts so - # 1. Left-pad the inputs - # 2. Process - if all(c[0].empty() for c in caches): - inputs = _left_pad_prompts(inputs, max_length=max_length) - prompt_cache = _make_cache(self.model, padding, self.max_kv_size) - - while inputs.shape[1] > prompt_checkpoint: - n_to_process = min( - self.prefill_step_size, inputs.shape[1] - prompt_checkpoint - ) - self.model(inputs[:, :n_to_process], cache=prompt_cache) - mx.eval([c.state for c in prompt_cache]) - inputs = inputs[:, n_to_process:] - processed_tokens += n_to_process - self.prompt_progress_callback( - [ - (uid, processed_tokens, length) - for uid, length in zip(uids, lengths) - ] - ) - mx.clear_cache() - - # Further prompt processing so we need to - # 1. Merge the KV caches and prepare for right padded prompts - # 2. Right pad the inputs - # 2. Process - # 3. 
Finalize the KV caches so they are left padded again - else: - last_inputs = mx.array([p[-prompt_checkpoint:] for p in inputs]) - inputs = _right_pad_prompts(inputs, max_length=max_length) - prompt_cache = _merge_caches(caches) - - for c in prompt_cache: - # subtract from lengths since we don't process the last - # `prompt_checkpoint` tokens during prefill - c.prepare( - lengths=[l - prompt_checkpoint for l in lengths], - right_padding=padding, - ) - - while inputs.shape[1] > prompt_checkpoint: - n_to_process = min( - self.prefill_step_size, inputs.shape[1] - prompt_checkpoint - ) - self.model(inputs[:, :n_to_process], cache=prompt_cache) - mx.eval([c.state for c in prompt_cache]) - inputs = inputs[:, n_to_process:] - processed_tokens += n_to_process - self.prompt_progress_callback( - [ - (uid, processed_tokens, length) - for uid, length in zip(uids, lengths) - ] - ) - mx.clear_cache() - - mx.eval([c.state for c in prompt_cache]) - inputs = last_inputs - - for c in prompt_cache: - c.finalize() - - # We processed L - prompt_checkpoint tokens so call the checkpoint - # callback. 
- if self.prompt_checkpoint_callback is not None: - self.prompt_checkpoint_callback( - [ - (uid, prompt_checkpoint, _lazy_extract_cache(prompt_cache, i)) - for i, uid in enumerate(uids) - ] - ) - # Process the remaining prompt_checkpoint-1 tokens - if prompt_checkpoint > 1: - self.model(inputs[:, : prompt_checkpoint - 1], cache=prompt_cache) - mx.eval([c.state for c in prompt_cache]) - mx.clear_cache() - - y, logprobs = self._step( - inputs, prompt_cache, samplers, logits_processors, tokens - ) - - mx.async_eval(y, logprobs) - - return Batch( - list(uids), - y, - logprobs, - list(max_tokens), - [0] * len(uids), - prompt_cache, - list(samplers), - list(logits_processors), - tokens, - ) - - def _step( - self, - input_tokens: mx.array, - prompt_cache: List[Any], - samplers: list | None, - logits_processors: list | None, - tokens: List[mx.array], - ): - batch_size = input_tokens.shape[0] - - logits = self.model(input_tokens, cache=prompt_cache) - logits = logits[:, -1, :] - - if any(logits_processors): - processed_logits = [] - for e in range(batch_size): - sample_logits = logits[e : e + 1] - for processor in logits_processors[e]: - sample_logits = processor(tokens[e], sample_logits) - processed_logits.append(sample_logits) - logits = mx.concatenate(processed_logits, axis=0) - - logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True) - if any(samplers): - all_samples = [] - for e in range(batch_size): - sample_sampler = samplers[e] or self.sampler - sampled = sample_sampler(logprobs[e : e + 1]) - all_samples.append(sampled) - sampled = mx.concatenate(all_samples, axis=0) - else: - sampled = self.sampler(logprobs) - - return sampled, list(logprobs) - - def stats(self): - self._stats.prompt_tps = self._stats.prompt_tokens / self._stats.prompt_time - self._stats.generation_tps = ( - self._stats.generation_tokens / self._stats.generation_time - ) - self._stats.peak_memory = mx.get_peak_memory() / 1e9 - return self._stats - - def _next(self): - tic = 
time.perf_counter() - - prompt_processing = False - batch = self.active_batch - num_active = len(batch) if batch else 0 - num_to_add = self.completion_batch_size - num_active - while num_to_add >= self.prefill_batch_size: - prompts = self.unprocessed_prompts[: self.prefill_batch_size] - # Finish processing the last examples of the last batch - if len(prompts) == 0 and num_active > 0: - break - # No more prompts and no more completions, all done - elif len(prompts) == 0: - self.active_batch = None - return [] - # Process prompts - if batch is not None and not prompt_processing: - # Finish any active completion tokens - mx.eval(batch.y, batch.logprobs) - self._stats.generation_time += time.perf_counter() - tic - tic = time.perf_counter() - - batch = self._process_prompts(prompts) - self.unprocessed_prompts = self.unprocessed_prompts[ - self.prefill_batch_size : - ] - prompt_processing = True - # If there was no active batch, set it - if self.active_batch is None: - self.active_batch = batch - else: - self.active_batch.extend(batch) - - num_active = len(self.active_batch) - num_to_add -= len(batch) - - batch = self.active_batch - y, logprobs = batch.y, batch.logprobs - for i, toks in enumerate(batch.tokens): - batch.tokens[i] = mx.concatenate((toks, y[i : i + 1])) - batch.y, batch.logprobs = self._step( - y[:, None], - batch.cache, - batch.samplers, - batch.logits_processors, - batch.tokens, - ) - - mx.async_eval(batch.y, batch.logprobs, batch.tokens) - - y = y.tolist() - toc = time.perf_counter() - if prompt_processing: - self._stats.prompt_time += toc - tic - else: - self._stats.generation_time += toc - tic - keep_idx = [] - end_idx = [] - responses = [] - - for e, (t, uid, num_tok, max_tok) in enumerate( - zip(y, batch.uids, batch.num_tokens, batch.max_tokens) - ): - cache = None - num_tok += 1 - batch.num_tokens[e] = num_tok - if t in self.stop_tokens: - finish_reason = "stop" - end_idx.append(e) - elif num_tok >= max_tok: - finish_reason = "length" - 
end_idx.append(e) - else: - finish_reason = None - keep_idx.append(e) - if finish_reason is not None: - cache = batch.extract_cache(e) - responses.append(self.Response(uid, t, logprobs[e], finish_reason, cache)) - - # Remove any finished completions - if len(end_idx): - if len(keep_idx) > 0: - batch.filter(keep_idx) - else: - self.active_batch = None - - self._next_count += 1 - if self._next_count % 512 == 0: - mx.clear_cache() - self._stats.generation_tokens += len(responses) - return responses - - def next(self): - with mx.stream(generation_stream): - return self._next() - - -def batch_generate( - model, - tokenizer, - prompts: List[List[int]], - prompt_caches: Optional[List[List[Any]]] = None, - max_tokens: Union[int, List[int]] = 128, - verbose: bool = False, - return_prompt_caches: bool = False, - **kwargs, -) -> BatchResponse: - """ - Generate responses for the given batch of prompts. - - Args: - model (nn.Module): The language model. - tokenizer (PreTrainedTokenizer): The tokenizer. - prompts (List[List[int]]): The input prompts. - prompt_caches (List[List[Any]], optional): Pre-computed prompt-caches - for each input prompt. Note, unlike ``generate_step``, the caches - won't be updated in-place. - verbose (bool): If ``True``, print tokens and timing information. - Default: ``False``. - max_tokens (Union[int, List[int]): Maximum number of output tokens. This - can be per prompt if a list is provided. - return_prompt_caches (bool): Return the prompt caches in the batch - responses. Default: ``False``. - kwargs: The remaining options get passed to :obj:`BatchGenerator`. - See :obj:`BatchGenerator` for more details. 
- """ - - gen = BatchGenerator2( - model, - stop_tokens=[[t] for t in tokenizer.eos_token_ids], - **kwargs, - ) - num_samples = len(prompts) - fin = 0 - if verbose: - print(f"[batch_generate] Finished processing 0/{num_samples} ...", end="\r") - - uids = gen.insert(prompts, max_tokens, caches=prompt_caches) - results = {uid: [] for uid in uids} - prompt_caches = {} - with gen.stats() as stats: - while True: - prompt_responses, token_responses = gen.next() - if not prompt_responses and not token_responses: - break - for r in token_responses: - if r.finish_reason is not None: - if return_prompt_caches: - prompt_caches[r.uid] = r.prompt_cache - if verbose: - fin += 1 - print( - f"[batch_generate] Finished processing {fin}/{num_samples} ...", - end="\r", - ) - if r.finish_reason != "stop": - results[r.uid].append(r.token) - gen.close() - if verbose: - print(f"[batch_generate] Finished processing {fin}/{num_samples}") - - # Return results in correct order - texts = [tokenizer.decode(results[uid]) for uid in uids] - caches = [prompt_caches[uid] for uid in uids] if return_prompt_caches else None - if verbose: - print( - f"[batch_generate] Prompt: {stats.prompt_tokens} tokens, {stats.prompt_tps:.3f} tokens-per-sec" - ) - print( - f"[batch_generate] Generation: {stats.generation_tokens} tokens, " - f"{stats.generation_tps:.3f} tokens-per-sec" - ) - print(f"[batch_generate] Peak memory: {stats.peak_memory:.3f} GB") - return BatchResponse(texts, stats, caches) - - class SequenceMatcher: def __init__(self, sequences): self._trie = {} @@ -1907,7 +1426,6 @@ def empty( fallback_sampler: Callable[[mx.array], mx.array], stop_matcher: SequenceMatcher, ): - """Create an empty GenerationBatch.""" return cls( model=model, fallback_sampler=fallback_sampler, @@ -1922,12 +1440,15 @@ def empty( ) -class BatchGenerator2: +class BatchGenerator: """ - A batch generator that manages both prompt processing and generation phases. + A batch generator implements continuous batching. 
This class provides automatic management of prompt processing and generation - batches, handling the transition between phases seamlessly. + batches, handling the transition between the two. + + It also allows for segmented prompt processing which guarantees that the + generator will stop at these boundaries when processing an input. """ def __init__( @@ -2140,6 +1661,9 @@ def _next(self): if len(self._generation_batch) > 0: generation_responses = self._generation_batch.next() self._gen_tokens_counter += len(generation_responses) + self._steps_counter += 1 + if self._steps_counter % 512 == 0: + mx.clear_cache() # Exit early because we already have our hands full with decoding if len(self._generation_batch) >= self.completion_batch_size: @@ -2218,6 +1742,88 @@ def next(self): return self._next() +def batch_generate( + model, + tokenizer, + prompts: List[List[int]], + prompt_caches: Optional[List[List[Any]]] = None, + max_tokens: Union[int, List[int]] = 128, + verbose: bool = False, + return_prompt_caches: bool = False, + logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None, + **kwargs, +) -> BatchResponse: + """ + Generate responses for the given batch of prompts. + + Args: + model (nn.Module): The language model. + tokenizer (PreTrainedTokenizer): The tokenizer. + prompts (List[List[int]]): The input prompts. + prompt_caches (List[List[Any]], optional): Pre-computed prompt-caches + for each input prompt. Note, unlike ``generate_step``, the caches + won't be updated in-place. + verbose (bool): If ``True``, print tokens and timing information. + Default: ``False``. + max_tokens (Union[int, List[int]]): Maximum number of output tokens. This + can be per prompt if a list is provided. + return_prompt_caches (bool): Return the prompt caches in the batch + responses. Default: ``False``. 
+ logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional): + A list of functions that take tokens and logits and return the processed logits. Default: ``None``. + kwargs: The remaining options get passed to :obj:`BatchGenerator`. + See :obj:`BatchGenerator` for more details. + """ + + gen = BatchGenerator( + model, + stop_tokens=[[t] for t in tokenizer.eos_token_ids], + **kwargs, + ) + num_samples = len(prompts) + fin = 0 + if verbose: + print(f"[batch_generate] Finished processing 0/{num_samples} ...", end="\r") + + uids = gen.insert(prompts, max_tokens, caches=prompt_caches) + results = {uid: [] for uid in uids} + prompt_caches = {} + with gen.stats() as stats: + while True: + prompt_responses, token_responses = gen.next() + if not prompt_responses and not token_responses: + break + for r in token_responses: + if r.finish_reason is not None: + if return_prompt_caches: + prompt_caches[r.uid] = r.prompt_cache + if verbose: + fin += 1 + print( + f"[batch_generate] Finished processing {fin}/{num_samples} ...", + end="\r", + ) + if r.finish_reason != "stop": + results[r.uid].append(r.token) + gen.close() + if verbose: + print(f"[batch_generate] Finished processing {fin}/{num_samples}") + + # Return results in correct order + texts = [tokenizer.decode(results[uid]) for uid in uids] + caches = [prompt_caches[uid] for uid in uids] if return_prompt_caches else None + if verbose: + print( + f"[batch_generate] Prompt: {stats.prompt_tokens} tokens, {stats.prompt_tps:.3f} tokens-per-sec" + ) + print( + f"[batch_generate] Generation: {stats.generation_tokens} tokens, " + f"{stats.generation_tps:.3f} tokens-per-sec" + ) + print(f"[batch_generate] Peak memory: {stats.peak_memory:.3f} GB") + return BatchResponse(texts, stats, caches) + + def main(): parser = setup_arg_parser() args = parser.parse_args() From 366cbdec4a47db8bb96687153ca58b28e340376d Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Wed, 25 Mar 2026 16:50:22 -0700 Subject: [PATCH 
09/34] Fix the generate tests --- mlx_lm/generate.py | 75 +++++++++++++++++++++++++++--------------- mlx_lm/models/cache.py | 56 +++++++++++++++++++++---------- tests/test_generate.py | 24 +++++++------- 3 files changed, 99 insertions(+), 56 deletions(-) diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py index 4f98bfd10..71b2730b0 100644 --- a/mlx_lm/generate.py +++ b/mlx_lm/generate.py @@ -16,6 +16,7 @@ Generator, List, Optional, + Sequence, Tuple, Union, ) @@ -826,21 +827,6 @@ class BatchStats: peak_memory: float = 0 -@dataclass -class BatchResponse: - """ - An data object to hold a batch generation response. - - Args: - texts: (List[str]): The generated text for each prompt. - stats (BatchStats): Statistics about the generation. - """ - - texts: List[str] - stats: BatchStats - caches: Optional[List[List[Any]]] - - @dataclass class Batch: uids: List[int] @@ -950,9 +936,13 @@ def __init__(self, sequences): # Make the trie for seq in sequences: current = self._trie - for x in seq: - current = current.setdefault(x, {}) - current["__seq__"] = tuple(seq) + try: + for x in seq: + current = current.setdefault(x, {}) + current["__seq__"] = tuple(seq) + except TypeError: + current = current.setdefault(seq, {}) + current["__seq__"] = (seq,) # Propagate failure links and matched sequences queue = deque() @@ -1077,8 +1067,12 @@ def filter(self, keep: List[int]): self.tokens = [self.tokens[idx] for idx in keep] if any(self.samplers): self.samplers = [self.samplers[idx] for idx in keep] + else: + self.samplers = [None] * len(keep) if any(self.logits_processors): self.logits_processors = [self.logits_processors[idx] for idx in keep] + else: + self.logits_processors = [[]] * len(keep) self.max_tokens = [self.max_tokens[idx] for idx in keep] def prompt(self, tokens: List[List[int]]): @@ -1352,6 +1346,7 @@ def filter(self, keep: List[int]): self.samplers = [self.samplers[idx] for idx in keep] if any(self.logits_processors): self.logits_processors = [self.logits_processors[idx] 
for idx in keep] + self.max_tokens = [self.max_tokens[idx] for idx in keep] self._next_tokens = self._next_tokens[keep] if keep else None self._next_logprobs = [self._next_logprobs[idx] for idx in keep] @@ -1455,7 +1450,7 @@ def __init__( self, model: nn.Module, max_tokens: int = 128, - stop_tokens: Optional[List[List[int]]] = None, + stop_tokens: Optional[Sequence[Sequence[int]]] = None, sampler: Optional[Callable[[mx.array], mx.array]] = None, logits_processors: Optional[ List[Callable[[mx.array, mx.array], mx.array]] @@ -1505,6 +1500,9 @@ def close(self): mx.set_wired_limit(self._old_wired_limit) self._old_wired_limit = None + def __del__(self): + self.close() + @contextlib.contextmanager def stats(self, stats=None): stats = stats or BatchStats() @@ -1526,9 +1524,6 @@ def stats(self, stats=None): stats.generation_tps = stats.generation_tokens / stats.generation_time stats.peak_memory = max(stats.peak_memory, mx.get_peak_memory() / 1e9) - def __del__(self): - self.close() - def insert( self, prompts: List[List[int]], @@ -1741,6 +1736,35 @@ def next(self): with mx.stream(generation_stream): return self._next() + def next_generated(self): + """ + Return only generated tokens ignoring batch generation responses. + + Returns: + List of GenerationBatch.Response objects + """ + with mx.stream(generation_stream): + while True: + prompt_responses, generation_responses = self._next() + if not generation_responses and prompt_responses: + continue + return generation_responses + + +@dataclass +class BatchResponse: + """ + A data object to hold a batch generation response. + + Args: + texts: (List[str]): The generated text for each prompt. + stats (BatchStats): Statistics about the generation. 
+ """ + + texts: List[str] + stats: BatchStats + caches: Optional[List[List[Any]]] + def batch_generate( model, @@ -1789,11 +1813,8 @@ def batch_generate( results = {uid: [] for uid in uids} prompt_caches = {} with gen.stats() as stats: - while True: - prompt_responses, token_responses = gen.next() - if not prompt_responses and not token_responses: - break - for r in token_responses: + while responses := gen.next_generated(): + for r in responses: if r.finish_reason is not None: if return_prompt_caches: prompt_caches[r.uid] = r.prompt_cache diff --git a/mlx_lm/models/cache.py b/mlx_lm/models/cache.py index 537009642..d87ff1340 100644 --- a/mlx_lm/models/cache.py +++ b/mlx_lm/models/cache.py @@ -621,13 +621,23 @@ def filter(self, batch_indices): """ In-place filter to keep just the given indices in the cache. """ - self.cache = [c[batch_indices] for c in self.cache] + self.cache = [c[batch_indices] if c is not None else None for c in self.cache] + if self.lengths is not None: + self.lengths = self.lengths[batch_indices] def extend(self, other): """ In-place extend this cache with the other cache. """ - self.cache = [mx.concatenate([c, o]) for c, o in zip(self.cache, other.cache)] + + def cat(a, b): + if a is None: + return b + if b is None: + return a + return mx.concatenate([a, b]) + + self.cache = [cat(c, o) for c, o in zip(self.cache, other.cache)] def extract(self, idx): cache = ArraysCache(len(self.cache)) @@ -975,19 +985,18 @@ def filter(self, batch_indices): """ In-place filter to keep just the given indices in the cache. 
""" - if self.keys is None: - return - - self.keys = self.keys[batch_indices] - self.values = self.values[batch_indices] + if self.keys is not None: + self.keys = self.keys[batch_indices] + self.values = self.values[batch_indices] self.offset = self.offset[batch_indices] self.left_padding = self.left_padding[batch_indices] # Shift left to reduce padding min_left_pad = self.left_padding.min().item() if min_left_pad > 0: - self.keys = self.keys[..., min_left_pad:, :] - self.values = self.values[..., min_left_pad:, :] + if self.keys is not None: + self.keys = self.keys[..., min_left_pad:, :] + self.values = self.values[..., min_left_pad:, :] self._idx -= min_left_pad self.left_padding -= min_left_pad @@ -995,15 +1004,30 @@ def extend(self, other): """ In-place extend this cache with the other cache. """ + if self.keys is None and other.keys is None: + self.left_padding = mx.concatenate([self.left_padding, other.left_padding]) + self.offset = mx.concatenate([self.offset, other.offset]) + return + max_idx = max(self._idx, other._idx) - max_size = max(self.keys.shape[2], other.keys.shape[2]) + L1 = L2 = 0 + if self.keys is not None: + B, H, L1, D = self.keys.shape + M = self.values.shape[3] + if other.keys is not None: + B, H, L2, D = other.keys.shape + M = other.values.shape[3] + max_size = max(L1, L2) # Pad the keys and values so they are right-justified # with the index and the same size def pad(c): - left = max_idx - c._idx - right = max_size - c.keys.shape[2] - left k, v = c.keys, c.values + if k is None: + k = mx.array([]).reshape(B, H, 0, D) + v = mx.array([]).reshape(B, H, 0, M) + left = max_idx - c._idx + right = max_size - k.shape[2] - left if right < 0: k = k[..., :right, :] v = v[..., :right, :] @@ -1300,11 +1324,9 @@ def filter(self, batch_indices): """ In-place filter to keep just the given indices in the cache. 
""" - if self.keys is None: - return - - self.keys = self.keys[batch_indices] - self.values = self.values[batch_indices] + if self.keys is not None: + self.keys = self.keys[batch_indices] + self.values = self.values[batch_indices] self.offset = self.offset[batch_indices] self.left_padding = self.left_padding[batch_indices] diff --git a/tests/test_generate.py b/tests/test_generate.py index fee5801a6..fc23c7bf9 100644 --- a/tests/test_generate.py +++ b/tests/test_generate.py @@ -199,7 +199,7 @@ def test_batch_matches_single(self): self.model, stop_tokens=self.tokenizer.eos_token_ids, max_tokens=1 ) uids = gen.insert(prompts) - batch_responses = {r.uid: r for r in gen.next()} + batch_responses = {r.uid: r for r in gen.next_generated()} # Do a test for each prompt the logits are close for e, prompt in enumerate(prompts): @@ -241,7 +241,7 @@ def test_many_batches(self): batch_responses = {} not_in = True iters = 0 - while responses := gen.next(): + while responses := gen.next_generated(): for r in responses: not_in &= r.uid not in batch_responses batch_responses[r.uid] = r @@ -289,7 +289,7 @@ def test_batch_unique_max_toks(self): num_toks = [2, 3, 4, 5] uids = gen.insert(prompts, max_tokens=num_toks) batch_responses = {uid: [] for uid in uids} - while responses := gen.next(): + while responses := gen.next_generated(): for r in responses: batch_responses[r.uid].append(r.token) @@ -337,7 +337,7 @@ def test_batch_sliding_window(self): ) uids = batch_gen.insert(prompts) batch_responses = {uid: [] for uid in uids} - while responses := batch_gen.next(): + while responses := batch_gen.next_generated(): for r in responses: batch_responses[r.uid].append(r.logprobs) @@ -370,7 +370,7 @@ def test_batch_generate_with_logits_processors(self): ) prompt = self.tokenizer.encode("hello") uids = batch_gen.insert([prompt]) - response = batch_gen.next()[0] + response = batch_gen.next_generated()[0] logprobs = response.logprobs self.assertEqual(logprobs[0].item(), 0.0) 
self.assertEqual(logprobs.argmin().item(), 1) @@ -395,7 +395,7 @@ def test_batch_generate_with_logits_processors(self): processors = make_logits_processors(logit_bias) (uid2,) = batch_gen.insert([prompt], logits_processors=[processors]) - responses = batch_gen.next() + responses = batch_gen.next_generated() responses = {response.uid: response for response in responses} self.assertEqual(responses[uid0].logprobs[0].item(), 0.0) self.assertEqual(responses[uid1].logprobs[1].item(), 0.0) @@ -410,7 +410,7 @@ def test_batch_generate_with_samplers(self): ) prompt = self.tokenizer.encode("hello") uids = batch_gen.insert([prompt]) - response = batch_gen.next()[0] + response = batch_gen.next_generated()[0] self.assertEqual(response.token, 1) del batch_gen @@ -427,7 +427,7 @@ def test_batch_generate_with_samplers(self): samplers=[lambda _: mx.array([2]), lambda _: mx.array([3])], ) - responses = batch_gen.next() + responses = batch_gen.next_generated() responses = {response.uid: response for response in responses} self.assertEqual(responses[uid0].token, 1) self.assertEqual(responses[uid1].token, 2) @@ -481,7 +481,7 @@ def test_batch_continued_generation(self): ) uids = batch_gen.insert(prompts_a) caches = {uid: None for uid in uids} - while responses := batch_gen.next(): + while responses := batch_gen.next_generated(): for r in responses: if r.finish_reason is not None: caches[r.uid] = r.prompt_cache @@ -490,7 +490,7 @@ def test_batch_continued_generation(self): # Generate the 2nd time uids = batch_gen.insert(prompts_b, caches=caches) batch_responses = {uid: [] for uid in uids} - while responses := batch_gen.next(): + while responses := batch_gen.next_generated(): for r in responses: batch_responses[r.uid].append(r.logprobs) @@ -543,7 +543,7 @@ def rand_prompt(n): uids = batch_gen.insert(prompts_a) caches = {uid: None for uid in uids} - while responses := batch_gen.next(): + while responses := batch_gen.next_generated(): for r in responses: if r.finish_reason is not None: 
caches[r.uid] = r.prompt_cache @@ -553,7 +553,7 @@ def rand_prompt(n): # Generate the 2nd time uids = batch_gen.insert(prompts_b, caches=caches) batch_responses = {uid: [] for uid in uids} - while responses := batch_gen.next(): + while responses := batch_gen.next_generated(): for r in responses: batch_responses[r.uid].append(r.logprobs) From f5e5745ddd9b1b744adcebe2e29a684ae68bd7fb Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Wed, 25 Mar 2026 23:54:58 -0700 Subject: [PATCH 10/34] Fix rotating cache merge with empty and merge with full --- mlx_lm/models/cache.py | 31 ++++++++++++++++++++++++++----- 1 file changed, 26 insertions(+), 5 deletions(-) diff --git a/mlx_lm/models/cache.py b/mlx_lm/models/cache.py index d87ff1340..88731c009 100644 --- a/mlx_lm/models/cache.py +++ b/mlx_lm/models/cache.py @@ -1084,6 +1084,9 @@ def merge(cls, caches): return cache + def size(self): + return self._idx + def empty(self): return self.keys is None @@ -1334,17 +1337,32 @@ def extend(self, other): """ In-place extend this cache with the other cache. 
""" + if self.keys is None and other.keys is None: + self.left_padding = mx.concatenate([self.left_padding, other.left_padding]) + self.offset = mx.concatenate([self.offset, other.offset]) + return + if (self.rotated != other.rotated) or self._idx != other._idx: self._temporal_order() other._temporal_order() max_idx = max(self._idx, other._idx) - max_size = max(self.keys.shape[2], other.keys.shape[2]) + L1 = L2 = 0 + if self.keys is not None: + B, H, L1, D = self.keys.shape + M = self.values.shape[3] + if other.keys is not None: + B, H, L2, D = other.keys.shape + M = other.values.shape[3] + max_size = max(L1, L2) def pad(c): left = max_idx - c._idx - right = max_size - c.keys.shape[2] - left k, v = c.keys, c.values + if k is None: + k = mx.array([]).reshape(B, H, 0, D) + v = mx.array([]).reshape(B, H, 0, M) + right = max_size - k.shape[2] - left if right < 0: k = k[..., :right, :] v = v[..., :right, :] @@ -1403,11 +1421,11 @@ def merge(cls, caches): keys = mx.zeros((B, H, max_length, Dk), dtype=dt) values = mx.zeros((B, H, max_length, Dv), dtype=dt) - for i, (p, c) in enumerate(zip(padding, caches)): + for i, (p, l, c) in enumerate(zip(padding, lengths, caches)): if c.keys is None: continue - keys[i : i + 1, :, p : p + c._idx] = c._temporal_order(c.keys) - values[i : i + 1, :, p : p + c._idx] = c._temporal_order(c.values) + keys[i : i + 1, :, p : p + l] = c._temporal_order(c.keys) + values[i : i + 1, :, p : p + l] = c._temporal_order(c.values) cache = cls(caches[0].max_size, padding) cache.keys = keys @@ -1418,6 +1436,9 @@ def merge(cls, caches): return cache + def size(self): + return min(self._offset, self.max_size) + def empty(self): return self.keys is None From b8a53347376720dd417422d186163abd8ed0492a Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Thu, 26 Mar 2026 01:05:42 -0700 Subject: [PATCH 11/34] Add per sequence stop matcher --- mlx_lm/generate.py | 60 +++++++++++++++++++++++++++++------------- tests/test_generate.py | 36 
+++++++++++++++++++++++++ 2 files changed, 77 insertions(+), 19 deletions(-) diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py index 71b2730b0..f57e9c2fb 100644 --- a/mlx_lm/generate.py +++ b/mlx_lm/generate.py @@ -1005,7 +1005,7 @@ def __init__( logits_processors: Optional[ List[List[Callable[[mx.array, mx.array], mx.array]]] ] = None, - stop_matcher: Optional[SequenceMatcher] = None, + stop_matchers: Optional[List[SequenceMatcher]] = None, max_tokens: Optional[List[int]] = None, ): self.model = model @@ -1019,7 +1019,11 @@ def __init__( self.logits_processors = ( logits_processors if logits_processors is not None else [] ) - self.stop_matcher = stop_matcher or SequenceMatcher([]) + self.stop_matchers = ( + stop_matchers + if stop_matchers is not None + else [SequenceMatcher([])] * len(uids) + ) self.max_tokens = ( max_tokens if max_tokens is not None @@ -1047,6 +1051,7 @@ def extend(self, batch): self.samplers.extend(samplers) self.logits_processors.extend(logits_processors) self.max_tokens.extend(batch.max_tokens) + self.stop_matchers.extend(batch.stop_matchers) def split(self, indices: List[int]): indices = sorted(indices) @@ -1074,6 +1079,7 @@ def filter(self, keep: List[int]): else: self.logits_processors = [[]] * len(keep) self.max_tokens = [self.max_tokens[idx] for idx in keep] + self.stop_matchers = [self.stop_matchers[idx] for idx in keep] def prompt(self, tokens: List[List[int]]): """ @@ -1146,7 +1152,7 @@ def generate(self, tokens: List[List[int]]): self.samplers, self.fallback_sampler, self.logits_processors, - self.stop_matcher, + self.stop_matchers, self.max_tokens, ) @@ -1170,7 +1176,6 @@ def empty( return cls( model=model, fallback_sampler=fallback_sampler, - stop_matcher=stop_matcher, prefill_step_size=prefill_step_size, uids=[], caches=[], @@ -1178,6 +1183,7 @@ def empty( samplers=[], logits_processors=[], max_tokens=[], + stop_matchers=[], ) @@ -1211,7 +1217,7 @@ def __init__( logits_processors: Optional[ List[List[Callable[[mx.array, 
mx.array], mx.array]]] ], - stop_matcher: SequenceMatcher, + stop_matchers: List[SequenceMatcher], max_tokens: List[int], ): self.model = model @@ -1222,7 +1228,7 @@ def __init__( self.samplers = samplers self.fallback_sampler = fallback_sampler self.logits_processors = logits_processors - self.stop_matcher = stop_matcher + self.stop_matchers = stop_matchers self.max_tokens = max_tokens if self.samplers and len(self.samplers) != len(self.uids): @@ -1236,7 +1242,7 @@ def __init__( self._next_logprobs = [] self._token_context = [mx.array(t[-256:]) for t in tokens] self._num_tokens = [0] * len(self.uids) - self._matcher_states = [self.stop_matcher.make_state()] * len(self.uids) + self._matcher_states = [m.make_state() for m in stop_matchers] if self.uids: self._step() @@ -1252,6 +1258,7 @@ def extend(self, batch): self.samplers.extend(batch.samplers) self.logits_processors.extend(batch.logits_processors) self.max_tokens.extend(batch.max_tokens) + self.stop_matchers.extend(batch.stop_matchers) if self._current_tokens is None: self._current_tokens = batch._current_tokens self._current_logprobs = batch._current_logprobs @@ -1347,6 +1354,7 @@ def filter(self, keep: List[int]): if any(self.logits_processors): self.logits_processors = [self.logits_processors[idx] for idx in keep] self.max_tokens = [self.max_tokens[idx] for idx in keep] + self.stop_matchers = [self.stop_matchers[idx] for idx in keep] self._next_tokens = self._next_tokens[keep] if keep else None self._next_logprobs = [self._next_logprobs[idx] for idx in keep] @@ -1376,7 +1384,7 @@ def next(self) -> List[Response]: if self._num_tokens[i] >= self.max_tokens[i]: finish_reason = "length" - self._matcher_states[i], match = self.stop_matcher.match( + self._matcher_states[i], match = self.stop_matchers[i].match( self._matcher_states[i], tokens[i] ) if match is not None: @@ -1419,12 +1427,10 @@ def empty( cls, model: nn.Module, fallback_sampler: Callable[[mx.array], mx.array], - stop_matcher: SequenceMatcher, ): 
return cls( model=model, fallback_sampler=fallback_sampler, - stop_matcher=stop_matcher, uids=[], inputs=mx.array([], dtype=mx.uint32), prompt_cache=[], @@ -1432,6 +1438,7 @@ def empty( samplers=[], logits_processors=[], max_tokens=[], + stop_matchers=[], ) @@ -1461,7 +1468,6 @@ def __init__( ): self.model = model self.max_tokens = max_tokens - self.stop_matcher = SequenceMatcher(stop_tokens or []) self.sampler = sampler or (lambda x: mx.argmax(x, axis=-1)) self.logits_processors = logits_processors or [] self.uid_count = 0 @@ -1469,16 +1475,15 @@ def __init__( self.prefill_batch_size = prefill_batch_size self.completion_batch_size = max(completion_batch_size, prefill_batch_size) + self._default_stop_matcher = SequenceMatcher(stop_tokens or []) self._uid_count = 0 self._prompt_batch = PromptProcessingBatch.empty( self.model, self.sampler, - self.stop_matcher, + self._default_stop_matcher, prefill_step_size=prefill_step_size, ) - self._generation_batch = GenerationBatch.empty( - self.model, self.sampler, self.stop_matcher - ) + self._generation_batch = GenerationBatch.empty(self.model, self.sampler) self._unprocessed_sequences = deque() self._currently_processing = [] @@ -1534,6 +1539,7 @@ def insert( logits_processors: Optional[ List[List[Callable[[mx.array, mx.array], mx.array]]] ] = None, + stop_matchers: Optional[List[SequenceMatcher]] = None, ): return self.insert_segments( [[p] for p in prompts], @@ -1542,6 +1548,7 @@ def insert( all_tokens, samplers, logits_processors, + stop_matchers, ) def insert_segments( @@ -1554,6 +1561,7 @@ def insert_segments( logits_processors: Optional[ List[List[Callable[[mx.array, mx.array], mx.array]]] ] = None, + stop_matchers: Optional[List[SequenceMatcher]] = None, ): uids = [] @@ -1573,14 +1581,25 @@ def insert_segments( [self.logits_processors] * len(segments) ) - for seq, m, c, at, s, lp in zip( - segments, max_tokens, caches, all_tokens, samplers, logits_processors + if stop_matchers is None: + stop_matchers = 
[self._default_stop_matcher] * len(segments) + + for seq, m, c, at, s, lp, sm in zip( + segments, + max_tokens, + caches, + all_tokens, + samplers, + logits_processors, + stop_matchers, ): seq = list(seq) if len(seq[-1]) != 1: seq.append(seq[-1][-1:]) seq[-2] = seq[-2][:-1] - self._unprocessed_sequences.append((self._uid_count, seq, m, c, at, s, lp)) + self._unprocessed_sequences.append( + (self._uid_count, seq, m, c, at, s, lp, sm) + ) uids.append(self._uid_count) self._uid_count += 1 @@ -1624,6 +1643,7 @@ def _make_batch(self, n: int): samplers = [] logits_processors = [] max_tokens = [] + stop_matchers = [] for _ in range(n): sequence = self._unprocessed_sequences.popleft() uids.append(sequence[0]) @@ -1632,9 +1652,11 @@ def _make_batch(self, n: int): samplers.append(sequence[5]) logits_processors.append(sequence[6]) max_tokens.append(sequence[2]) + stop_matchers.append(sequence[7]) self._currently_processing.append( [sequence[1], 0, sum(len(s) for s in sequence[1])] ) + return PromptProcessingBatch( model=self.model, uids=uids, @@ -1644,7 +1666,7 @@ def _make_batch(self, n: int): samplers=samplers, fallback_sampler=self.sampler, logits_processors=logits_processors, - stop_matcher=self.stop_matcher, + stop_matchers=stop_matchers, max_tokens=max_tokens, ) diff --git a/tests/test_generate.py b/tests/test_generate.py index fc23c7bf9..40be315ad 100644 --- a/tests/test_generate.py +++ b/tests/test_generate.py @@ -9,6 +9,7 @@ from mlx_lm.generate import ( BatchGenerator, GenerationResponse, + SequenceMatcher, batch_generate, generate, generate_step, @@ -433,6 +434,41 @@ def test_batch_generate_with_samplers(self): self.assertEqual(responses[uid1].token, 2) self.assertEqual(responses[uid2].token, 3) + def test_batch_generate_with_stop_matchers(self): + """Test that batch_generate with per-sequence stop_matchers stops on different tokens.""" + batch_gen = BatchGenerator( + self.model, + max_tokens=10, + ) + prompt = self.tokenizer.encode("hello") + + stop_matcher_0 = 
SequenceMatcher([[0]]) + stop_matcher_1 = SequenceMatcher([[1]]) + stop_matcher_2 = SequenceMatcher([[2]]) + + processor_0 = make_logits_processors({0: 2000.0}) + processor_1 = make_logits_processors({1: 2000.0}) + processor_2 = make_logits_processors({2: 2000.0}) + + uid0, uid1, uid2 = batch_gen.insert( + [prompt, prompt, prompt], + logits_processors=[processor_0, processor_1, processor_2], + stop_matchers=[stop_matcher_0, stop_matcher_1, stop_matcher_2], + ) + + responses = batch_gen.next_generated() + responses = {response.uid: response for response in responses} + + self.assertEqual(responses[uid0].token, 0) + self.assertEqual(responses[uid1].token, 1) + self.assertEqual(responses[uid2].token, 2) + self.assertEqual(responses[uid0].finish_reason, "stop") + self.assertEqual(responses[uid1].finish_reason, "stop") + self.assertEqual(responses[uid2].finish_reason, "stop") + self.assertEqual(responses[uid0].stop_sequence, (0,)) + self.assertEqual(responses[uid1].stop_sequence, (1,)) + self.assertEqual(responses[uid2].stop_sequence, (2,)) + def test_batch_continued_generation(self): for rotating in [False, True]: if rotating: From 6e2604039c019c84c633e66709875208e084a00d Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Thu, 26 Mar 2026 01:09:48 -0700 Subject: [PATCH 12/34] Remove forgotten pdb --- mlx_lm/examples/batch_generate_response.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/mlx_lm/examples/batch_generate_response.py b/mlx_lm/examples/batch_generate_response.py index 908fefc09..0925113d9 100644 --- a/mlx_lm/examples/batch_generate_response.py +++ b/mlx_lm/examples/batch_generate_response.py @@ -30,9 +30,6 @@ model, tokenizer, prompts, verbose=False, return_prompt_caches=True, max_tokens=2048 ) print(result.texts[-1]) -import pdb - -pdb.set_trace() prompts = [ "Could you summarize that?", From dca3e4c23338435f3016d39efa27305132cc5d8a Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Thu, 26 Mar 2026 02:15:20 -0700 Subject: [PATCH 
13/34] Add remove and prompt_cache_nbytes --- mlx_lm/generate.py | 45 ++++++++++++++++++++++++++++++++++++++------- 1 file changed, 38 insertions(+), 7 deletions(-) diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py index f57e9c2fb..d8ea3904f 100644 --- a/mlx_lm/generate.py +++ b/mlx_lm/generate.py @@ -1609,14 +1609,14 @@ def _find_uids(self, uids): uids = set(uids) results = {} for i, uid_i in enumerate(self._generation_batch.uids): - if uid_i in uid: - result[uid_i] = (2, i) + if uid_i in uids: + results[uid_i] = (2, i) for i, uid_i in enumerate(self._prompt_batch.uids): if uid_i in uids: - result[uid_i] = (1, i) + results[uid_i] = (1, i) for i, seq in enumerate(self._unprocessed_sequences): - if seq[0] in uid: - result[seq[0]] = (0, i) + if seq[0] in uids: + results[seq[0]] = (0, i) return results def extract_cache(self, uids): @@ -1627,15 +1627,46 @@ def extract_cache(self, uids): elif stage == 1: results[uid] = ( self._prompt_batch.extract_cache(idx), - self._prompt_batch.tokens[i], + self._prompt_batch.tokens[idx], ) else: results[uid] = ( self._generation_batch.extract_cache(idx), - self._generation_batch.tokens[i], + self._generation_batch.tokens[idx], ) return results + def remove(self, uids, return_prompt_caches=False): + caches = {} + if return_prompt_caches: + caches = self.extract_cache(uids) + + keep = ( + set(range(len(self._unprocessed_sequences))), + set(range(len(self._prompt_batch))), + set(range(len(self._generation_batch))), + ) + for stage, idx in self._find_uids(uids).values(): + keep[stage].remove(idx) + + if len(keep[0]) < len(self._unprocessed_sequences): + self._unprocessed_sequences = deque( + x for i, x in enumerate(self._unprocessed_sequences) if i in keep[0] + ) + if len(keep[1]) < len(self._prompt_batch): + self._prompt_batch.filter(sorted(keep[1])) + if len(keep[2]) < len(self._generation_batch): + self._generation_batch.filter(sorted(keep[2])) + + return caches + + @property + def prompt_cache_nbytes(self): + total = sum(c.nbytes 
for c in self._unprocessed_sequences for c in p[3]) + total += sum(c.nbytes for c in self._prompt_batch.prompt_cache) + total += sum(c.nbytes for c in self._generation_batch.prompt_cache) + return total + def _make_batch(self, n: int): uids = [] caches = [] From 7a2b273c6ef4dd7a1da1a6d452ca9ee678a827a1 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Thu, 26 Mar 2026 15:06:11 -0700 Subject: [PATCH 14/34] Change the SequenceMatcher to a full on state machine --- mlx_lm/generate.py | 238 ++++++++++++++++++++++++++++----------------- 1 file changed, 149 insertions(+), 89 deletions(-) diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py index d8ea3904f..73ee521aa 100644 --- a/mlx_lm/generate.py +++ b/mlx_lm/generate.py @@ -929,53 +929,112 @@ def _extend_cache(cache_a, cache_b): return cache_a -class SequenceMatcher: - def __init__(self, sequences): - self._trie = {} - - # Make the trie - for seq in sequences: - current = self._trie - try: - for x in seq: - current = current.setdefault(x, {}) - current["__seq__"] = tuple(seq) - except TypeError: - current = current.setdefault(seq, {}) - current["__seq__"] = (seq,) - - # Propagate failure links and matched sequences - queue = deque() - for x, child in self._trie.items(): - child["__fail__"] = self._trie +def _build_trie(sequences): + """Build an Aho-Corasick trie from the provided sequences + + See https://en.wikipedia.org/wiki/Aho–Corasick_algorithm . + """ + trie = {} + for idx, seq in enumerate(sequences): + node = trie + try: + for tok in seq: + node = node.setdefault(tok, {}) + node["__match__"] = (tuple(seq), idx) + except TypeError: + node = node.setdefault(seq, {}) + node["__match__"] = ((seq,), idx) + + # BFS to set failure links and propagate matches. 
+ queue = deque() + for key, child in trie.items(): + if key == "__match__": + continue + child["__fail__"] = trie + queue.append(child) + while queue: + parent = queue.popleft() + for key, child in parent.items(): + if key in ("__fail__", "__match__"): + continue queue.append(child) - while queue: - parent = queue.popleft() - for x, child in parent.items(): - if x in ("__fail__", "__seq__"): - continue - queue.append(child) - failure_node = parent["__fail__"] - while x not in failure_node and failure_node is not self._trie: - failure_node = failure_node["__fail__"] - if x in failure_node: - child["__fail__"] = failure_node[x] - else: - child["__fail__"] = self._trie - if "__seq__" not in child and "__seq__" in child["__fail__"]: - child["__seq__"] = child["__fail__"]["__seq__"] + fail = parent["__fail__"] + while key not in fail and fail is not trie: + fail = fail["__fail__"] + child["__fail__"] = fail[key] if key in fail else trie + if "__match__" not in child and "__match__" in child["__fail__"]: + child["__match__"] = child["__fail__"]["__match__"] + return trie + + +def _step_trie(node, trie, x): + """One step in the Aho-Corasick trie.""" + while x not in node and node is not trie: + node = node["__fail__"] + if x in node: + node = node[x] + return node + + +class SequenceStateMachine: + """A state machine that uses one Aho-Corasick trie per state to efficiently + track state across a generated sequence. + + The transitions are provided as state -> [(sequence, new_state)]. 
+ + Example: + + sm = SequenceStateMachine( + transitions={ + "normal": [ + (think_start_tokens, "reasoning"), + (tool_start_tokens, "tool"), + (eos, None), + ], + "reasoning": [ + (think_end_tokens, "normal"), + (eos, None), + ], + "tool": [ + (tool_end_tokens, None), + (eos, None) + ], + }, + initial="normal" + ) + """ + + def __init__(self, transitions={}, initial="normal"): + self._initial = initial + self._states = {} + for src, edges in transitions.items(): + sequences, dst = zip(*edges) + self._states[src] = (_build_trie(sequences), dst) + if not self._states: + self._states[initial] = (_build_trie([]), []) + + def __deepcopy__(self, memo): + new = object.__new__(SequenceStateMachine) + new._initial = self._initial + new._states = self._states + return new def make_state(self): - return (self._trie, self._trie) + return (self._initial, self._states[self._initial][0], self._states) @staticmethod def match(state, x): - state, trie = state - while x not in state and state is not trie: - state = state["__fail__"] - if x in state: - state = state[x] - return (state, trie), state.get("__seq__") + s, n, states = state + n = _step_trie(n, states[s][0], x) + + seq = None + match = n.get("__match__") + if match is not None: + seq = match[0] + s = states[s][1][match[1]] + n = states[s][0] if s is not None else None + + return (s, n, states), seq, s class PromptProcessingBatch: @@ -1005,7 +1064,7 @@ def __init__( logits_processors: Optional[ List[List[Callable[[mx.array, mx.array], mx.array]]] ] = None, - stop_matchers: Optional[List[SequenceMatcher]] = None, + state_machines: Optional[List[SequenceStateMachine]] = None, max_tokens: Optional[List[int]] = None, ): self.model = model @@ -1019,10 +1078,10 @@ def __init__( self.logits_processors = ( logits_processors if logits_processors is not None else [] ) - self.stop_matchers = ( - stop_matchers - if stop_matchers is not None - else [SequenceMatcher([])] * len(uids) + self.state_machines = ( + state_machines + if 
state_machines is not None + else [SequenceStateMachine()] * len(uids) ) self.max_tokens = ( max_tokens @@ -1051,7 +1110,7 @@ def extend(self, batch): self.samplers.extend(samplers) self.logits_processors.extend(logits_processors) self.max_tokens.extend(batch.max_tokens) - self.stop_matchers.extend(batch.stop_matchers) + self.state_machines.extend(batch.state_machines) def split(self, indices: List[int]): indices = sorted(indices) @@ -1079,7 +1138,7 @@ def filter(self, keep: List[int]): else: self.logits_processors = [[]] * len(keep) self.max_tokens = [self.max_tokens[idx] for idx in keep] - self.stop_matchers = [self.stop_matchers[idx] for idx in keep] + self.state_machines = [self.state_machines[idx] for idx in keep] def prompt(self, tokens: List[List[int]]): """ @@ -1152,7 +1211,7 @@ def generate(self, tokens: List[List[int]]): self.samplers, self.fallback_sampler, self.logits_processors, - self.stop_matchers, + self.state_machines, self.max_tokens, ) @@ -1170,7 +1229,7 @@ def empty( cls, model: nn.Module, fallback_sampler: Callable[[mx.array], mx.array], - stop_matcher: SequenceMatcher, + state_machine: SequenceStateMachine, prefill_step_size: int = 2048, ): return cls( @@ -1183,7 +1242,7 @@ def empty( samplers=[], logits_processors=[], max_tokens=[], - stop_matchers=[], + state_machines=[], ) @@ -1201,7 +1260,8 @@ class Response: token: int logprobs: mx.array finish_reason: Optional[str] - stop_sequence: Optional[List[int]] + current_state: Optional[str] + match_sequence: Optional[List[int]] prompt_cache: Optional[List[Any]] all_tokens: Optional[List[int]] @@ -1217,7 +1277,7 @@ def __init__( logits_processors: Optional[ List[List[Callable[[mx.array, mx.array], mx.array]]] ], - stop_matchers: List[SequenceMatcher], + state_machines: List[SequenceStateMachine], max_tokens: List[int], ): self.model = model @@ -1228,7 +1288,7 @@ def __init__( self.samplers = samplers self.fallback_sampler = fallback_sampler self.logits_processors = logits_processors - 
self.stop_matchers = stop_matchers + self.state_machines = state_machines self.max_tokens = max_tokens if self.samplers and len(self.samplers) != len(self.uids): @@ -1242,7 +1302,7 @@ def __init__( self._next_logprobs = [] self._token_context = [mx.array(t[-256:]) for t in tokens] self._num_tokens = [0] * len(self.uids) - self._matcher_states = [m.make_state() for m in stop_matchers] + self._matcher_states = [m.make_state() for m in state_machines] if self.uids: self._step() @@ -1258,7 +1318,7 @@ def extend(self, batch): self.samplers.extend(batch.samplers) self.logits_processors.extend(batch.logits_processors) self.max_tokens.extend(batch.max_tokens) - self.stop_matchers.extend(batch.stop_matchers) + self.state_machines.extend(batch.state_machines) if self._current_tokens is None: self._current_tokens = batch._current_tokens self._current_logprobs = batch._current_logprobs @@ -1354,7 +1414,7 @@ def filter(self, keep: List[int]): if any(self.logits_processors): self.logits_processors = [self.logits_processors[idx] for idx in keep] self.max_tokens = [self.max_tokens[idx] for idx in keep] - self.stop_matchers = [self.stop_matchers[idx] for idx in keep] + self.state_machines = [self.state_machines[idx] for idx in keep] self._next_tokens = self._next_tokens[keep] if keep else None self._next_logprobs = [self._next_logprobs[idx] for idx in keep] @@ -1378,18 +1438,17 @@ def next(self) -> List[Response]: responses = [] for i in range(len(self.uids)): finish_reason = None - stop_sequence = None + match_sequence = None self._num_tokens[i] += 1 if self._num_tokens[i] >= self.max_tokens[i]: finish_reason = "length" - self._matcher_states[i], match = self.stop_matchers[i].match( - self._matcher_states[i], tokens[i] + self._matcher_states[i], match_sequence, current_state = ( + self.state_machines[i].match(self._matcher_states[i], tokens[i]) ) - if match is not None: + if match_sequence is not None and current_state is None: finish_reason = "stop" - stop_sequence = match if 
finish_reason is not None: responses.append( @@ -1398,7 +1457,8 @@ def next(self) -> List[Response]: token=tokens[i], logprobs=logprobs[i], finish_reason=finish_reason, - stop_sequence=stop_sequence, + current_state=current_state, + match_sequence=match_sequence, prompt_cache=self.extract_cache(i), all_tokens=self.tokens[i], ) @@ -1411,7 +1471,8 @@ def next(self) -> List[Response]: token=tokens[i], logprobs=logprobs[i], finish_reason=None, - stop_sequence=None, + match_sequence=match_sequence, + current_state=current_state, prompt_cache=None, all_tokens=None, ) @@ -1438,7 +1499,7 @@ def empty( samplers=[], logits_processors=[], max_tokens=[], - stop_matchers=[], + state_machines=[], ) @@ -1475,12 +1536,15 @@ def __init__( self.prefill_batch_size = prefill_batch_size self.completion_batch_size = max(completion_batch_size, prefill_batch_size) - self._default_stop_matcher = SequenceMatcher(stop_tokens or []) + self._default_state_machine = SequenceStateMachine( + {"default": [(seq, None) for seq in (stop_tokens or [])]}, + initial="default", + ) self._uid_count = 0 self._prompt_batch = PromptProcessingBatch.empty( self.model, self.sampler, - self._default_stop_matcher, + self._default_state_machine, prefill_step_size=prefill_step_size, ) self._generation_batch = GenerationBatch.empty(self.model, self.sampler) @@ -1532,14 +1596,14 @@ def stats(self, stats=None): def insert( self, prompts: List[List[int]], - max_tokens: Optional[Union[List[int], int]] = None, + max_tokens: Optional[List[int]] = None, caches: Optional[List[List[Any]]] = None, all_tokens: Optional[List[List[int]]] = None, samplers: Optional[List[Callable[[mx.array], mx.array]]] = None, logits_processors: Optional[ List[List[Callable[[mx.array, mx.array], mx.array]]] ] = None, - stop_matchers: Optional[List[SequenceMatcher]] = None, + state_machines: Optional[List[SequenceStateMachine]] = None, ): return self.insert_segments( [[p] for p in prompts], @@ -1548,41 +1612,37 @@ def insert( all_tokens, samplers, 
logits_processors, - stop_matchers, + state_machines, ) def insert_segments( self, segments: List[List[List[int]]], - max_tokens: Optional[Union[List[int], int]] = None, + max_tokens: Optional[List[int]] = None, caches: Optional[List[List[Any]]] = None, all_tokens: Optional[List[List[int]]] = None, samplers: Optional[List[Callable[[mx.array], mx.array]]] = None, logits_processors: Optional[ List[List[Callable[[mx.array, mx.array], mx.array]]] ] = None, - stop_matchers: Optional[List[SequenceMatcher]] = None, + state_machines: Optional[List[SequenceStateMachine]] = None, ): uids = [] - if max_tokens is None or isinstance(max_tokens, int): - max_tokens = [max_tokens or self.max_tokens] * len(segments) - - if caches is None: - caches = [None] * len(segments) - for i in range(len(segments)): - if caches[i] is None: - caches[i] = cache.make_prompt_cache(self.model) - + max_tokens = max_tokens or [self.max_tokens] * len(segments) all_tokens = all_tokens or [[] for _ in segments] - - samplers = samplers or ([None] * len(segments)) + samplers = samplers or [None] * len(segments) logits_processors = logits_processors or ( [self.logits_processors] * len(segments) ) + state_machines = state_machines or ( + [self._default_state_machine] * len(segments) + ) - if stop_matchers is None: - stop_matchers = [self._default_stop_matcher] * len(segments) + caches = caches or [None] * len(segments) + for i in range(len(segments)): + if caches[i] is None: + caches[i] = cache.make_prompt_cache(self.model) for seq, m, c, at, s, lp, sm in zip( segments, @@ -1591,7 +1651,7 @@ def insert_segments( all_tokens, samplers, logits_processors, - stop_matchers, + state_machines, ): seq = list(seq) if len(seq[-1]) != 1: @@ -1674,7 +1734,7 @@ def _make_batch(self, n: int): samplers = [] logits_processors = [] max_tokens = [] - stop_matchers = [] + state_machines = [] for _ in range(n): sequence = self._unprocessed_sequences.popleft() uids.append(sequence[0]) @@ -1683,7 +1743,7 @@ def _make_batch(self, 
n: int): samplers.append(sequence[5]) logits_processors.append(sequence[6]) max_tokens.append(sequence[2]) - stop_matchers.append(sequence[7]) + state_machines.append(sequence[7]) self._currently_processing.append( [sequence[1], 0, sum(len(s) for s in sequence[1])] ) @@ -1697,7 +1757,7 @@ def _make_batch(self, n: int): samplers=samplers, fallback_sampler=self.sampler, logits_processors=logits_processors, - stop_matchers=stop_matchers, + state_machines=state_machines, max_tokens=max_tokens, ) From 6225ddeed2785c67a68ae9087bb648212addb560 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Thu, 26 Mar 2026 15:11:27 -0700 Subject: [PATCH 15/34] Fix empty stop tokens --- mlx_lm/generate.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py index 73ee521aa..ba4b16ac5 100644 --- a/mlx_lm/generate.py +++ b/mlx_lm/generate.py @@ -1537,8 +1537,8 @@ def __init__( self.completion_batch_size = max(completion_batch_size, prefill_batch_size) self._default_state_machine = SequenceStateMachine( - {"default": [(seq, None) for seq in (stop_tokens or [])]}, - initial="default", + {"normal": [(seq, None) for seq in stop_tokens]} if stop_tokens else {}, + initial="normal", ) self._uid_count = 0 self._prompt_batch = PromptProcessingBatch.empty( From d02eb3c35e7d8980645a1050ca8b859e0daa6471 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Thu, 26 Mar 2026 15:13:50 -0700 Subject: [PATCH 16/34] Fix max tokens handling in batch_generate --- mlx_lm/generate.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py index ba4b16ac5..22a8de5c2 100644 --- a/mlx_lm/generate.py +++ b/mlx_lm/generate.py @@ -1922,6 +1922,9 @@ def batch_generate( if verbose: print(f"[batch_generate] Finished processing 0/{num_samples} ...", end="\r") + if isinstance(max_tokens, int): + max_tokens = [max_tokens] * len(prompts) + uids = gen.insert(prompts, max_tokens, caches=prompt_caches) results = {uid: 
[] for uid in uids} prompt_caches = {} From e85cc9de696047f1ee73d8986d7b2ed0c2a20f43 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Thu, 26 Mar 2026 18:05:23 -0700 Subject: [PATCH 17/34] Start transitioning the server to the new APIs --- mlx_lm/generate.py | 2 - mlx_lm/server.py | 265 +++++++++++++++++++++++++++++------------ tests/test_generate.py | 20 ++-- 3 files changed, 201 insertions(+), 86 deletions(-) diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py index 22a8de5c2..5cf394a11 100644 --- a/mlx_lm/generate.py +++ b/mlx_lm/generate.py @@ -1229,7 +1229,6 @@ def empty( cls, model: nn.Module, fallback_sampler: Callable[[mx.array], mx.array], - state_machine: SequenceStateMachine, prefill_step_size: int = 2048, ): return cls( @@ -1544,7 +1543,6 @@ def __init__( self._prompt_batch = PromptProcessingBatch.empty( self.model, self.sampler, - self._default_state_machine, prefill_step_size=prefill_step_size, ) self._generation_batch = GenerationBatch.empty(self.model, self.sampler) diff --git a/mlx_lm/server.py b/mlx_lm/server.py index c5d1f95c3..dbf4b809d 100644 --- a/mlx_lm/server.py +++ b/mlx_lm/server.py @@ -32,7 +32,12 @@ from huggingface_hub import scan_cache_dir from ._version import __version__ -from .generate import BatchGenerator, generation_stream, stream_generate +from .generate import ( + BatchGenerator, + SequenceStateMachine, + generation_stream, + stream_generate, +) from .models.cache import ( LRUPromptCache, can_trim_prompt_cache, @@ -249,6 +254,8 @@ def stop(self): class Response: text: str token: int + state: str + match: Tuple[int] logprob: float finish_reason: Optional[str] top_tokens: Tuple[Dict[str, Any]] @@ -449,6 +456,7 @@ def __init__(self, model_provider: ModelProvider, prompt_cache: LRUPromptCache): self.model_provider = model_provider self.prompt_cache = prompt_cache self.requests = Queue() + self._state_machine_cache = {} self._time_budget = TimeBudget() self._is_distributed = mx.distributed.init().size() > 1 @@ -518,6 
+526,17 @@ def _share_request(self, request): return rq, *shareable def _tokenize(self, tokenizer, request, args): + """Tokenize a request and split the prompt into segments. + + Returns a tuple + + * prompt - full list of tokens + * segments - a list of lists of tokens. Up to 3 segments that + correspond to system prompt, context, thinking tail. + * initial state - a string that contains the initial state of the + state machine (normal or thinking depending on whether we have tail + or not) + """ if request.request_type == "chat": messages = request.messages tools = request.tools @@ -536,35 +555,118 @@ def _tokenize(self, tokenizer, request, args): if args.chat_template_kwargs: chat_template_args = chat_template_args.copy() chat_template_args.update(args.chat_template_kwargs) - return tokenizer.apply_chat_template( - messages, + template_kwargs = dict( tools=tools, - add_generation_prompt=True, tokenize=True, **chat_template_args, ) + prompt = tokenizer.apply_chat_template( + messages, + add_generation_prompt=True, + **template_kwargs, + ) else: - return tokenizer.encode(convert_chat(messages, role_mapping)) + prompt = tokenizer.encode(convert_chat(messages, role_mapping)) + return prompt, [prompt], "normal" else: - return tokenizer.encode(request.prompt) - - def _compute_prompt_checkpoint(self, tokenizer, request, prompt): - if request.request_type != "chat": - return False, -1 - if request.messages[-1]["role"] != "user": - return False, -1 - - # Save the KV cache at the end of the prompt just before - # the think start token which will likely be removed in the - # next turn. - prompt_checkpoint = -1 + prompt = tokenizer.encode(request.prompt) + return prompt, [prompt], "normal" + + # If we are here it means we have a chat request so we need to search + # for segments for better cache management. + + # It is not a user message so no segmentation needed. 
+ if messages[-1]["role"] != "user": + return prompt, [prompt], "normal" + + segments = [] + initial_state = "normal" + + # Find where the system prompt ends and add it as a segment. + num_system = 0 + sys_end = 0 + for m in messages: + if m["role"] == "system": + num_system += 1 + else: + break + if num_system > 0: + sys_tokens = tokenizer.apply_chat_template( + messages[:num_system], + add_generation_prompt=False, + **template_kwargs, + ) + sys_end = len(sys_tokens) + if sys_end > 0 and sys_end < len(prompt): + segments.append(prompt[:sys_end]) + + # The following code does 2 things: + # 1. Find a tail segment that contains thinking tokens + # 2. Find whether we have a think start token without corresponding + # think end token and set the state to reasoning. + tail_start = len(prompt) + state = "reasoning" if tokenizer.has_thinking: - for i in range(1, min(11, len(prompt)) - 1, 1): + for i in range(1, min(11, len(prompt) - sys_end), 1): + if prompt[-i] == tokenizer.think_end_id: + state = "normal" + continue if prompt[-i] == tokenizer.think_start_id: - prompt_checkpoint = -i - 1 + tail_start = len(prompt) - i + initial_state = state break - return True, prompt_checkpoint + # Finalize the segments and return + if sys_end < tail_start: + segments.append(prompt[sys_end:tail_start]) + if tail_start < len(prompt): + segments.append(prompt[tail_start:]) + if not segments: + segments = [prompt] + + return prompt, segments, initial_state + + def _make_state_machine( + self, model_key, tokenizer, stop_words, initial_state="normal" + ): + """Make a new SequenceStateMachine or fetch it if we 've made it before.""" + cache_key = (model_key, tuple(stop_words)) + sm = self._state_machine_cache.get(cache_key) + if sm is not None: + return sm + + eos = [([t], None) for t in tokenizer.eos_token_ids] + stop = [ + (tokenizer.encode(w, add_special_tokens=False), None) for w in stop_words + ] + common_stops = eos + stop + + transitions = {"normal": list(common_stops)} + + if 
tokenizer.has_thinking: + transitions["normal"].append(([tokenizer.think_start_id], "reasoning")) + transitions["reasoning"] = [ + ([tokenizer.think_end_id], "normal"), + ] + common_stops + + if tokenizer.has_tool_calling: + tool_start = tokenizer.encode( + tokenizer.tool_call_start, add_special_tokens=False + ) + tool_end = tokenizer.encode( + tokenizer.tool_call_end, add_special_tokens=False + ) + transitions["normal"].append((tool_start, "tool")) + transitions["tool"] = [ + (tool_end, "normal"), + ] + common_stops + + sm = SequenceStateMachine(transitions, initial=initial_state) + if len(self._state_machine_cache) > 100: + self._state_machine_cache.clear() + self._state_machine_cache[cache_key] = sm + + return sm def _is_batchable(self, args): if not self.model_provider.is_batchable: @@ -591,23 +693,6 @@ def get_next_request(timeout=None): else: return self._next_request(timeout) - def progress_callback(info): - for uid, processed, total in info: - if uid in batch_results: - batch_results[uid]["rqueue"].put((min(processed, total), total)) - - def checkpoint_callback(prompts): - for uid, prompt_end, cache in prompts: - rs = batch_results[uid] - if not rs["checkpoint"]: - continue - self.prompt_cache.insert_cache( - current_model_key, - rs["cache_key"][:-prompt_end], - list(cache), - cache_type="user", - ) - if self._is_distributed: seed = mx.distributed.all_sum(mx.random.state[0]).view(mx.uint64).item() mx.random.seed(seed) @@ -633,7 +718,9 @@ def checkpoint_callback(prompts): and self._is_batchable(args) ): try: - prompt = self._tokenize(current_tokenizer, request, args) + prompt, segments, initial_state = self._tokenize( + current_tokenizer, request, args + ) except Exception as e: rqueue.put(e) continue @@ -660,28 +747,38 @@ def checkpoint_callback(prompts): cache, rest = self.prompt_cache.fetch_nearest_cache( current_model_key, prompt ) - ctx.prompt_cache_count = len(prompt) - len(rest) - if cache is None: - cache = 
make_prompt_cache(self.model_provider.model) - - do_checkpoint, checkpoint_position = ( - self._compute_prompt_checkpoint(tokenizer, request, prompt) - ) - - (uid,) = batch_generator.insert( - [rest], - args.max_tokens, + N = len(prompt) - len(rest) + ctx.prompt_cache_count = N + prompt = prompt[:N] + while N >= 0: + if N >= len(segments[0]): + N -= len(segments.pop(0)) + else: + segments[0] = segments[0][N:] + break + + (uid,) = batch_generator.insert_segments( + segments=[segments], + max_tokens=[args.max_tokens], caches=[cache], + all_tokens=[prompt], samplers=[_make_sampler(args, tokenizer)], logits_processors=[_make_logits_processors(args)], - prompt_checkpoints=[checkpoint_position], + state_machines=[ + self._make_state_machine( + self.model_provider.model_key, + tokenizer, + args.stop_words, + initial_state=initial_state, + ) + ], ) batch_results[uid] = { "ctx": ctx, - "cache_key": prompt[:], "rqueue": rqueue, "detokenizer": tokenizer.detokenizer, - "checkpoint": do_checkpoint, + "num_segments": len(segments), + "top_logprobs": args.top_logprobs, } # just making sure we don't leave a reference around del cache @@ -714,12 +811,9 @@ def checkpoint_callback(prompts): batch_results = {} batch_generator = BatchGenerator( model, - stop_tokens=tokenizer.eos_token_ids, completion_batch_size=self.cli_args.decode_concurrency, prefill_batch_size=self.cli_args.prompt_concurrency, prefill_step_size=self.cli_args.prefill_step_size, - prompt_progress_callback=progress_callback, - prompt_checkpoint_callback=checkpoint_callback, ) unprocessed_requests.append((rqueue, request, args)) continue @@ -746,24 +840,52 @@ def checkpoint_callback(prompts): uids_to_remove = [] for _ in self._time_budget: - responses = batch_generator.next() - if not responses: + prompt_responses, gen_responses = batch_generator.next() + if not prompt_responses and not gen_responses: break - for r in responses: + # Progress report for prompt processing + for r in prompt_responses: result = 
batch_results[r.uid] - result["cache_key"].append(r.token) - if r.finish_reason != "stop": - result["detokenizer"].add_token(r.token) + result["rqueue"].put(r.progress) + + # Save the caches at end of segments + eos_ids = [ + r.uid + for r in prompt_responses + if r.end_of_segment and not r.end_of_prompt + ] + caches = batch_generator.extract_cache(eos_ids) + for uid, (cache, cache_key) in caches.items(): + batch_results[uid]["num_segments"] -= 1 + cache_type = ( + "system" + if batch_results[uid]["num_segments"] > 1 + else "user" + ) + self.prompt_cache.insert_cache( + self.model_provider.model_key, + cache_key, + cache, + cache_type=cache_type, + ) + del caches + for r in gen_responses: + result = batch_results[r.uid] + result["detokenizer"].add_token(r.token) result["rqueue"].put( Response( result["detokenizer"].last_segment, r.token, + r.current_state, + r.match_sequence, r.logprobs[r.token].item(), r.finish_reason, _format_top_logprobs( - r.logprobs, args.top_logprobs, current_tokenizer + r.logprobs, + result["top_logprobs"], + current_tokenizer, ), ) ) @@ -771,7 +893,10 @@ def checkpoint_callback(prompts): if r.finish_reason is not None: result["rqueue"].put(None) self.prompt_cache.insert_cache( - current_model_key, result["cache_key"], r.prompt_cache + current_model_key, + r.all_tokens, + r.prompt_cache, + cache_type="assistant", ) del batch_results[r.uid] @@ -781,16 +906,8 @@ def checkpoint_callback(prompts): uids_to_remove = self._share_object(uids_to_remove) if uids_to_remove: with mx.stream(generation_stream): - caches = batch_generator.remove( - uids_to_remove, return_prompt_caches=True - ) - for uid, prompt_cache in caches.items(): - if uid not in batch_results: - continue - result = batch_results[uid] - self.prompt_cache.insert_cache( - current_model_key, result["cache_key"], prompt_cache - ) + batch_generator.remove(uids_to_remove) + for uid in uids_to_remove: del batch_results[uid] def _serve_single(self, request): @@ -807,7 +924,7 @@ def 
progress(tokens_processed, tokens_total): draft_model = self.model_provider.draft_model # Prepare the prompt - prompt = self._tokenize(tokenizer, request, args) + prompt, _, initial_state = self._tokenize(tokenizer, request, args) # Start the generation context ctx = GenerationContext( diff --git a/tests/test_generate.py b/tests/test_generate.py index 40be315ad..95fc548e2 100644 --- a/tests/test_generate.py +++ b/tests/test_generate.py @@ -9,7 +9,7 @@ from mlx_lm.generate import ( BatchGenerator, GenerationResponse, - SequenceMatcher, + SequenceStateMachine, batch_generate, generate, generate_step, @@ -434,17 +434,17 @@ def test_batch_generate_with_samplers(self): self.assertEqual(responses[uid1].token, 2) self.assertEqual(responses[uid2].token, 3) - def test_batch_generate_with_stop_matchers(self): - """Test that batch_generate with per-sequence stop_matchers stops on different tokens.""" + def test_batch_generate_with_state_machines(self): + """Test that batch_generate with per-sequence state_machines stops on different tokens.""" batch_gen = BatchGenerator( self.model, max_tokens=10, ) prompt = self.tokenizer.encode("hello") - stop_matcher_0 = SequenceMatcher([[0]]) - stop_matcher_1 = SequenceMatcher([[1]]) - stop_matcher_2 = SequenceMatcher([[2]]) + sm_0 = SequenceStateMachine({"normal": [([0], None)]}, initial="normal") + sm_1 = SequenceStateMachine({"normal": [([1], None)]}, initial="normal") + sm_2 = SequenceStateMachine({"normal": [([2], None)]}, initial="normal") processor_0 = make_logits_processors({0: 2000.0}) processor_1 = make_logits_processors({1: 2000.0}) @@ -453,7 +453,7 @@ def test_batch_generate_with_stop_matchers(self): uid0, uid1, uid2 = batch_gen.insert( [prompt, prompt, prompt], logits_processors=[processor_0, processor_1, processor_2], - stop_matchers=[stop_matcher_0, stop_matcher_1, stop_matcher_2], + state_machines=[sm_0, sm_1, sm_2], ) responses = batch_gen.next_generated() @@ -465,9 +465,9 @@ def 
test_batch_generate_with_stop_matchers(self): self.assertEqual(responses[uid0].finish_reason, "stop") self.assertEqual(responses[uid1].finish_reason, "stop") self.assertEqual(responses[uid2].finish_reason, "stop") - self.assertEqual(responses[uid0].stop_sequence, (0,)) - self.assertEqual(responses[uid1].stop_sequence, (1,)) - self.assertEqual(responses[uid2].stop_sequence, (2,)) + self.assertEqual(responses[uid0].match_sequence, (0,)) + self.assertEqual(responses[uid1].match_sequence, (1,)) + self.assertEqual(responses[uid2].match_sequence, (2,)) def test_batch_continued_generation(self): for rotating in [False, True]: From 972027a9115df326dc236ac323be92527568c095 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Fri, 27 Mar 2026 10:09:55 -0700 Subject: [PATCH 18/34] Move _serve_single to the state machine --- mlx_lm/server.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/mlx_lm/server.py b/mlx_lm/server.py index dbf4b809d..204ce8853 100644 --- a/mlx_lm/server.py +++ b/mlx_lm/server.py @@ -923,8 +923,15 @@ def progress(tokens_processed, tokens_total): tokenizer = self.model_provider.tokenizer draft_model = self.model_provider.draft_model - # Prepare the prompt + # Prepare the prompt and state machine prompt, _, initial_state = self._tokenize(tokenizer, request, args) + sm = self._make_state_machine( + self.model_provider.model_key, + tokenizer, + args.stop_words, + initial_state=initial_state, + ) + sm_state = sm.make_state() # Start the generation context ctx = GenerationContext( @@ -979,12 +986,18 @@ def progress(tokens_processed, tokens_total): prompt_progress_callback=progress, prefill_step_size=self.cli_args.prefill_step_size, ): + finish_reason = gen.finish_reason + sm_state, match_sequence, current_state = sm.match(sm_state, gen.token) + if match_sequence is not None and current_state is None: + finish_reason = "stop" rqueue.put( Response( gen.text, gen.token, + current_state, + match_sequence, 
gen.logprobs[gen.token].item(), - gen.finish_reason, + finish_reason, _format_top_logprobs( gen.logprobs, args.top_logprobs, tokenizer ), @@ -997,6 +1010,9 @@ def progress(tokens_processed, tokens_total): raise NotImplementedError() break + if finish_reason is not None: + break + rqueue.put(None) # Save the KV cache again From d19333d165775dbad5dc8c7bc65a5b0ef289be04 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Fri, 27 Mar 2026 17:15:22 -0700 Subject: [PATCH 19/34] Running server --- mlx_lm/generate.py | 5 +- mlx_lm/server.py | 405 +++++++++++++++++++++------------------------ 2 files changed, 197 insertions(+), 213 deletions(-) diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py index 5cf394a11..2231108c4 100644 --- a/mlx_lm/generate.py +++ b/mlx_lm/generate.py @@ -1092,6 +1092,9 @@ def __init__( def __len__(self): return len(self.uids) + def extract_cache(self, idx: int) -> List[Any]: + return [c.extract(idx) for c in self.prompt_cache] + def extend(self, batch): if not any(self.samplers): self.samplers = [None] * len(self.uids) @@ -1720,7 +1723,7 @@ def remove(self, uids, return_prompt_caches=False): @property def prompt_cache_nbytes(self): - total = sum(c.nbytes for c in self._unprocessed_sequences for c in p[3]) + total = sum(c.nbytes for p in self._unprocessed_sequences for c in p[3]) total += sum(c.nbytes for c in self._prompt_batch.prompt_cache) total += sum(c.nbytes for c in self._generation_batch.prompt_cache) return total diff --git a/mlx_lm/server.py b/mlx_lm/server.py index 204ce8853..eb82fbde7 100644 --- a/mlx_lm/server.py +++ b/mlx_lm/server.py @@ -9,6 +9,7 @@ import time import uuid import warnings +from collections import deque from dataclasses import dataclass, field from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer from pathlib import Path @@ -53,6 +54,41 @@ def get_system_fingerprint(): return f"{__version__}-{mx.__version__}-{platform.platform()}-{gpu_arch}" +class ToolCallFormatter: + def __init__(self, 
tool_parser, tools, streaming=False): + self._idx = 0 + self._tool_parser = tool_parser + self._tools = tools + self._streaming = streaming + + def _format(self, tc): + tc_id = tc.pop("id", None) or str(uuid.uuid4()) + tc["arguments"] = json.dumps(tc["arguments"], ensure_ascii=False) + out = { + "function": tc, + "type": "function", + "id": tc_id, + } + if self._streaming: + out["index"] = self._idx + self._idx += 1 + + return out + + def __call__(self, tool_calls): + if not tool_calls: + return [] + + result = [] + for tool_text in tool_calls: + parsed = self._tool_parser(tool_text, self._tools) + if not isinstance(parsed, list): + parsed = [parsed] + result.extend(self._format(tc) for tc in parsed) + + return result + + class StopCondition(NamedTuple): stop_met: bool trim_length: int @@ -232,15 +268,11 @@ class CompletionRequest: @dataclass class GenerationContext: has_tool_calling: bool - tool_call_start: str - tool_call_end: str - tool_parser: Callable[[str, Any], Dict] has_thinking: bool - think_start_id: int - think_end_id: int - think_end: str - eos_token_ids: set - stop_token_sequences: List[List[int]] + tool_parser: Callable[[str, Any], Dict] + + sequences: Dict[Tuple[int], str] + prompt: List[int] prompt_cache_count: int = -1 @@ -629,44 +661,64 @@ def _tokenize(self, tokenizer, request, args): def _make_state_machine( self, model_key, tokenizer, stop_words, initial_state="normal" ): - """Make a new SequenceStateMachine or fetch it if we 've made it before.""" - cache_key = (model_key, tuple(stop_words)) - sm = self._state_machine_cache.get(cache_key) - if sm is not None: - return sm - - eos = [([t], None) for t in tokenizer.eos_token_ids] - stop = [ - (tokenizer.encode(w, add_special_tokens=False), None) for w in stop_words - ] - common_stops = eos + stop - - transitions = {"normal": list(common_stops)} + """Make a new SequenceStateMachine or fetch it if we 've made it before. 
+ Return also a dictionary that maps the token sequences in the state + machine to their strings. + """ + cache_key = (model_key, tuple(stop_words), initial_state) + rs = self._state_machine_cache.get(cache_key) + if rs is not None: + return rs + + # Will hold the state machine transitions and the sequences map to + # strings. + transitions = {} + sequences = {} + + # Add all the stop sequences + common_stops = [] + for t in tokenizer.eos_token_ids: + sequences[(t,)] = tokenizer.convert_ids_to_tokens(t) + common_stops.append(((t,), None)) + for w in stop_words: + t = tuple(tokenizer.encode(w, add_special_tokens=False)) + sequences[t] = w + common_stops.append((t, None)) + + # From normal to stop + transitions["normal"] = list(common_stops) + + # Reasoning related transitions if tokenizer.has_thinking: - transitions["normal"].append(([tokenizer.think_start_id], "reasoning")) - transitions["reasoning"] = [ - ([tokenizer.think_end_id], "normal"), - ] + common_stops - + ts = tokenizer.think_start_id + te = tokenizer.think_end_id + transitions["normal"].append(((ts,), "reasoning")) + transitions["reasoning"] = [((te,), "normal")] + transitions["reasoning"].extend(common_stops) + sequences[(ts,)] = tokenizer.convert_ids_to_tokens(ts) + sequences[(te,)] = tokenizer.convert_ids_to_tokens(te) + + # Tool calling relating transitions if tokenizer.has_tool_calling: - tool_start = tokenizer.encode( - tokenizer.tool_call_start, add_special_tokens=False + ts = tuple( + tokenizer.encode(tokenizer.tool_call_start, add_special_tokens=False) ) - tool_end = tokenizer.encode( - tokenizer.tool_call_end, add_special_tokens=False + te = tuple( + tokenizer.encode(tokenizer.tool_call_end, add_special_tokens=False) ) - transitions["normal"].append((tool_start, "tool")) - transitions["tool"] = [ - (tool_end, "normal"), - ] + common_stops + transitions["normal"].append((ts, "tool")) + transitions["tool"] = [(te, "normal")] + transitions["tool"].extend(common_stops) + sequences[ts] = 
tokenizer.tool_call_start + sequences[te] = tokenizer.tool_call_end sm = SequenceStateMachine(transitions, initial=initial_state) if len(self._state_machine_cache) > 100: self._state_machine_cache.clear() - self._state_machine_cache[cache_key] = sm + self._state_machine_cache[cache_key] = (sm, sequences) - return sm + return sm, sequences def _is_batchable(self, args): if not self.model_provider.is_batchable: @@ -725,31 +777,19 @@ def get_next_request(timeout=None): rqueue.put(e) continue - ctx = GenerationContext( - has_tool_calling=tokenizer.has_tool_calling, - tool_call_start=tokenizer.tool_call_start, - tool_call_end=tokenizer.tool_call_end, - tool_parser=tokenizer.tool_parser, - has_thinking=tokenizer.has_thinking, - think_start_id=tokenizer.think_start_id, - think_end=tokenizer.think_end, - think_end_id=tokenizer.think_end_id, - eos_token_ids=tokenizer.eos_token_ids, - stop_token_sequences=[ - tokenizer.encode(stop_word, add_special_tokens=False) - for stop_word in args.stop_words - ], - prompt=prompt, + sm, sequences = self._make_state_machine( + self.model_provider.model_key, + tokenizer, + args.stop_words, + initial_state, ) - rqueue.put(ctx) self._log_cache_stats() cache, rest = self.prompt_cache.fetch_nearest_cache( current_model_key, prompt ) - N = len(prompt) - len(rest) - ctx.prompt_cache_count = N - prompt = prompt[:N] + prompt_cache_count = len(prompt) - len(rest) + N = prompt_cache_count while N >= 0: if N >= len(segments[0]): N -= len(segments.pop(0)) @@ -757,21 +797,24 @@ def get_next_request(timeout=None): segments[0] = segments[0][N:] break + ctx = GenerationContext( + has_tool_calling=tokenizer.has_tool_calling, + has_thinking=tokenizer.has_thinking, + tool_parser=tokenizer.tool_parser, + sequences=sequences, + prompt=prompt, + prompt_cache_count=prompt_cache_count, + ) + rqueue.put(ctx) + (uid,) = batch_generator.insert_segments( segments=[segments], max_tokens=[args.max_tokens], caches=[cache], - all_tokens=[prompt], + 
all_tokens=[prompt[:prompt_cache_count]], samplers=[_make_sampler(args, tokenizer)], logits_processors=[_make_logits_processors(args)], - state_machines=[ - self._make_state_machine( - self.model_provider.model_key, - tokenizer, - args.stop_words, - initial_state=initial_state, - ) - ], + state_machines=[sm], ) batch_results[uid] = { "ctx": ctx, @@ -925,7 +968,7 @@ def progress(tokens_processed, tokens_total): # Prepare the prompt and state machine prompt, _, initial_state = self._tokenize(tokenizer, request, args) - sm = self._make_state_machine( + sm, sequences = self._make_state_machine( self.model_provider.model_key, tokenizer, args.stop_words, @@ -935,19 +978,10 @@ def progress(tokens_processed, tokens_total): # Start the generation context ctx = GenerationContext( + has_thinking=tokenizer.has_thinking, has_tool_calling=tokenizer.has_tool_calling, - tool_call_start=tokenizer.tool_call_start, - tool_call_end=tokenizer.tool_call_end, tool_parser=tokenizer.tool_parser, - has_thinking=tokenizer.has_thinking, - think_start_id=tokenizer.think_start_id, - think_end=tokenizer.think_end, - think_end_id=tokenizer.think_end_id, - eos_token_ids=tokenizer.eos_token_ids, - stop_token_sequences=[ - tokenizer.encode(stop_word, add_special_tokens=False) - for stop_word in args.stop_words - ], + sequences=sequences, prompt=prompt, ) rqueue.put(ctx) @@ -1045,11 +1079,25 @@ def _inner(): continue yield response + def _process_control_tokens(ctx, token_stream): + buffer_size = max(len(s) for s in ctx.sequences) + buffered_stream = deque() + + for tok in token_stream: + buffered_stream.append(tok) + if tok.match is not None: + for _ in tok.match: + buffered_stream.pop() + if len(buffered_stream) >= buffer_size: + yield buffered_stream.popleft() + while len(buffered_stream) > 0: + yield buffered_stream.popleft() + ctx = response_queue.get() if isinstance(ctx, Exception): raise ctx - return ctx, _inner() + return ctx, _process_control_tokens(ctx, _inner()) @property def 
cli_args(self): @@ -1144,11 +1192,18 @@ def do_POST(self): ) return - indent = "\t" # Backslashes can't be inside of f-strings - logging.debug(f"Incoming Request Body: {json.dumps(self.body, indent=indent)}") - assert isinstance( - self.body, dict - ), f"Request should be dict, but got {type(self.body)}" + if logging.getLogger().isEnabledFor(logging.DEBUG): + debug_body = json.dumps(self.body, indent="\t") + logging.debug(f"Incoming Request Body: {debug_body}") + if not isinstance(self.body, dict): + debug_body = json.dumps(self.body, indent="\t") + logging.error(f"Invalid Request Body: {debug_body}") + self._set_completion_headers(400) + self.end_headers() + self.wfile.write( + json.dumps({"error": "Request should be a JSON dictionary"}).encode() + ) + return # Extract request parameters from the body self.stream = self.body.get("stream", False) @@ -1425,20 +1480,16 @@ def handle_completion(self, request: CompletionRequest, stop_words: List[str]): chat_template_kwargs=self.chat_template_kwargs, ) - # Create keepalive callback to send SSE comments during long prompt processing - def keepalive_callback(processed_tokens, total_tokens): - logging.info( - f"Prompt processing progress: {processed_tokens}/{total_tokens}" - ) + # Keep connection allive during long prompt processing (and also log + # the progress) + def keepalive_callback(processed, total): + logging.info(f"Prompt processing progress: {processed}/{total}") if self.stream: try: - # Send SSE comment for keepalive - invisible to clients but keeps connection alive - self.wfile.write( - f": keepalive {processed_tokens}/{total_tokens}\n\n".encode() - ) + msg = f": keepalive {processed}/{total}\n\n".encode() + self.wfile.write(msg) self.wfile.flush() except (BrokenPipeError, ConnectionResetError, OSError): - # Client disconnected, ignore pass # Create the token generator @@ -1451,7 +1502,7 @@ def keepalive_callback(processed_tokens, total_tokens): except Exception as e: self._set_completion_headers(404) 
self.end_headers() - self.wfile.write(json.dumps({"error": f"{e}"}).encode()) + self.wfile.write(json.dumps({"error": str(e)}).encode()) return # Prepare the headers @@ -1463,148 +1514,78 @@ def keepalive_callback(processed_tokens, total_tokens): self._set_completion_headers(200) logging.debug("Starting completion:") - # Variables to save the tool calls in as they are being generated by - # the model. - in_tool_call = False - made_tool_call = False - tool_calls = [] - tool_text = "" - tool_idx = 0 + # Tool call formatter + tool_formatter = ToolCallFormatter(ctx.tool_parser, request.tools, self.stream) - def format_tool_call(tool_call): - nonlocal tool_idx - tool_call_id = tool_call.pop("id", None) or str(uuid.uuid4()) - tool_call["arguments"] = json.dumps( - tool_call["arguments"], ensure_ascii=False - ) - out = { - "function": tool_call, - "type": "function", - "id": tool_call_id, - } - if self.stream: - out["index"] = tool_idx - tool_idx += 1 - return out - - def parse_tools(tool_calls): - if not tool_calls: - return [] - result = [] - for tool_text in tool_calls: - parsed = ctx.tool_parser(tool_text, request.tools) - if isinstance(parsed, list): - result.extend(format_tool_call(tc) for tc in parsed) - else: - result.append(format_tool_call(parsed)) - return result - - # Start out in reasoning if the model is a reasoning model and the - # prompt has an open think token but no closing think token - in_reasoning = False - if ctx.has_thinking: - for i in range(len(ctx.prompt) - 1, -1, -1): - if ctx.prompt[i] == ctx.think_end_id: - break - elif ctx.prompt[i] == ctx.think_start_id: - in_reasoning = True - break + # Variables to save the generated text, tokens, logprobs, tools etc + prev_state = None + finish_reason = "stop" reasoning_text = "" - - # Variables to save the generated tokens and the corresponding probs + made_tool_call = False + tool_text = "" + tool_calls = [] + text = "" tokens = [] token_logprobs = [] top_tokens = [] - # Variables to save the generated 
text - text = "" - segment = "" - - # Well finally save the reason for stopping - finish_reason = "length" - # Process the generated tokens for gen in response: logging.debug(gen.text) - # Gather the text in tool calling or text variables - if in_reasoning: - if gen.text == ctx.think_end: - in_reasoning = False - else: - reasoning_text += gen.text - elif ctx.has_tool_calling and gen.text == ctx.tool_call_start: - made_tool_call = True - in_tool_call = True - elif in_tool_call: - if gen.text == ctx.tool_call_end: + # Collect the text according to our current state and state + # transitions. Reasoning or tool or normal text. + if gen.state == "reasoning": + reasoning_text += gen.text + elif gen.state == "tool": + tool_text += gen.text + elif gen.state == "normal": + if prev_state == "tool": tool_calls.append(tool_text) tool_text = "" - in_tool_call = False - else: - tool_text += gen.text - else: + made_tool_call = True text += gen.text - segment += gen.text - # Save the token and its logprob + # Add the tokens and logprobs to the vars. 
tokens.append(gen.token) if args.logprobs: token_logprobs.append(gen.logprob) - - # If requested save the k top logprobs if args.top_logprobs > 0: top_tokens.append(gen.top_tokens) - # Check if we should stop early - stop_condition = stopping_criteria( - tokens, - ctx.eos_token_ids, - ctx.stop_token_sequences, - stop_words, - ) - if stop_condition.stop_met: - finish_reason = "tool_calls" if made_tool_call else "stop" - ctx.stop() - tokens = tokens[: len(tokens) - stop_condition.trim_length] - text = text[: len(text) - stop_condition.trim_text_length] - segment = "" - break - - if self.stream and not in_tool_call: - # If the end of tokens overlaps with a stop sequence, generate new - # tokens until we know if the stop sequence is hit or not - if any( - ( - sequence_overlap(tokens, sequence) - for sequence in ctx.stop_token_sequences - ) - ): - continue - elif segment or tool_calls or reasoning_text: - response = self.generate_response( - segment, - None, - tool_calls=parse_tools(tool_calls), - reasoning_text=reasoning_text, - ) - self.wfile.write(f"data: {json.dumps(response)}\n\n".encode()) - self.wfile.flush() - reasoning_text = "" - segment = "" - tool_calls = [] + if ( + self.stream + and gen.state != "tool" + and (text or tool_calls or reasoning_text) + ): + response = self.generate_response( + text, + None, + tool_calls=tool_formatter(tool_calls), + reasoning_text=reasoning_text, + ) + self.wfile.write(f"data: {json.dumps(response)}\n\n".encode()) + self.wfile.flush() + reasoning_text = "" + text = "" + tool_calls = [] if gen.finish_reason is not None: finish_reason = gen.finish_reason - # Flush any remaining tool text (e.g. 
when tool_call_end is empty) - if in_tool_call and tool_text: + prev_state = gen.state + + if prev_state == "tool" and tool_text: tool_calls.append(tool_text) + made_tool_call = True + + if finish_reason == "stop" and made_tool_call: + finish_reason = "tool_calls" if self.stream: response = self.generate_response( - segment, + text, finish_reason, - tool_calls=parse_tools(tool_calls), + tool_calls=tool_formatter(tool_calls), reasoning_text=reasoning_text, ) self.wfile.write(f"data: {json.dumps(response)}\n\n".encode()) @@ -1630,13 +1611,13 @@ def parse_tools(tool_calls): top_tokens=top_tokens, tokens=tokens, reasoning_text=reasoning_text, - tool_calls=parse_tools(tool_calls), + tool_calls=tool_formatter(tool_calls), ) - response_json = json.dumps(response).encode() - indent = "\t" # Backslashes can't be inside of f-strings - logging.debug(f"Outgoing Response: {json.dumps(response, indent=indent)}") + if logging.getLogger().isEnabledFor(logging.DEBUG): + response_debug = json.dumps(response, indent="\t") + logging.debug(f"Outgoing Response: {response_debug}") - # Send an additional Content-Length header when it is known + response_json = json.dumps(response).encode() self.send_header("Content-Length", str(len(response_json))) self.end_headers() self.wfile.write(response_json) From ef0df36e4194b9ee3b0e1e7198b847a8954833bd Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Fri, 27 Mar 2026 17:30:10 -0700 Subject: [PATCH 20/34] Small cleanup --- mlx_lm/server.py | 70 ++------------------------------------------ tests/test_server.py | 12 -------- 2 files changed, 2 insertions(+), 80 deletions(-) diff --git a/mlx_lm/server.py b/mlx_lm/server.py index eb82fbde7..c4e30c01d 100644 --- a/mlx_lm/server.py +++ b/mlx_lm/server.py @@ -10,7 +10,7 @@ import uuid import warnings from collections import deque -from dataclasses import dataclass, field +from dataclasses import dataclass from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer from pathlib import 
Path from queue import Empty as QueueEmpty @@ -41,9 +41,7 @@ ) from .models.cache import ( LRUPromptCache, - can_trim_prompt_cache, make_prompt_cache, - trim_prompt_cache, ) from .sample_utils import make_logits_processors, make_sampler from .utils import _parse_size, load, sharded_load @@ -89,69 +87,6 @@ def __call__(self, tool_calls): return result -class StopCondition(NamedTuple): - stop_met: bool - trim_length: int - trim_text_length: int - - -def stopping_criteria( - tokens: List[int], - eos_token_ids: set, - stop_id_sequences: List[List[int]], - stop_words: List[str], -) -> StopCondition: - """ - Determines whether the token generation should stop based on predefined - conditions. - - Args: - tokens (List[int]): The current sequence of generated tokens. - eos_token_ids (set): The token IDs that represents the - end-of-sequence. If the last token in ``tokens`` is in the set, - the generation should stop. - stop_id_sequences (List[List[[int]]): A list of integer lists, each - representing a sequence of token IDs. If the end of the `tokens` - list matches any of these sequences, the generation should stop. - stop_words (List[str]): The stop words that correspond to the - ``stop_id_sequences``. - - Returns: - StopCondition: A named tuple indicating whether the stop condition has - been met (`stop_met`) and how many tokens should be trimmed from the - end if it has (`trim_length`) as well as the text that should be - trimmed. 
- """ - if tokens and tokens[-1] in eos_token_ids: - return StopCondition(stop_met=True, trim_length=0, trim_text_length=0) - - for stop_ids, stop_word in zip(stop_id_sequences, stop_words): - if len(tokens) >= len(stop_ids): - if tokens[-len(stop_ids) :] == stop_ids: - return StopCondition( - stop_met=True, - trim_length=len(stop_ids), - trim_text_length=len(stop_word), - ) - - return StopCondition(stop_met=False, trim_length=0, trim_text_length=0) - - -def sequence_overlap(s1: Sequence, s2: Sequence) -> bool: - """ - Checks if a suffix of s1 has overlap with a prefix of s2 - - Args: - s1 (Sequence): The first sequence - s2 (Sequence): The second sequence - - Returns: - bool: If the two sequences have overlap - """ - max_overlap = min(len(s1), len(s2)) - return any(s1[-i:] == s2[:i] for i in range(1, max_overlap + 1)) - - def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None): default_role_mapping = { "system_prompt": ( @@ -1445,8 +1380,7 @@ def handle_completion(self, request: CompletionRequest, stop_words: List[str]): Args: prompt (List[int]): The tokenized prompt. 
- stop_words (List[str]): A list of stop words passed to the - stopping_criteria function + stop_words (List[str]): A list of stop words """ args = GenerationArguments( model=ModelDescription( diff --git a/tests/test_server.py b/tests/test_server.py index b6907ac34..48141c074 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -218,18 +218,6 @@ def test_handle_models(self): self.assertEqual(model["object"], "model") self.assertIn("created", model) - def test_sequence_overlap(self): - from mlx_lm.server import sequence_overlap - - self.assertTrue(sequence_overlap([1], [1])) - self.assertTrue(sequence_overlap([1, 2], [1, 2])) - self.assertTrue(sequence_overlap([1, 3], [3, 4])) - self.assertTrue(sequence_overlap([1, 2, 3], [2, 3])) - - self.assertFalse(sequence_overlap([1], [2])) - self.assertFalse(sequence_overlap([1, 2], [3, 4])) - self.assertFalse(sequence_overlap([1, 2, 3], [4, 1, 2, 3])) - class TestServerWithDraftModel(unittest.TestCase): @classmethod From c61d0f8d3686b9708e6dd638b58f3ba853b0c348 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Fri, 27 Mar 2026 18:47:00 -0700 Subject: [PATCH 21/34] Ensure that we insert a copy of the list in the cache --- mlx_lm/server.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlx_lm/server.py b/mlx_lm/server.py index c4e30c01d..678de7daf 100644 --- a/mlx_lm/server.py +++ b/mlx_lm/server.py @@ -843,7 +843,7 @@ def get_next_request(timeout=None): ) self.prompt_cache.insert_cache( self.model_provider.model_key, - cache_key, + cache_key[:], cache, cache_type=cache_type, ) @@ -872,7 +872,7 @@ def get_next_request(timeout=None): result["rqueue"].put(None) self.prompt_cache.insert_cache( current_model_key, - r.all_tokens, + r.all_tokens[:], r.prompt_cache, cache_type="assistant", ) From c53567f7adc3b259d12159b8e8f1b0a09f75fd5d Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Fri, 27 Mar 2026 19:18:37 -0700 Subject: [PATCH 22/34] Chat templates require user message --- 
mlx_lm/server.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/mlx_lm/server.py b/mlx_lm/server.py index 678de7daf..45ad48b9c 100644 --- a/mlx_lm/server.py +++ b/mlx_lm/server.py @@ -559,11 +559,14 @@ def _tokenize(self, tokenizer, request, args): break if num_system > 0: sys_tokens = tokenizer.apply_chat_template( - messages[:num_system], + messages[:num_system] + [{"role": "user", "content": ""}], add_generation_prompt=False, **template_kwargs, ) - sys_end = len(sys_tokens) + for i, (a, b) in enumerate(zip(sys_tokens, prompt)): + if a != b: + sys_end = i + break if sys_end > 0 and sys_end < len(prompt): segments.append(prompt[:sys_end]) From e24c9543e0cb4d9ab9ef2f641d0f413708591c62 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Fri, 27 Mar 2026 19:30:28 -0700 Subject: [PATCH 23/34] Fix the in reasoning test --- mlx_lm/server.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/mlx_lm/server.py b/mlx_lm/server.py index 45ad48b9c..5896e297d 100644 --- a/mlx_lm/server.py +++ b/mlx_lm/server.py @@ -571,19 +571,21 @@ def _tokenize(self, tokenizer, request, args): segments.append(prompt[:sys_end]) # The following code does 2 things: - # 1. Find a tail segment that contains thinking tokens + # 1. Find a tail segment that contains thinking tokens (small up to + # 11 tokens) # 2. Find whether we have a think start token without corresponding # think end token and set the state to reasoning. 
tail_start = len(prompt) - state = "reasoning" if tokenizer.has_thinking: - for i in range(1, min(11, len(prompt) - sys_end), 1): + has_think_end = False + for i in range(1, len(prompt) - sys_end, 1): if prompt[-i] == tokenizer.think_end_id: - state = "normal" - continue - if prompt[-i] == tokenizer.think_start_id: + has_think_end = True + elif prompt[-i] == tokenizer.think_start_id: tail_start = len(prompt) - i - initial_state = state + initial_state = "reasoning" if not has_think_end else "normal" + break + if has_think_end and i >= 11: break # Finalize the segments and return From f49b34e7a8b433916ce81f493857fb1906af53e4 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Fri, 27 Mar 2026 19:52:49 -0700 Subject: [PATCH 24/34] Fix the initial state --- mlx_lm/server.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/mlx_lm/server.py b/mlx_lm/server.py index 5896e297d..dd6dc7c31 100644 --- a/mlx_lm/server.py +++ b/mlx_lm/server.py @@ -542,12 +542,21 @@ def _tokenize(self, tokenizer, request, args): # If we are here it means we have a chat request so we need to search # for segments for better cache management. + # Choose the initial state among only reasoning or normal + initial_state = "normal" + if tokenizer.has_thinking: + for i in range(-1, -len(prompt), -1): + if prompt[i] == tokenizer.think_start_id: + initial_state = "reasoning" + break + if prompt[i] == tokenizer.think_end_id: + break + # It is not a user message so no segmentation needed. if messages[-1]["role"] != "user": - return prompt, [prompt], "normal" + return prompt, [prompt], initial_state segments = [] - initial_state = "normal" # Find where the system prompt ends and add it as a segment. num_system = 0 @@ -570,22 +579,13 @@ def _tokenize(self, tokenizer, request, args): if sys_end > 0 and sys_end < len(prompt): segments.append(prompt[:sys_end]) - # The following code does 2 things: - # 1. 
Find a tail segment that contains thinking tokens (small up to - # 11 tokens) - # 2. Find whether we have a think start token without corresponding - # think end token and set the state to reasoning. + # Find a tail segment that contains thinking tokens (small up to 11 + # tokens) tail_start = len(prompt) if tokenizer.has_thinking: - has_think_end = False - for i in range(1, len(prompt) - sys_end, 1): - if prompt[-i] == tokenizer.think_end_id: - has_think_end = True - elif prompt[-i] == tokenizer.think_start_id: + for i in range(1, min(11, len(prompt) - sys_end), 1): + if prompt[-i] == tokenizer.think_start_id: tail_start = len(prompt) - i - initial_state = "reasoning" if not has_think_end else "normal" - break - if has_think_end and i >= 11: break # Finalize the segments and return From 6b1a033a64e3c703fa982048ca1ba9d2a75cde02 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Sat, 28 Mar 2026 00:55:57 -0700 Subject: [PATCH 25/34] Change the copy for prompt batch split --- mlx_lm/generate.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py index 2231108c4..a4836753f 100644 --- a/mlx_lm/generate.py +++ b/mlx_lm/generate.py @@ -1115,10 +1115,24 @@ def extend(self, batch): self.max_tokens.extend(batch.max_tokens) self.state_machines.extend(batch.state_machines) + def _copy(self): + new_batch = self.__class__.__new__(self.__class__) + new_batch.model = self.model + new_batch.uids = list(self.uids) + new_batch.prompt_cache = copy.deepcopy(self.prompt_cache) + new_batch.tokens = list(self.tokens) + new_batch.prefill_step_size = self.prefill_step_size + new_batch.samplers = list(self.samplers) + new_batch.fallback_sampler = self.fallback_sampler + new_batch.logits_processors = list(self.logits_processors) + new_batch.state_machines = list(self.state_machines) + new_batch.max_tokens = list(self.max_tokens) + return new_batch + def split(self, indices: List[int]): indices = sorted(indices) 
indices_left = sorted(set(range(len(self.uids))) - set(indices)) - new_batch = copy.deepcopy(self) + new_batch = self._copy() self.filter(indices_left) new_batch.filter(indices) From 61fc80837357c43fabcd4de543439247d186be26 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Sat, 28 Mar 2026 03:04:11 -0700 Subject: [PATCH 26/34] Fixes --- mlx_lm/generate.py | 5 ++++- mlx_lm/models/cache.py | 2 +- mlx_lm/server.py | 15 +++++-------- tests/test_generate.py | 51 ++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 62 insertions(+), 11 deletions(-) diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py index a4836753f..d1cb9299d 100644 --- a/mlx_lm/generate.py +++ b/mlx_lm/generate.py @@ -923,7 +923,7 @@ def _extend_cache(cache_a, cache_b): if not cache_a: return cache_b if not cache_b: - return cache_b + return cache_a for ca, cb in zip(cache_a, cache_b): ca.extend(cb) return cache_a @@ -1730,6 +1730,9 @@ def remove(self, uids, return_prompt_caches=False): ) if len(keep[1]) < len(self._prompt_batch): self._prompt_batch.filter(sorted(keep[1])) + self._currently_processing = [ + x for i, x in enumerate(self._currently_processing) if i in keep[1] + ] if len(keep[2]) < len(self._generation_batch): self._generation_batch.filter(sorted(keep[2])) diff --git a/mlx_lm/models/cache.py b/mlx_lm/models/cache.py index 88731c009..7f8d6928f 100644 --- a/mlx_lm/models/cache.py +++ b/mlx_lm/models/cache.py @@ -1579,7 +1579,7 @@ def pop(self): while i + 1 < len(self._ordering): lru_a = self._lrus[self._ordering[i]] lru_b = self._lrus[self._ordering[i + 1]] - if len(lru_a) >= len(lru_b): + if len(lru_a) >= len(lru_b) > 0: return lru_a.popleft() i += 1 return lru_b.popleft() diff --git a/mlx_lm/server.py b/mlx_lm/server.py index dd6dc7c31..bf4f69eb7 100644 --- a/mlx_lm/server.py +++ b/mlx_lm/server.py @@ -756,11 +756,12 @@ def get_next_request(timeout=None): logits_processors=[_make_logits_processors(args)], state_machines=[sm], ) + segment_cache = ["user", "system"] 
batch_results[uid] = { "ctx": ctx, "rqueue": rqueue, "detokenizer": tokenizer.detokenizer, - "num_segments": len(segments), + "segment_cache": segment_cache[: len(segments) - 1], "top_logprobs": args.top_logprobs, } # just making sure we don't leave a reference around @@ -836,21 +837,17 @@ def get_next_request(timeout=None): eos_ids = [ r.uid for r in prompt_responses - if r.end_of_segment and not r.end_of_prompt + if r.end_of_segment + and not r.end_of_prompt + and batch_results[r.uid]["segment_cache"] ] caches = batch_generator.extract_cache(eos_ids) for uid, (cache, cache_key) in caches.items(): - batch_results[uid]["num_segments"] -= 1 - cache_type = ( - "system" - if batch_results[uid]["num_segments"] > 1 - else "user" - ) self.prompt_cache.insert_cache( self.model_provider.model_key, cache_key[:], cache, - cache_type=cache_type, + cache_type=batch_results[uid]["segment_cache"].pop(), ) del caches diff --git a/tests/test_generate.py b/tests/test_generate.py index 95fc548e2..2a9a4b10c 100644 --- a/tests/test_generate.py +++ b/tests/test_generate.py @@ -668,6 +668,57 @@ def test_batch_continued_generation_gated_delta(self): model = qwen3_next.Model(args) self._continued_generation_test_helper(model) + def test_extend_cache_with_empty(self): + from mlx_lm.generate import _extend_cache + from mlx_lm.models.cache import make_prompt_cache + + cache_a = make_prompt_cache(self.model) + + prompt = mx.array([[1, 2, 3]]) + self.model(prompt, cache=cache_a) + mx.eval([c.state for c in cache_a]) + + result = _extend_cache(cache_a, []) + self.assertEqual(len(result), len(cache_a)) + for c in result: + self.assertGreater(c.offset, 0) + + result = _extend_cache([], cache_a) + self.assertEqual(len(result), len(cache_a)) + for c in result: + self.assertGreater(c.offset, 0) + + def test_remove_prompt_batch_updates_currently_processing(self): + prompt_a = self.tokenizer.encode("Write a long story about a cat") + prompt_b = self.tokenizer.encode("Write a long story about a dog") + 
+ gen = BatchGenerator( + self.model, + max_tokens=5, + prefill_batch_size=2, + prefill_step_size=4, + completion_batch_size=4, + ) + uid_a, uid_b = gen.insert([prompt_a, prompt_b]) + + gen.next() + + found = gen._find_uids([uid_a, uid_b]) + for uid in [uid_a, uid_b]: + self.assertIn(uid, found) + self.assertEqual(found[uid][0], 1) + + gen.remove([uid_a]) + + self.assertEqual(len(gen._currently_processing), len(gen._prompt_batch)) + + found = gen._find_uids([uid_b]) + self.assertIn(uid_b, found) + + while responses := gen.next_generated(): + if all(r.finish_reason is not None for r in responses): + break + if __name__ == "__main__": unittest.main() From efb30e2dc141a3485e3f743756aa6863220da1ff Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Sat, 28 Mar 2026 04:33:09 -0700 Subject: [PATCH 27/34] Fix qwen 3.5 --- mlx_lm/models/qwen3_5.py | 7 ++++++- mlx_lm/server.py | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/mlx_lm/models/qwen3_5.py b/mlx_lm/models/qwen3_5.py index 636a6ede4..a86710e74 100644 --- a/mlx_lm/models/qwen3_5.py +++ b/mlx_lm/models/qwen3_5.py @@ -158,7 +158,12 @@ def __call__( conv_input = mx.concatenate([conv_state, qkv], axis=1) if cache is not None: n_keep = self.conv_kernel_size - 1 - cache[0] = mx.contiguous(conv_input[:, -n_keep:]) + if cache.lengths is not None: + ends = mx.clip(cache.lengths, 0, S) + positions = (ends[:, None] + mx.arange(n_keep))[..., None] + cache[0] = mx.take_along_axis(conv_input, positions, axis=1) + else: + cache[0] = mx.contiguous(conv_input[:, -n_keep:, :]) conv_out = nn.silu(self.conv1d(conv_input)) q, k, v = [ diff --git a/mlx_lm/server.py b/mlx_lm/server.py index bf4f69eb7..61a51f043 100644 --- a/mlx_lm/server.py +++ b/mlx_lm/server.py @@ -730,7 +730,7 @@ def get_next_request(timeout=None): ) prompt_cache_count = len(prompt) - len(rest) N = prompt_cache_count - while N >= 0: + while N > 0: if N >= len(segments[0]): N -= len(segments.pop(0)) else: From 
c2783f3a0b479ae481088722e4e2f1613920808d Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Sat, 28 Mar 2026 05:06:07 -0700 Subject: [PATCH 28/34] Fix batched gated delta net --- mlx_lm/models/gated_delta.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mlx_lm/models/gated_delta.py b/mlx_lm/models/gated_delta.py index af983f6d3..fa6a2ed3f 100644 --- a/mlx_lm/models/gated_delta.py +++ b/mlx_lm/models/gated_delta.py @@ -81,6 +81,8 @@ def _make_gated_delta_kernel(has_mask=False, vectorized=False): if (thread_index_in_simdgroup == 0) {{ y[dv_idx] = static_cast(out); }} + }} else {{ + y[dv_idx] = static_cast(0); }} // Increment data pointers to next time step q_ += Hk * Dk; From bef54c3e5ca396078fbaa045ce65403a2960ab04 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Sat, 28 Mar 2026 05:12:12 -0700 Subject: [PATCH 29/34] Remove unused batch dataclass --- mlx_lm/generate.py | 44 -------------------------------------------- 1 file changed, 44 deletions(-) diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py index d1cb9299d..4f4adefa3 100644 --- a/mlx_lm/generate.py +++ b/mlx_lm/generate.py @@ -827,50 +827,6 @@ class BatchStats: peak_memory: float = 0 -@dataclass -class Batch: - uids: List[int] - y: mx.array - logprobs: mx.array - max_tokens: List[int] - num_tokens: List[int] - cache: List[Any] - samplers: List[Any] - logits_processors: List[Any] - tokens: List[mx.array] - - def __len__(self): - return len(self.uids) - - def filter(self, keep_idx: List[int]): - self.uids = [self.uids[k] for k in keep_idx] - self.logprobs = [self.logprobs[k] for k in keep_idx] - self.max_tokens = [self.max_tokens[k] for k in keep_idx] - self.num_tokens = [self.num_tokens[k] for k in keep_idx] - self.samplers = [self.samplers[k] for k in keep_idx] - self.logits_processors = [self.logits_processors[k] for k in keep_idx] - self.tokens = [self.tokens[k] for k in keep_idx] - keep_idx = mx.array(keep_idx, mx.int32) - self.y = self.y[keep_idx] - for c in self.cache: - 
c.filter(keep_idx) - - def extend(self, other): - self.uids.extend(other.uids) - self.y = mx.concatenate([self.y, other.y]) - self.logprobs.extend(other.logprobs) - self.num_tokens.extend(other.num_tokens) - self.max_tokens.extend(other.max_tokens) - self.samplers.extend(other.samplers) - self.logits_processors.extend(other.logits_processors) - self.tokens.extend(other.tokens) - for c, o in zip(self.cache, other.cache): - c.extend(o) - - def extract_cache(self, idx): - return [c.extract(idx) for c in self.cache] - - def _make_cache(model, left_padding, max_kv_size): """ Convert a list of regular caches into their corresponding From da0b679a5f455462441cda9001f2f01b8cc593be Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Sat, 28 Mar 2026 22:33:17 -0700 Subject: [PATCH 30/34] Small refactor of server methods --- mlx_lm/models/cache.py | 2 +- mlx_lm/server.py | 205 +++++++++++++++++------------------------ 2 files changed, 86 insertions(+), 121 deletions(-) diff --git a/mlx_lm/models/cache.py b/mlx_lm/models/cache.py index 7f8d6928f..f01f1d4d1 100644 --- a/mlx_lm/models/cache.py +++ b/mlx_lm/models/cache.py @@ -1579,7 +1579,7 @@ def pop(self): while i + 1 < len(self._ordering): lru_a = self._lrus[self._ordering[i]] lru_b = self._lrus[self._ordering[i + 1]] - if len(lru_a) >= len(lru_b) > 0: + if lru_a and len(lru_a) >= len(lru_b): return lru_a.popleft() i += 1 return lru_b.popleft() diff --git a/mlx_lm/server.py b/mlx_lm/server.py index 61a51f043..3bc220cde 100644 --- a/mlx_lm/server.py +++ b/mlx_lm/server.py @@ -70,7 +70,6 @@ def _format(self, tc): if self._streaming: out["index"] = self._idx self._idx += 1 - return out def __call__(self, tool_calls): @@ -83,7 +82,6 @@ def __call__(self, tool_calls): if not isinstance(parsed, list): parsed = [parsed] result.extend(self._format(tc) for tc in parsed) - return result @@ -98,7 +96,7 @@ def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None): "assistant": "ASSISTANT: ", "stop": "\n", } - 
role_mapping = role_mapping if role_mapping is not None else default_role_mapping + role_mapping = role_mapping or default_role_mapping prompt = "" for line in messages: @@ -128,7 +126,7 @@ def process_message_content(messages): """ for message in messages: - content = message.get("content", None) + content = message.get("content") if isinstance(content, list): text_fragments = [ fragment["text"] for fragment in content if fragment["type"] == "text" @@ -138,10 +136,11 @@ def process_message_content(messages): message["content"] = "".join(text_fragments) elif content is None: message["content"] = "" - if tool_calls := message.get("tool_calls", False): + + if tool_calls := message.get("tool_calls"): for tool_call in tool_calls: - if func := tool_call.get("function", False): - if args := func.get("arguments", False): + if func := tool_call.get("function"): + if args := func.get("arguments"): func["arguments"] = json.loads(args) @@ -234,7 +233,6 @@ def __init__(self, budget=0.5, iterations=25, sync_frequency=10): self._budget = budget self._iterations = iterations self._sync_frequency = sync_frequency - self._start = None self._current_iterations = None self._loops = 0 @@ -252,20 +250,22 @@ def __next__(self): return None self._current_iterations += 1 - if self._current_iterations > self._iterations: - self._loops += 1 - self._time_spent += time.time() - self._start - if self._loops % self._sync_frequency == 0: - with mx.stream(generation_stream): - loop_time = mx.distributed.all_sum(self._time_spent).item() - avg_loop_time = loop_time / ( - mx.distributed.init().size() * self._sync_frequency - ) - factor = self._budget / avg_loop_time - self._iterations = max(round(self._iterations * factor), 1) - self._loops = 0 - self._time_spent = 0 - raise StopIteration() + if self._current_iterations <= self._iterations: + return None + + self._loops += 1 + self._time_spent += time.time() - self._start + if self._loops % self._sync_frequency == 0: + with 
mx.stream(generation_stream): + loop_time = mx.distributed.all_sum(self._time_spent).item() + avg_loop_time = loop_time / ( + mx.distributed.init().size() * self._sync_frequency + ) + factor = self._budget / avg_loop_time + self._iterations = max(round(self._iterations * factor), 1) + self._loops = 0 + self._time_spent = 0 + raise StopIteration() class ModelProvider: @@ -404,17 +404,17 @@ def _make_logits_processors(args): ) -def _format_top_logprobs(logprobs, top_logprobs, tokenizer) -> Tuple[Dict[str, Any]]: - """Returns info dicts for the top `top_logprobs` tokens from `logprobs`""" - if top_logprobs <= 0: +def _format_top_logprobs(logprobs, top_n, tokenizer) -> Tuple[Dict[str, Any]]: + """Returns info dicts for the top `top_n` tokens from `logprobs`""" + if top_n <= 0: return () - sorted_indices = mx.argpartition(-logprobs, kth=top_logprobs - 1) - top_indices = sorted_indices[:top_logprobs].tolist() - top_logprobs = logprobs[top_indices].tolist() + sorted_indices = mx.argpartition(-logprobs, kth=top_n - 1) + top_indices = sorted_indices[:top_n].tolist() + top_probs = logprobs[top_indices].tolist() txts = tokenizer.convert_ids_to_tokens(top_indices) return tuple( {"id": i, "token": s, "logprob": g} - for i, s, g in zip(top_indices, txts, top_logprobs) + for i, s, g in zip(top_indices, txts, top_probs) ) @@ -454,7 +454,6 @@ def _next_request(self, timeout=None): request = self.requests.get_nowait() except QueueEmpty: pass - return self._share_request(request) def _share_object(self, obj): @@ -466,19 +465,17 @@ def _share_object(self, obj): if obj is None: mx.eval(mx.distributed.all_sum(0)) return None - else: - data = mx.array(pickle.dumps(obj)) - mx.eval(mx.distributed.all_sum(data.size)) - mx.eval(mx.distributed.all_sum(data)) - return obj + data = mx.array(pickle.dumps(obj)) + mx.eval(mx.distributed.all_sum(data.size)) + mx.eval(mx.distributed.all_sum(data)) + return obj else: size = mx.distributed.all_sum(0).item() if size == 0: return None - else: - data = 
mx.zeros(size, dtype=mx.uint8) - data = mx.distributed.all_sum(data) - return pickle.loads(data) + data = mx.zeros(size, dtype=mx.uint8) + data = mx.distributed.all_sum(data) + return pickle.loads(data) def _share_request(self, request): if not self._is_distributed: @@ -661,12 +658,7 @@ def _make_state_machine( return sm, sequences def _is_batchable(self, args): - if not self.model_provider.is_batchable: - return False - if args.seed is not None: - return False - - return True + return self.model_provider.is_batchable and args.seed is None def _generate(self): current_model = None @@ -1186,87 +1178,60 @@ def do_POST(self): request = request_factories[self.path]() self.handle_completion(request, stop_words) - def validate_model_parameters(self): - """ - Validate the model parameters passed in the request for the correct types and values. - """ - if not isinstance(self.stream, bool): - raise ValueError("stream must be a boolean") - - if not isinstance(self.max_tokens, int) or self.max_tokens < 0: - raise ValueError("max_tokens must be a non-negative integer") - - if not isinstance(self.temperature, (float, int)) or self.temperature < 0: - raise ValueError("temperature must be a non-negative float") - - if not isinstance(self.top_p, (float, int)) or self.top_p < 0 or self.top_p > 1: - raise ValueError("top_p must be a float between 0 and 1") - - if not isinstance(self.top_k, int) or self.top_k < 0: - raise ValueError("top_k must be a non-negative integer") - - if not isinstance(self.min_p, (float, int)) or self.min_p < 0 or self.min_p > 1: - raise ValueError("min_p must be a float between 0 and 1") - - if not isinstance(self.num_draft_tokens, int) or self.num_draft_tokens < 0: - raise ValueError("num_draft_tokens must be a non-negative integer") - - if ( - not isinstance(self.repetition_penalty, (float, int)) - or self.repetition_penalty < 0 - ): - raise ValueError("repetition_penalty must be a non-negative float") - if ( - not isinstance(self.repetition_context_size, 
int) - or self.repetition_context_size < 0 - ): - raise ValueError("repetition_context_size must be a non-negative integer") - if not isinstance(self.presence_penalty, (float, int)): - raise ValueError("Presence penalty must be must be a float") - if ( - not isinstance(self.presence_context_size, int) - or self.presence_context_size < 0 - ): - raise ValueError("presence_context_size must be a non-negative integer") - if not isinstance(self.frequency_penalty, (float, int)): - raise ValueError("Presence penalty must be must be a float") - if ( - not isinstance(self.frequency_context_size, int) - or self.frequency_context_size < 0 - ): - raise ValueError("frequency_context_size must be a non-negative integer") - - if not isinstance(self.logprobs, bool): - raise ValueError("logprobs must be a boolean") + def _validate( + self, + name, + expected_type, + min_val=None, + max_val=None, + optional=False, + whitelist=None, + ): + value = getattr(self, name) + if optional and value is None: + return + if not isinstance(value, expected_type): + try: + allowed = tuple(et.__name__ for et in expected_type) + except TypeError: + allowed = expected_type.__name__ + raise ValueError(f"{name} must be of type {allowed}") + if whitelist is not None and value in whitelist: + return + if min_val is not None and value < min_val: + raise ValueError(f"{name} must be at least {min_val}") + if max_val is not None and value > max_val: + raise ValueError(f"{name} must be at most {max_val}") - if self.top_logprobs != -1 and not (0 < self.top_logprobs <= 10): - raise ValueError( - f"top_logprobs must be between 1 and 10 but got {self.top_logprobs:,}" - ) + def validate_model_parameters(self): + """Validate that the passed model parameters have correct types and values.""" + self._validate("stream", bool) + self._validate("max_tokens", int, min_val=0) + self._validate("temperature", (float, int), min_val=0) + self._validate("top_p", (float, int), min_val=0, max_val=1) + self._validate("top_k", 
int, min_val=0) + self._validate("min_p", (float, int), min_val=0, max_val=1) + self._validate("num_draft_tokens", int, min_val=0) + self._validate("repetition_penalty", (float, int), min_val=0) + self._validate("repetition_context_size", int, min_val=0) + self._validate("presence_penalty", (float, int)) + self._validate("presence_context_size", int, min_val=0) + self._validate("frequency_penalty", (float, int)) + self._validate("frequency_context_size", int, min_val=0) + self._validate("logprobs", bool) + self._validate("top_logprobs", int, min_val=0, max_val=11, whitelist=[-1]) + self._validate("xtc_probability", float, min_val=0, max_val=1) + self._validate("xtc_threshold", float, min_val=0, max_val=1) + self._validate("requested_model", str) + self._validate("adapter", str, optional=True) + self._validate("seed", int, optional=True) + self._validate("logit_bias", dict, optional=True) if self.logit_bias is not None: - if not isinstance(self.logit_bias, dict): - raise ValueError("logit_bias must be a dict of int to float") - try: - self.logit_bias = {int(k): v for k, v in self.logit_bias.items()} + self.logit_bias = {int(k): float(v) for k, v in self.logit_bias.items()} except ValueError: raise ValueError("logit_bias must be a dict of int to float") - if not ( - isinstance(self.xtc_probability, float) - and 0.00 <= self.xtc_probability <= 1.00 - ): - raise ValueError(f"xtc_probability must be a float between 0.00 and 1.00") - if not ( - isinstance(self.xtc_threshold, float) and 0.00 <= self.xtc_threshold <= 0.50 - ): - raise ValueError(f"xtc_threshold must be a float between 0.00 and 0.5") - if not isinstance(self.requested_model, str): - raise ValueError("model must be a string") - if self.adapter is not None and not isinstance(self.adapter, str): - raise ValueError("adapter must be a string") - if self.seed is not None and not isinstance(self.seed, int): - raise ValueError("seed must be an integer") def generate_response( self, From 
4154a532115f813291bdffcbe2b77226b9ce3e23 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Sun, 29 Mar 2026 00:47:48 -0700 Subject: [PATCH 31/34] Stop generation or prompt processing on disconnect --- mlx_lm/server.py | 183 ++++++++++++++++++++++++----------------------- 1 file changed, 94 insertions(+), 89 deletions(-) diff --git a/mlx_lm/server.py b/mlx_lm/server.py index 3bc220cde..70dea804b 100644 --- a/mlx_lm/server.py +++ b/mlx_lm/server.py @@ -824,6 +824,8 @@ def get_next_request(timeout=None): for r in prompt_responses: result = batch_results[r.uid] result["rqueue"].put(r.progress) + if result["ctx"]._should_stop: + uids_to_remove.append(r.uid) # Save the caches at end of segments eos_ids = [ @@ -1386,12 +1388,9 @@ def handle_completion(self, request: CompletionRequest, stop_words: List[str]): def keepalive_callback(processed, total): logging.info(f"Prompt processing progress: {processed}/{total}") if self.stream: - try: - msg = f": keepalive {processed}/{total}\n\n".encode() - self.wfile.write(msg) - self.wfile.flush() - except (BrokenPipeError, ConnectionResetError, OSError): - pass + msg = f": keepalive {processed}/{total}\n\n".encode() + self.wfile.write(msg) + self.wfile.flush() # Create the token generator try: @@ -1430,99 +1429,105 @@ def keepalive_callback(processed, total): token_logprobs = [] top_tokens = [] - for gen in response: - logging.debug(gen.text) - - # Collect the text according to our current state and state - # transitions. Reasoning or tool or normal text. - if gen.state == "reasoning": - reasoning_text += gen.text - elif gen.state == "tool": - tool_text += gen.text - elif gen.state == "normal": - if prev_state == "tool": - tool_calls.append(tool_text) - tool_text = "" - made_tool_call = True - text += gen.text - - # Add the tokens and logprobs to the vars. 
- tokens.append(gen.token) - if args.logprobs: - token_logprobs.append(gen.logprob) - if args.top_logprobs > 0: - top_tokens.append(gen.top_tokens) - - if ( - self.stream - and gen.state != "tool" - and (text or tool_calls or reasoning_text) - ): - response = self.generate_response( - text, - None, - tool_calls=tool_formatter(tool_calls), - reasoning_text=reasoning_text, - ) - self.wfile.write(f"data: {json.dumps(response)}\n\n".encode()) - self.wfile.flush() - reasoning_text = "" - text = "" - tool_calls = [] + try: + for gen in response: + logging.debug(gen.text) + + # Collect the text according to our current state and state + # transitions. Reasoning or tool or normal text. + if gen.state == "reasoning": + reasoning_text += gen.text + elif gen.state == "tool": + tool_text += gen.text + elif gen.state == "normal": + if prev_state == "tool": + tool_calls.append(tool_text) + tool_text = "" + made_tool_call = True + text += gen.text + + # Add the tokens and logprobs to the vars. + tokens.append(gen.token) + if args.logprobs: + token_logprobs.append(gen.logprob) + if args.top_logprobs > 0: + top_tokens.append(gen.top_tokens) - if gen.finish_reason is not None: - finish_reason = gen.finish_reason + if ( + self.stream + and gen.state != "tool" + and (text or tool_calls or reasoning_text) + ): + resp = self.generate_response( + text, + None, + tool_calls=tool_formatter(tool_calls), + reasoning_text=reasoning_text, + ) + self.wfile.write(f"data: {json.dumps(resp)}\n\n".encode()) + self.wfile.flush() + reasoning_text = "" + text = "" + tool_calls = [] - prev_state = gen.state + if gen.finish_reason is not None: + finish_reason = gen.finish_reason - if prev_state == "tool" and tool_text: - tool_calls.append(tool_text) - made_tool_call = True + prev_state = gen.state - if finish_reason == "stop" and made_tool_call: - finish_reason = "tool_calls" + if prev_state == "tool" and tool_text: + tool_calls.append(tool_text) + made_tool_call = True - if self.stream: - response = 
self.generate_response( - text, - finish_reason, - tool_calls=tool_formatter(tool_calls), - reasoning_text=reasoning_text, - ) - self.wfile.write(f"data: {json.dumps(response)}\n\n".encode()) - self.wfile.flush() - if self.stream_options is not None and self.stream_options["include_usage"]: - response = self.completion_usage_response( + if finish_reason == "stop" and made_tool_call: + finish_reason = "tool_calls" + + if self.stream: + resp = self.generate_response( + text, + finish_reason, + tool_calls=tool_formatter(tool_calls), + reasoning_text=reasoning_text, + ) + self.wfile.write(f"data: {json.dumps(resp)}\n\n".encode()) + self.wfile.flush() + if ( + self.stream_options is not None + and self.stream_options["include_usage"] + ): + resp = self.completion_usage_response( + len(ctx.prompt), + len(tokens), + ctx.prompt_cache_count, + ) + self.wfile.write(f"data: {json.dumps(resp)}\n\n".encode()) + self.wfile.flush() + self.wfile.write("data: [DONE]\n\n".encode()) + self.wfile.flush() + else: + resp = self.generate_response( + text, + finish_reason, len(ctx.prompt), len(tokens), ctx.prompt_cache_count, + token_logprobs=token_logprobs, + top_tokens=top_tokens, + tokens=tokens, + reasoning_text=reasoning_text, + tool_calls=tool_formatter(tool_calls), ) - self.wfile.write(f"data: {json.dumps(response)}\n\n".encode()) + if logging.getLogger().isEnabledFor(logging.DEBUG): + response_debug = json.dumps(resp, indent="\t") + logging.debug(f"Outgoing Response: {response_debug}") + + response_json = json.dumps(resp).encode() + self.send_header("Content-Length", str(len(response_json))) + self.end_headers() + self.wfile.write(response_json) self.wfile.flush() - self.wfile.write("data: [DONE]\n\n".encode()) - self.wfile.flush() - else: - response = self.generate_response( - text, - finish_reason, - len(ctx.prompt), - len(tokens), - ctx.prompt_cache_count, - token_logprobs=token_logprobs, - top_tokens=top_tokens, - tokens=tokens, - reasoning_text=reasoning_text, - 
tool_calls=tool_formatter(tool_calls), - ) - if logging.getLogger().isEnabledFor(logging.DEBUG): - response_debug = json.dumps(response, indent="\t") - logging.debug(f"Outgoing Response: {response_debug}") - - response_json = json.dumps(response).encode() - self.send_header("Content-Length", str(len(response_json))) - self.end_headers() - self.wfile.write(response_json) - self.wfile.flush() + finally: + ctx.stop() def completion_usage_response( self, From 401b86ee683f8176f4f03c5e7d91f7dffae8ee3a Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Mon, 30 Mar 2026 01:49:34 -0700 Subject: [PATCH 32/34] Fix batched deepseek_v32 and GLM --- mlx_lm/models/deepseek_v32.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/mlx_lm/models/deepseek_v32.py b/mlx_lm/models/deepseek_v32.py index e40e52950..7c97682e7 100644 --- a/mlx_lm/models/deepseek_v32.py +++ b/mlx_lm/models/deepseek_v32.py @@ -87,19 +87,15 @@ def __call__( b, s, _ = x.shape q = self.wq_b(qr) q = q.reshape(b, s, self.n_heads, self.head_dim).swapaxes(1, 2) - q_pe, q_nope = mx.split(q, [self.rope_head_dim], axis=-1) + k = self.wk(x) + k = self.k_norm(k) + k = mx.reshape(k, (b, 1, s, self.head_dim)) offset = cache.offset if cache is not None else 0 - q_pe = self.rope(q_pe, offset=offset) - q = mx.concatenate([q_pe, q_nope], axis=-1) + q = self.rope(q, offset=offset) + k = self.rope(k, offset=offset) - k = self.wk(x) - k = self.k_norm(k) - k = mx.reshape(k, (b, 1, s, self.head_dim)) - k_pe, k_nope = mx.split(k, [self.rope_head_dim], axis=-1) - k_pe = self.rope(k_pe, offset=offset) - k = mx.concatenate([k_pe, k_nope], axis=-1) if cache is not None: k, _ = cache.update_and_fetch(k, mx.zeros([b, 1, s, 0])) if k.shape[2] <= self.index_topk: @@ -221,7 +217,8 @@ def __call__( mx.broadcast_to(idx, idx.shape[:-1] + (k_pe.shape[-1],)), axis=2, ) - mask = None + if mask is not None: + mask = mx.take_along_axis(mask, topk_indices, axis=-1) else: shape = list(topk_indices.shape) 
shape[-1] = kv_latent.shape[2] From d4e898c4218aac72af52eaff7227476bf5fbec7e Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Mon, 30 Mar 2026 02:18:45 -0700 Subject: [PATCH 33/34] Add cache type logging and fix segment type --- mlx_lm/models/cache.py | 20 ++++++++++++++++++- mlx_lm/server.py | 45 +++++++++++++++++++++++++++--------------- 2 files changed, 48 insertions(+), 17 deletions(-) diff --git a/mlx_lm/models/cache.py b/mlx_lm/models/cache.py index f01f1d4d1..d6f101f90 100644 --- a/mlx_lm/models/cache.py +++ b/mlx_lm/models/cache.py @@ -1554,6 +1554,7 @@ class LRUPromptCache: class CacheEntry: prompt_cache: List[Any] nbytes: int + cache_type: str class CacheOrder: def __init__(self, ordering: List[str] = ["assistant", "user", "system"]): @@ -1590,6 +1591,7 @@ def __init__(self, max_size: int = 10, max_bytes: int = 1 << 63): self._trie = PromptTrie() self._lru = LRUPromptCache.CacheOrder() self._n_bytes = 0 + self._n_bytes_by_type = {k: 0 for k in self._lru._ordering} def __len__(self): return len(self._lru) @@ -1630,14 +1632,16 @@ def insert_cache( ): # Make the cache entry entry = LRUPromptCache.CacheEntry( - prompt_cache, sum(c.nbytes for c in prompt_cache) + prompt_cache, sum(c.nbytes for c in prompt_cache), cache_type ) # Insert into the trie and update the byte counter and lru position self._n_bytes += entry.nbytes + self._n_bytes_by_type[cache_type] += entry.nbytes prev = self._trie.add(model, tokens, entry) if prev is not None: self._n_bytes -= prev.nbytes + self._n_bytes_by_type[prev.cache_type] -= prev.nbytes self._lru.remove(model, tokens) self._lru.push(model, tokens, cache_type) @@ -1646,6 +1650,7 @@ def insert_cache( if can_trim_prompt_cache(prompt_cache): for prefix_len, entry in self._trie.pop_prefixes(model, tokens): self._n_bytes -= entry.nbytes + self._n_bytes_by_type[entry.cache_type] -= entry.nbytes self._lru.remove(model, tokens[:prefix_len]) # Ensure we match the constraints @@ -1653,10 +1658,12 @@ def insert_cache( model, 
tokens = self._lru.pop() entry = self._trie.pop(model, tokens) self._n_bytes -= entry.nbytes + self._n_bytes_by_type[entry.cache_type] -= entry.nbytes while self._n_bytes > self.max_bytes: model, tokens = self._lru.pop() entry = self._trie.pop(model, tokens) self._n_bytes -= entry.nbytes + self._n_bytes_by_type[entry.cache_type] -= entry.nbytes def trim_to( self, *, n_sequences: Optional[int] = None, n_bytes: Optional[int] = None @@ -1668,7 +1675,18 @@ def trim_to( model, tokens = self._lru.pop() entry = self._trie.pop(model, tokens) self._n_bytes -= entry.nbytes + self._n_bytes_by_type[entry.cache_type] -= entry.nbytes while self._n_bytes > n_bytes: model, tokens = self._lru.pop() entry = self._trie.pop(model, tokens) self._n_bytes -= entry.nbytes + self._n_bytes_by_type[entry.cache_type] -= entry.nbytes + + def stats_by_type(self): + result = {} + for cache_type in self._lru._ordering: + result[cache_type] = { + "n_sequences": len(self._lru._lrus[cache_type]), + "n_bytes": self._n_bytes_by_type[cache_type], + } + return result diff --git a/mlx_lm/server.py b/mlx_lm/server.py index 70dea804b..6a352f2f3 100644 --- a/mlx_lm/server.py +++ b/mlx_lm/server.py @@ -440,9 +440,15 @@ def join(self): self._generation_thread.join() def _log_cache_stats(self): - ncaches = len(self.prompt_cache) - nbytes = self.prompt_cache.nbytes - logging.info(f"KV Caches: {ncaches} seq, {nbytes / 1e9:.2f} GB") + n_sequences = len(self.prompt_cache) + n_bytes = self.prompt_cache.nbytes + logging.info(f"Prompt Cache: {n_sequences} sequences, {n_bytes / 1e9:.2f} GB") + for cache_type, stats in self.prompt_cache.stats_by_type().items(): + n_sequences = stats["n_sequences"] + n_bytes = stats["n_bytes"] + logging.info( + f"- {cache_type}: {n_sequences} sequences, {n_bytes / 1e9:.2f} GB" + ) def _next_request(self, timeout=None): request = None @@ -494,10 +500,12 @@ def _tokenize(self, tokenizer, request, args): Returns a tuple - * prompt - full list of tokens - * segments - a list of lists of 
tokens. Up to 3 segments that + * prompt - Full list of tokens + * segments - A list of lists of tokens. Up to 3 segments that correspond to system prompt, context, thinking tail. - * initial state - a string that contains the initial state of the + * segment_types - A string per segment indicating if the segment is a + system prompt or a user prompt or nothing special. + * initial state - A string that contains the initial state of the state machine (normal or thinking depending on whether we have tail or not) """ @@ -531,10 +539,10 @@ def _tokenize(self, tokenizer, request, args): ) else: prompt = tokenizer.encode(convert_chat(messages, role_mapping)) - return prompt, [prompt], "normal" + return prompt, [prompt], ["assistant"], "normal" else: prompt = tokenizer.encode(request.prompt) - return prompt, [prompt], "normal" + return prompt, [prompt], ["assistant"], "normal" # If we are here it means we have a chat request so we need to search # for segments for better cache management. @@ -551,9 +559,10 @@ def _tokenize(self, tokenizer, request, args): # It is not a user message so no segmentation needed. if messages[-1]["role"] != "user": - return prompt, [prompt], initial_state + return prompt, [prompt], ["assistant"], initial_state segments = [] + segment_types = [] # Find where the system prompt ends and add it as a segment. 
num_system = 0 @@ -575,6 +584,7 @@ def _tokenize(self, tokenizer, request, args): break if sys_end > 0 and sys_end < len(prompt): segments.append(prompt[:sys_end]) + segment_types.append("system") # Find a tail segment that contains thinking tokens (small up to 11 # tokens) @@ -588,12 +598,15 @@ def _tokenize(self, tokenizer, request, args): # Finalize the segments and return if sys_end < tail_start: segments.append(prompt[sys_end:tail_start]) + segment_types.append("user") if tail_start < len(prompt): segments.append(prompt[tail_start:]) + segment_types.append("assistant") if not segments: segments = [prompt] + segment_types = ["assistant"] - return prompt, segments, initial_state + return prompt, segments, segment_types, initial_state def _make_state_machine( self, model_key, tokenizer, stop_words, initial_state="normal" @@ -702,7 +715,7 @@ def get_next_request(timeout=None): and self._is_batchable(args) ): try: - prompt, segments, initial_state = self._tokenize( + prompt, segments, segment_types, initial_state = self._tokenize( current_tokenizer, request, args ) except Exception as e: @@ -725,6 +738,7 @@ def get_next_request(timeout=None): while N > 0: if N >= len(segments[0]): N -= len(segments.pop(0)) + segment_types.pop(0) else: segments[0] = segments[0][N:] break @@ -748,12 +762,11 @@ def get_next_request(timeout=None): logits_processors=[_make_logits_processors(args)], state_machines=[sm], ) - segment_cache = ["user", "system"] batch_results[uid] = { "ctx": ctx, "rqueue": rqueue, "detokenizer": tokenizer.detokenizer, - "segment_cache": segment_cache[: len(segments) - 1], + "segment_types": segment_types[::-1], "top_logprobs": args.top_logprobs, } # just making sure we don't leave a reference around @@ -833,7 +846,7 @@ def get_next_request(timeout=None): for r in prompt_responses if r.end_of_segment and not r.end_of_prompt - and batch_results[r.uid]["segment_cache"] + and batch_results[r.uid]["segment_types"] ] caches = batch_generator.extract_cache(eos_ids) 
for uid, (cache, cache_key) in caches.items(): @@ -841,7 +854,7 @@ def get_next_request(timeout=None): self.model_provider.model_key, cache_key[:], cache, - cache_type=batch_results[uid]["segment_cache"].pop(), + cache_type=batch_results[uid]["segment_types"].pop(), ) del caches @@ -898,7 +911,7 @@ def progress(tokens_processed, tokens_total): draft_model = self.model_provider.draft_model # Prepare the prompt and state machine - prompt, _, initial_state = self._tokenize(tokenizer, request, args) + prompt, _, _, initial_state = self._tokenize(tokenizer, request, args) sm, sequences = self._make_state_machine( self.model_provider.model_key, tokenizer, From ee44cd201c13a30b6e7f311b6b02c272de7bc3ba Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Mon, 30 Mar 2026 13:56:52 -0700 Subject: [PATCH 34/34] Handle edge case when uid both removed and finished --- mlx_lm/server.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mlx_lm/server.py b/mlx_lm/server.py index 6a352f2f3..6c69ea74c 100644 --- a/mlx_lm/server.py +++ b/mlx_lm/server.py @@ -895,7 +895,9 @@ def get_next_request(timeout=None): with mx.stream(generation_stream): batch_generator.remove(uids_to_remove) for uid in uids_to_remove: - del batch_results[uid] + # It may have already been removed during + # generation + batch_results.pop(uid, None) def _serve_single(self, request): rqueue, request, args = request