#!/usr/bin/env python3
"""
Smoke test for speculative decoding with real models.

Usage: python tests/smoke_test_specdec.py

Uses Qwen3.5-35B-A3B-8bit as target, Qwen3.5-4B-4bit as draft.
Tests the SimpleEngine path (mlx_lm.stream_generate with draft_model).
"""

import gc
import os
import sys
import time

# Add project to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

TARGET = os.path.expanduser("~/ai-models/mlx_models/Qwen3.5-35B-A3B-8bit")
DRAFT = os.path.expanduser("~/ai-models/mlx_models/Qwen3.5-4B-4bit")
PROMPT = "What is the capital of France? Answer in one sentence."
MAX_TOKENS = 64
NUM_DRAFT = 3


def _rate(count, seconds):
    """Return ``count / seconds``, or 0.0 when no time was measured.

    Keeps the report lines from raising ZeroDivisionError when a run
    finishes without any measurable elapsed time.
    """
    return count / seconds if seconds > 0 else 0.0


def _get_vocab_size(model):
    """Return the row count of ``embed_tokens.weight`` or None if not found.

    Looks for ``embed_tokens`` one level down, under the ``model`` or
    ``language_model`` attribute of *model*.
    """
    for attr in ("model", "language_model"):
        sub = getattr(model, attr, None)
        if sub is None:
            continue
        embed = getattr(sub, "embed_tokens", None)
        if embed is not None:
            return embed.weight.shape[0]
    return None


def test_without_draft():
    """Baseline: generate without speculative decoding.

    Returns:
        ``(token_count, elapsed_seconds)`` for the generation loop only
        (model loading is not timed).
    """
    from mlx_lm import load, stream_generate

    print("=" * 60)
    print("Loading target model (no draft)...")
    model, tokenizer = load(TARGET)
    print(f"Target loaded. Generating {MAX_TOKENS} tokens...")

    tokens = []
    t0 = time.perf_counter()
    for resp in stream_generate(model, tokenizer, prompt=PROMPT, max_tokens=MAX_TOKENS):
        tokens.append(resp.token)
    elapsed = time.perf_counter() - t0

    text = tokenizer.decode(tokens)
    print(f"Output ({len(tokens)} tokens, {_rate(len(tokens), elapsed):.1f} tok/s):")
    print(f" {text}")
    print()
    return len(tokens), elapsed


def test_with_draft():
    """Speculative: generate with draft model.

    Returns:
        ``(token_count, elapsed_seconds)`` for the generation loop only.

    Raises:
        AssertionError: if both vocab sizes could be determined and differ.
    """
    from mlx_lm import load, stream_generate

    print("=" * 60)
    print("Loading target + draft model...")
    model, tokenizer = load(TARGET)
    draft_model, _ = load(DRAFT)

    # Verify vocab match — only enforced when both sizes could be located.
    target_vocab = _get_vocab_size(model)
    draft_vocab = _get_vocab_size(draft_model)
    print(f"Target vocab: {target_vocab}, Draft vocab: {draft_vocab}")
    if target_vocab and draft_vocab:
        assert target_vocab == draft_vocab, "Vocab size mismatch!"

    print(f"Generating {MAX_TOKENS} tokens with num_draft_tokens={NUM_DRAFT}...")

    tokens = []
    from_draft_count = 0
    t0 = time.perf_counter()
    for resp in stream_generate(
        model,
        tokenizer,
        prompt=PROMPT,
        max_tokens=MAX_TOKENS,
        draft_model=draft_model,
        num_draft_tokens=NUM_DRAFT,
    ):
        tokens.append(resp.token)
        if resp.from_draft:
            from_draft_count += 1
    elapsed = time.perf_counter() - t0

    text = tokenizer.decode(tokens)
    accept_rate = from_draft_count / len(tokens) * 100 if tokens else 0
    print(f"Output ({len(tokens)} tokens, {_rate(len(tokens), elapsed):.1f} tok/s):")
    print(f" {text}")
    print(f"Draft acceptance: {from_draft_count}/{len(tokens)} ({accept_rate:.0f}%)")
    print()
    return len(tokens), elapsed


