diff --git a/vllm_mlx/api/tool_logits.py b/vllm_mlx/api/tool_logits.py
index fbe29c49..854af08d 100644
--- a/vllm_mlx/api/tool_logits.py
+++ b/vllm_mlx/api/tool_logits.py
@@ -96,6 +96,8 @@ def __init__(self, tokenizer: Any, bias_strength: float = 20.0):
         self._active_pattern: str | None = None
         self._pattern_pos = 0  # Position within active pattern's token sequence
         self._last_param_close_pos = -1  # Track last position to avoid re-triggering
+        self._consecutive_bias_count = 0  # Safety: escape hatch for stuck patterns
+        self._max_consecutive_bias = 50  # Max tokens to bias before force-resetting
 
     def reset(self) -> None:
         """Reset state for a new generation."""
@@ -103,6 +105,7 @@ def reset(self) -> None:
         self._active_pattern = None
         self._pattern_pos = 0
         self._last_param_close_pos = -1
+        self._consecutive_bias_count = 0
 
     def __call__(self, token_ids: Any, logits: Any) -> Any:
         """
@@ -126,6 +129,17 @@ def __call__(self, token_ids: Any, logits: Any) -> Any:
         if not id_list:
             return logits
 
+        # Safety: escape hatch if stuck in a bias loop
+        if self._consecutive_bias_count >= self._max_consecutive_bias:
+            logger.warning(
+                "Tool logits processor hit max consecutive bias limit "
+                f"({self._max_consecutive_bias}), resetting state"
+            )
+            self._active_pattern = None
+            self._pattern_pos = 0
+            self._consecutive_bias_count = 0
+            return logits
+
         # Decode last token to update recent text
         last_token_text = self.tokenizer.decode(
             [id_list[-1]], skip_special_tokens=False
@@ -141,6 +155,7 @@ def __call__(self, token_ids: Any, logits: Any) -> Any:
             if self._pattern_pos < len(pattern_tokens):
                 target_token = pattern_tokens[self._pattern_pos]
                 self._pattern_pos += 1
+                self._consecutive_bias_count += 1
 
                 # Add bias to the expected token
                 bias = mx.zeros_like(logits)
@@ -154,8 +169,12 @@ def __call__(self, token_ids: Any, logits: Any) -> Any:
                 # re-activating on stale _recent_text
                 self._active_pattern = None
                 self._pattern_pos = 0
+                self._consecutive_bias_count = 0
                 return logits
 
+        # Not biasing — reset counter
+        self._consecutive_bias_count = 0
+
         # Check if we should start tracking a pattern
         for pattern, trigger in self.PATTERNS:
             if trigger and self._recent_text.rstrip().endswith(trigger):
@@ -166,6 +185,7 @@ def __call__(self, token_ids: Any, logits: Any) -> Any:
                 # Bias first token
                 target_token = pattern_tokens[0]
                 self._pattern_pos = 1
+                self._consecutive_bias_count = 1
 
                 bias = mx.zeros_like(logits)
                 if logits.ndim == 2:
diff --git a/vllm_mlx/models/llm.py b/vllm_mlx/models/llm.py
index 1c75fb3a..c95b8e98 100644
--- a/vllm_mlx/models/llm.py
+++ b/vllm_mlx/models/llm.py
@@ -74,6 +74,10 @@ def __init__(
         self.draft_model = None
         self._loaded = False
 
+        # Prompt cache for KV reuse across requests
+        self._prompt_cache = None
+        self._cached_token_ids: list[int] = []
+
     def load(self) -> None:
         """Load the model and tokenizer."""
         if self._loaded:
@@ -214,6 +218,71 @@ def generate(
             finish_reason=finish_reason,
         )
 
+    def _find_common_prefix_len(self, new_tokens: list[int]) -> int:
+        """Find the length of the common prefix between cached and new tokens."""
+        common = 0
+        limit = min(len(self._cached_token_ids), len(new_tokens))
+        for i in range(limit):
+            if self._cached_token_ids[i] != new_tokens[i]:
+                break
+            common += 1
+        return common
+
+    def _save_cache_snapshot(self, token_ids: list[int]) -> None:
+        """Record the token IDs that the live prompt cache currently covers."""
+        if self._prompt_cache is None:
+            return
+        # Store the token IDs that correspond to this cache state
+        # The cache itself is the live object — we just track what's in it
+        self._cached_token_ids = list(token_ids)
+
+    def _prepare_cache_for_prompt(self, prompt_token_ids: list[int]) -> list[int]:
+        """
+        Prepare the prompt cache and return only the tokens that need processing.
+
+        If the new prompt shares a prefix with the cached tokens, trim the cache
+        to the common prefix and return only the suffix tokens.
+
+        The cache may contain more entries than _cached_token_ids because
+        generated tokens from the previous call are also in the cache.
+        We must trim based on actual cache offset, not just tracked token count.
+
+        Returns:
+            Token IDs that still need to be processed (the non-cached suffix).
+        """
+        if self._prompt_cache is None:
+            # First call — create fresh cache
+            from mlx_lm.models.cache import make_prompt_cache
+            self._prompt_cache = make_prompt_cache(self.model)
+            self._cached_token_ids = []
+            return prompt_token_ids
+
+        common_len = self._find_common_prefix_len(prompt_token_ids)
+
+        if common_len == 0:
+            # No overlap — reset cache entirely
+            for c in self._prompt_cache:
+                current = c.offset if hasattr(c, 'offset') else 0
+                if current > 0:
+                    c.trim(current)
+            self._cached_token_ids = []
+            return prompt_token_ids
+
+        # Trim cache to common prefix length.
+        # Cache offset = prompt_tokens + generated_tokens from last call,
+        # so we must trim (cache_offset - common_len), not just
+        # (cached_token_ids_len - common_len).
+        for c in self._prompt_cache:
+            current = c.offset if hasattr(c, 'offset') else 0
+            to_trim = current - common_len
+            if to_trim > 0:
+                c.trim(to_trim)
+        self._cached_token_ids = self._cached_token_ids[:common_len]
+
+        # Return only the suffix that needs processing
+        suffix = prompt_token_ids[common_len:]
+        return suffix
+
     def stream_generate(
         self,
         prompt: str,
@@ -224,7 +293,12 @@ def stream_generate(
         stop: list[str] | None = None,
     ) -> Iterator[StreamingOutput]:
         """
