Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions vllm_mlx/api/tool_logits.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,13 +96,16 @@ def __init__(self, tokenizer: Any, bias_strength: float = 20.0):
self._active_pattern: str | None = None
self._pattern_pos = 0 # Position within active pattern's token sequence
self._last_param_close_pos = -1 # Track last </parameter> position to avoid re-triggering
self._consecutive_bias_count = 0 # Safety: escape hatch for stuck patterns
self._max_consecutive_bias = 50 # Max tokens to bias before force-resetting

def reset(self) -> None:
    """Return the processor to its initial state for a fresh generation."""
    # No pattern is being tracked and no bias has been applied yet.
    self._active_pattern = None
    self._pattern_pos = 0
    self._consecutive_bias_count = 0
    # Forget all text observed during the previous generation.
    self._recent_text = ""
    self._last_param_close_pos = -1

def __call__(self, token_ids: Any, logits: Any) -> Any:
"""
Expand All @@ -126,6 +129,17 @@ def __call__(self, token_ids: Any, logits: Any) -> Any:
if not id_list:
return logits

# Safety: escape hatch if stuck in a bias loop
if self._consecutive_bias_count >= self._max_consecutive_bias:
logger.warning(
"Tool logits processor hit max consecutive bias limit "
f"({self._max_consecutive_bias}), resetting state"
)
self._active_pattern = None
self._pattern_pos = 0
self._consecutive_bias_count = 0
return logits

# Decode last token to update recent text
last_token_text = self.tokenizer.decode(
[id_list[-1]], skip_special_tokens=False
Expand All @@ -141,6 +155,7 @@ def __call__(self, token_ids: Any, logits: Any) -> Any:
if self._pattern_pos < len(pattern_tokens):
target_token = pattern_tokens[self._pattern_pos]
self._pattern_pos += 1
self._consecutive_bias_count += 1

# Add bias to the expected token
bias = mx.zeros_like(logits)
Expand All @@ -154,8 +169,12 @@ def __call__(self, token_ids: Any, logits: Any) -> Any:
# re-activating on stale _recent_text
self._active_pattern = None
self._pattern_pos = 0
self._consecutive_bias_count = 0
return logits

# Not biasing — reset counter
self._consecutive_bias_count = 0

# Check if we should start tracking a pattern
for pattern, trigger in self.PATTERNS:
if trigger and self._recent_text.rstrip().endswith(trigger):
Expand All @@ -166,6 +185,7 @@ def __call__(self, token_ids: Any, logits: Any) -> Any:
# Bias first token
target_token = pattern_tokens[0]
self._pattern_pos = 1
self._consecutive_bias_count = 1

bias = mx.zeros_like(logits)
if logits.ndim == 2:
Expand Down
122 changes: 120 additions & 2 deletions vllm_mlx/models/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,14 @@
integrating with vLLM's model execution system.
"""

import copy
import logging
import time
from dataclasses import dataclass
from typing import Iterator