def main():
    """Run the baseline and speculative passes and print a comparison."""
    print("Speculative Decoding Smoke Test")
    print("Target:", TARGET)
    print("Draft:", DRAFT)
    print()

    n1, t1 = test_without_draft()

    # Clear model from memory before loading target + draft together.
    import mlx.core as mx

    gc.collect()
    mx.clear_cache()

    n2, t2 = test_with_draft()

    base_tps = _rate(n1, t1)
    draft_tps = _rate(n2, t2)

    print("=" * 60)
    print("RESULTS:")
    print(f" Without draft: {n1} tokens in {t1:.2f}s ({base_tps:.1f} tok/s)")
    print(f" With draft: {n2} tokens in {t2:.2f}s ({draft_tps:.1f} tok/s)")
    if t1 > 0 and t2 > 0:
        # Speedup is draft throughput relative to baseline throughput.
        if base_tps > 0 and draft_tps > 0:
            print(f" Speedup: {draft_tps / base_tps:.2f}x")
        else:
            print(" N/A")


if __name__ == "__main__":
    main()
class TestCLIArgs:
    """Argparse round-trip checks for the speculative decoding flags.

    The parsers are built locally inside the tests; the project's own CLI
    module is not imported here.
    """

    @staticmethod
    def _build_parser(with_draft_model):
        """Return a parser with ``--num-draft-tokens`` and, optionally, ``--draft-model``."""
        import argparse

        cli = argparse.ArgumentParser()
        if with_draft_model:
            cli.add_argument("--draft-model", type=str, default=None)
        cli.add_argument("--num-draft-tokens", type=int, default=3)
        return cli

    def test_draft_model_arg_parsing(self):
        argv = ["--draft-model", "/path/to/model", "--num-draft-tokens", "5"]
        parsed = self._build_parser(with_draft_model=True).parse_args(argv)
        assert parsed.draft_model == "/path/to/model"
        assert parsed.num_draft_tokens == 5

    def test_default_num_draft_tokens(self):
        parsed = self._build_parser(with_draft_model=False).parse_args([])
        assert parsed.num_draft_tokens == 3


# ---------------------------------------------------------------------------
# Tests: SimpleEngine draft model
# ---------------------------------------------------------------------------


class TestSimpleEngineDraftModel:
    """Constructor-level checks: SimpleEngine keeps the draft settings it is given."""

    def test_draft_model_params_stored(self):
        from vllm_mlx.engine.simple import SimpleEngine

        configured = SimpleEngine(
            model_name="test-model",
            draft_model_path="/path/to/draft",
            num_draft_tokens=5,
        )
        assert configured._draft_model_path == "/path/to/draft"
        assert configured._num_draft_tokens == 5

    def test_no_draft_model_by_default(self):
        from vllm_mlx.engine.simple import SimpleEngine

        plain = SimpleEngine(model_name="test-model")
        assert plain._draft_model_path is None
        assert plain._num_draft_tokens == 3


class TestMLXLanguageModelDraftModel:
    """Constructor-level checks: MLXLanguageModel stores draft settings, with ``draft_model`` still None."""

    def test_draft_model_params_stored(self):
        from vllm_mlx.models.llm import MLXLanguageModel

        configured = MLXLanguageModel(
            model_name="test-model",
            draft_model_path="/path/to/draft",
            num_draft_tokens=5,
        )
        assert configured._draft_model_path == "/path/to/draft"
        assert configured._num_draft_tokens == 5
        # Constructing the wrapper must not have produced a draft model yet.
        assert configured.draft_model is None

    def test_no_draft_model_by_default(self):
        from vllm_mlx.models.llm import MLXLanguageModel

        plain = MLXLanguageModel(model_name="test-model")
        assert plain._draft_model_path is None
        assert plain.draft_model is None
# Content-part "type" values that mark a message as carrying media.
_MEDIA_TYPES = frozenset(
    {
        "image_url",
        "video_url",
        "audio_url",
        "image",
        "video",
        "audio",
    }
)


def _has_media_content(messages: list) -> bool:
    """Return True if any message has a list-content part typed as media.

    Only messages whose ``content`` is a list are inspected; within such a
    list, only dict parts whose ``type`` is in ``_MEDIA_TYPES`` count.
    String content and non-dict parts are ignored.
    """
    for message in messages:
        parts = message.get("content")
        if not isinstance(parts, list):
            continue
        if any(
            isinstance(item, dict) and item.get("type") in _MEDIA_TYPES
            for item in parts
        ):
            return True
    return False