-        Stream text generation token by token.
+        Stream text generation token by token with KV cache reuse.
+
+        Maintains a persistent prompt cache across calls. When consecutive
+        requests share a common prefix (e.g. same system prompt + tools),
+        only the new suffix tokens are processed, dramatically reducing
+        prefill time.
 
         Args:
             prompt: Input prompt text
@@ -242,6 +316,29 @@ def stream_generate(
 
         from mlx_lm import stream_generate
 
+        # Tokenize the full prompt
+        add_special_tokens = (
+            self.tokenizer.bos_token is None
+            or not prompt.startswith(self.tokenizer.bos_token)
+        )
+        full_token_ids = self.tokenizer.encode(
+            prompt, add_special_tokens=add_special_tokens
+        )
+
+        # Prepare cache and get only the tokens that need processing
+        suffix_tokens = self._prepare_cache_for_prompt(full_token_ids)
+        prefix_len = len(full_token_ids) - len(suffix_tokens)
+
+        # Record the new prompt's tokens immediately so an abandoned stream
+        # cannot leave _cached_token_ids out of sync with the live cache.
+        self._cached_token_ids = list(full_token_ids)
+
+        if prefix_len > 0:
+            logger.info(
+                f"Prompt cache hit: {prefix_len} cached / "
+                f"{len(suffix_tokens)} new tokens "
+                f"(saved {prefix_len} tokens of prefill)"
+            )
+
         # Create sampler with parameters
         sampler = self._create_sampler(temperature, top_p)
 
@@ -252,6 +349,7 @@ def stream_generate(
         gen_kwargs = {
             "max_tokens": max_tokens,
             "sampler": sampler,
+            "prompt_cache": self._prompt_cache,
         }
 
         # Add draft model for speculative decoding if available
@@ -259,10 +357,23 @@ def stream_generate(
             gen_kwargs["draft_model"] = self.draft_model
             gen_kwargs["num_draft_tokens"] = self.num_draft_tokens
 
+        # Pass token IDs (not string) so mlx-lm skips re-tokenization.
+        # If suffix is empty (exact same prompt), we still need at least 1 token
+        # for generate_step. Pop the last token from cache and re-process it.
+        if not suffix_tokens:
+            if self._prompt_cache and full_token_ids:
+                for c in self._prompt_cache:
+                    c.trim(1)
+                prompt_to_send = full_token_ids[-1:]
+            else:
+                prompt_to_send = full_token_ids
+        else:
+            prompt_to_send = suffix_tokens
+
         for response in stream_generate(
             self.model,
             self.tokenizer,
-            prompt=prompt,
+            prompt=prompt_to_send,
             **gen_kwargs,
         ):
             token_count += 1
@@ -293,6 +404,11 @@ def stream_generate(
             if finished:
                 break
 
+        # Save cache state: prompt tokens only (not generated tokens)
+        # The cache now has prompt + generated tokens; we save the prompt part
+        # so next request can match against it
+        self._save_cache_snapshot(full_token_ids)
+
     def chat(
         self,
         messages: list[dict],