import mlx.core as mx

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -74,6 +78,11 @@ def __init__(
self.draft_model = None
self._loaded = False

# Prompt cache for KV reuse across requests
self._prompt_cache = None
self._cached_token_ids: list[int] = []
self._cache_lock = False # Simple guard against concurrent use

def load(self) -> None:
"""Load the model and tokenizer."""
if self._loaded:
Expand Down Expand Up @@ -214,6 +223,71 @@ def generate(
finish_reason=finish_reason,
)

def _find_common_prefix_len(self, new_tokens: list[int]) -> int:
"""Find the length of the common prefix between cached and new tokens."""
common = 0
limit = min(len(self._cached_token_ids), len(new_tokens))
for i in range(limit):
if self._cached_token_ids[i] != new_tokens[i]:
break
common += 1
return common

def _save_cache_snapshot(self, token_ids: list[int]) -> None:
"""Save a deep copy of the prompt cache state for future reuse."""
if self._prompt_cache is None:
return
# Store the token IDs that correspond to this cache state
# The cache itself is the live object — we just track what's in it
self._cached_token_ids = list(token_ids)

def _prepare_cache_for_prompt(self, prompt_token_ids: list[int]) -> list[int]:
"""
Prepare the prompt cache and return only the tokens that need processing.

If the new prompt shares a prefix with the cached tokens, trim the cache
to the common prefix and return only the suffix tokens.

The cache may contain more entries than _cached_token_ids because
generated tokens from the previous call are also in the cache.
We must trim based on actual cache offset, not just tracked token count.

Returns:
Token IDs that still need to be processed (the non-cached suffix).
"""
if self._prompt_cache is None:
# First call — create fresh cache
from mlx_lm.models.cache import make_prompt_cache
self._prompt_cache = make_prompt_cache(self.model)
self._cached_token_ids = []
return prompt_token_ids

common_len = self._find_common_prefix_len(prompt_token_ids)

if common_len == 0:
# No overlap — reset cache entirely
for c in self._prompt_cache:
current = c.offset if hasattr(c, 'offset') else 0
if current > 0:
c.trim(current)
self._cached_token_ids = []
return prompt_token_ids

# Trim cache to common prefix length.
# Cache offset = prompt_tokens + generated_tokens from last call,
# so we must trim (cache_offset - common_len), not just
# (cached_token_ids_len - common_len).
for c in self._prompt_cache:
current = c.offset if hasattr(c, 'offset') else 0
to_trim = current - common_len
if to_trim > 0:
c.trim(to_trim)
self._cached_token_ids = self._cached_token_ids[:common_len]

# Return only the suffix that needs processing
suffix = prompt_token_ids[common_len:]
return suffix

def stream_generate(
self,
prompt: str,
Expand All @@ -224,7 +298,12 @@ def stream_generate(
stop: list[str] | None = None,
) -> Iterator[StreamingOutput]:
"""
Stream text generation token by token.
Stream text generation token by token with KV cache reuse.

Maintains a persistent prompt cache across calls. When consecutive
requests share a common prefix (e.g. same system prompt + tools),
only the new suffix tokens are processed, dramatically reducing
prefill time.

Args:
prompt: Input prompt text
Expand All @@ -242,6 +321,26 @@ def stream_generate(

from mlx_lm import stream_generate

# Tokenize the full prompt
add_special_tokens = (
self.tokenizer.bos_token is None
or not prompt.startswith(self.tokenizer.bos_token)
)
full_token_ids = self.tokenizer.encode(
prompt, add_special_tokens=add_special_tokens
)

# Prepare cache and get only the tokens that need processing
suffix_tokens = self._prepare_cache_for_prompt(full_token_ids)
prefix_len = len(full_token_ids) - len(suffix_tokens)

if prefix_len > 0 and len(suffix_tokens) < len(full_token_ids):
logger.info(
f"Prompt cache hit: {prefix_len} cached / "
f"{len(suffix_tokens)} new tokens "
f"(saved {prefix_len} tokens of prefill)"
)

# Create sampler with parameters
sampler = self._create_sampler(temperature, top_p)

Expand All @@ -252,17 +351,31 @@ def stream_generate(
gen_kwargs = {
"max_tokens": max_tokens,
"sampler": sampler,
"prompt_cache": self._prompt_cache,
}

# Add draft model for speculative decoding if available
if self.draft_model is not None:
gen_kwargs["draft_model"] = self.draft_model
gen_kwargs["num_draft_tokens"] = self.num_draft_tokens

# Pass token IDs (not string) so mlx-lm skips re-tokenization.
# If suffix is empty (exact same prompt), we still need at least 1 token
# for generate_step. Pop the last token from cache and re-process it.
if not suffix_tokens:
if self._prompt_cache and full_token_ids:
for c in self._prompt_cache:
c.trim(1)
prompt_to_send = full_token_ids[-1:]
else:
prompt_to_send = full_token_ids
else:
prompt_to_send = suffix_tokens

for response in stream_generate(
self.model,
self.tokenizer,
prompt=prompt,
prompt=prompt_to_send,
**gen_kwargs,
):
token_count += 1
Expand Down Expand Up @@ -293,6 +406,11 @@ def stream_generate(
if finished:
break

# Save cache state: prompt tokens only (not generated tokens)
# The cache now has prompt + generated tokens; we save the prompt part
# so next request can match against it
self._save_cache_snapshot(full_token_ids)

def chat(
self,
messages: list[dict],
Expand Down