@@ -147,12 +177,28 @@ def __init__( scheduler_config: Optional scheduler configuration stream_interval: Tokens to batch before streaming (1=every token) force_mllm: Force loading as MLLM even if not auto-detected + mtp: Enable MTP per-request routing (text-only → TextModel, media → MLLM) + prefill_step_size: Chunk size for prompt prefill (default 2048) + specprefill_enabled: Enable SpecPrefill sparse prefill + specprefill_draft_model_path: Draft model directory name under ~/ai-models/mlx_models/ + specprefill_threshold: Minimum suffix tokens to trigger SpecPrefill (default 8192) + specprefill_keep_pct: Fraction of tokens to keep (default 0.3) """ self._model_name = model_name self._trust_remote_code = trust_remote_code self._scheduler_config = scheduler_config self._stream_interval = stream_interval self._is_mllm = force_mllm or is_mllm_model(model_name) + self._mtp = mtp + self._prefill_step_size = prefill_step_size or 2048 + + # SpecPrefill configuration + self._specprefill_enabled = specprefill_enabled + self._specprefill_draft_model_path = specprefill_draft_model_path + self._specprefill_threshold = specprefill_threshold + self._specprefill_keep_pct = specprefill_keep_pct + self._specprefill_lock = asyncio.Lock() + self._draft_model = None self._model = None self._processor = None # For MLLM @@ -162,6 +208,16 @@ def __init__( self._mllm_instance = None # MLXMultimodalLM instance self._loaded = False + # Per-request routing state (MLLM+MTP mode) + self._text_model = None + self._text_tokenizer = None + self._text_generation_lock = asyncio.Lock() + + # System prompt KV cache (reduces repeated prefill across requests) + self._system_kv_snapshot = None # List of (keys, values) per backbone layer + self._system_kv_hash = None # Hash of system prefix text + self._system_kv_token_count = 0 # Tokens in cached prefix + @property def model_name(self) -> str: """Get the model name.""" @@ -241,6 +297,73 @@ async def _start_mllm(self) -> None: 
f"completion_batch={completion_batch_size}" ) + # Build TextModel for MTP per-request routing (text-only → MTP, media → MLLM) + if self._mtp: + try: + from ..text_model_from_vlm import build_text_model + + self._text_model = build_text_model( + self._mllm_instance.model, self._model_name + ) + if self._text_model is not None: + # Get tokenizer from the MLLM instance (same model, shared tokenizer) + self._text_tokenizer = self._mllm_instance.get_tokenizer() + + # Apply Qwen3.5 eos_token fix (matches SimpleEngine pattern) + if "qwen3" in self._model_name.lower(): + self._text_tokenizer.eos_token = "<|im_end|>" + self._text_tokenizer.eos_token_id = ( + self._text_tokenizer.convert_tokens_to_ids("<|im_end|>") + ) + + # Check if TextModel actually has MTP + has_mtp = ( + hasattr(self._text_model, "mtp") + and self._text_model.mtp is not None + ) + if has_mtp: + logger.info( + "BatchedEngine MLLM+MTP routing: " + "text-only → TextModel (MTP), media → MLLM" + ) + else: + logger.warning( + "TextModel built but no MTP head — " + "text-only won't use MTP" + ) + self._text_model = None + self._text_tokenizer = None + except Exception as e: + logger.error(f"MTP TextModel build failed: {e}") + self._text_model = None + self._text_tokenizer = None + + # Load SpecPrefill draft model (for TextModel path — sparse cache + # is incompatible with MTP, so specprefill generates autoregressively) + if self._specprefill_enabled and self._specprefill_draft_model_path: + try: + from pathlib import Path + + from mlx_lm import load as mlx_lm_load + + draft_path = str( + Path.home() + / "ai-models" + / "mlx_models" + / self._specprefill_draft_model_path + ) + self._draft_model, _ = mlx_lm_load(draft_path) + logger.info( + "SpecPrefill draft model loaded: %s (threshold=%d, keep=%.0f%%)", + self._specprefill_draft_model_path, + self._specprefill_threshold, + self._specprefill_keep_pct * 100, + ) + except Exception as e: + logger.warning("Failed to load SpecPrefill draft model: %s", e) + 
self._specprefill_enabled = False + self._draft_model = None + async def _start_llm(self) -> None: """Start the LLM engine with AsyncEngineCore.""" from ..engine_core import AsyncEngineCore, EngineConfig @@ -327,6 +450,12 @@ async def stop(self) -> None: self._tokenizer = None self._processor = None self._mllm_instance = None + self._text_model = None + self._text_tokenizer = None + self._draft_model = None + self._system_kv_snapshot = None + self._system_kv_hash = None + self._system_kv_token_count = 0 self._loaded = False logger.info("BatchedEngine stopped") @@ -612,6 +741,20 @@ async def chat( if not self._loaded: await self.start() + # Normalize messages before any path (developer->system, merge consecutive) + messages = _normalize_messages(messages) + + # Per-request MTP routing: text-only → TextModel, media → MLLM + if self._text_model is not None and not _has_media_content(messages): + return await self._chat_text_model( + messages, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + tools=tools, + **kwargs, + ) + # Extract images/videos from messages (OpenAI multimodal format) # Note: We only use extracted media here, messages are already processed by server _, extracted_images, extracted_videos = extract_multimodal_content(messages) @@ -723,6 +866,22 @@ async def stream_chat( if not self._loaded: await self.start() + # Normalize messages before any path (developer->system, merge consecutive) + messages = _normalize_messages(messages) + + # Per-request MTP routing: text-only → TextModel, media → MLLM + if self._text_model is not None and not _has_media_content(messages): + async for output in self._stream_chat_text_model( + messages, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + tools=tools, + **kwargs, + ): + yield output + return + # Extract images/videos from messages (OpenAI multimodal format) # Note: We only use extracted media here, messages are already processed by server _, extracted_images, extracted_videos = 
extract_multimodal_content(messages) @@ -755,6 +914,469 @@ async def stream_chat( ): yield output + async def _chat_text_model( + self, + messages: list[dict[str, Any]], + max_tokens: int = 256, + temperature: float = 0.7, + top_p: float = 0.9, + tools: list[dict] | None = None, + **kwargs, + ) -> GenerationOutput: + """Non-streaming text-only generation via mlx_lm TextModel with MTP. + + Collects all streaming output into a single GenerationOutput. + Used when MLLM+MTP routing is active and the request has no media. + """ + logger.info("Text-only request → TextModel (MTP) [non-streaming]") + accumulated_text = "" + last_chunk = None + async for chunk in self._stream_chat_text_model( + messages, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + tools=tools, + **kwargs, + ): + accumulated_text = chunk.text + last_chunk = chunk + if last_chunk is not None: + return GenerationOutput( + text=accumulated_text, + prompt_tokens=last_chunk.prompt_tokens, + completion_tokens=last_chunk.completion_tokens, + finish_reason=last_chunk.finish_reason, + ) + return GenerationOutput(text="", finish_reason="stop") + + async def _stream_chat_text_model( + self, + messages: list[dict[str, Any]], + max_tokens: int = 256, + temperature: float = 0.7, + top_p: float = 0.9, + tools: list[dict] | None = None, + **kwargs, + ) -> AsyncIterator[GenerationOutput]: + """Streaming text-only generation via mlx_lm TextModel with MTP. + + Used when MLLM+MTP routing is active and the request has no media. + Runs the full generation in a single thread to maintain Metal safety. + + System prompt KV caching: on the first request, prefills system tokens + and snapshots backbone KV state. Subsequent requests with the same + system prompt restore the snapshot and only prefill the suffix tokens. + + SpecPrefill: when a draft model is loaded and the prompt exceeds the + threshold, uses attention-based sparse prefill for faster TTFT. 
+ Composes with system KV cache (sparse-prefill only the suffix when + cache hits). Falls back to normal path on any error. + """ + import hashlib + import os + + import mlx.core as mx + from mlx_lm import stream_generate as mlx_stream_generate + from mlx_lm.models.cache import make_prompt_cache + from mlx_lm.sample_utils import make_sampler + + # Per-request specprefill overrides (from extra_body) + specprefill_override = kwargs.pop("specprefill", None) + specprefill_keep_pct_override = kwargs.pop("specprefill_keep_pct", None) + + # Convert tools for template + template_tools = convert_tools_for_template(tools) if tools else None + + # Read enable_thinking from env (set by runtime_patches, consistent with MLLM path) + enable_thinking_env = os.environ.get("VLLM_MLX_ENABLE_THINKING", "true") + enable_thinking = enable_thinking_env.lower() in ("true", "1", "yes") + + # Apply chat template + template_kwargs = { + "tokenize": False, + "add_generation_prompt": True, + "enable_thinking": enable_thinking, + } + if template_tools: + template_kwargs["tools"] = template_tools + + try: + prompt = self._text_tokenizer.apply_chat_template( + messages, **template_kwargs + ) + except TypeError: + # Template doesn't accept tools= or enable_thinking= + template_kwargs.pop("tools", None) + template_kwargs.pop("enable_thinking", None) + prompt = self._text_tokenizer.apply_chat_template( + messages, **template_kwargs + ) + + # Build sampler + sampler = make_sampler(temp=temperature, top_p=top_p) + max_tokens = max_tokens or 4096 + + # --- System KV cache: find system prefix boundary --- + # ChatML (Qwen 3.5): everything before first <|im_start|>user is the system prefix + USER_MARKER = "<|im_start|>user" + marker_pos = prompt.find(USER_MARKER) + if marker_pos > 0: + system_prefix = prompt[:marker_pos] + suffix = prompt[marker_pos:] + prefix_hash = hashlib.sha256(system_prefix.encode()).hexdigest()[:16] + else: + system_prefix = None + suffix = prompt + prefix_hash = None + + # Check 
for cache hit + cache_hit = ( + prefix_hash is not None + and prefix_hash == self._system_kv_hash + and self._system_kv_snapshot is not None + ) + + if cache_hit: + logger.info( + "Text-only request → TextModel (MTP) [streaming, system KV cache HIT: " + "reusing %d cached tokens, hash=%s]", + self._system_kv_token_count, + prefix_hash, + ) + else: + logger.info("Text-only request → TextModel (MTP) [streaming]") + + prefill_step_size = self._prefill_step_size + + # --- SpecPrefill decision --- + # Determine whether to use specprefill for this request. + # Must be decided before entering the generation lock so we can + # tokenize and check the threshold outside the critical section. + _SPECPREFILL_MAX_TOKENS = 196608 + use_specprefill = False + if self._draft_model is not None: + if specprefill_override is True: + use_specprefill = True + elif specprefill_override is None and self._specprefill_enabled: + use_specprefill = True + # specprefill_override=False explicitly disables + + # Tokenize to determine token count for specprefill threshold check. + # We need this for both specprefill and normal paths anyway. 
+ sp_tokens = None # tokens to score (suffix or full prompt) + sp_offset = 0 # position offset for sparse_prefill + sp_n_total = 0 # total prompt tokens (for logging / threshold) + + if use_specprefill: + if cache_hit: + # Score only the suffix — system prefix is already cached + sp_tokens = self._text_tokenizer.encode(suffix) + sp_offset = self._system_kv_token_count + sp_n_total = sp_offset + len(sp_tokens) + else: + # Score the full prompt + sp_tokens = self._text_tokenizer.encode(prompt) + sp_offset = 0 + sp_n_total = len(sp_tokens) + + n_sp_tokens = len(sp_tokens) + + # Threshold check (skip when force-enabled via per-request override) + if ( + specprefill_override is not True + and n_sp_tokens <= self._specprefill_threshold + ): + use_specprefill = False + + # Upper bound: cap to avoid draft model OOM + if use_specprefill and n_sp_tokens > _SPECPREFILL_MAX_TOKENS: + logger.warning( + "SpecPrefill: prompt %d tokens exceeds max %d, " + "falling back to normal path", + n_sp_tokens, + _SPECPREFILL_MAX_TOKENS, + ) + use_specprefill = False + + # Run under generation lock, all tokens in single thread (Metal safety) + async with self._text_generation_lock: + + def _run_with_cache(): + if use_specprefill: + try: + return _run_specprefill() + except Exception as e: + logger.error( + "SpecPrefill failed, falling back to normal path: %s", e + ) + # Fall through to normal path + if cache_hit: + return _run_cache_hit() + else: + return _run_cache_miss() + + def _run_specprefill(): + """Score tokens, sparse prefill, generate autoregressively. + + Composes with system KV cache: when cache_hit, restores the + system KV snapshot first, then sparse-prefills only the suffix + tokens with position_offset = system_kv_token_count. + + Does NOT use MTP (sparse cache is incompatible with MTP + speculative decoding). 
+ """ + import time + from types import SimpleNamespace + + from ..specprefill import ( + cleanup_rope, + score_tokens, + select_chunks, + sparse_prefill, + ) + + # Build target cache (optionally restore system KV snapshot) + target_cache = make_prompt_cache(self._text_model) + if cache_hit: + for layer_idx, snapshot_state in enumerate( + self._system_kv_snapshot + ): + if layer_idx < len(target_cache): + target_cache[layer_idx].state = snapshot_state + mx.eval([c.state for c in target_cache if hasattr(c, "state")]) + + try: + # Phase 1: Score with draft model + t0 = time.monotonic() + importance = score_tokens( + self._draft_model, + sp_tokens, + prefill_step_size=prefill_step_size, + ) + t_score = time.monotonic() - t0 + + # Phase 2: Select important chunks + effective_keep = ( + specprefill_keep_pct_override or self._specprefill_keep_pct + ) + selected = select_chunks(importance, keep_pct=effective_keep) + n_selected = selected.shape[0] + n_scored = len(sp_tokens) + + # Phase 3: Sparse prefill on target model + t0 = time.monotonic() + logits = sparse_prefill( + self._text_model, + sp_tokens, + selected, + target_cache, + step_size=prefill_step_size, + position_offset=sp_offset, + ) + t_prefill = time.monotonic() - t0 + + logger.info( + "SpecPrefill: scored %d tokens in %.1fs, " + "sparse prefill %d/%d (keep=%.0f%%) in %.1fs " + "(offset=%d, effective_keep=%.2f)", + n_scored, + t_score, + n_selected, + n_scored, + n_selected / n_scored * 100, + t_prefill, + sp_offset, + effective_keep, + ) + + # Phase 4: Generate (simple autoregressive, no MTP) + eos_id = self._text_tokenizer.eos_token_id + y = sampler(logits[:, -1, :]) + mx.eval(y) + + results = [] + generated_ids = [] + prev_decoded = "" + + for _ in range(max_tokens): + tok_id = y.item() + generated_ids.append(tok_id) + + # Incremental text decode + decoded = self._text_tokenizer.decode(generated_ids) + new_text = decoded[len(prev_decoded) :] + prev_decoded = decoded + + is_eos = tok_id == eos_id + 
results.append( + SimpleNamespace( + text=new_text, + finish_reason="stop" if is_eos else None, + ) + ) + + if is_eos: + break + + # Next token + logits = self._text_model(y.reshape(1, -1), cache=target_cache) + y = sampler(logits[:, -1, :]) + mx.eval(y) + + return results, sp_n_total + + finally: + cleanup_rope(self._text_model) + + def _run_cache_hit(): + """Restore system KV snapshot, prefill only suffix, generate.""" + # Restore cached KV state into a fresh cache + restored_cache = make_prompt_cache(self._text_model) + for layer_idx, snapshot_state in enumerate(self._system_kv_snapshot): + if layer_idx < len(restored_cache): + restored_cache[layer_idx].state = snapshot_state + mx.eval([c.state for c in restored_cache if hasattr(c, "state")]) + + # Tokenize just the suffix and generate with the primed cache. + # stream_generate accepts mx.array prompt (skips tokenization) + # and prompt_cache is forwarded to mtp_generate_step. + suffix_tokens = self._text_tokenizer.encode(suffix) + suffix_array = mx.array(suffix_tokens) + n_suffix = len(suffix_tokens) + + logger.info( + "System KV cache HIT: prefilling %d suffix tokens " + "(skipped %d cached tokens)", + n_suffix, + self._system_kv_token_count, + ) + + results = [] + for resp in mlx_stream_generate( + self._text_model, + self._text_tokenizer, + prompt=suffix_array, + max_tokens=max_tokens, + sampler=sampler, + mtp=True, + prompt_cache=restored_cache, + prefill_step_size=prefill_step_size, + ): + results.append(resp) + return results, self._system_kv_token_count + len(suffix_tokens) + + def _run_cache_miss(): + """Full prefill + generation, then snapshot system KV for next time.""" + results = [] + for resp in mlx_stream_generate( + self._text_model, + self._text_tokenizer, + prompt=prompt, + max_tokens=max_tokens, + sampler=sampler, + mtp=True, + prefill_step_size=prefill_step_size, + ): + results.append(resp) + + # Snapshot system KV for next request (if we found a system prefix) + if prefix_hash is not None 
and system_prefix is not None: + try: + _snapshot_system_kv() + except Exception as e: + logger.warning("Failed to snapshot system KV cache: %s", e) + + # Get total prompt token count from generation response + prompt_tokens = 0 + if results and hasattr(results[0], "prompt_tokens"): + prompt_tokens = results[0].prompt_tokens + return results, prompt_tokens + + def _snapshot_system_kv(): + """Prefill just the system prefix on a fresh cache and save snapshot.""" + snapshot_cache = make_prompt_cache(self._text_model) + prefix_tokens = self._text_tokenizer.encode(system_prefix) + prefix_ids = mx.array(prefix_tokens) + + # Chunked prefill of system prefix + for i in range(0, prefix_ids.size, prefill_step_size): + chunk = prefix_ids[i : i + prefill_step_size] + self._text_model(chunk[None], cache=snapshot_cache) + mx.eval([c.state for c in snapshot_cache if hasattr(c, "state")]) + + # Save snapshot: deep copy of each cache layer's state + self._system_kv_snapshot = [] + for c in snapshot_cache: + state = c.state + if isinstance(state, tuple) and len(state) == 2: + # KVCache: (keys, values) — copy to detach from cache + keys, values = state + self._system_kv_snapshot.append( + (mx.array(keys), mx.array(values)) + ) + elif isinstance(state, list): + # ArraysCache: list of arrays (Mamba/hybrid) + self._system_kv_snapshot.append( + [mx.array(a) if a is not None else None for a in state] + ) + else: + # Unknown cache type — store as-is + self._system_kv_snapshot.append(state) + + self._system_kv_token_count = len(prefix_tokens) + self._system_kv_hash = prefix_hash + + cache_bytes = 0 + for entry in self._system_kv_snapshot: + if isinstance(entry, tuple) and len(entry) == 2: + cache_bytes += entry[0].nbytes + entry[1].nbytes + elif isinstance(entry, list): + cache_bytes += sum(a.nbytes for a in entry if a is not None) + logger.info( + "System KV cache: stored %d-token snapshot " "(%.1f MB), hash=%s", + len(prefix_tokens), + cache_bytes / 1e6, + prefix_hash, + ) + + result = 
await asyncio.to_thread(_run_with_cache) + all_resps, prompt_token_count = result + + # Yield results as GenerationOutput + accumulated_text = "" + token_count = 0 + finished = False + for i, resp in enumerate(all_resps): + token_count += 1 + new_text = resp.text if hasattr(resp, "text") else str(resp) + accumulated_text += new_text + + is_last = i == len(all_resps) - 1 + finished = is_last or token_count >= max_tokens + + yield GenerationOutput( + text=accumulated_text, + new_text=new_text, + prompt_tokens=prompt_token_count, + completion_tokens=token_count, + finished=finished, + finish_reason="stop" if finished else None, + ) + + if finished: + break + + if not finished: + yield GenerationOutput( + text=accumulated_text, + new_text="", + prompt_tokens=prompt_token_count, + completion_tokens=token_count, + finished=True, + finish_reason="length", + ) + def get_stats(self) -> dict[str, Any]: """Get engine statistics.""" stats = { @@ -779,6 +1401,29 @@ def get_stats(self) -> dict[str, Any]: elif self._engine: stats.update(self._engine.get_stats()) + # SpecPrefill stats + if self._draft_model is not None: + stats["specprefill"] = { + "enabled": self._specprefill_enabled, + "draft_model": self._specprefill_draft_model_path, + "threshold": self._specprefill_threshold, + "keep_pct": self._specprefill_keep_pct, + } + + # System KV cache stats + if self._system_kv_snapshot is not None: + cache_bytes = 0 + for entry in self._system_kv_snapshot: + if isinstance(entry, tuple) and len(entry) == 2: + cache_bytes += entry[0].nbytes + entry[1].nbytes + elif isinstance(entry, list): + cache_bytes += sum(a.nbytes for a in entry if a is not None) + stats["system_kv_cache"] = { + "tokens": self._system_kv_token_count, + "hash": self._system_kv_hash, + "memory_mb": round(cache_bytes / 1e6, 1), + } + return stats def get_cache_stats(self) -> dict[str, Any] | None: diff --git a/vllm_mlx/message_utils.py b/vllm_mlx/message_utils.py new file mode 100644 index 00000000..621ac057 --- 
# SPDX-License-Identifier: Apache-2.0
"""
Shared message normalization utilities.

Provides ``_normalize_messages()`` which maps non-standard roles, merges
consecutive same-role messages, and hoists system messages to position [0].
Used by both SimpleEngine and BatchedEngine before ``apply_chat_template``.
"""

import logging

logger = logging.getLogger(__name__)

# Non-standard role names and the standard role each is rewritten to.
_ROLE_MAP = {"developer": "system"}


def _is_mergeable(msg: dict) -> bool:
    """Return True if *msg* is plain text that can be concatenated safely.

    Messages with list (multimodal) content are excluded to preserve their
    attachments. Tool results and messages carrying ``tool_calls`` are
    excluded because concatenating only ``content`` would drop the second
    message's ``tool_call_id`` / ``tool_calls``.
    """
    return (
        isinstance(msg.get("content"), str)
        and msg["role"] != "tool"
        and not msg.get("tool_calls")
        and "tool_call_id" not in msg
    )


def _merge_adjacent(messages) -> list[dict]:
    """Merge neighbouring same-role plain-text messages.

    Contents are joined with a blank line. The dicts in *messages* are
    reused (and the first of each merged run is updated in place), so
    callers must pass copies, never caller-owned input.
    """
    merged: list[dict] = []
    for msg in messages:
        prev = merged[-1] if merged else None
        if (
            prev is not None
            and prev["role"] == msg["role"]
            and _is_mergeable(prev)
            and _is_mergeable(msg)
        ):
            prev["content"] = prev["content"] + "\n\n" + msg["content"]
            logger.debug(
                "Merged consecutive %s messages (%d chars total)",
                msg["role"],
                len(prev["content"]),
            )
        else:
            merged.append(msg)
    return merged


def _system_text_parts(system_msgs: list[dict]) -> list[str]:
    """Collect the text of system messages (string content or text parts)."""
    parts: list[str] = []
    for msg in system_msgs:
        content = msg.get("content")
        if isinstance(content, str):
            parts.append(content)
        elif isinstance(content, list):
            # Multimodal system message — extract text parts
            for part in content:
                if isinstance(part, str):
                    parts.append(part)
                elif isinstance(part, dict) and part.get("type") == "text":
                    parts.append(part.get("text", ""))
    return parts


def _normalize_messages(messages: list[dict]) -> list[dict]:
    """Normalize message roles and merge consecutive same-role messages.

    1. Maps non-standard roles to standard ones (e.g. ``developer`` -> ``system``).
    2. Merges consecutive same-role messages to satisfy chat template constraints
       (Qwen 3.5, Llama, etc. require alternating roles).
    3. Hoists system messages to position [0], combining their text, and
       merges again, since removing a mid-conversation system message can
       leave two same-role messages adjacent.

    Only merges when both messages have string content and carry no
    tool-call metadata. Messages with list content (multimodal) are left
    as-is to preserve image/video attachments.

    Args:
        messages: Chat messages; each must have a ``role`` key.

    Returns:
        A new list of (shallow-copied) messages. The input list and its
        dicts are not modified. A falsy *messages* is returned unchanged.

    Raises:
        KeyError: if a message has no ``role`` key.
    """
    if not messages:
        return messages

    # Shallow-copy each message with its role mapped, then merge neighbours.
    merged = _merge_adjacent(
        {**msg, "role": _ROLE_MAP.get(msg["role"], msg["role"])} for msg in messages
    )

    # Hoist system messages to position [0] and merge them.
    # Many CLIs (OpenCode, Qwen Code, Kilo) send system messages mid-conversation;
    # the Qwen 3.5 chat template rejects any system message not at position [0].
    system_msgs = [m for m in merged if m["role"] == "system"]
    if system_msgs and (len(system_msgs) > 1 or merged[0]["role"] != "system"):
        non_system = [m for m in merged if m["role"] != "system"]
        parts = _system_text_parts(system_msgs)
        if parts:
            combined_system = {"role": "system", "content": "\n\n".join(parts)}
            hoisted = [combined_system] + non_system
            logger.info(
                "Hoisted %d system message(s) to position [0] (%d chars)",
                len(system_msgs),
                len(combined_system["content"]),
            )
        else:
            # No string content — just move the first system msg to front
            hoisted = system_msgs[:1] + non_system
            logger.info("Hoisted system message to position [0]")
        # Pulling a system message out of the middle can make its former
        # neighbours adjacent, so merge once more.
        merged = _merge_adjacent(hoisted)

    if len(merged) != len(messages):
        logger.info(
            "Normalized messages: merged %d -> %d", len(messages), len(merged)
        )

    return merged