diff --git a/mlx_lm/benchmark.py b/mlx_lm/benchmark.py index 3b2cb66a4..8bfafe898 100644 --- a/mlx_lm/benchmark.py +++ b/mlx_lm/benchmark.py @@ -148,10 +148,13 @@ def batch_bench(): for i in range(args.num_trials): if args.delay > 0: time.sleep(args.delay) + tic = time.perf_counter() response = _bench() + toc = time.perf_counter() responses.append(response) results = [(k, getattr(response, k)) for k in report_keys] results = [f"{k}={v:.3f}" for k, v in results] + results.append(f"total_time={toc - tic:.3f}") rprint(f"Trial {i+1}: " + ", ".join(results)) def avg(k): diff --git a/mlx_lm/examples/batch_generate_response.py b/mlx_lm/examples/batch_generate_response.py index 6d07b4fba..0925113d9 100644 --- a/mlx_lm/examples/batch_generate_response.py +++ b/mlx_lm/examples/batch_generate_response.py @@ -27,7 +27,7 @@ # Set `verbose=True` to see generation statistics result = batch_generate( - model, tokenizer, prompts, verbose=False, return_prompt_caches=True + model, tokenizer, prompts, verbose=False, return_prompt_caches=True, max_tokens=2048 ) print(result.texts[-1]) diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py index 22531c644..4f4adefa3 100644 --- a/mlx_lm/generate.py +++ b/mlx_lm/generate.py @@ -2,10 +2,12 @@ import argparse import contextlib +import copy import functools import json import sys import time +from collections import deque from dataclasses import dataclass from functools import partial from typing import ( @@ -14,6 +16,7 @@ Generator, List, Optional, + Sequence, Tuple, Union, ) @@ -824,65 +827,6 @@ class BatchStats: peak_memory: float = 0 -@dataclass -class BatchResponse: - """ - An data object to hold a batch generation response. - - Args: - texts: (List[str]): The generated text for each prompt. - stats (BatchStats): Statistics about the generation. - """ - - texts: List[str] - stats: BatchStats - caches: Optional[List[List[Any]]] - - -@dataclass -class Batch: - uids: List[int] - y: mx.array - logprobs: mx.array - max_tokens: List[int] - num_tokens: List[int] - cache: List[Any] - samplers: List[Any] - logits_processors: List[Any] - tokens: List[mx.array] - - def __len__(self): - return len(self.uids) - - def filter(self, keep_idx: List[int]): - self.uids = [self.uids[k] for k in keep_idx] - self.logprobs = [self.logprobs[k] for k in keep_idx] - self.max_tokens = [self.max_tokens[k] for k in keep_idx] - self.num_tokens = [self.num_tokens[k] for k in keep_idx] - self.samplers = [self.samplers[k] for k in keep_idx] - self.logits_processors = [self.logits_processors[k] for k in keep_idx] - self.tokens = [self.tokens[k] for k in keep_idx] - keep_idx = mx.array(keep_idx, mx.int32) - self.y = self.y[keep_idx] - for c in self.cache: - c.filter(keep_idx) - - def extend(self, other): - self.uids.extend(other.uids) - self.y = mx.concatenate([self.y, other.y]) - self.logprobs.extend(other.logprobs) - self.num_tokens.extend(other.num_tokens) - self.max_tokens.extend(other.max_tokens) - self.samplers.extend(other.samplers) - self.logits_processors.extend(other.logits_processors) - self.tokens.extend(other.tokens) - for c, o in zip(self.cache, other.cache): - c.extend(o) - - def extract_cache(self, idx): - return [c.extract(idx) for c in self.cache] - - def _make_cache(model, left_padding, max_kv_size): """ Convert a list of regular caches into their corresponding @@ -917,6 +861,10 @@ def to_batch_cache(c): def _merge_caches(caches): batch_cache = [] + + if not caches: + return batch_cache + for i in range(len(caches[0])): if hasattr(caches[0][i], "merge"): batch_cache.append(caches[0][i].merge([c[i] for c in caches])) @@ -927,25 +875,622 @@ def _merge_caches(caches): return batch_cache -def _lazy_extract_cache(cache, i): - # Generators like lambdas are late bound so we can't just use it in the loop - return (c.extract(i) for c in cache) +def _extend_cache(cache_a, cache_b): + if not cache_a: + return cache_b + if not cache_b: + return cache_a + for ca, cb in zip(cache_a, cache_b): + ca.extend(cb) + return cache_a -class BatchGenerator: +def _build_trie(sequences): + """Build an Aho-Corasick trie from the provided sequences + + See https://en.wikipedia.org/wiki/Aho–Corasick_algorithm . + """ + trie = {} + for idx, seq in enumerate(sequences): + node = trie + try: + for tok in seq: + node = node.setdefault(tok, {}) + node["__match__"] = (tuple(seq), idx) + except TypeError: + node = node.setdefault(seq, {}) + node["__match__"] = ((seq,), idx) + + # BFS to set failure links and propagate matches. + queue = deque() + for key, child in trie.items(): + if key == "__match__": + continue + child["__fail__"] = trie + queue.append(child) + while queue: + parent = queue.popleft() + for key, child in parent.items(): + if key in ("__fail__", "__match__"): + continue + queue.append(child) + fail = parent["__fail__"] + while key not in fail and fail is not trie: + fail = fail["__fail__"] + child["__fail__"] = fail[key] if key in fail else trie + if "__match__" not in child and "__match__" in child["__fail__"]: + child["__match__"] = child["__fail__"]["__match__"] + return trie + + +def _step_trie(node, trie, x): + """One step in the Aho-Corasick trie.""" + while x not in node and node is not trie: + node = node["__fail__"] + if x in node: + node = node[x] + return node + + +class SequenceStateMachine: + """A state machine that uses one Aho-Corasick trie per state to efficiently + track state across a generated sequence. + + The transitions are provided as state -> [(sequence, new_state)]. + + Example: + + sm = SequenceStateMachine( + transitions={ + "normal": [ + (think_start_tokens, "reasoning"), + (tool_start_tokens, "tool"), + (eos, None), + ], + "reasoning": [ + (think_end_tokens, "normal"), + (eos, None), + ], + "tool": [ + (tool_end_tokens, None), + (eos, None) + ], + }, + initial="normal" + ) + """ + + def __init__(self, transitions={}, initial="normal"): + self._initial = initial + self._states = {} + for src, edges in transitions.items(): + sequences, dst = zip(*edges) + self._states[src] = (_build_trie(sequences), dst) + if not self._states: + self._states[initial] = (_build_trie([]), []) + + def __deepcopy__(self, memo): + new = object.__new__(SequenceStateMachine) + new._initial = self._initial + new._states = self._states + return new + + def make_state(self): + return (self._initial, self._states[self._initial][0], self._states) + + @staticmethod + def match(state, x): + s, n, states = state + n = _step_trie(n, states[s][0], x) + + seq = None + match = n.get("__match__") + if match is not None: + seq = match[0] + s = states[s][1][match[1]] + n = states[s][0] if s is not None else None + + return (s, n, states), seq, s + + +class PromptProcessingBatch: + """ + A batch processor for prompt tokens with support for incremental processing. + + This class handles batched prompt processing, managing KV caches and preparing + tokens for generation. It supports extending, filtering, and splitting batches. + """ + + @dataclass + class Response: + uid: int + progress: tuple + end_of_segment: bool + end_of_prompt: bool + + def __init__( + self, + model: nn.Module, + uids: List[int], + caches: List[List[Any]], + tokens: Optional[List[List[int]]] = None, + prefill_step_size: int = 2048, + samplers: Optional[List[Callable[[mx.array], mx.array]]] = None, + fallback_sampler: Optional[Callable[[mx.array], mx.array]] = None, + logits_processors: Optional[ + List[List[Callable[[mx.array, mx.array], mx.array]]] + ] = None, + state_machines: Optional[List[SequenceStateMachine]] = None, + max_tokens: Optional[List[int]] = None, + ): + self.model = model + self.uids = uids + self.prompt_cache = _merge_caches(caches) + self.tokens = tokens if tokens is not None else [[] for _ in uids] + + self.prefill_step_size = prefill_step_size + self.samplers = samplers if samplers is not None else [] + self.fallback_sampler = fallback_sampler or (lambda x: mx.argmax(x, axis=-1)) + self.logits_processors = ( + logits_processors if logits_processors is not None else [] + ) + self.state_machines = ( + state_machines + if state_machines is not None + else [SequenceStateMachine()] * len(uids) + ) + self.max_tokens = ( + max_tokens + if max_tokens is not None + else [DEFAULT_MAX_TOKENS] * len(self.uids) + ) + + def __len__(self): + return len(self.uids) + + def extract_cache(self, idx: int) -> List[Any]: + return [c.extract(idx) for c in self.prompt_cache] + + def extend(self, batch): + if not any(self.samplers): + self.samplers = [None] * len(self.uids) + if not any(self.logits_processors): + self.logits_processors = [None] * len(self.uids) + samplers = batch.samplers if any(batch.samplers) else [None] * len(batch.uids) + logits_processors = ( + batch.logits_processors + if any(batch.logits_processors) + else [None] * len(batch.uids) + ) + + self.uids.extend(batch.uids) + self.prompt_cache = _extend_cache(self.prompt_cache, batch.prompt_cache) + self.tokens.extend(batch.tokens) + self.samplers.extend(samplers) + self.logits_processors.extend(logits_processors) + self.max_tokens.extend(batch.max_tokens) + self.state_machines.extend(batch.state_machines) + + def _copy(self): + new_batch = self.__class__.__new__(self.__class__) + new_batch.model = self.model + new_batch.uids = list(self.uids) + new_batch.prompt_cache = copy.deepcopy(self.prompt_cache) + new_batch.tokens = list(self.tokens) + new_batch.prefill_step_size = self.prefill_step_size + new_batch.samplers = list(self.samplers) + new_batch.fallback_sampler = self.fallback_sampler + new_batch.logits_processors = list(self.logits_processors) + new_batch.state_machines = list(self.state_machines) + new_batch.max_tokens = list(self.max_tokens) + return new_batch + + def split(self, indices: List[int]): + indices = sorted(indices) + indices_left = sorted(set(range(len(self.uids))) - set(indices)) + new_batch = self._copy() + self.filter(indices_left) + new_batch.filter(indices) + + return new_batch + + def filter(self, keep: List[int]): + self.uids = [self.uids[idx] for idx in keep] + if not keep: + self.prompt_cache.clear() + else: + for c in self.prompt_cache: + c.filter(keep) + self.tokens = [self.tokens[idx] for idx in keep] + if any(self.samplers): + self.samplers = [self.samplers[idx] for idx in keep] + else: + self.samplers = [None] * len(keep) + if any(self.logits_processors): + self.logits_processors = [self.logits_processors[idx] for idx in keep] + else: + self.logits_processors = [[]] * len(keep) + self.max_tokens = [self.max_tokens[idx] for idx in keep] + self.state_machines = [self.state_machines[idx] for idx in keep] + + def prompt(self, tokens: List[List[int]]): + """ + Process prompt tokens through the model. + + Args: + tokens: List of token sequences to process. + """ + if len(self.uids) != len(tokens): + raise ValueError("The batch length doesn't match the number of inputs") + + if not tokens: + return + + # Add the tokens to the self.tokens so they represent the tokens + # contained in the KV Cache. + for sti, ti in zip(self.tokens, tokens): + sti += ti + + # Calculate if we need to pad + lengths = [len(p) for p in tokens] + max_length = max(lengths) + padding = [max_length - l for l in lengths] + max_padding = max(padding) + + # Prepare the caches and inputs. Right pad if needed otherwise just + # cast to array. + if max_padding > 0: + tokens = _right_pad_prompts(tokens, max_length=max_length) + for c in self.prompt_cache: + c.prepare(lengths=lengths, right_padding=padding) + else: + tokens = mx.array(tokens) + + # Actual prompt processing loop + while tokens.shape[1] > 0: + n_to_process = min(self.prefill_step_size, tokens.shape[1]) + self.model(tokens[:, :n_to_process], cache=self.prompt_cache) + mx.eval([c.state for c in self.prompt_cache]) + mx.clear_cache() + tokens = tokens[:, n_to_process:] + + # Finalize the cache if there was any padding + if max_padding > 0: + for c in self.prompt_cache: + c.finalize() + mx.eval([c.state for c in self.prompt_cache]) + mx.clear_cache() + + def generate(self, tokens: List[List[int]]): + """ + Transition from prompt processing to generation. + + Args: + tokens: Final tokens for each sequence to start generation. + + Returns: + A GenerationBatch ready for token generation. + """ + if any(len(t) > 1 for t in tokens): + self.prompt([t[:-1] for t in tokens]) + last_token = mx.array([t[-1] for t in tokens]) + + generation = GenerationBatch( + self.model, + self.uids, + last_token, + self.prompt_cache, + self.tokens, + self.samplers, + self.fallback_sampler, + self.logits_processors, + self.state_machines, + self.max_tokens, + ) + + self.uids = [] + self.prompt_cache = [] + self.tokens = [] + self.samplers = [] + self.logits_processors = [] + self.max_tokens = [] + + return generation + + @classmethod + def empty( + cls, + model: nn.Module, + fallback_sampler: Callable[[mx.array], mx.array], + prefill_step_size: int = 2048, + ): + return cls( + model=model, + fallback_sampler=fallback_sampler, + prefill_step_size=prefill_step_size, + uids=[], + caches=[], + tokens=[], + samplers=[], + logits_processors=[], + max_tokens=[], + state_machines=[], + ) + + +class GenerationBatch: + """ + A batched token generator that manages multiple sequences in parallel. + + This class handles the generation phase after prompt processing, managing + KV caches, sampling, and stop sequence detection for multiple sequences. + """ + @dataclass class Response: uid: int token: int logprobs: mx.array finish_reason: Optional[str] - prompt_cache: Callable[[], List[Any]] + current_state: Optional[str] + match_sequence: Optional[List[int]] + prompt_cache: Optional[List[Any]] + all_tokens: Optional[List[int]] def __init__( self, - model, + model: nn.Module, + uids: List[int], + inputs: mx.array, + prompt_cache: List[Any], + tokens: List[List[int]], + samplers: Optional[List[Callable[[mx.array], mx.array]]], + fallback_sampler: Callable[[mx.array], mx.array], + logits_processors: Optional[ + List[List[Callable[[mx.array, mx.array], mx.array]]] + ], + state_machines: List[SequenceStateMachine], + max_tokens: List[int], + ): + self.model = model + self.uids = uids + self.prompt_cache = prompt_cache + self.tokens = tokens + + self.samplers = samplers + self.fallback_sampler = fallback_sampler + self.logits_processors = logits_processors + self.state_machines = state_machines + self.max_tokens = max_tokens + + if self.samplers and len(self.samplers) != len(self.uids): + raise ValueError("Insufficient number of samplers provided") + if self.logits_processors and len(self.logits_processors) != len(self.uids): + raise ValueError("Insufficient number of logits_processors provided") + + self._current_tokens = None + self._current_logprobs = [] + self._next_tokens = inputs + self._next_logprobs = [] + self._token_context = [mx.array(t[-256:]) for t in tokens] + self._num_tokens = [0] * len(self.uids) + self._matcher_states = [m.make_state() for m in state_machines] + + if self.uids: + self._step() + + def __len__(self): + return len(self.uids) + + def extend(self, batch): + """Extend this batch with another generation batch.""" + self.uids.extend(batch.uids) + self.prompt_cache = _extend_cache(self.prompt_cache, batch.prompt_cache) + self.tokens.extend(batch.tokens) + self.samplers.extend(batch.samplers) + self.logits_processors.extend(batch.logits_processors) + self.max_tokens.extend(batch.max_tokens) + self.state_machines.extend(batch.state_machines) + if self._current_tokens is None: + self._current_tokens = batch._current_tokens + self._current_logprobs = batch._current_logprobs + elif batch._current_tokens is not None: + self._current_tokens = mx.concatenate( + [self._current_tokens, batch._current_tokens] + ) + self._current_logprobs.extend(batch._current_logprobs) + if self._next_tokens is None: + self._next_tokens = batch._next_tokens + self._next_logprobs = batch._next_logprobs + elif batch._next_tokens is not None: + self._next_tokens = mx.concatenate([self._next_tokens, batch._next_tokens]) + self._next_logprobs.extend(batch._next_logprobs) + self._token_context.extend(batch._token_context) + self._num_tokens.extend(batch._num_tokens) + self._matcher_states.extend(batch._matcher_states) + + def _step(self) -> Tuple[List[int], List[mx.array]]: + """ + Perform a single generation step. + + Returns: + Tuple of token list and logprobs list. + """ + self._current_tokens = self._next_tokens + self._current_logprobs = self._next_logprobs + inputs = self._current_tokens + + # Update the token context that will be used by the logits processors + for i, ti in enumerate(self._token_context): + self._token_context[i] = mx.concatenate( + [ti[1:] if len(ti) == 256 else ti, inputs[i : i + 1]] + ) + + # Forward pass + logits = self.model(inputs[:, None], cache=self.prompt_cache) + logits = logits[:, -1, :] + + # Logits processors + if any(self.logits_processors): + processed_logits = [] + for e in range(len(self.uids)): + sample_logits = logits[e : e + 1] + for processor in self.logits_processors[e]: + sample_logits = processor(self.tokens[e], sample_logits) + processed_logits.append(sample_logits) + logits = mx.concatenate(processed_logits, axis=0) + + # Normalize the logits + logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True) + + # Sample + if any(self.samplers): + all_samples = [] + for e in range(len(self.uids)): + sample_sampler = self.samplers[e] or self.fallback_sampler + sampled = sample_sampler(logprobs[e : e + 1]) + all_samples.append(sampled) + sampled = mx.concatenate(all_samples, axis=0) + else: + sampled = self.fallback_sampler(logprobs) + + # Assign the next step to member variables and start computing it + # asynchronously + self._next_tokens = sampled + self._next_logprobs = list(logprobs) + mx.async_eval(self._next_tokens, self._next_logprobs, self._token_context) + + # Eval the current tokens and current logprobs. After that also add + # them to self.tokens so that it always represents the tokens contained + # in the KV Cache. + mx.eval(inputs, self._current_logprobs) + inputs = inputs.tolist() + for sti, ti in zip(self.tokens, inputs): + sti.append(ti) + return inputs, self._current_logprobs + + def extract_cache(self, idx: int) -> List[Any]: + return [c.extract(idx) for c in self.prompt_cache] + + def filter(self, keep: List[int]): + """Filter the batch to keep only the specified indices.""" + self.uids = [self.uids[idx] for idx in keep] + if not keep: + self.prompt_cache.clear() + else: + for c in self.prompt_cache: + c.filter(keep) + self.tokens = [self.tokens[idx] for idx in keep] + if any(self.samplers): + self.samplers = [self.samplers[idx] for idx in keep] + if any(self.logits_processors): + self.logits_processors = [self.logits_processors[idx] for idx in keep] + self.max_tokens = [self.max_tokens[idx] for idx in keep] + self.state_machines = [self.state_machines[idx] for idx in keep] + + self._next_tokens = self._next_tokens[keep] if keep else None + self._next_logprobs = [self._next_logprobs[idx] for idx in keep] + self._token_context = [self._token_context[idx] for idx in keep] + self._num_tokens = [self._num_tokens[idx] for idx in keep] + self._matcher_states = [self._matcher_states[idx] for idx in keep] + + def next(self) -> List[Response]: + """ + Generate the next batch of tokens. + + Returns: + List of Response objects for each sequence in the batch. + """ + if not self.uids: + return [] + + tokens, logprobs = self._step() + + keep = [] + responses = [] + for i in range(len(self.uids)): + finish_reason = None + match_sequence = None + + self._num_tokens[i] += 1 + if self._num_tokens[i] >= self.max_tokens[i]: + finish_reason = "length" + + self._matcher_states[i], match_sequence, current_state = ( + self.state_machines[i].match(self._matcher_states[i], tokens[i]) + ) + if match_sequence is not None and current_state is None: + finish_reason = "stop" + + if finish_reason is not None: + responses.append( + self.Response( + uid=self.uids[i], + token=tokens[i], + logprobs=logprobs[i], + finish_reason=finish_reason, + current_state=current_state, + match_sequence=match_sequence, + prompt_cache=self.extract_cache(i), + all_tokens=self.tokens[i], + ) + ) + else: + keep.append(i) + responses.append( + self.Response( + uid=self.uids[i], + token=tokens[i], + logprobs=logprobs[i], + finish_reason=None, + match_sequence=match_sequence, + current_state=current_state, + prompt_cache=None, + all_tokens=None, + ) + ) + + if len(keep) < len(self.uids): + self.filter(keep) + + return responses + + @classmethod + def empty( + cls, + model: nn.Module, + fallback_sampler: Callable[[mx.array], mx.array], + ): + return cls( + model=model, + fallback_sampler=fallback_sampler, + uids=[], + inputs=mx.array([], dtype=mx.uint32), + prompt_cache=[], + tokens=[], + samplers=[], + logits_processors=[], + max_tokens=[], + state_machines=[], + ) + + +class BatchGenerator: + """ + A batch generator implements continuous batching. + + This class provides automatic management of prompt processing and generation + batches, handling the transition between the two. + + It also allows for segmented prompt processing which guarantees that the + generator will stop at these boundaries when processing an input. + """ + + def __init__( + self, + model: nn.Module, max_tokens: int = 128, - stop_tokens: Optional[set] = None, + stop_tokens: Optional[Sequence[Sequence[int]]] = None, sampler: Optional[Callable[[mx.array], mx.array]] = None, logits_processors: Optional[ List[Callable[[mx.array, mx.array], mx.array]] @@ -953,31 +1498,34 @@ def __init__( completion_batch_size: int = 32, prefill_batch_size: int = 8, prefill_step_size: int = 2048, - prompt_checkpoint_callback: Optional[ - Callable[[List[Tuple[int, int, List[Any]]]], None] - ] = None, - prompt_progress_callback: Optional[ - Callable[[List[Tuple[int, int, int]]], None] - ] = None, - max_kv_size: Optional[int] = None, ): self.model = model - self.unprocessed_prompts = [] self.max_tokens = max_tokens - self.stop_tokens = stop_tokens or set() self.sampler = sampler or (lambda x: mx.argmax(x, axis=-1)) self.logits_processors = logits_processors or [] self.uid_count = 0 self.prefill_step_size = prefill_step_size self.prefill_batch_size = prefill_batch_size self.completion_batch_size = max(completion_batch_size, prefill_batch_size) - self.prompt_checkpoint_callback = prompt_checkpoint_callback - self.prompt_progress_callback = prompt_progress_callback or (lambda *_: None) - self._stats = BatchStats() - self._next_count = 0 - self.max_kv_size = max_kv_size - self.active_batch = None + self._default_state_machine = SequenceStateMachine( + {"normal": [(seq, None) for seq in stop_tokens]} if stop_tokens else {}, + initial="normal", + ) + self._uid_count = 0 + self._prompt_batch = PromptProcessingBatch.empty( + self.model, + self.sampler, + prefill_step_size=prefill_step_size, + ) + self._generation_batch = GenerationBatch.empty(self.model, self.sampler) + self._unprocessed_sequences = deque() + self._currently_processing = [] + + self._prompt_tokens_counter = 0 + self._prompt_time_counter = 0 + self._gen_tokens_counter = 0 + self._steps_counter = 0 if mx.metal.is_available(): self._old_wired_limit = mx.set_wired_limit( @@ -995,338 +1543,314 @@ def close(self): def __del__(self): self.close() + @contextlib.contextmanager + def stats(self, stats=None): + stats = stats or BatchStats() + self._prompt_tokens_counter = 0 + self._prompt_time_counter = 0 + self._gen_tokens_counter = 0 + tic = time.perf_counter() + try: + yield stats + finally: + toc = time.perf_counter() + total_time = toc - tic + gen_time = total_time - self._prompt_time_counter + stats.prompt_tokens += self._prompt_tokens_counter + stats.prompt_time += self._prompt_time_counter + stats.prompt_tps = stats.prompt_tokens / stats.prompt_time + stats.generation_tokens += self._gen_tokens_counter + stats.generation_time += gen_time + stats.generation_tps = stats.generation_tokens / stats.generation_time + stats.peak_memory = max(stats.peak_memory, mx.get_peak_memory() / 1e9) + def insert( self, - prompts, - max_tokens: Union[List[int], int, None] = None, - caches=None, - samplers: list | None = None, - logits_processors: list | None = None, - prompt_checkpoints: list | int | None = None, + prompts: List[List[int]], + max_tokens: Optional[List[int]] = None, + caches: Optional[List[List[Any]]] = None, + all_tokens: Optional[List[List[int]]] = None, + samplers: Optional[List[Callable[[mx.array], mx.array]]] = None, + logits_processors: Optional[ + List[List[Callable[[mx.array, mx.array], mx.array]]] + ] = None, + state_machines: Optional[List[SequenceStateMachine]] = None, ): - uids = [] + return self.insert_segments( + [[p] for p in prompts], + max_tokens, + caches, + all_tokens, + samplers, + logits_processors, + state_machines, + ) - if max_tokens is None or isinstance(max_tokens, int): - max_tokens = [max_tokens or self.max_tokens] * len(prompts) + def insert_segments( + self, + segments: List[List[List[int]]], + max_tokens: Optional[List[int]] = None, + caches: Optional[List[List[Any]]] = None, + all_tokens: Optional[List[List[int]]] = None, + samplers: Optional[List[Callable[[mx.array], mx.array]]] = None, + logits_processors: Optional[ + List[List[Callable[[mx.array, mx.array], mx.array]]] + ] = None, + state_machines: Optional[List[SequenceStateMachine]] = None, + ): + uids = [] - if prompt_checkpoints is None or isinstance(prompt_checkpoints, int): - prompt_checkpoints = [prompt_checkpoints or -1] * len(prompts) + max_tokens = max_tokens or [self.max_tokens] * len(segments) + all_tokens = all_tokens or [[] for _ in segments] + samplers = samplers or [None] * len(segments) + logits_processors = logits_processors or ( + [self.logits_processors] * len(segments) + ) + state_machines = state_machines or ( + [self._default_state_machine] * len(segments) + ) - if caches is None: - caches = [None] * len(prompts) - for i in range(len(prompts)): + caches = caches or [None] * len(segments) + for i in range(len(segments)): if caches[i] is None: caches[i] = cache.make_prompt_cache(self.model) - samplers = samplers or [None] * len(prompts) - logits_processors = logits_processors or [self.logits_processors] * len(prompts) - - for p, m, c, s, lp, pc in zip( - prompts, max_tokens, caches, samplers, logits_processors, prompt_checkpoints - ): - self.unprocessed_prompts.append((self.uid_count, p, m, c, s, lp, pc)) - uids.append(self.uid_count) - self.uid_count += 1 - # Sort in ascending order of length - self.unprocessed_prompts = sorted( - self.unprocessed_prompts, - key=lambda x: len(x[1]) + max(c.size() for c in x[3]), - ) - return uids - - def remove(self, uids: List[int], return_prompt_caches: bool = False): - caches = {} - uids = set(uids) - if self.active_batch is not None: - batch = self.active_batch - if return_prompt_caches: - for e, uid in enumerate(batch.uids): - if uid not in uids: - continue - caches[uid] = batch.extract_cache(e) - keep_idx = [e for e, uid in enumerate(batch.uids) if uid not in uids] - if len(keep_idx) > 0: - batch.filter(keep_idx) - else: - self.active_batch = None - - for i in reversed(range(len(self.unprocessed_prompts))): - if self.unprocessed_prompts[i][0] in uids: - self.unprocessed_prompts.pop(i) - - if return_prompt_caches: - return caches - - @property - def prompt_cache_nbytes(self): - total = sum(c.nbytes for p in self.unprocessed_prompts for c in p[3]) - if self.active_batch is not None: - total += sum(c.nbytes for c in self.active_batch.cache) - return total - - def _process_prompts(self, prompts): - ( - uids, - inputs, + for seq, m, c, at, s, lp, sm in zip( + segments, max_tokens, caches, + all_tokens, samplers, logits_processors, - prompt_checkpoints, - ) = zip(*prompts) + state_machines, + ): + seq = list(seq) + if len(seq[-1]) != 1: + seq.append(seq[-1][-1:]) + seq[-2] = seq[-2][:-1] + self._unprocessed_sequences.append( + (self._uid_count, seq, m, c, at, s, lp, sm) + ) + uids.append(self._uid_count) + self._uid_count += 1 - lengths = [len(p) for p in inputs] - max_length = max(lengths) - padding = [max_length - l for l in lengths] + return uids - # Get the checkpoint token as an offset from the end of each prompt. - # Then select the largest one so that we perform the checkpoint at - # least `pc` before the end. - prompt_checkpoints = [ - (l - pc if pc > 0 else -pc) for l, pc in zip(lengths, prompt_checkpoints) - ] - prompt_checkpoint = max(1, max(prompt_checkpoints)) - - self._stats.prompt_tokens += sum(lengths) - - tokens = [mx.array(inp) for inp in inputs] - processed_tokens = 0 - - # New prompts so - # 1. Left-pad the inputs - # 2. Process - if all(c[0].empty() for c in caches): - inputs = _left_pad_prompts(inputs, max_length=max_length) - prompt_cache = _make_cache(self.model, padding, self.max_kv_size) - - while inputs.shape[1] > prompt_checkpoint: - n_to_process = min( - self.prefill_step_size, inputs.shape[1] - prompt_checkpoint + def _find_uids(self, uids): + uids = set(uids) + results = {} + for i, uid_i in enumerate(self._generation_batch.uids): + if uid_i in uids: + results[uid_i] = (2, i) + for i, uid_i in enumerate(self._prompt_batch.uids): + if uid_i in uids: + results[uid_i] = (1, i) + for i, seq in enumerate(self._unprocessed_sequences): + if seq[0] in uids: + results[seq[0]] = (0, i) + return results + + def extract_cache(self, uids): + results = {} + for uid, (stage, idx) in self._find_uids(uids).items(): + if stage == 0: + results[uid] = self._unprocessed_sequences[idx][3:5] + elif stage == 1: + results[uid] = ( + self._prompt_batch.extract_cache(idx), + self._prompt_batch.tokens[idx], ) - self.model(inputs[:, :n_to_process], cache=prompt_cache) - mx.eval([c.state for c in prompt_cache]) - inputs = inputs[:, n_to_process:] - processed_tokens += n_to_process - self.prompt_progress_callback( - [ - (uid, processed_tokens, length) - for uid, length in zip(uids, lengths) - ] + else: + results[uid] = ( + self._generation_batch.extract_cache(idx), + self._generation_batch.tokens[idx], ) - mx.clear_cache() + return results - # Further prompt processing so we need to - # 1. Merge the KV caches and prepare for right padded prompts - # 2. Right pad the inputs - # 2. Process - # 3. Finalize the KV caches so they are left padded again - else: - last_inputs = mx.array([p[-prompt_checkpoint:] for p in inputs]) - inputs = _right_pad_prompts(inputs, max_length=max_length) - prompt_cache = _merge_caches(caches) - - for c in prompt_cache: - # subtract from lengths since we don't process the last - # `prompt_checkpoint` tokens during prefill - c.prepare( - lengths=[l - prompt_checkpoint for l in lengths], - right_padding=padding, - ) + def remove(self, uids, return_prompt_caches=False): + caches = {} + if return_prompt_caches: + caches = self.extract_cache(uids) - while inputs.shape[1] > prompt_checkpoint: - n_to_process = min( - self.prefill_step_size, inputs.shape[1] - prompt_checkpoint - ) - self.model(inputs[:, :n_to_process], cache=prompt_cache) - mx.eval([c.state for c in prompt_cache]) - inputs = inputs[:, n_to_process:] - processed_tokens += n_to_process - self.prompt_progress_callback( - [ - (uid, processed_tokens, length) - for uid, length in zip(uids, lengths) - ] - ) - mx.clear_cache() + keep = ( + set(range(len(self._unprocessed_sequences))), + set(range(len(self._prompt_batch))), + set(range(len(self._generation_batch))), + ) + for stage, idx in self._find_uids(uids).values(): + keep[stage].remove(idx) - mx.eval([c.state for c in prompt_cache]) - inputs = last_inputs - - for c in prompt_cache: - c.finalize() - - # We processed L - prompt_checkpoint tokens so call the checkpoint - # callback. - if self.prompt_checkpoint_callback is not None: - self.prompt_checkpoint_callback( - [ - (uid, prompt_checkpoint, _lazy_extract_cache(prompt_cache, i)) - for i, uid in enumerate(uids) - ] + if len(keep[0]) < len(self._unprocessed_sequences): + self._unprocessed_sequences = deque( + x for i, x in enumerate(self._unprocessed_sequences) if i in keep[0] ) - # Process the remaining prompt_checkpoint-1 tokens - if prompt_checkpoint > 1: - self.model(inputs[:, : prompt_checkpoint - 1], cache=prompt_cache) - mx.eval([c.state for c in prompt_cache]) - mx.clear_cache() - - y, logprobs = self._step( - inputs, prompt_cache, samplers, logits_processors, tokens - ) + if len(keep[1]) < len(self._prompt_batch): + self._prompt_batch.filter(sorted(keep[1])) + self._currently_processing = [ + x for i, x in enumerate(self._currently_processing) if i in keep[1] + ] + if len(keep[2]) < len(self._generation_batch): + self._generation_batch.filter(sorted(keep[2])) - mx.async_eval(y, logprobs) - - return Batch( - list(uids), - y, - logprobs, - list(max_tokens), - [0] * len(uids), - prompt_cache, - list(samplers), - list(logits_processors), - tokens, - ) + return caches - def _step( - self, - input_tokens: mx.array, - prompt_cache: List[Any], - samplers: list | None, - logits_processors: list | None, - tokens: List[mx.array], - ): - batch_size = input_tokens.shape[0] + @property + def prompt_cache_nbytes(self): + total = sum(c.nbytes for p in self._unprocessed_sequences for c in p[3]) + total += sum(c.nbytes for c in self._prompt_batch.prompt_cache) + total += sum(c.nbytes for c in self._generation_batch.prompt_cache) + return total - logits = self.model(input_tokens, cache=prompt_cache) - logits = logits[:, -1, :] + def _make_batch(self, n: int): + uids = [] + caches = [] + tokens = [] + samplers = [] + logits_processors = [] + max_tokens = [] + state_machines = [] + for _ in range(n): + sequence = self._unprocessed_sequences.popleft() + uids.append(sequence[0]) + caches.append(sequence[3]) + tokens.append(sequence[4]) + samplers.append(sequence[5]) + logits_processors.append(sequence[6]) + max_tokens.append(sequence[2]) + state_machines.append(sequence[7]) + self._currently_processing.append( + [sequence[1], 0, sum(len(s) for s in sequence[1])] + ) - if any(logits_processors): - processed_logits = [] - for e in range(batch_size): - sample_logits = logits[e : e + 1] - for processor in logits_processors[e]: - sample_logits = processor(tokens[e], sample_logits) - processed_logits.append(sample_logits) - logits = mx.concatenate(processed_logits, axis=0) + return PromptProcessingBatch( + model=self.model, + uids=uids, + caches=caches, + tokens=tokens, + prefill_step_size=self.prefill_step_size, + samplers=samplers, + fallback_sampler=self.sampler, + logits_processors=logits_processors, + state_machines=state_machines, + max_tokens=max_tokens, + ) - logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True) - if any(samplers): - all_samples = [] - for e in range(batch_size): - sample_sampler = samplers[e] or self.sampler - sampled = sample_sampler(logprobs[e : e + 1]) - all_samples.append(sampled) - sampled = mx.concatenate(all_samples, axis=0) - else: - sampled = self.sampler(logprobs) + def _next(self): + generation_responses = [] + prompt_responses = [] + + # Generate tokens first + if len(self._generation_batch) > 0: + generation_responses = self._generation_batch.next() + self._gen_tokens_counter += len(generation_responses) + self._steps_counter += 1 + if self._steps_counter % 512 == 0: + mx.clear_cache() - return sampled, list(logprobs) + # Exit early because we already have our hands full with decoding + if len(self._generation_batch) >= self.completion_batch_size: + return prompt_responses, generation_responses - def stats(self): - self._stats.prompt_tps = self._stats.prompt_tokens / self._stats.prompt_time - self._stats.generation_tps = ( - self._stats.generation_tokens / self._stats.generation_time + # Check if we have sequences and add them to the prompt batch + n = min( + self.prefill_batch_size - len(self._prompt_batch), + self.completion_batch_size - len(self._generation_batch), + len(self._unprocessed_sequences), ) - self._stats.peak_memory = mx.get_peak_memory() / 1e9 - return self._stats + if n > 0: + self._prompt_batch.extend(self._make_batch(n)) + + # Split the prompt sequences to the ones moving to generation and the rest + keep = [] + split = [] + for i, seq in enumerate(self._currently_processing): + segments = seq[0] + if len(segments) == 1 and len(segments[0]) == 1: + split.append(i) + else: + keep.append(i) + + # Actually split off part of the prompt batch and start generation + if split: + last_inputs = [self._currently_processing[i][0][0] for i in split] + progress = [(self._currently_processing[i][2],) * 2 for i in split] + self._currently_processing = [self._currently_processing[i] for i in keep] + gen_batch = self._prompt_batch.split(split).generate(last_inputs) + for i, p in enumerate(progress): + prompt_responses.append( + PromptProcessingBatch.Response( + gen_batch.uids[i], + p, + True, + True, + ) + ) + self._generation_batch.extend(gen_batch) - def _next(self): + # Extract the next prompts input + prompts = [] + for i, seq in enumerate(self._currently_processing): + response = PromptProcessingBatch.Response( + self._prompt_batch.uids[i], 0, False, False + ) + segments = seq[0] + n = min(len(segments[0]), self.prefill_step_size) + prompts.append(segments[0][:n]) + segments[0] = segments[0][n:] + if len(segments[0]) == 0: + segments.pop(0) + response.end_of_segment = True + seq[1] += len(prompts[-1]) + response.progress = (seq[1], seq[2]) + prompt_responses.append(response) + + # Process the prompts + self._prompt_tokens_counter += sum(len(p) for p in prompts) tic = time.perf_counter() + self._prompt_batch.prompt(prompts) + toc = time.perf_counter() + self._prompt_time_counter += toc - tic - prompt_processing = False - batch = self.active_batch - num_active = len(batch) if batch else 0 - num_to_add = self.completion_batch_size - num_active - while num_to_add >= self.prefill_batch_size: - prompts = self.unprocessed_prompts[: self.prefill_batch_size] - # Finish processing the last examples of the last batch - if len(prompts) == 0 and num_active > 0: - break - # No more prompts and no more completions, all done - elif len(prompts) == 0: - self.active_batch = None - return [] - # Process prompts - if batch is not None and not prompt_processing: - # Finish any active completion tokens - mx.eval(batch.y, batch.logprobs) - self._stats.generation_time += time.perf_counter() - tic - tic = time.perf_counter() + return prompt_responses, generation_responses - batch = self._process_prompts(prompts) - self.unprocessed_prompts = self.unprocessed_prompts[ - self.prefill_batch_size : - ] - prompt_processing = True - # If there was no active batch, set it - if self.active_batch is None: - self.active_batch = batch - else: - self.active_batch.extend(batch) - - num_active = len(self.active_batch) - num_to_add -= len(batch) - - batch = self.active_batch - y, logprobs = batch.y, batch.logprobs - for i, toks in enumerate(batch.tokens): - batch.tokens[i] = mx.concatenate((toks, y[i : i + 1])) - batch.y, batch.logprobs = self._step( - y[:, None], - batch.cache, - batch.samplers, - batch.logits_processors, - batch.tokens, - ) + def next(self): + """ + Get the next batch of responses. - mx.async_eval(batch.y, batch.logprobs, batch.tokens) + Returns: + Tuple of prompt processing responses and generation responses. + """ + with mx.stream(generation_stream): + return self._next() - y = y.tolist() - toc = time.perf_counter() - if prompt_processing: - self._stats.prompt_time += toc - tic - else: - self._stats.generation_time += toc - tic - keep_idx = [] - end_idx = [] - responses = [] + def next_generated(self): + """ + Return only generated tokens ignoring batch generation responses. - for e, (t, uid, num_tok, max_tok) in enumerate( - zip(y, batch.uids, batch.num_tokens, batch.max_tokens) - ): - cache = None - num_tok += 1 - batch.num_tokens[e] = num_tok - if t in self.stop_tokens: - finish_reason = "stop" - end_idx.append(e) - elif num_tok >= max_tok: - finish_reason = "length" - end_idx.append(e) - else: - finish_reason = None - keep_idx.append(e) - if finish_reason is not None: - cache = batch.extract_cache(e) - responses.append(self.Response(uid, t, logprobs[e], finish_reason, cache)) + Returns: + List of GenerationBatch.Response objects + """ + with mx.stream(generation_stream): + while True: + prompt_responses, generation_responses = self._next() + if not generation_responses and prompt_responses: + continue + return generation_responses - # Remove any finished completions - if len(end_idx): - if len(keep_idx) > 0: - batch.filter(keep_idx) - else: - self.active_batch = None - self._next_count += 1 - if self._next_count % 512 == 0: - mx.clear_cache() - self._stats.generation_tokens += len(responses) - return responses +@dataclass +class BatchResponse: + """ + A data object to hold a batch generation response. - def next(self): - with mx.stream(generation_stream): - return self._next() + Args: + texts: (List[str]): The generated text for each prompt. + stats (BatchStats): Statistics about the generation. + """ + + texts: List[str] + stats: BatchStats + caches: Optional[List[List[Any]]] def batch_generate( @@ -1337,6 +1861,7 @@ def batch_generate( max_tokens: Union[int, List[int]] = 128, verbose: bool = False, return_prompt_caches: bool = False, + logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None, **kwargs, ) -> BatchResponse: """ @@ -1355,13 +1880,15 @@ def batch_generate( can be per prompt if a list is provided. return_prompt_caches (bool): Return the prompt caches in the batch responses. Default: ``False``. + logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional): + A list of functions that take tokens and logits and return the processed logits. Default: ``None``. kwargs: The remaining options get passed to :obj:`BatchGenerator`. See :obj:`BatchGenerator` for more details. """ gen = BatchGenerator( model, - stop_tokens=tokenizer.eos_token_ids, + stop_tokens=[[t] for t in tokenizer.eos_token_ids], **kwargs, ) num_samples = len(prompts) @@ -1369,29 +1896,32 @@ def batch_generate( if verbose: print(f"[batch_generate] Finished processing 0/{num_samples} ...", end="\r") + if isinstance(max_tokens, int): + max_tokens = [max_tokens] * len(prompts) + uids = gen.insert(prompts, max_tokens, caches=prompt_caches) results = {uid: [] for uid in uids} prompt_caches = {} - while responses := gen.next(): - for r in responses: - if r.finish_reason is not None: - if return_prompt_caches: - prompt_caches[r.uid] = r.prompt_cache - if verbose: - fin += 1 - print( - f"[batch_generate] Finished processing {fin}/{num_samples} ...", - end="\r", - ) - if r.finish_reason != "stop": - results[r.uid].append(r.token) + with gen.stats() as stats: + while responses := gen.next_generated(): + for r in responses: + if r.finish_reason is not None: + if return_prompt_caches: + prompt_caches[r.uid] = r.prompt_cache + if verbose: + fin += 1 + print( + f"[batch_generate] Finished processing {fin}/{num_samples} ...", + end="\r", + ) + if r.finish_reason != "stop": + results[r.uid].append(r.token) gen.close() if verbose: print(f"[batch_generate] Finished processing {fin}/{num_samples}") # Return results in correct order texts = [tokenizer.decode(results[uid]) for uid in uids] - stats = gen.stats() caches = [prompt_caches[uid] for uid in uids] if return_prompt_caches else None if verbose: print( diff --git a/mlx_lm/models/cache.py b/mlx_lm/models/cache.py index 756ce4ecf..d6f101f90 100644 --- a/mlx_lm/models/cache.py +++ b/mlx_lm/models/cache.py @@ -621,13 +621,23 @@ def filter(self, batch_indices): """ In-place filter to keep just the given indices in the cache. """ - self.cache = [c[batch_indices] for c in self.cache] + self.cache = [c[batch_indices] if c is not None else None for c in self.cache] + if self.lengths is not None: + self.lengths = self.lengths[batch_indices] def extend(self, other): """ In-place extend this cache with the other cache. """ - self.cache = [mx.concatenate([c, o]) for c, o in zip(self.cache, other.cache)] + + def cat(a, b): + if a is None: + return b + if b is None: + return a + return mx.concatenate([a, b]) + + self.cache = [cat(c, o) for c, o in zip(self.cache, other.cache)] def extract(self, idx): cache = ArraysCache(len(self.cache)) @@ -662,6 +672,11 @@ def merge(cls, caches): n_state = len(caches[0].cache) B = len(caches) cache = cls(n_state) + + # All caches are empty so return early + if all(c.empty() for c in caches): + return cache + for e in range(n_state): c_init = next(iter(c[e] for c in caches if c[e] is not None)) shape = list(c_init.shape) @@ -970,16 +985,18 @@ def filter(self, batch_indices): """ In-place filter to keep just the given indices in the cache. """ - self.keys = self.keys[batch_indices] - self.values = self.values[batch_indices] + if self.keys is not None: + self.keys = self.keys[batch_indices] + self.values = self.values[batch_indices] self.offset = self.offset[batch_indices] self.left_padding = self.left_padding[batch_indices] # Shift left to reduce padding min_left_pad = self.left_padding.min().item() if min_left_pad > 0: - self.keys = self.keys[..., min_left_pad:, :] - self.values = self.values[..., min_left_pad:, :] + if self.keys is not None: + self.keys = self.keys[..., min_left_pad:, :] + self.values = self.values[..., min_left_pad:, :] self._idx -= min_left_pad self.left_padding -= min_left_pad @@ -987,15 +1004,30 @@ def extend(self, other): """ In-place extend this cache with the other cache. """ + if self.keys is None and other.keys is None: + self.left_padding = mx.concatenate([self.left_padding, other.left_padding]) + self.offset = mx.concatenate([self.offset, other.offset]) + return + max_idx = max(self._idx, other._idx) - max_size = max(self.keys.shape[2], other.keys.shape[2]) + L1 = L2 = 0 + if self.keys is not None: + B, H, L1, D = self.keys.shape + M = self.values.shape[3] + if other.keys is not None: + B, H, L2, D = other.keys.shape + M = other.values.shape[3] + max_size = max(L1, L2) # Pad the keys and values so they are right-justified # with the index and the same size def pad(c): - left = max_idx - c._idx - right = max_size - c.keys.shape[2] - left k, v = c.keys, c.values + if k is None: + k = mx.array([]).reshape(B, H, 0, D) + v = mx.array([]).reshape(B, H, 0, M) + left = max_idx - c._idx + right = max_size - k.shape[2] - left if right < 0: k = k[..., :right, :] v = v[..., :right, :] @@ -1024,6 +1056,11 @@ def extract(self, idx): def merge(cls, caches): lengths = [c.size() for c in caches] max_length = max(lengths) + + # No cache has content so make an empty one + if max_length == 0: + return BatchKVCache([0] * len(caches)) + padding = [max_length - l for l in lengths] B = len(caches) H = max(c.keys.shape[1] for c in caches if c.keys is not None) @@ -1047,6 +1084,9 @@ def merge(cls, caches): return cache + def size(self): + return self._idx + def empty(self): return self.keys is None @@ -1287,8 +1327,9 @@ def filter(self, batch_indices): """ In-place filter to keep just the given indices in the cache. """ - self.keys = self.keys[batch_indices] - self.values = self.values[batch_indices] + if self.keys is not None: + self.keys = self.keys[batch_indices] + self.values = self.values[batch_indices] self.offset = self.offset[batch_indices] self.left_padding = self.left_padding[batch_indices] @@ -1296,17 +1337,32 @@ def extend(self, other): """ In-place extend this cache with the other cache. """ + if self.keys is None and other.keys is None: + self.left_padding = mx.concatenate([self.left_padding, other.left_padding]) + self.offset = mx.concatenate([self.offset, other.offset]) + return + if (self.rotated != other.rotated) or self._idx != other._idx: self._temporal_order() other._temporal_order() max_idx = max(self._idx, other._idx) - max_size = max(self.keys.shape[2], other.keys.shape[2]) + L1 = L2 = 0 + if self.keys is not None: + B, H, L1, D = self.keys.shape + M = self.values.shape[3] + if other.keys is not None: + B, H, L2, D = other.keys.shape + M = other.values.shape[3] + max_size = max(L1, L2) def pad(c): left = max_idx - c._idx - right = max_size - c.keys.shape[2] - left k, v = c.keys, c.values + if k is None: + k = mx.array([]).reshape(B, H, 0, D) + v = mx.array([]).reshape(B, H, 0, M) + right = max_size - k.shape[2] - left if right < 0: k = k[..., :right, :] v = v[..., :right, :] @@ -1351,6 +1407,11 @@ def merge(cls, caches): offsets = [c.offset for c in caches] lengths = [c.size() for c in caches] max_length = max(lengths) + + # No cache has content so make an empty one + if max_length == 0: + return cls(caches[0].max_size, [0] * len(caches)) + padding = [max_length - l for l in lengths] B = len(caches) H = max(c.keys.shape[1] for c in caches if c.keys is not None) @@ -1360,11 +1421,11 @@ def merge(cls, caches): keys = mx.zeros((B, H, max_length, Dk), dtype=dt) values = mx.zeros((B, H, max_length, Dv), dtype=dt) - for i, (p, c) in enumerate(zip(padding, caches)): + for i, (p, l, c) in enumerate(zip(padding, lengths, caches)): if c.keys is None: continue - keys[i : i + 1, :, p : p + c._idx] = c._temporal_order(c.keys) - values[i : i + 1, :, p : p + c._idx] = c._temporal_order(c.values) + keys[i : i + 1, :, p : p + l] = c._temporal_order(c.keys) + values[i : i + 1, :, p : p + l] = c._temporal_order(c.values) cache = cls(caches[0].max_size, padding) cache.keys = keys @@ -1375,6 +1436,9 @@ def merge(cls, caches): return cache + def size(self): + return min(self._offset, self.max_size) + def empty(self): return self.keys is None @@ -1490,6 +1554,7 @@ class LRUPromptCache: class CacheEntry: prompt_cache: List[Any] nbytes: int + cache_type: str class CacheOrder: def __init__(self, ordering: List[str] = ["assistant", "user", "system"]): @@ -1515,7 +1580,7 @@ def pop(self): while i + 1 < len(self._ordering): lru_a = self._lrus[self._ordering[i]] lru_b = self._lrus[self._ordering[i + 1]] - if len(lru_a) >= len(lru_b): + if lru_a and len(lru_a) >= len(lru_b): return lru_a.popleft() i += 1 return lru_b.popleft() @@ -1526,6 +1591,7 @@ def __init__(self, max_size: int = 10, max_bytes: int = 1 << 63): self._trie = PromptTrie() self._lru = LRUPromptCache.CacheOrder() self._n_bytes = 0 + self._n_bytes_by_type = {k: 0 for k in self._lru._ordering} def __len__(self): return len(self._lru) @@ -1566,14 +1632,16 @@ def insert_cache( ): # Make the cache entry entry = LRUPromptCache.CacheEntry( - prompt_cache, sum(c.nbytes for c in prompt_cache) + prompt_cache, sum(c.nbytes for c in prompt_cache), cache_type ) # Insert into the trie and update the byte counter and lru position self._n_bytes += entry.nbytes + self._n_bytes_by_type[cache_type] += entry.nbytes prev = self._trie.add(model, tokens, entry) if prev is not None: self._n_bytes -= prev.nbytes + self._n_bytes_by_type[prev.cache_type] -= prev.nbytes self._lru.remove(model, tokens) self._lru.push(model, tokens, cache_type) @@ -1582,6 +1650,7 @@ def insert_cache( if can_trim_prompt_cache(prompt_cache): for prefix_len, entry in self._trie.pop_prefixes(model, tokens): self._n_bytes -= entry.nbytes + self._n_bytes_by_type[entry.cache_type] -= entry.nbytes self._lru.remove(model, tokens[:prefix_len]) # Ensure we match the constraints @@ -1589,10 +1658,12 @@ def insert_cache( model, tokens = self._lru.pop() entry = self._trie.pop(model, tokens) self._n_bytes -= entry.nbytes + self._n_bytes_by_type[entry.cache_type] -= entry.nbytes while self._n_bytes > self.max_bytes: model, tokens = self._lru.pop() entry = self._trie.pop(model, tokens) self._n_bytes -= entry.nbytes + self._n_bytes_by_type[entry.cache_type] -= entry.nbytes def trim_to( self, *, n_sequences: Optional[int] = None, n_bytes: Optional[int] = None @@ -1604,7 +1675,18 @@ def trim_to( model, tokens = self._lru.pop() entry = self._trie.pop(model, tokens) self._n_bytes -= entry.nbytes + self._n_bytes_by_type[entry.cache_type] -= entry.nbytes while self._n_bytes > n_bytes: model, tokens = self._lru.pop() entry = self._trie.pop(model, tokens) self._n_bytes -= entry.nbytes + self._n_bytes_by_type[entry.cache_type] -= entry.nbytes + + def stats_by_type(self): + result = {} + for cache_type in self._lru._ordering: + result[cache_type] = { + "n_sequences": len(self._lru._lrus[cache_type]), + "n_bytes": self._n_bytes_by_type[cache_type], + } + return result diff --git a/mlx_lm/models/deepseek_v32.py b/mlx_lm/models/deepseek_v32.py index e40e52950..7c97682e7 100644 --- a/mlx_lm/models/deepseek_v32.py +++ b/mlx_lm/models/deepseek_v32.py @@ -87,19 +87,15 @@ def __call__( b, s, _ = x.shape q = self.wq_b(qr) q = q.reshape(b, s, self.n_heads, self.head_dim).swapaxes(1, 2) - q_pe, q_nope = mx.split(q, [self.rope_head_dim], axis=-1) + k = self.wk(x) + k = self.k_norm(k) + k = mx.reshape(k, (b, 1, s, self.head_dim)) offset = cache.offset if cache is not None else 0 - q_pe = self.rope(q_pe, offset=offset) - q = mx.concatenate([q_pe, q_nope], axis=-1) + q = self.rope(q, offset=offset) + k = self.rope(k, offset=offset) - k = self.wk(x) - k = self.k_norm(k) - k = mx.reshape(k, (b, 1, s, self.head_dim)) - k_pe, k_nope = mx.split(k, [self.rope_head_dim], axis=-1) - k_pe = self.rope(k_pe, offset=offset) - k = mx.concatenate([k_pe, k_nope], axis=-1) if cache is not None: k, _ = cache.update_and_fetch(k, mx.zeros([b, 1, s, 0])) if k.shape[2] <= self.index_topk: @@ -221,7 +217,8 @@ def __call__( mx.broadcast_to(idx, idx.shape[:-1] + (k_pe.shape[-1],)), axis=2, ) - mask = None + if mask is not None: + mask = mx.take_along_axis(mask, topk_indices, axis=-1) else: shape = list(topk_indices.shape) shape[-1] = kv_latent.shape[2] diff --git a/mlx_lm/models/gated_delta.py b/mlx_lm/models/gated_delta.py index af983f6d3..fa6a2ed3f 100644 --- a/mlx_lm/models/gated_delta.py +++ b/mlx_lm/models/gated_delta.py @@ -81,6 +81,8 @@ def _make_gated_delta_kernel(has_mask=False, vectorized=False): if (thread_index_in_simdgroup == 0) {{ y[dv_idx] = static_cast(out); }} + }} else {{ + y[dv_idx] = static_cast(0); }} // Increment data pointers to next time step q_ += Hk * Dk; diff --git a/mlx_lm/models/qwen3_5.py b/mlx_lm/models/qwen3_5.py index 636a6ede4..a86710e74 100644 --- a/mlx_lm/models/qwen3_5.py +++ b/mlx_lm/models/qwen3_5.py @@ -158,7 +158,12 @@ def __call__( conv_input = mx.concatenate([conv_state, qkv], axis=1) if cache is not None: n_keep = self.conv_kernel_size - 1 - cache[0] = mx.contiguous(conv_input[:, -n_keep:]) + if cache.lengths is not None: + ends = mx.clip(cache.lengths, 0, S) + positions = (ends[:, None] + mx.arange(n_keep))[..., None] + cache[0] = mx.take_along_axis(conv_input, positions, axis=1) + else: + cache[0] = mx.contiguous(conv_input[:, -n_keep:, :]) conv_out = nn.silu(self.conv1d(conv_input)) q, k, v = [ diff --git a/mlx_lm/server.py b/mlx_lm/server.py index c5d1f95c3..6c69ea74c 100644 --- a/mlx_lm/server.py +++ b/mlx_lm/server.py @@ -9,7 +9,8 @@ import time import uuid import warnings -from dataclasses import dataclass, field +from collections import deque +from dataclasses import dataclass from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer from pathlib import Path from queue import Empty as QueueEmpty @@ -32,12 +33,15 @@ from huggingface_hub import scan_cache_dir from ._version import __version__ -from .generate import BatchGenerator, generation_stream, stream_generate +from .generate import ( + BatchGenerator, + SequenceStateMachine, + generation_stream, + stream_generate, +) from .models.cache import ( LRUPromptCache, - can_trim_prompt_cache, make_prompt_cache, - trim_prompt_cache, ) from .sample_utils import make_logits_processors, make_sampler from .utils import _parse_size, load, sharded_load @@ -48,67 +52,37 @@ def get_system_fingerprint(): return f"{__version__}-{mx.__version__}-{platform.platform()}-{gpu_arch}" -class StopCondition(NamedTuple): - stop_met: bool - trim_length: int - trim_text_length: int - - -def stopping_criteria( - tokens: List[int], - eos_token_ids: set, - stop_id_sequences: List[List[int]], - stop_words: List[str], -) -> StopCondition: - """ - Determines whether the token generation should stop based on predefined - conditions. - - Args: - tokens (List[int]): The current sequence of generated tokens. - eos_token_ids (set): The token IDs that represents the - end-of-sequence. If the last token in ``tokens`` is in the set, - the generation should stop. - stop_id_sequences (List[List[[int]]): A list of integer lists, each - representing a sequence of token IDs. If the end of the `tokens` - list matches any of these sequences, the generation should stop. - stop_words (List[str]): The stop words that correspond to the - ``stop_id_sequences``. - - Returns: - StopCondition: A named tuple indicating whether the stop condition has - been met (`stop_met`) and how many tokens should be trimmed from the - end if it has (`trim_length`) as well as the text that should be - trimmed. - """ - if tokens and tokens[-1] in eos_token_ids: - return StopCondition(stop_met=True, trim_length=0, trim_text_length=0) - - for stop_ids, stop_word in zip(stop_id_sequences, stop_words): - if len(tokens) >= len(stop_ids): - if tokens[-len(stop_ids) :] == stop_ids: - return StopCondition( - stop_met=True, - trim_length=len(stop_ids), - trim_text_length=len(stop_word), - ) - - return StopCondition(stop_met=False, trim_length=0, trim_text_length=0) - - -def sequence_overlap(s1: Sequence, s2: Sequence) -> bool: - """ - Checks if a suffix of s1 has overlap with a prefix of s2 +class ToolCallFormatter: + def __init__(self, tool_parser, tools, streaming=False): + self._idx = 0 + self._tool_parser = tool_parser + self._tools = tools + self._streaming = streaming + + def _format(self, tc): + tc_id = tc.pop("id", None) or str(uuid.uuid4()) + tc["arguments"] = json.dumps(tc["arguments"], ensure_ascii=False) + out = { + "function": tc, + "type": "function", + "id": tc_id, + } + if self._streaming: + out["index"] = self._idx + self._idx += 1 + return out - Args: - s1 (Sequence): The first sequence - s2 (Sequence): The second sequence + def __call__(self, tool_calls): + if not tool_calls: + return [] - Returns: - bool: If the two sequences have overlap - """ - max_overlap = min(len(s1), len(s2)) - return any(s1[-i:] == s2[:i] for i in range(1, max_overlap + 1)) + result = [] + for tool_text in tool_calls: + parsed = self._tool_parser(tool_text, self._tools) + if not isinstance(parsed, list): + parsed = [parsed] + result.extend(self._format(tc) for tc in parsed) + return result def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None): @@ -122,7 +96,7 @@ def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None): "assistant": "ASSISTANT: ", "stop": "\n", } - role_mapping = role_mapping if role_mapping is not None else default_role_mapping + role_mapping = role_mapping or default_role_mapping prompt = "" for line in messages: @@ -152,7 +126,7 @@ def process_message_content(messages): """ for message in messages: - content = message.get("content", None) + content = message.get("content") if isinstance(content, list): text_fragments = [ fragment["text"] for fragment in content if fragment["type"] == "text" @@ -162,10 +136,11 @@ def process_message_content(messages): message["content"] = "".join(text_fragments) elif content is None: message["content"] = "" - if tool_calls := message.get("tool_calls", False): + + if tool_calls := message.get("tool_calls"): for tool_call in tool_calls: - if func := tool_call.get("function", False): - if args := func.get("arguments", False): + if func := tool_call.get("function"): + if args := func.get("arguments"): func["arguments"] = json.loads(args) @@ -227,15 +202,11 @@ class CompletionRequest: @dataclass class GenerationContext: has_tool_calling: bool - tool_call_start: str - tool_call_end: str - tool_parser: Callable[[str, Any], Dict] has_thinking: bool - think_start_id: int - think_end_id: int - think_end: str - eos_token_ids: set - stop_token_sequences: List[List[int]] + tool_parser: Callable[[str, Any], Dict] + + sequences: Dict[Tuple[int], str] + prompt: List[int] prompt_cache_count: int = -1 @@ -249,6 +220,8 @@ def stop(self): class Response: text: str token: int + state: str + match: Tuple[int] logprob: float finish_reason: Optional[str] top_tokens: Tuple[Dict[str, Any]] @@ -260,7 +233,6 @@ def __init__(self, budget=0.5, iterations=25, sync_frequency=10): self._budget = budget self._iterations = iterations self._sync_frequency = sync_frequency - self._start = None self._current_iterations = None self._loops = 0 @@ -278,20 +250,22 @@ def __next__(self): return None self._current_iterations += 1 - if self._current_iterations > self._iterations: - self._loops += 1 - self._time_spent += time.time() - self._start - if self._loops % self._sync_frequency == 0: - with mx.stream(generation_stream): - loop_time = mx.distributed.all_sum(self._time_spent).item() - avg_loop_time = loop_time / ( - mx.distributed.init().size() * self._sync_frequency - ) - factor = self._budget / avg_loop_time - self._iterations = max(round(self._iterations * factor), 1) - self._loops = 0 - self._time_spent = 0 - raise StopIteration() + if self._current_iterations <= self._iterations: + return None + + self._loops += 1 + self._time_spent += time.time() - self._start + if self._loops % self._sync_frequency == 0: + with mx.stream(generation_stream): + loop_time = mx.distributed.all_sum(self._time_spent).item() + avg_loop_time = loop_time / ( + mx.distributed.init().size() * self._sync_frequency + ) + factor = self._budget / avg_loop_time + self._iterations = max(round(self._iterations * factor), 1) + self._loops = 0 + self._time_spent = 0 + raise StopIteration() class ModelProvider: @@ -430,17 +404,17 @@ def _make_logits_processors(args): ) -def _format_top_logprobs(logprobs, top_logprobs, tokenizer) -> Tuple[Dict[str, Any]]: - """Returns info dicts for the top `top_logprobs` tokens from `logprobs`""" - if top_logprobs <= 0: +def _format_top_logprobs(logprobs, top_n, tokenizer) -> Tuple[Dict[str, Any]]: + """Returns info dicts for the top `top_n` tokens from `logprobs`""" + if top_n <= 0: return () - sorted_indices = mx.argpartition(-logprobs, kth=top_logprobs - 1) - top_indices = sorted_indices[:top_logprobs].tolist() - top_logprobs = logprobs[top_indices].tolist() + sorted_indices = mx.argpartition(-logprobs, kth=top_n - 1) + top_indices = sorted_indices[:top_n].tolist() + top_probs = logprobs[top_indices].tolist() txts = tokenizer.convert_ids_to_tokens(top_indices) return tuple( {"id": i, "token": s, "logprob": g} - for i, s, g in zip(top_indices, txts, top_logprobs) + for i, s, g in zip(top_indices, txts, top_probs) ) @@ -449,6 +423,7 @@ def __init__(self, model_provider: ModelProvider, prompt_cache: LRUPromptCache): self.model_provider = model_provider self.prompt_cache = prompt_cache self.requests = Queue() + self._state_machine_cache = {} self._time_budget = TimeBudget() self._is_distributed = mx.distributed.init().size() > 1 @@ -465,9 +440,15 @@ def join(self): self._generation_thread.join() def _log_cache_stats(self): - ncaches = len(self.prompt_cache) - nbytes = self.prompt_cache.nbytes - logging.info(f"KV Caches: {ncaches} seq, {nbytes / 1e9:.2f} GB") + n_sequences = len(self.prompt_cache) + n_bytes = self.prompt_cache.nbytes + logging.info(f"Prompt Cache: {n_sequences} sequences, {n_bytes / 1e9:.2f} GB") + for cache_type, stats in self.prompt_cache.stats_by_type().items(): + n_sequences = stats["n_sequences"] + n_bytes = stats["n_bytes"] + logging.info( + f"- {cache_type}: {n_sequences} sequences, {n_bytes / 1e9:.2f} GB" + ) def _next_request(self, timeout=None): request = None @@ -479,7 +460,6 @@ def _next_request(self, timeout=None): request = self.requests.get_nowait() except QueueEmpty: pass - return self._share_request(request) def _share_object(self, obj): @@ -491,19 +471,17 @@ def _share_object(self, obj): if obj is None: mx.eval(mx.distributed.all_sum(0)) return None - else: - data = mx.array(pickle.dumps(obj)) - mx.eval(mx.distributed.all_sum(data.size)) - mx.eval(mx.distributed.all_sum(data)) - return obj + data = mx.array(pickle.dumps(obj)) + mx.eval(mx.distributed.all_sum(data.size)) + mx.eval(mx.distributed.all_sum(data)) + return obj else: size = mx.distributed.all_sum(0).item() if size == 0: return None - else: - data = mx.zeros(size, dtype=mx.uint8) - data = mx.distributed.all_sum(data) - return pickle.loads(data) + data = mx.zeros(size, dtype=mx.uint8) + data = mx.distributed.all_sum(data) + return pickle.loads(data) def _share_request(self, request): if not self._is_distributed: @@ -518,6 +496,19 @@ def _share_request(self, request): return rq, *shareable def _tokenize(self, tokenizer, request, args): + """Tokenize a request and split the prompt into segments. + + Returns a tuple + + * prompt - Full list of tokens + * segments - A list of lists of tokens. Up to 3 segments that + correspond to system prompt, context, thinking tail. + * segment_types - A string per segment indicating if the segment is a + system prompt or a user prompt or nothing special. + * initial state - A string that contains the initial state of the + state machine (normal or thinking depending on whether we have tail + or not) + """ if request.request_type == "chat": messages = request.messages tools = request.tools @@ -536,43 +527,151 @@ def _tokenize(self, tokenizer, request, args): if args.chat_template_kwargs: chat_template_args = chat_template_args.copy() chat_template_args.update(args.chat_template_kwargs) - return tokenizer.apply_chat_template( - messages, + template_kwargs = dict( tools=tools, - add_generation_prompt=True, tokenize=True, **chat_template_args, ) + prompt = tokenizer.apply_chat_template( + messages, + add_generation_prompt=True, + **template_kwargs, + ) else: - return tokenizer.encode(convert_chat(messages, role_mapping)) + prompt = tokenizer.encode(convert_chat(messages, role_mapping)) + return prompt, [prompt], ["assistant"], "normal" else: - return tokenizer.encode(request.prompt) - - def _compute_prompt_checkpoint(self, tokenizer, request, prompt): - if request.request_type != "chat": - return False, -1 - if request.messages[-1]["role"] != "user": - return False, -1 - - # Save the KV cache at the end of the prompt just before - # the think start token which will likely be removed in the - # next turn. - prompt_checkpoint = -1 + prompt = tokenizer.encode(request.prompt) + return prompt, [prompt], ["assistant"], "normal" + + # If we are here it means we have a chat request so we need to search + # for segments for better cache management. + + # Choose the initial state among only reasoning or normal + initial_state = "normal" + if tokenizer.has_thinking: + for i in range(-1, -len(prompt), -1): + if prompt[i] == tokenizer.think_start_id: + initial_state = "reasoning" + break + if prompt[i] == tokenizer.think_end_id: + break + + # It is not a user message so no segmentation needed. + if messages[-1]["role"] != "user": + return prompt, [prompt], ["assistant"], initial_state + + segments = [] + segment_types = [] + + # Find where the system prompt ends and add it as a segment. + num_system = 0 + sys_end = 0 + for m in messages: + if m["role"] == "system": + num_system += 1 + else: + break + if num_system > 0: + sys_tokens = tokenizer.apply_chat_template( + messages[:num_system] + [{"role": "user", "content": ""}], + add_generation_prompt=False, + **template_kwargs, + ) + for i, (a, b) in enumerate(zip(sys_tokens, prompt)): + if a != b: + sys_end = i + break + if sys_end > 0 and sys_end < len(prompt): + segments.append(prompt[:sys_end]) + segment_types.append("system") + + # Find a tail segment that contains thinking tokens (small up to 11 + # tokens) + tail_start = len(prompt) if tokenizer.has_thinking: - for i in range(1, min(11, len(prompt)) - 1, 1): + for i in range(1, min(11, len(prompt) - sys_end), 1): if prompt[-i] == tokenizer.think_start_id: - prompt_checkpoint = -i - 1 + tail_start = len(prompt) - i break - return True, prompt_checkpoint + # Finalize the segments and return + if sys_end < tail_start: + segments.append(prompt[sys_end:tail_start]) + segment_types.append("user") + if tail_start < len(prompt): + segments.append(prompt[tail_start:]) + segment_types.append("assistant") + if not segments: + segments = [prompt] + segment_types = ["assistant"] + + return prompt, segments, segment_types, initial_state + + def _make_state_machine( + self, model_key, tokenizer, stop_words, initial_state="normal" + ): + """Make a new SequenceStateMachine or fetch it if we 've made it before. - def _is_batchable(self, args): - if not self.model_provider.is_batchable: - return False - if args.seed is not None: - return False + Return also a dictionary that maps the token sequences in the state + machine to their strings. + """ + cache_key = (model_key, tuple(stop_words), initial_state) + rs = self._state_machine_cache.get(cache_key) + if rs is not None: + return rs + + # Will hold the state machine transitions and the sequences map to + # strings. + transitions = {} + sequences = {} + + # Add all the stop sequences + common_stops = [] + for t in tokenizer.eos_token_ids: + sequences[(t,)] = tokenizer.convert_ids_to_tokens(t) + common_stops.append(((t,), None)) + for w in stop_words: + t = tuple(tokenizer.encode(w, add_special_tokens=False)) + sequences[t] = w + common_stops.append((t, None)) + + # From normal to stop + transitions["normal"] = list(common_stops) + + # Reasoning related transitions + if tokenizer.has_thinking: + ts = tokenizer.think_start_id + te = tokenizer.think_end_id + transitions["normal"].append(((ts,), "reasoning")) + transitions["reasoning"] = [((te,), "normal")] + transitions["reasoning"].extend(common_stops) + sequences[(ts,)] = tokenizer.convert_ids_to_tokens(ts) + sequences[(te,)] = tokenizer.convert_ids_to_tokens(te) + + # Tool calling relating transitions + if tokenizer.has_tool_calling: + ts = tuple( + tokenizer.encode(tokenizer.tool_call_start, add_special_tokens=False) + ) + te = tuple( + tokenizer.encode(tokenizer.tool_call_end, add_special_tokens=False) + ) + transitions["normal"].append((ts, "tool")) + transitions["tool"] = [(te, "normal")] + transitions["tool"].extend(common_stops) + sequences[ts] = tokenizer.tool_call_start + sequences[te] = tokenizer.tool_call_end - return True + sm = SequenceStateMachine(transitions, initial=initial_state) + if len(self._state_machine_cache) > 100: + self._state_machine_cache.clear() + self._state_machine_cache[cache_key] = (sm, sequences) + + return sm, sequences + + def _is_batchable(self, args): + return self.model_provider.is_batchable and args.seed is None def _generate(self): current_model = None @@ -591,23 +690,6 @@ def get_next_request(timeout=None): else: return self._next_request(timeout) - def progress_callback(info): - for uid, processed, total in info: - if uid in batch_results: - batch_results[uid]["rqueue"].put((min(processed, total), total)) - - def checkpoint_callback(prompts): - for uid, prompt_end, cache in prompts: - rs = batch_results[uid] - if not rs["checkpoint"]: - continue - self.prompt_cache.insert_cache( - current_model_key, - rs["cache_key"][:-prompt_end], - list(cache), - cache_type="user", - ) - if self._is_distributed: seed = mx.distributed.all_sum(mx.random.state[0]).view(mx.uint64).item() mx.random.seed(seed) @@ -633,55 +715,59 @@ def checkpoint_callback(prompts): and self._is_batchable(args) ): try: - prompt = self._tokenize(current_tokenizer, request, args) + prompt, segments, segment_types, initial_state = self._tokenize( + current_tokenizer, request, args + ) except Exception as e: rqueue.put(e) continue - ctx = GenerationContext( - has_tool_calling=tokenizer.has_tool_calling, - tool_call_start=tokenizer.tool_call_start, - tool_call_end=tokenizer.tool_call_end, - tool_parser=tokenizer.tool_parser, - has_thinking=tokenizer.has_thinking, - think_start_id=tokenizer.think_start_id, - think_end=tokenizer.think_end, - think_end_id=tokenizer.think_end_id, - eos_token_ids=tokenizer.eos_token_ids, - stop_token_sequences=[ - tokenizer.encode(stop_word, add_special_tokens=False) - for stop_word in args.stop_words - ], - prompt=prompt, + sm, sequences = self._make_state_machine( + self.model_provider.model_key, + tokenizer, + args.stop_words, + initial_state, ) - rqueue.put(ctx) self._log_cache_stats() cache, rest = self.prompt_cache.fetch_nearest_cache( current_model_key, prompt ) - ctx.prompt_cache_count = len(prompt) - len(rest) - if cache is None: - cache = make_prompt_cache(self.model_provider.model) + prompt_cache_count = len(prompt) - len(rest) + N = prompt_cache_count + while N > 0: + if N >= len(segments[0]): + N -= len(segments.pop(0)) + segment_types.pop(0) + else: + segments[0] = segments[0][N:] + break - do_checkpoint, checkpoint_position = ( - self._compute_prompt_checkpoint(tokenizer, request, prompt) + ctx = GenerationContext( + has_tool_calling=tokenizer.has_tool_calling, + has_thinking=tokenizer.has_thinking, + tool_parser=tokenizer.tool_parser, + sequences=sequences, + prompt=prompt, + prompt_cache_count=prompt_cache_count, ) + rqueue.put(ctx) - (uid,) = batch_generator.insert( - [rest], - args.max_tokens, + (uid,) = batch_generator.insert_segments( + segments=[segments], + max_tokens=[args.max_tokens], caches=[cache], + all_tokens=[prompt[:prompt_cache_count]], samplers=[_make_sampler(args, tokenizer)], logits_processors=[_make_logits_processors(args)], - prompt_checkpoints=[checkpoint_position], + state_machines=[sm], ) batch_results[uid] = { "ctx": ctx, - "cache_key": prompt[:], "rqueue": rqueue, "detokenizer": tokenizer.detokenizer, - "checkpoint": do_checkpoint, + "segment_types": segment_types[::-1], + "top_logprobs": args.top_logprobs, } # just making sure we don't leave a reference around del cache @@ -714,12 +800,9 @@ def checkpoint_callback(prompts): batch_results = {} batch_generator = BatchGenerator( model, - stop_tokens=tokenizer.eos_token_ids, completion_batch_size=self.cli_args.decode_concurrency, prefill_batch_size=self.cli_args.prompt_concurrency, prefill_step_size=self.cli_args.prefill_step_size, - prompt_progress_callback=progress_callback, - prompt_checkpoint_callback=checkpoint_callback, ) unprocessed_requests.append((rqueue, request, args)) continue @@ -746,24 +829,50 @@ def checkpoint_callback(prompts): uids_to_remove = [] for _ in self._time_budget: - responses = batch_generator.next() - if not responses: + prompt_responses, gen_responses = batch_generator.next() + if not prompt_responses and not gen_responses: break - for r in responses: + # Progress report for prompt processing + for r in prompt_responses: result = batch_results[r.uid] - result["cache_key"].append(r.token) - if r.finish_reason != "stop": - result["detokenizer"].add_token(r.token) + result["rqueue"].put(r.progress) + if result["ctx"]._should_stop: + uids_to_remove.append(r.uid) + # Save the caches at end of segments + eos_ids = [ + r.uid + for r in prompt_responses + if r.end_of_segment + and not r.end_of_prompt + and batch_results[r.uid]["segment_types"] + ] + caches = batch_generator.extract_cache(eos_ids) + for uid, (cache, cache_key) in caches.items(): + self.prompt_cache.insert_cache( + self.model_provider.model_key, + cache_key[:], + cache, + cache_type=batch_results[uid]["segment_types"].pop(), + ) + del caches + + for r in gen_responses: + result = batch_results[r.uid] + result["detokenizer"].add_token(r.token) result["rqueue"].put( Response( result["detokenizer"].last_segment, r.token, + r.current_state, + r.match_sequence, r.logprobs[r.token].item(), r.finish_reason, _format_top_logprobs( - r.logprobs, args.top_logprobs, current_tokenizer + r.logprobs, + result["top_logprobs"], + current_tokenizer, ), ) ) @@ -771,7 +880,10 @@ def checkpoint_callback(prompts): if r.finish_reason is not None: result["rqueue"].put(None) self.prompt_cache.insert_cache( - current_model_key, result["cache_key"], r.prompt_cache + current_model_key, + r.all_tokens[:], + r.prompt_cache, + cache_type="assistant", ) del batch_results[r.uid] @@ -781,17 +893,11 @@ def checkpoint_callback(prompts): uids_to_remove = self._share_object(uids_to_remove) if uids_to_remove: with mx.stream(generation_stream): - caches = batch_generator.remove( - uids_to_remove, return_prompt_caches=True - ) - for uid, prompt_cache in caches.items(): - if uid not in batch_results: - continue - result = batch_results[uid] - self.prompt_cache.insert_cache( - current_model_key, result["cache_key"], prompt_cache - ) - del batch_results[uid] + batch_generator.remove(uids_to_remove) + for uid in uids_to_remove: + # It may have already been removed during + # generation + batch_results.pop(uid, None) def _serve_single(self, request): rqueue, request, args = request @@ -806,24 +912,22 @@ def progress(tokens_processed, tokens_total): tokenizer = self.model_provider.tokenizer draft_model = self.model_provider.draft_model - # Prepare the prompt - prompt = self._tokenize(tokenizer, request, args) + # Prepare the prompt and state machine + prompt, _, _, initial_state = self._tokenize(tokenizer, request, args) + sm, sequences = self._make_state_machine( + self.model_provider.model_key, + tokenizer, + args.stop_words, + initial_state=initial_state, + ) + sm_state = sm.make_state() # Start the generation context ctx = GenerationContext( + has_thinking=tokenizer.has_thinking, has_tool_calling=tokenizer.has_tool_calling, - tool_call_start=tokenizer.tool_call_start, - tool_call_end=tokenizer.tool_call_end, tool_parser=tokenizer.tool_parser, - has_thinking=tokenizer.has_thinking, - think_start_id=tokenizer.think_start_id, - think_end=tokenizer.think_end, - think_end_id=tokenizer.think_end_id, - eos_token_ids=tokenizer.eos_token_ids, - stop_token_sequences=[ - tokenizer.encode(stop_word, add_special_tokens=False) - for stop_word in args.stop_words - ], + sequences=sequences, prompt=prompt, ) rqueue.put(ctx) @@ -862,12 +966,18 @@ def progress(tokens_processed, tokens_total): prompt_progress_callback=progress, prefill_step_size=self.cli_args.prefill_step_size, ): + finish_reason = gen.finish_reason + sm_state, match_sequence, current_state = sm.match(sm_state, gen.token) + if match_sequence is not None and current_state is None: + finish_reason = "stop" rqueue.put( Response( gen.text, gen.token, + current_state, + match_sequence, gen.logprobs[gen.token].item(), - gen.finish_reason, + finish_reason, _format_top_logprobs( gen.logprobs, args.top_logprobs, tokenizer ), @@ -880,6 +990,9 @@ def progress(tokens_processed, tokens_total): raise NotImplementedError() break + if finish_reason is not None: + break + rqueue.put(None) # Save the KV cache again @@ -912,11 +1025,25 @@ def _inner(): continue yield response + def _process_control_tokens(ctx, token_stream): + buffer_size = max(len(s) for s in ctx.sequences) + buffered_stream = deque() + + for tok in token_stream: + buffered_stream.append(tok) + if tok.match is not None: + for _ in tok.match: + buffered_stream.pop() + if len(buffered_stream) >= buffer_size: + yield buffered_stream.popleft() + while len(buffered_stream) > 0: + yield buffered_stream.popleft() + ctx = response_queue.get() if isinstance(ctx, Exception): raise ctx - return ctx, _inner() + return ctx, _process_control_tokens(ctx, _inner()) @property def cli_args(self): @@ -1011,11 +1138,18 @@ def do_POST(self): ) return - indent = "\t" # Backslashes can't be inside of f-strings - logging.debug(f"Incoming Request Body: {json.dumps(self.body, indent=indent)}") - assert isinstance( - self.body, dict - ), f"Request should be dict, but got {type(self.body)}" + if logging.getLogger().isEnabledFor(logging.DEBUG): + debug_body = json.dumps(self.body, indent="\t") + logging.debug(f"Incoming Request Body: {debug_body}") + if not isinstance(self.body, dict): + debug_body = json.dumps(self.body, indent="\t") + logging.error(f"Invalid Request Body: {debug_body}") + self._set_completion_headers(400) + self.end_headers() + self.wfile.write( + json.dumps({"error": "Request should be a JSON dictionary"}).encode() + ) + return # Extract request parameters from the body self.stream = self.body.get("stream", False) @@ -1061,87 +1195,60 @@ def do_POST(self): request = request_factories[self.path]() self.handle_completion(request, stop_words) - def validate_model_parameters(self): - """ - Validate the model parameters passed in the request for the correct types and values. - """ - if not isinstance(self.stream, bool): - raise ValueError("stream must be a boolean") - - if not isinstance(self.max_tokens, int) or self.max_tokens < 0: - raise ValueError("max_tokens must be a non-negative integer") - - if not isinstance(self.temperature, (float, int)) or self.temperature < 0: - raise ValueError("temperature must be a non-negative float") - - if not isinstance(self.top_p, (float, int)) or self.top_p < 0 or self.top_p > 1: - raise ValueError("top_p must be a float between 0 and 1") - - if not isinstance(self.top_k, int) or self.top_k < 0: - raise ValueError("top_k must be a non-negative integer") - - if not isinstance(self.min_p, (float, int)) or self.min_p < 0 or self.min_p > 1: - raise ValueError("min_p must be a float between 0 and 1") - - if not isinstance(self.num_draft_tokens, int) or self.num_draft_tokens < 0: - raise ValueError("num_draft_tokens must be a non-negative integer") - - if ( - not isinstance(self.repetition_penalty, (float, int)) - or self.repetition_penalty < 0 - ): - raise ValueError("repetition_penalty must be a non-negative float") - if ( - not isinstance(self.repetition_context_size, int) - or self.repetition_context_size < 0 - ): - raise ValueError("repetition_context_size must be a non-negative integer") - if not isinstance(self.presence_penalty, (float, int)): - raise ValueError("Presence penalty must be must be a float") - if ( - not isinstance(self.presence_context_size, int) - or self.presence_context_size < 0 - ): - raise ValueError("presence_context_size must be a non-negative integer") - if not isinstance(self.frequency_penalty, (float, int)): - raise ValueError("Presence penalty must be must be a float") - if ( - not isinstance(self.frequency_context_size, int) - or self.frequency_context_size < 0 - ): - raise ValueError("frequency_context_size must be a non-negative integer") - - if not isinstance(self.logprobs, bool): - raise ValueError("logprobs must be a boolean") + def _validate( + self, + name, + expected_type, + min_val=None, + max_val=None, + optional=False, + whitelist=None, + ): + value = getattr(self, name) + if optional and value is None: + return + if not isinstance(value, expected_type): + try: + allowed = tuple(et.__name__ for et in expected_type) + except TypeError: + allowed = expected_type.__name__ + raise ValueError(f"{name} must be of type {allowed}") + if whitelist is not None and value in whitelist: + return + if min_val is not None and value < min_val: + raise ValueError(f"{name} must be at least {min_val}") + if max_val is not None and value > max_val: + raise ValueError(f"{name} must be at most {max_val}") - if self.top_logprobs != -1 and not (0 < self.top_logprobs <= 10): - raise ValueError( - f"top_logprobs must be between 1 and 10 but got {self.top_logprobs:,}" - ) + def validate_model_parameters(self): + """Validate that the passed model parameters have correct types and values.""" + self._validate("stream", bool) + self._validate("max_tokens", int, min_val=0) + self._validate("temperature", (float, int), min_val=0) + self._validate("top_p", (float, int), min_val=0, max_val=1) + self._validate("top_k", int, min_val=0) + self._validate("min_p", (float, int), min_val=0, max_val=1) + self._validate("num_draft_tokens", int, min_val=0) + self._validate("repetition_penalty", (float, int), min_val=0) + self._validate("repetition_context_size", int, min_val=0) + self._validate("presence_penalty", (float, int)) + self._validate("presence_context_size", int, min_val=0) + self._validate("frequency_penalty", (float, int)) + self._validate("frequency_context_size", int, min_val=0) + self._validate("logprobs", bool) + self._validate("top_logprobs", int, min_val=0, max_val=11, whitelist=[-1]) + self._validate("xtc_probability", float, min_val=0, max_val=1) + self._validate("xtc_threshold", float, min_val=0, max_val=1) + self._validate("requested_model", str) + self._validate("adapter", str, optional=True) + self._validate("seed", int, optional=True) + self._validate("logit_bias", dict, optional=True) if self.logit_bias is not None: - if not isinstance(self.logit_bias, dict): - raise ValueError("logit_bias must be a dict of int to float") - try: - self.logit_bias = {int(k): v for k, v in self.logit_bias.items()} + self.logit_bias = {int(k): float(v) for k, v in self.logit_bias.items()} except ValueError: raise ValueError("logit_bias must be a dict of int to float") - if not ( - isinstance(self.xtc_probability, float) - and 0.00 <= self.xtc_probability <= 1.00 - ): - raise ValueError(f"xtc_probability must be a float between 0.00 and 1.00") - if not ( - isinstance(self.xtc_threshold, float) and 0.00 <= self.xtc_threshold <= 0.50 - ): - raise ValueError(f"xtc_threshold must be a float between 0.00 and 0.5") - if not isinstance(self.requested_model, str): - raise ValueError("model must be a string") - if self.adapter is not None and not isinstance(self.adapter, str): - raise ValueError("adapter must be a string") - if self.seed is not None and not isinstance(self.seed, int): - raise ValueError("seed must be an integer") def generate_response( self, @@ -1257,8 +1364,7 @@ def handle_completion(self, request: CompletionRequest, stop_words: List[str]): Args: prompt (List[int]): The tokenized prompt. - stop_words (List[str]): A list of stop words passed to the - stopping_criteria function + stop_words (List[str]): A list of stop words """ args = GenerationArguments( model=ModelDescription( @@ -1292,21 +1398,14 @@ def handle_completion(self, request: CompletionRequest, stop_words: List[str]): chat_template_kwargs=self.chat_template_kwargs, ) - # Create keepalive callback to send SSE comments during long prompt processing - def keepalive_callback(processed_tokens, total_tokens): - logging.info( - f"Prompt processing progress: {processed_tokens}/{total_tokens}" - ) + # Keep connection allive during long prompt processing (and also log + # the progress) + def keepalive_callback(processed, total): + logging.info(f"Prompt processing progress: {processed}/{total}") if self.stream: - try: - # Send SSE comment for keepalive - invisible to clients but keeps connection alive - self.wfile.write( - f": keepalive {processed_tokens}/{total_tokens}\n\n".encode() - ) - self.wfile.flush() - except (BrokenPipeError, ConnectionResetError, OSError): - # Client disconnected, ignore - pass + msg = f": keepalive {processed}/{total}\n\n".encode() + self.wfile.write(msg) + self.wfile.flush() # Create the token generator try: @@ -1318,7 +1417,7 @@ def keepalive_callback(processed_tokens, total_tokens): except Exception as e: self._set_completion_headers(404) self.end_headers() - self.wfile.write(json.dumps({"error": f"{e}"}).encode()) + self.wfile.write(json.dumps({"error": str(e)}).encode()) return # Prepare the headers @@ -1330,184 +1429,120 @@ def keepalive_callback(processed_tokens, total_tokens): self._set_completion_headers(200) logging.debug("Starting completion:") - # Variables to save the tool calls in as they are being generated by - # the model. - in_tool_call = False - made_tool_call = False - tool_calls = [] - tool_text = "" - tool_idx = 0 + # Tool call formatter + tool_formatter = ToolCallFormatter(ctx.tool_parser, request.tools, self.stream) - def format_tool_call(tool_call): - nonlocal tool_idx - tool_call_id = tool_call.pop("id", None) or str(uuid.uuid4()) - tool_call["arguments"] = json.dumps( - tool_call["arguments"], ensure_ascii=False - ) - out = { - "function": tool_call, - "type": "function", - "id": tool_call_id, - } - if self.stream: - out["index"] = tool_idx - tool_idx += 1 - return out - - def parse_tools(tool_calls): - if not tool_calls: - return [] - result = [] - for tool_text in tool_calls: - parsed = ctx.tool_parser(tool_text, request.tools) - if isinstance(parsed, list): - result.extend(format_tool_call(tc) for tc in parsed) - else: - result.append(format_tool_call(parsed)) - return result - - # Start out in reasoning if the model is a reasoning model and the - # prompt has an open think token but no closing think token - in_reasoning = False - if ctx.has_thinking: - for i in range(len(ctx.prompt) - 1, -1, -1): - if ctx.prompt[i] == ctx.think_end_id: - break - elif ctx.prompt[i] == ctx.think_start_id: - in_reasoning = True - break + # Variables to save the generated text, tokens, logprobs, tools etc + prev_state = None + finish_reason = "stop" reasoning_text = "" - - # Variables to save the generated tokens and the corresponding probs + made_tool_call = False + tool_text = "" + tool_calls = [] + text = "" tokens = [] token_logprobs = [] top_tokens = [] - # Variables to save the generated text - text = "" - segment = "" - - # Well finally save the reason for stopping - finish_reason = "length" - # Process the generated tokens - for gen in response: - logging.debug(gen.text) - - # Gather the text in tool calling or text variables - if in_reasoning: - if gen.text == ctx.think_end: - in_reasoning = False - else: + try: + for gen in response: + logging.debug(gen.text) + + # Collect the text according to our current state and state + # transitions. Reasoning or tool or normal text. + if gen.state == "reasoning": reasoning_text += gen.text - elif ctx.has_tool_calling and gen.text == ctx.tool_call_start: - made_tool_call = True - in_tool_call = True - elif in_tool_call: - if gen.text == ctx.tool_call_end: - tool_calls.append(tool_text) - tool_text = "" - in_tool_call = False - else: + elif gen.state == "tool": tool_text += gen.text - else: - text += gen.text - segment += gen.text - - # Save the token and its logprob - tokens.append(gen.token) - if args.logprobs: - token_logprobs.append(gen.logprob) - - # If requested save the k top logprobs - if args.top_logprobs > 0: - top_tokens.append(gen.top_tokens) - - # Check if we should stop early - stop_condition = stopping_criteria( - tokens, - ctx.eos_token_ids, - ctx.stop_token_sequences, - stop_words, - ) - if stop_condition.stop_met: - finish_reason = "tool_calls" if made_tool_call else "stop" - ctx.stop() - tokens = tokens[: len(tokens) - stop_condition.trim_length] - text = text[: len(text) - stop_condition.trim_text_length] - segment = "" - break + elif gen.state == "normal": + if prev_state == "tool": + tool_calls.append(tool_text) + tool_text = "" + made_tool_call = True + text += gen.text + + # Add the tokens and logprobs to the vars. + tokens.append(gen.token) + if args.logprobs: + token_logprobs.append(gen.logprob) + if args.top_logprobs > 0: + top_tokens.append(gen.top_tokens) - if self.stream and not in_tool_call: - # If the end of tokens overlaps with a stop sequence, generate new - # tokens until we know if the stop sequence is hit or not - if any( - ( - sequence_overlap(tokens, sequence) - for sequence in ctx.stop_token_sequences - ) + if ( + self.stream + and gen.state != "tool" + and (text or tool_calls or reasoning_text) ): - continue - elif segment or tool_calls or reasoning_text: - response = self.generate_response( - segment, + resp = self.generate_response( + text, None, - tool_calls=parse_tools(tool_calls), + tool_calls=tool_formatter(tool_calls), reasoning_text=reasoning_text, ) - self.wfile.write(f"data: {json.dumps(response)}\n\n".encode()) + self.wfile.write(f"data: {json.dumps(resp)}\n\n".encode()) self.wfile.flush() reasoning_text = "" - segment = "" + text = "" tool_calls = [] - if gen.finish_reason is not None: - finish_reason = gen.finish_reason + if gen.finish_reason is not None: + finish_reason = gen.finish_reason - # Flush any remaining tool text (e.g. when tool_call_end is empty) - if in_tool_call and tool_text: - tool_calls.append(tool_text) + prev_state = gen.state - if self.stream: - response = self.generate_response( - segment, - finish_reason, - tool_calls=parse_tools(tool_calls), - reasoning_text=reasoning_text, - ) - self.wfile.write(f"data: {json.dumps(response)}\n\n".encode()) - self.wfile.flush() - if self.stream_options is not None and self.stream_options["include_usage"]: - response = self.completion_usage_response( + if prev_state == "tool" and tool_text: + tool_calls.append(tool_text) + made_tool_call = True + + if finish_reason == "stop" and made_tool_call: + finish_reason = "tool_calls" + + if self.stream: + resp = self.generate_response( + text, + finish_reason, + tool_calls=tool_formatter(tool_calls), + reasoning_text=reasoning_text, + ) + self.wfile.write(f"data: {json.dumps(resp)}\n\n".encode()) + self.wfile.flush() + if ( + self.stream_options is not None + and self.stream_options["include_usage"] + ): + resp = self.completion_usage_response( + len(ctx.prompt), + len(tokens), + ctx.prompt_cache_count, + ) + self.wfile.write(f"data: {json.dumps(resp)}\n\n".encode()) + self.wfile.flush() + self.wfile.write("data: [DONE]\n\n".encode()) + self.wfile.flush() + else: + resp = self.generate_response( + text, + finish_reason, len(ctx.prompt), len(tokens), ctx.prompt_cache_count, + token_logprobs=token_logprobs, + top_tokens=top_tokens, + tokens=tokens, + reasoning_text=reasoning_text, + tool_calls=tool_formatter(tool_calls), ) - self.wfile.write(f"data: {json.dumps(response)}\n\n".encode()) + if logging.getLogger().isEnabledFor(logging.DEBUG): + response_debug = json.dumps(resp, indent="\t") + logging.debug(f"Outgoing Response: {response_debug}") + + response_json = json.dumps(resp).encode() + self.send_header("Content-Length", str(len(response_json))) + self.end_headers() + self.wfile.write(response_json) self.wfile.flush() - self.wfile.write("data: [DONE]\n\n".encode()) - self.wfile.flush() - else: - response = self.generate_response( - text, - finish_reason, - len(ctx.prompt), - len(tokens), - ctx.prompt_cache_count, - token_logprobs=token_logprobs, - top_tokens=top_tokens, - tokens=tokens, - reasoning_text=reasoning_text, - tool_calls=parse_tools(tool_calls), - ) - response_json = json.dumps(response).encode() - indent = "\t" # Backslashes can't be inside of f-strings - logging.debug(f"Outgoing Response: {json.dumps(response, indent=indent)}") - - # Send an additional Content-Length header when it is known - self.send_header("Content-Length", str(len(response_json))) - self.end_headers() - self.wfile.write(response_json) - self.wfile.flush() + finally: + ctx.stop() def completion_usage_response( self, diff --git a/tests/test_generate.py b/tests/test_generate.py index fee5801a6..2a9a4b10c 100644 --- a/tests/test_generate.py +++ b/tests/test_generate.py @@ -9,6 +9,7 @@ from mlx_lm.generate import ( BatchGenerator, GenerationResponse, + SequenceStateMachine, batch_generate, generate, generate_step, @@ -199,7 +200,7 @@ def test_batch_matches_single(self): self.model, stop_tokens=self.tokenizer.eos_token_ids, max_tokens=1 ) uids = gen.insert(prompts) - batch_responses = {r.uid: r for r in gen.next()} + batch_responses = {r.uid: r for r in gen.next_generated()} # Do a test for each prompt the logits are close for e, prompt in enumerate(prompts): @@ -241,7 +242,7 @@ def test_many_batches(self): batch_responses = {} not_in = True iters = 0 - while responses := gen.next(): + while responses := gen.next_generated(): for r in responses: not_in &= r.uid not in batch_responses batch_responses[r.uid] = r @@ -289,7 +290,7 @@ def test_batch_unique_max_toks(self): num_toks = [2, 3, 4, 5] uids = gen.insert(prompts, max_tokens=num_toks) batch_responses = {uid: [] for uid in uids} - while responses := gen.next(): + while responses := gen.next_generated(): for r in responses: batch_responses[r.uid].append(r.token) @@ -337,7 +338,7 @@ def test_batch_sliding_window(self): ) uids = batch_gen.insert(prompts) batch_responses = {uid: [] for uid in uids} - while responses := batch_gen.next(): + while responses := batch_gen.next_generated(): for r in responses: batch_responses[r.uid].append(r.logprobs) @@ -370,7 +371,7 @@ def test_batch_generate_with_logits_processors(self): ) prompt = self.tokenizer.encode("hello") uids = batch_gen.insert([prompt]) - response = batch_gen.next()[0] + response = batch_gen.next_generated()[0] logprobs = response.logprobs self.assertEqual(logprobs[0].item(), 0.0) self.assertEqual(logprobs.argmin().item(), 1) @@ -395,7 +396,7 @@ def test_batch_generate_with_logits_processors(self): processors = make_logits_processors(logit_bias) (uid2,) = batch_gen.insert([prompt], logits_processors=[processors]) - responses = batch_gen.next() + responses = batch_gen.next_generated() responses = {response.uid: response for response in responses} self.assertEqual(responses[uid0].logprobs[0].item(), 0.0) self.assertEqual(responses[uid1].logprobs[1].item(), 0.0) @@ -410,7 +411,7 @@ def test_batch_generate_with_samplers(self): ) prompt = self.tokenizer.encode("hello") uids = batch_gen.insert([prompt]) - response = batch_gen.next()[0] + response = batch_gen.next_generated()[0] self.assertEqual(response.token, 1) del batch_gen @@ -427,12 +428,47 @@ def test_batch_generate_with_samplers(self): samplers=[lambda _: mx.array([2]), lambda _: mx.array([3])], ) - responses = batch_gen.next() + responses = batch_gen.next_generated() responses = {response.uid: response for response in responses} self.assertEqual(responses[uid0].token, 1) self.assertEqual(responses[uid1].token, 2) self.assertEqual(responses[uid2].token, 3) + def test_batch_generate_with_state_machines(self): + """Test that batch_generate with per-sequence state_machines stops on different tokens.""" + batch_gen = BatchGenerator( + self.model, + max_tokens=10, + ) + prompt = self.tokenizer.encode("hello") + + sm_0 = SequenceStateMachine({"normal": [([0], None)]}, initial="normal") + sm_1 = SequenceStateMachine({"normal": [([1], None)]}, initial="normal") + sm_2 = SequenceStateMachine({"normal": [([2], None)]}, initial="normal") + + processor_0 = make_logits_processors({0: 2000.0}) + processor_1 = make_logits_processors({1: 2000.0}) + processor_2 = make_logits_processors({2: 2000.0}) + + uid0, uid1, uid2 = batch_gen.insert( + [prompt, prompt, prompt], + logits_processors=[processor_0, processor_1, processor_2], + state_machines=[sm_0, sm_1, sm_2], + ) + + responses = batch_gen.next_generated() + responses = {response.uid: response for response in responses} + + self.assertEqual(responses[uid0].token, 0) + self.assertEqual(responses[uid1].token, 1) + self.assertEqual(responses[uid2].token, 2) + self.assertEqual(responses[uid0].finish_reason, "stop") + self.assertEqual(responses[uid1].finish_reason, "stop") + self.assertEqual(responses[uid2].finish_reason, "stop") + self.assertEqual(responses[uid0].match_sequence, (0,)) + self.assertEqual(responses[uid1].match_sequence, (1,)) + self.assertEqual(responses[uid2].match_sequence, (2,)) + def test_batch_continued_generation(self): for rotating in [False, True]: if rotating: @@ -481,7 +517,7 @@ def test_batch_continued_generation(self): ) uids = batch_gen.insert(prompts_a) caches = {uid: None for uid in uids} - while responses := batch_gen.next(): + while responses := batch_gen.next_generated(): for r in responses: if r.finish_reason is not None: caches[r.uid] = r.prompt_cache @@ -490,7 +526,7 @@ def test_batch_continued_generation(self): # Generate the 2nd time uids = batch_gen.insert(prompts_b, caches=caches) batch_responses = {uid: [] for uid in uids} - while responses := batch_gen.next(): + while responses := batch_gen.next_generated(): for r in responses: batch_responses[r.uid].append(r.logprobs) @@ -543,7 +579,7 @@ def rand_prompt(n): uids = batch_gen.insert(prompts_a) caches = {uid: None for uid in uids} - while responses := batch_gen.next(): + while responses := batch_gen.next_generated(): for r in responses: if r.finish_reason is not None: caches[r.uid] = r.prompt_cache @@ -553,7 +589,7 @@ def rand_prompt(n): # Generate the 2nd time uids = batch_gen.insert(prompts_b, caches=caches) batch_responses = {uid: [] for uid in uids} - while responses := batch_gen.next(): + while responses := batch_gen.next_generated(): for r in responses: batch_responses[r.uid].append(r.logprobs) @@ -632,6 +668,57 @@ def test_batch_continued_generation_gated_delta(self): model = qwen3_next.Model(args) self._continued_generation_test_helper(model) + def test_extend_cache_with_empty(self): + from mlx_lm.generate import _extend_cache + from mlx_lm.models.cache import make_prompt_cache + + cache_a = make_prompt_cache(self.model) + + prompt = mx.array([[1, 2, 3]]) + self.model(prompt, cache=cache_a) + mx.eval([c.state for c in cache_a]) + + result = _extend_cache(cache_a, []) + self.assertEqual(len(result), len(cache_a)) + for c in result: + self.assertGreater(c.offset, 0) + + result = _extend_cache([], cache_a) + self.assertEqual(len(result), len(cache_a)) + for c in result: + self.assertGreater(c.offset, 0) + + def test_remove_prompt_batch_updates_currently_processing(self): + prompt_a = self.tokenizer.encode("Write a long story about a cat") + prompt_b = self.tokenizer.encode("Write a long story about a dog") + + gen = BatchGenerator( + self.model, + max_tokens=5, + prefill_batch_size=2, + prefill_step_size=4, + completion_batch_size=4, + ) + uid_a, uid_b = gen.insert([prompt_a, prompt_b]) + + gen.next() + + found = gen._find_uids([uid_a, uid_b]) + for uid in [uid_a, uid_b]: + self.assertIn(uid, found) + self.assertEqual(found[uid][0], 1) + + gen.remove([uid_a]) + + self.assertEqual(len(gen._currently_processing), len(gen._prompt_batch)) + + found = gen._find_uids([uid_b]) + self.assertIn(uid_b, found) + + while responses := gen.next_generated(): + if all(r.finish_reason is not None for r in responses): + break + if __name__ == "__main__": unittest.main() diff --git a/tests/test_server.py b/tests/test_server.py index b6907ac34..48141c074 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -218,18 +218,6 @@ def test_handle_models(self): self.assertEqual(model["object"], "model") self.assertIn("created", model) - def test_sequence_overlap(self): - from mlx_lm.server import sequence_overlap - - self.assertTrue(sequence_overlap([1], [1])) - self.assertTrue(sequence_overlap([1, 2], [1, 2])) - self.assertTrue(sequence_overlap([1, 3], [3, 4])) - self.assertTrue(sequence_overlap([1, 2, 3], [2, 3])) - - self.assertFalse(sequence_overlap([1], [2])) - self.assertFalse(sequence_overlap([1, 2], [3, 4])) - self.assertFalse(sequence_overlap([1, 2, 3], [4, 1, 2, 3])) - class TestServerWithDraftModel(unittest.TestCase): @classmethod