From 98d37ad983b5966e2eba7373f4efb57c7990388f Mon Sep 17 00:00:00 2001
From: Raullen
Date: Wed, 25 Feb 2026 12:13:27 -0800
Subject: [PATCH] fix: Fix broken prompt cache + add TTFT optimizations
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

The prompt cache was saving state AFTER yielding the finished chunk,
but the caller breaks before the generator resumes — so
_save_cache_snapshot() never executed. Every request did full prefill
regardless of cache state (only 7 template tokens matched).

Fix: move _save_cache_snapshot() BEFORE the final yield.

Results (10K token prompt): 127s → 0.32s (397x speedup on cache hit)
Results (5K token prompt): 21s → 0.28s (75x speedup on cache hit)

Partial cache hits also work: same system prompt + different user
message → 0.45s vs 16.4s cold (36x speedup).

Additional changes:
- Add prompt_tokens to StreamingOutput, removing double tokenization
  in SimpleEngine (was encoding the full prompt twice per request)
- Add --prefill-step-size CLI arg (tune prefill chunk size)
- Add --kv-bits / --kv-group-size CLI args (KV cache quantization)
- Add TTFT breakdown logging (tokenize, prefill, total times)
- Set _reasoning_parser_name in cli.py (was only set in server.py)

Co-Authored-By: Claude Opus 4.6
---
 vllm_mlx/cli.py           | 23 +++++++++++++++++++
 vllm_mlx/engine/simple.py | 21 +++++++++++------
 vllm_mlx/models/llm.py    | 47 ++++++++++++++++++++++++++++++++++-----
 vllm_mlx/server.py        | 32 ++++++++++++++++++++++++++
 4 files changed, 111 insertions(+), 12 deletions(-)

diff --git a/vllm_mlx/cli.py b/vllm_mlx/cli.py
index aa7ad428..65f4f930 100644
--- a/vllm_mlx/cli.py
+++ b/vllm_mlx/cli.py
@@ -84,6 +84,7 @@ def serve_command(args):
             parser_cls = get_parser(args.reasoning_parser)
             server._reasoning_parser = parser_cls()
+            server._reasoning_parser_name = args.reasoning_parser
             logger.info(f"Reasoning parser enabled: {args.reasoning_parser}")
         except KeyError as e:
             print(f"Error: {e}")
@@ -216,6 +217,9 @@ def serve_command(args):
         gpu_memory_utilization=args.gpu_memory_utilization,
         draft_model=args.draft_model,
         num_draft_tokens=args.num_draft_tokens,
+        prefill_step_size=args.prefill_step_size,
+        kv_bits=args.kv_bits,
+        kv_group_size=args.kv_group_size,
     )

     # Start server
@@ -845,6 +849,25 @@
         default=4,
         help="Number of tokens to generate speculatively per step (default: 4)",
     )
+    serve_parser.add_argument(
+        "--prefill-step-size",
+        type=int,
+        default=2048,
+        help="Tokens to process per prefill chunk in simple mode (default: 2048)",
+    )
+    serve_parser.add_argument(
+        "--kv-bits",
+        type=int,
+        default=None,
+        choices=[4, 8],
+        help="KV cache quantization bits for simple mode (4 or 8). Reduces memory for long contexts.",
+    )
+    serve_parser.add_argument(
+        "--kv-group-size",
+        type=int,
+        default=64,
+        help="Group size for KV cache quantization in simple mode (default: 64)",
+    )

     # Reasoning parser options - choices loaded dynamically from registry
     from .reasoning import list_parsers
diff --git a/vllm_mlx/engine/simple.py b/vllm_mlx/engine/simple.py
index 5540ea24..bebe6b8f 100644
--- a/vllm_mlx/engine/simple.py
+++ b/vllm_mlx/engine/simple.py
@@ -43,6 +43,9 @@ def __init__(
         force_mllm: bool = False,
         draft_model: str | None = None,
         num_draft_tokens: int = 4,
+        prefill_step_size: int = 2048,
+        kv_bits: int | None = None,
+        kv_group_size: int = 64,
     ):
         """
         Initialize the simple engine.
@@ -54,6 +57,9 @@ def __init__(
             force_mllm: Force loading as MLLM even if not auto-detected
             draft_model: Optional draft model path for speculative decoding
             num_draft_tokens: Number of tokens to generate speculatively per step
+            prefill_step_size: Tokens to process per prefill chunk (default: 2048)
+            kv_bits: KV cache quantization bits (None=no quantization, 4 or 8)
+            kv_group_size: Group size for KV cache quantization (default: 64)
         """
         self._model_name = model_name
         self._trust_remote_code = trust_remote_code
@@ -61,6 +67,9 @@ def __init__(
         self._is_mllm = force_mllm or is_mllm_model(model_name)
         self._draft_model_name = draft_model
         self._num_draft_tokens = num_draft_tokens
+        self._prefill_step_size = prefill_step_size
+        self._kv_bits = kv_bits
+        self._kv_group_size = kv_group_size

         self._model = None
         self._loaded = False
@@ -110,6 +119,9 @@ async def start(self) -> None:
             trust_remote_code=self._trust_remote_code,
             draft_model=self._draft_model_name,
             num_draft_tokens=self._num_draft_tokens,
+            prefill_step_size=self._prefill_step_size,
+            kv_bits=self._kv_bits,
+            kv_group_size=self._kv_group_size,
         )
         self._model.load()

@@ -207,8 +219,7 @@ async def stream_generate(
         async with self._generation_lock:
             accumulated_text = ""

-            # Compute prompt tokens upfront since StreamingOutput doesn't carry them
-            prompt_tokens = len(self._model.tokenizer.encode(prompt))
+            prompt_tokens = 0
             completion_tokens = 0
             finished = False

@@ -220,11 +231,7 @@
                 stop=stop,
                 **kwargs,
             ):
-                prompt_tokens = (
-                    chunk.prompt_tokens
-                    if hasattr(chunk, "prompt_tokens")
-                    else prompt_tokens
-                )
+                prompt_tokens = getattr(chunk, "prompt_tokens", 0) or prompt_tokens
                 completion_tokens += 1
                 new_text = chunk.text if hasattr(chunk, "text") else str(chunk)
                 accumulated_text += new_text
diff --git a/vllm_mlx/models/llm.py b/vllm_mlx/models/llm.py
index 84738db0..3645134f 100644
--- a/vllm_mlx/models/llm.py
+++ b/vllm_mlx/models/llm.py
@@ -35,6 +35,7 @@ class StreamingOutput:
     finished: bool = False
     finish_reason: str | None = None
     logprobs: Any = None  # mx.array of shape [vocab_size] from mlx-lm
+    prompt_tokens: int = 0


 class MLXLanguageModel:
@@ -57,6 +58,9 @@ def __init__(
         trust_remote_code: bool = False,
         draft_model: str | None = None,
         num_draft_tokens: int = 4,
+        prefill_step_size: int = 2048,
+        kv_bits: int | None = None,
+        kv_group_size: int = 64,
     ):
         """
         Initialize the MLX language model.
@@ -67,12 +71,18 @@ def __init__(
             trust_remote_code: Whether to trust remote code
             draft_model: Optional draft model path for speculative decoding
             num_draft_tokens: Number of tokens to generate speculatively per step
+            prefill_step_size: Tokens to process per prefill chunk (default: 2048)
+            kv_bits: KV cache quantization bits (None=no quantization, 4 or 8)
+            kv_group_size: Group size for KV cache quantization (default: 64)
         """
         self.model_name = model_name
         self.tokenizer_name = tokenizer_name or model_name
         self.trust_remote_code = trust_remote_code
         self.draft_model_name = draft_model
         self.num_draft_tokens = num_draft_tokens
+        self.prefill_step_size = prefill_step_size
+        self.kv_bits = kv_bits
+        self.kv_group_size = kv_group_size

         self.model = None
         self.tokenizer = None
@@ -320,8 +330,12 @@ def stream_generate(
         if not self._loaded:
             self.load()

+        import time as _time
+
         from mlx_lm import stream_generate

+        t0 = _time.perf_counter()
+
         # Tokenize the full prompt
         add_special_tokens = (
             self.tokenizer.bos_token is None
@@ -331,6 +345,8 @@
             prompt, add_special_tokens=add_special_tokens
         )

+        t_tokenize = _time.perf_counter()
+
         # Prepare cache and get only the tokens that need processing
         suffix_tokens = self._prepare_cache_for_prompt(full_token_ids)
         prefix_len = len(full_token_ids) - len(suffix_tokens)
@@ -341,6 +357,10 @@
                 f"{len(suffix_tokens)} new tokens "
                 f"(saved {prefix_len} tokens of prefill)"
             )
+        else:
+            logger.info(
+                f"Prompt cache miss: {len(full_token_ids)} tokens to prefill"
+            )

         # Create sampler with parameters
         sampler = self._create_sampler(temperature, top_p)
@@ -353,8 +373,14 @@
             "max_tokens": max_tokens,
             "sampler": sampler,
             "prompt_cache": self._prompt_cache,
+            "prefill_step_size": self.prefill_step_size,
         }

+        # KV cache quantization reduces memory pressure for long prompts
+        if self.kv_bits is not None:
+            gen_kwargs["kv_bits"] = self.kv_bits
+            gen_kwargs["kv_group_size"] = self.kv_group_size
+
         # Add draft model for speculative decoding if available
         if self.draft_model is not None:
             gen_kwargs["draft_model"] = self.draft_model
@@ -373,6 +399,7 @@
         else:
             prompt_to_send = suffix_tokens

+        t_first_token = None
         for response in stream_generate(
             self.model,
             self.tokenizer,
@@ -380,6 +407,15 @@
             **gen_kwargs,
         ):
             token_count += 1
+            if token_count == 1:
+                t_first_token = _time.perf_counter()
+                logger.info(
+                    f"TTFT breakdown: tokenize={t_tokenize - t0:.3f}s, "
+                    f"prefill+decode={t_first_token - t_tokenize:.3f}s, "
+                    f"total={t_first_token - t0:.3f}s "
+                    f"(prompt={len(full_token_ids)} tokens, "
+                    f"prefilled={len(prompt_to_send)} tokens)"
+                )
             # response.text is the new token text (not accumulated)
             new_text = response.text
             accumulated_text += new_text
@@ -396,6 +432,11 @@
             finish_reason = None
             if finished:
                 finish_reason = "stop" if should_stop else "length"
+                # Save cache BEFORE yielding the finished chunk.
+                # The caller may break/abandon this generator after
+                # receiving the finished chunk, so code after yield
+                # would never execute.
+                self._save_cache_snapshot(full_token_ids)

             yield StreamingOutput(
                 text=new_text,
@@ -403,16 +444,12 @@
                 finished=finished,
                 finish_reason=finish_reason,
                 logprobs=getattr(response, "logprobs", None),
+                prompt_tokens=len(full_token_ids),
             )

             if finished:
                 break

-        # Save cache state: prompt tokens only (not generated tokens)
-        # The cache now has prompt + generated tokens; we save the prompt part
-        # so next request can match against it
-        self._save_cache_snapshot(full_token_ids)
-
     def chat(
         self,
         messages: list[dict],
diff --git a/vllm_mlx/server.py b/vllm_mlx/server.py
index 72d943c3..ed473806 100644
--- a/vllm_mlx/server.py
+++ b/vllm_mlx/server.py
@@ -622,6 +622,9 @@ def load_model(
     gpu_memory_utilization: float = 0.90,
     draft_model: str | None = None,
     num_draft_tokens: int = 4,
+    prefill_step_size: int = 2048,
+    kv_bits: int | None = None,
+    kv_group_size: int = 64,
 ):
     """
     Load a model (auto-detects MLLM vs LLM).
@@ -637,6 +640,9 @@
             limit and emergency threshold (0.0-1.0, default 0.90)
         draft_model: Optional draft model for speculative decoding
         num_draft_tokens: Number of tokens to generate speculatively per step
+        prefill_step_size: Tokens to process per prefill chunk (default: 2048)
+        kv_bits: KV cache quantization bits (None=no quantization, 4 or 8)
+        kv_group_size: Group size for KV cache quantization (default: 64)
     """
     global _engine, _model_name, _default_max_tokens, _tool_parser_instance
@@ -688,6 +694,9 @@
            force_mllm=force_mllm,
            draft_model=draft_model,
            num_draft_tokens=num_draft_tokens,
+           prefill_step_size=prefill_step_size,
+           kv_bits=kv_bits,
+           kv_group_size=kv_group_size,
        )
        # Start SimpleEngine synchronously (no background loop)
        # Use new_event_loop() for Python 3.10+ compatibility (get_event_loop() is deprecated)
@@ -2798,6 +2807,26 @@ def main():
         default=4,
         help="Number of tokens to generate speculatively per step (default: 4)",
     )
+    parser.add_argument(
+        "--prefill-step-size",
+        type=int,
+        default=2048,
+        help="Tokens to process per prefill chunk (default: 2048). "
+        "Larger values may improve TTFT on Apple Silicon with sufficient memory.",
+    )
+    parser.add_argument(
+        "--kv-bits",
+        type=int,
+        default=None,
+        choices=[4, 8],
+        help="KV cache quantization bits (4 or 8). Reduces memory for long contexts.",
+    )
+    parser.add_argument(
+        "--kv-group-size",
+        type=int,
+        default=64,
+        help="Group size for KV cache quantization (default: 64)",
+    )

     args = parser.parse_args()
@@ -2858,6 +2887,9 @@
         force_mllm=args.mllm,
         draft_model=args.draft_model,
         num_draft_tokens=args.num_draft_tokens,
+        prefill_step_size=args.prefill_step_size,
+        kv_bits=args.kv_bits,
+        kv_group_size=args.kv_group_size,
     )

     # Start server
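
Note: the root cause is a general Python generator pitfall, not anything
specific to this codebase. A minimal, self-contained sketch (illustration
only, not project code; save_cache stands in for _save_cache_snapshot):

    def save_cache():
        print("cache saved")

    def stream():
        yield "finished chunk"  # consumer receives the final chunk here
        save_cache()            # unreachable once the consumer breaks

    for chunk in stream():
        break  # typical caller: stop as soon as the finished chunk arrives

    # Nothing is printed. The abandoned generator stays suspended at the
    # yield; when it is closed (or garbage-collected), GeneratorExit is
    # raised at that yield, so the code after it never runs. That is why
    # the patch saves the cache snapshot BEFORE the final yield.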
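
The new knobs flow through the SimpleEngine constructor. A hedged usage
sketch based only on the signatures added above (the model id is a
placeholder, and start() is the async loader shown in the simple.py diff):

    import asyncio

    from vllm_mlx.engine.simple import SimpleEngine

    engine = SimpleEngine(
        model_name="mlx-community/example-model-4bit",  # placeholder id
        prefill_step_size=4096,  # larger chunks can improve TTFT if memory allows
        kv_bits=8,               # quantize the KV cache to 8 bits
        kv_group_size=64,        # quantization group size
    )
    asyncio.run(engine.start())  # loads the model with these settings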