diff --git a/pyproject.toml b/pyproject.toml
index fa92e960..b036c78e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -77,6 +77,10 @@ vision = [
     "torch>=2.3.0",
     "torchvision>=0.18.0",
 ]
+# Guided decoding with outlines for structured JSON output
+guided = [
+    "outlines[mlxlm]>=1.0.0",
+]
 # Audio dependencies for TTS/STT (mlx-audio)
 audio = [
     "mlx-audio>=0.2.9",
diff --git a/vllm_mlx/api/guided.py b/vllm_mlx/api/guided.py
new file mode 100644
index 00000000..4f20dd2b
--- /dev/null
+++ b/vllm_mlx/api/guided.py
@@ -0,0 +1,238 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+Guided generation for structured JSON output using outlines.
+
+This module provides constrained decoding for JSON schema enforcement,
+ensuring model outputs strictly adhere to specified schemas.
+"""
+
+import logging
+from typing import Any
+
+logger = logging.getLogger(__name__)
+
+# Check for outlines availability (outlines>=1.0 with the mlxlm extra)
+try:
+    import mlx_lm
+    import outlines
+
+    HAS_OUTLINES = True
+except ImportError:
+    HAS_OUTLINES = False
+    outlines = None
+    mlx_lm = None
+
+
+def is_guided_available() -> bool:
+    """Check if guided generation with outlines is available."""
+    return HAS_OUTLINES
+
+
+def json_schema_to_pydantic(schema: dict[str, Any]) -> type | None:
+    """
+    Convert a JSON schema to a Pydantic model dynamically.
+
+    Args:
+        schema: JSON schema dict
+
+    Returns:
+        Dynamically created Pydantic model class, or None if conversion fails
+    """
+    try:
+        from pydantic import create_model
+
+        # Extract properties from schema
+        properties = schema.get("properties", {})
+        required = set(schema.get("required", []))
+
+        # Build field definitions for Pydantic
+        field_definitions = {}
+
+        type_mapping = {
+            "string": str,
+            "integer": int,
+            "number": float,
+            "boolean": bool,
+            "null": type(None),
+        }
+
+        for prop_name, prop_spec in properties.items():
+            prop_type = prop_spec.get("type", "string")
+
+            # Handle array type
+            if prop_type == "array":
+                items_type = prop_spec.get("items", {}).get("type", "string")
+                inner_type = type_mapping.get(items_type, str)
+                python_type = list[inner_type]
+            # Handle object type (nested)
+            elif prop_type == "object":
+                # For nested objects, use dict
+                python_type = dict
+            else:
+                python_type = type_mapping.get(prop_type, str)
+
+            # Make optional if not required
+            if prop_name not in required:
+                python_type = python_type | None
+                default = None
+            else:
+                default = ...
+
+            field_definitions[prop_name] = (python_type, default)
+
+        # Create the model dynamically
+        model = create_model("DynamicModel", **field_definitions)
+        return model
+
+    except Exception as e:
+        logger.warning(f"Failed to convert JSON schema to Pydantic: {e}")
+        return None
+
+
+class GuidedGenerator:
+    """
+    Guided generation using outlines for constrained JSON decoding.
+
+    This class wraps an MLX model to provide structured output generation
+    that guarantees valid JSON matching a specified schema.
+    """
+
+    def __init__(self, model, tokenizer):
+        """
+        Initialize the guided generator.
+
+        Args:
+            model: MLX model instance
+            tokenizer: Tokenizer instance
+
+        Raises:
+            ImportError: if outlines is not installed
+        """
+        if not HAS_OUTLINES:
+            raise ImportError(
+                "outlines is required for guided generation. "
+                "Install with: pip install 'vllm-mlx[guided]'"
+            )
+
+        self._model = model
+        self._tokenizer = tokenizer
+        self._outlines_model = None
+
+    def _get_outlines_model(self):
+        """Get or create the outlines model wrapper (created lazily, once)."""
+        if self._outlines_model is None:
+            self._outlines_model = outlines.from_mlxlm(self._model, self._tokenizer)
+        return self._outlines_model
+
+    def generate_json(
+        self,
+        prompt: str,
+        json_schema: dict[str, Any],
+        max_tokens: int = 256,
+        temperature: float = 0.7,
+    ) -> str | None:
+        """
+        Generate JSON output constrained to a schema.
+
+        Args:
+            prompt: Input prompt
+            json_schema: JSON schema to constrain output
+            max_tokens: Maximum tokens to generate
+            temperature: Sampling temperature
+                (NOTE: currently not forwarded to outlines; kept for API symmetry)
+
+        Returns:
+            JSON string matching the schema, or None if the schema could not
+            be converted or generation failed
+        """
+        # Convert schema to Pydantic model
+        pydantic_model = json_schema_to_pydantic(json_schema)
+
+        if pydantic_model is None:
+            logger.warning(
+                "Could not convert schema to Pydantic, falling back to raw generation"
+            )
+            return None
+
+        try:
+            outlines_model = self._get_outlines_model()
+
+            # Generate with schema constraint
+            result = outlines_model(
+                prompt,
+                output_type=pydantic_model,
+                max_tokens=max_tokens,
+            )
+
+            # result is a JSON string, validate and return
+            return result
+
+        except Exception as e:
+            logger.error(f"Guided generation failed: {e}")
+            return None
+
+    def generate_json_object(
+        self,
+        prompt: str,
+        max_tokens: int = 256,
+        temperature: float = 0.7,
+    ) -> str | None:
+        """
+        Generate any valid JSON object.
+
+        Args:
+            prompt: Input prompt
+            max_tokens: Maximum tokens to generate
+            temperature: Sampling temperature
+                (NOTE: currently not forwarded to outlines; kept for API symmetry)
+
+        Returns:
+            JSON string, or None on failure
+        """
+        try:
+            from outlines.types import Regex
+
+            outlines_model = self._get_outlines_model()
+
+            # Constrain output to a flat (non-nested) JSON object via regex.
+            # Uses the outlines v1 API: the legacy `outlines.generate` module
+            # was removed in outlines 1.0, which this package pins.
+            json_regex = r"\{[^{}]*\}"
+            generator = outlines.Generator(outlines_model, Regex(json_regex))
+            result = generator(prompt, max_tokens=max_tokens)
+
+            return result
+
+        except Exception as e:
+            logger.error(f"JSON object generation failed: {e}")
+            return None
+
+
+def generate_with_schema(
+    model,
+    tokenizer,
+    prompt: str,
+    json_schema: dict[str, Any],
+    max_tokens: int = 256,
+    temperature: float = 0.7,
+) -> str | None:
+    """
+    Convenience function for one-shot guided JSON generation.
+
+    Args:
+        model: MLX model
+        tokenizer: Tokenizer
+        prompt: Input prompt
+        json_schema: JSON schema
+        max_tokens: Maximum tokens
+        temperature: Sampling temperature
+
+    Returns:
+        JSON string or None if guided generation unavailable/failed
+    """
+    if not HAS_OUTLINES:
+        return None
+
+    try:
+        generator = GuidedGenerator(model, tokenizer)
+        return generator.generate_json(
+            prompt=prompt,
+            json_schema=json_schema,
+            max_tokens=max_tokens,
+            temperature=temperature,
+        )
+    except Exception as e:
+        logger.error(f"generate_with_schema failed: {e}")
+        return None
diff --git a/vllm_mlx/api/tool_calling.py b/vllm_mlx/api/tool_calling.py
index 1443c167..764159b6 100644
--- a/vllm_mlx/api/tool_calling.py
+++ b/vllm_mlx/api/tool_calling.py
@@ -21,6 +21,42 @@
 from .models import FunctionCall, ResponseFormat, ToolCall
 
 
+def _is_tool_call_json(obj: dict) -> bool:
+    """
+    Check if a JSON object looks like a tool call.
+
+    A tool call must have:
+    - "name" key with a string value (function name)
+    - "arguments" key (the function arguments)
+
+    This prevents false positives where regular JSON like {"name": "John", "age": 30}
+    would be incorrectly parsed as a tool call.
+
+    Args:
+        obj: JSON object to check
+
+    Returns:
+        True if object appears to be a tool call
+    """
+    if not isinstance(obj, dict):
+        return False
+
+    # Must have both "name" and "arguments" keys
+    if "name" not in obj or "arguments" not in obj:
+        return False
+
+    # "name" must be a non-empty string (function name)
+    if not isinstance(obj["name"], str) or not obj["name"].strip():
+        return False
+
+    # "arguments" must be a dict or string
+    args = obj["arguments"]
+    if not isinstance(args, (dict, str)):
+        return False
+
+    return True
+
+
 def _parse_raw_json_tool_calls(text: str) -> Optional[List[dict]]:
     """
     Parse raw JSON tool calls from model output.
@@ -30,6 +66,9 @@ def _parse_raw_json_tool_calls(text: str) -> Optional[List[dict]]:
     - Multiple objects separated by commas: {...}, {...}
     - JSON array: [{...}, {...}]
 
+    Only objects with BOTH "name" AND "arguments" keys are considered tool calls.
+    This prevents false positives with regular JSON objects.
+
     Args:
         text: Raw model output text
 
@@ -46,7 +85,7 @@ def _parse_raw_json_tool_calls(text: str) -> Optional[List[dict]]:
     try:
         parsed = json.loads(text)
         if isinstance(parsed, list) and all(
-            isinstance(item, dict) and "name" in item for item in parsed
+            _is_tool_call_json(item) for item in parsed
         ):
             return [
                 {"name": item["name"], "arguments": item.get("arguments", {})}
@@ -71,7 +110,8 @@ def _parse_raw_json_tool_calls(text: str) -> Optional[List[dict]]:
                 json_str = text[start : i + 1]
                 try:
                     obj = json.loads(json_str)
-                    if isinstance(obj, dict) and "name" in obj:
+                    # Only consider as tool call if it has both "name" AND "arguments"
+                    if _is_tool_call_json(obj):
                         tool_calls.append(
                             {"name": obj["name"], "arguments": obj.get("arguments", {})}
                         )
@@ -524,8 +564,14 @@ def build_json_system_prompt(
 
     if format_type == "json_object":
         return (
-            "You must respond with valid JSON only. "
-            "Do not include any explanation or text outside the JSON object."
+            "⚠️ JSON OUTPUT REQUIRED ⚠️\n\n"
+            "You MUST respond with ONLY a valid JSON object.\n\n"
+            "RULES:\n"
+            "- Start response with { or [\n"
+            "- NO text before or after JSON\n"
+            "- NO thinking or explanations\n"
+            "- NO markdown code blocks (```)\n"
+            "- ONLY the raw JSON object"
         )
 
     if format_type == "json_schema":
@@ -533,14 +579,89 @@ def build_json_system_prompt(
         schema = json_schema_spec.get("schema", {})
         name = json_schema_spec.get("name", "response")
         description = json_schema_spec.get("description", "")
+        strict = json_schema_spec.get("strict", False)
+
+        # If strict mode is enabled and guided generation is available,
+        # the server should use guided decoding instead of prompt injection
+        if strict:
+            # Return stronger instruction for strict mode
+            prompt = (
+                f"⚠️ STRICT JSON OUTPUT REQUIRED ⚠️\n\n"
+                f"You MUST respond with ONLY a valid JSON object matching the '{name}' schema.\n"
+            )
+            if description:
+                prompt += f"Purpose: {description}\n"
+            prompt += (
+                f"\nJSON Schema:\n```json\n{json.dumps(schema, indent=2)}\n```\n\n"
+                "STRICT RULES:\n"
+                # Plain (non-f) string: "{" must NOT be doubled here, or the
+                # model is literally told to start its reply with "{{".
+                "- Start response with { or [\n"
+                "- NO text before or after JSON\n"
+                "- NO thinking, reasoning, or explanations\n"
+                "- NO markdown code blocks\n"
+                "- ONLY the JSON object"
+            )
+            return prompt
 
-        prompt = f"You must respond with valid JSON matching the '{name}' schema."
+        # Standard (non-strict) mode - still strong but less aggressive
+        prompt = (
+            f"⚠️ JSON OUTPUT REQUIRED ⚠️\n\n"
+            f"Respond with a valid JSON object matching the '{name}' schema.\n"
+        )
         if description:
-            prompt += f" {description}"
+            prompt += f"Purpose: {description}\n"
         prompt += (
-            f"\n\nJSON Schema:\n```json\n{json.dumps(schema, indent=2)}\n```\n\n"
-            "Respond with only the JSON object, no additional text or explanation."
+            f"\nJSON Schema:\n```json\n{json.dumps(schema, indent=2)}\n```\n\n"
+            "RULES:\n"
+            "- Start response with { or [\n"
+            "- NO text before or after JSON\n"
+            "- NO markdown code blocks\n"
+            "- ONLY the JSON object"
         )
 
         return prompt
 
     return None
+
+
+def extract_json_schema_for_guided(
+    response_format: Optional[Union[ResponseFormat, Dict[str, Any]]] = None,
+) -> Optional[Dict[str, Any]]:
+    """
+    Extract JSON schema from response_format for guided generation.
+
+    Returns the schema dict if response_format specifies json_schema type,
+    otherwise returns None.
+
+    Args:
+        response_format: ResponseFormat specification
+
+    Returns:
+        JSON schema dict or None
+    """
+    if response_format is None:
+        return None
+
+    # Normalize to dict
+    if isinstance(response_format, ResponseFormat):
+        rf_dict = {"type": response_format.type, "json_schema": None}
+        if response_format.json_schema:
+            rf_dict["json_schema"] = {
+                "name": response_format.json_schema.name,
+                "description": response_format.json_schema.description,
+                "schema": response_format.json_schema.schema_,
+                "strict": response_format.json_schema.strict,
+            }
+    else:
+        rf_dict = response_format
+
+    format_type = rf_dict.get("type", "text")
+
+    if format_type != "json_schema":
+        return None
+
+    # `or {}` (not a .get default): the normalization above stores an explicit
+    # None when json_schema is absent, which .get's default would not replace.
+    json_schema_spec = rf_dict.get("json_schema") or {}
+    schema = json_schema_spec.get("schema", {})
+
+    if not schema:
+        return None
+
+    return schema
diff --git a/vllm_mlx/api/utils.py b/vllm_mlx/api/utils.py
index e916ae19..bcc1a405 100644
--- a/vllm_mlx/api/utils.py
+++ b/vllm_mlx/api/utils.py
@@ -101,6 +101,86 @@ def clean_output_text(text: str) -> str:
     return text
 
 
+# Pattern to match <think>...</think> blocks
+THINK_PATTERN = re.compile(r"<think>[\s\S]*?</think>\s*", re.DOTALL)
+
+
+def strip_thinking_tags(text: str) -> str:
+    """
+    Remove <think>...</think> blocks from model output.
+
+    Used when the client expects pure content (e.g., JSON) without
+    reasoning blocks that would break parsing.
+
+    Args:
+        text: Model output that may contain thinking blocks
+
+    Returns:
+        Text with thinking blocks removed
+    """
+    if not text:
+        return text
+    return THINK_PATTERN.sub("", text).strip()
+
+
+def extract_json_from_response(text: str) -> str:
+    """
+    Extract JSON object/array from model response that may contain reasoning text.
+
+    Qwen3 and other reasoning models often output:
+    "Let me think... reasoning text... {\"result\": 123}"
+
+    This function extracts just the JSON part when present.
+
+    Args:
+        text: Model output that may contain text before/after JSON
+
+    Returns:
+        Extracted JSON string if found, otherwise original text
+    """
+    if not text:
+        return text
+
+    text = text.strip()
+
+    # If already valid JSON, return as-is
+    if (text.startswith("{") and text.endswith("}")) or (
+        text.startswith("[") and text.endswith("]")
+    ):
+        return text
+
+    # Try to find a JSON object at the end of the response.
+    # Scan backwards from the final } to the { that balances it, so that a
+    # nested trailing object like 'x {"a": {"b": 1}}' is captured whole.
+    # (Braces inside JSON string values can still defeat this heuristic.)
+    if text.endswith("}"):
+        depth = 0
+        for i in range(len(text) - 1, -1, -1):
+            if text[i] == "}":
+                depth += 1
+            elif text[i] == "{":
+                depth -= 1
+                if depth == 0:
+                    return text[i:]
+
+    # Same for a JSON array at the end
+    if text.endswith("]"):
+        depth = 0
+        for i in range(len(text) - 1, -1, -1):
+            if text[i] == "]":
+                depth += 1
+            elif text[i] == "[":
+                depth -= 1
+                if depth == 0:
+                    return text[i:]
+
+    # No JSON found, return original
+    return text
+
+
 # =============================================================================
 # Model Detection
 # =============================================================================
diff --git a/vllm_mlx/cli.py b/vllm_mlx/cli.py
index dcbee8ac..c9923296 100644
--- a/vllm_mlx/cli.py
+++ b/vllm_mlx/cli.py
@@ -107,6 +107,10 @@ def serve_command(args):
     print(f"Loading model: {args.model}")
print(f"Default max tokens: {args.max_tokens}") + if args.draft_model: + print("Speculative decoding: ENABLED") + print(f" Draft model: {args.draft_model}") + print(f" Draft tokens: {args.num_draft_tokens}") # Store MCP config path for FastAPI startup if args.mcp_config: @@ -187,6 +191,8 @@ def serve_command(args): stream_interval=args.stream_interval if args.continuous_batching else 1, max_tokens=args.max_tokens, force_mllm=args.mllm, + draft_model=args.draft_model, + num_draft_tokens=args.num_draft_tokens, ) # Start server @@ -786,6 +792,19 @@ def main(): "Required for --enable-auto-tool-choice." ), ) + # Speculative decoding options + serve_parser.add_argument( + "--draft-model", + type=str, + default=None, + help="Draft model for speculative decoding (must use same tokenizer as main model)", + ) + serve_parser.add_argument( + "--num-draft-tokens", + type=int, + default=4, + help="Number of tokens to generate speculatively per step (default: 4)", + ) # Reasoning parser options - choices loaded dynamically from registry from .reasoning import list_parsers diff --git a/vllm_mlx/engine/__init__.py b/vllm_mlx/engine/__init__.py index f6625abd..d33c4bc4 100644 --- a/vllm_mlx/engine/__init__.py +++ b/vllm_mlx/engine/__init__.py @@ -12,6 +12,7 @@ from .base import BaseEngine, GenerationOutput from .simple import SimpleEngine from .batched import BatchedEngine +from .hybrid import HybridEngine # Re-export from parent engine.py for backwards compatibility from ..engine_core import EngineCore, AsyncEngineCore, EngineConfig @@ -21,6 +22,7 @@ "GenerationOutput", "SimpleEngine", "BatchedEngine", + "HybridEngine", # Core engine components "EngineCore", "AsyncEngineCore", diff --git a/vllm_mlx/engine/batched.py b/vllm_mlx/engine/batched.py index ce33e628..05906192 100644 --- a/vllm_mlx/engine/batched.py +++ b/vllm_mlx/engine/batched.py @@ -161,6 +161,7 @@ def __init__( self._mllm_scheduler = None # MLLMScheduler for MLLM self._mllm_instance = None # MLXMultimodalLM instance 
self._loaded = False + self._engine_started = False # Track if engine loop is running @property def model_name(self) -> str: @@ -800,3 +801,52 @@ def load_cache_from_disk(self, cache_dir: str) -> int: if self._engine: return self._engine.load_cache_from_disk(cache_dir) return 0 + + async def _inject_shared_model( + self, + model, + tokenizer, + start_engine: bool = True, + ) -> None: + """ + Inject a pre-loaded shared model instead of loading a new one. + + This is used by HybridEngine to share a single model instance + between SimpleEngine and BatchedEngine, saving ~44GB of RAM. + + Args: + model: Pre-loaded MLX model + tokenizer: Pre-loaded tokenizer + start_engine: Whether to start the engine loop immediately. + Set to False for HybridEngine (lazy start on first use). + """ + from ..engine_core import AsyncEngineCore, EngineConfig + from ..scheduler import SchedulerConfig + + self._model = model + self._tokenizer = tokenizer + + # Create engine config + scheduler_config = self._scheduler_config or SchedulerConfig() + engine_config = EngineConfig( + model_name=self._model_name, + scheduler_config=scheduler_config, + stream_interval=self._stream_interval, + ) + + # Create async engine with shared model + self._engine = AsyncEngineCore( + model=self._model, + tokenizer=self._tokenizer, + config=engine_config, + ) + + # Only start engine loop if requested (HybridEngine starts lazily) + if start_engine: + await self._engine.engine.start() + + self._loaded = True + self._engine_started = start_engine + logger.info( + f"BatchedEngine injected with shared model: {self._model_name} (started={start_engine})" + ) diff --git a/vllm_mlx/engine/hybrid.py b/vllm_mlx/engine/hybrid.py new file mode 100644 index 00000000..c0b7c2aa --- /dev/null +++ b/vllm_mlx/engine/hybrid.py @@ -0,0 +1,509 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Hybrid engine combining speculative decoding with continuous batching. 
+ +This engine shares a single model instance between SimpleEngine and BatchedEngine, +saving ~44GB RAM while providing: +- Speculative decoding (80+ tok/s) for single-user scenarios +- Continuous batching (60-70 tok/s) for multi-user scenarios + +The engine automatically switches between modes based on concurrent request count. + +Architecture: + HybridEngine + ├── Shared components (loaded once) + │ ├── _shared_model (44GB) + │ ├── _shared_tokenizer + │ └── _ownership_lock + │ + ├── SimpleEngine (speculative) + │ └── + draft_model (350MB) + │ + └── BatchedEngine (batching) + └── Scheduler + BatchGenerator + +Mode switching logic: + if active_requests < switch_threshold: + → SimpleEngine (speculative, 80+ tok/s) + else: + → BatchedEngine (batching, 60-70 tok/s) +""" + +import asyncio +import logging +from collections.abc import AsyncIterator +from typing import Any + +from mlx_lm import load + +from .base import BaseEngine, GenerationOutput +from .batched import BatchedEngine +from .simple import SimpleEngine +from ..model_registry import get_registry + +logger = logging.getLogger(__name__) + + +class HybridEngine(BaseEngine): + """ + Hybrid engine: speculative decoding for single-user, + continuous batching for multi-user. Shares model instance. + + This engine provides the best of both worlds: + - When serving a single user: uses SimpleEngine with speculative decoding + for maximum throughput (80+ tok/s with Qwen3-0.6B draft model) + - When serving multiple concurrent users: switches to BatchedEngine + for efficient continuous batching + + The key innovation is sharing a single model instance (~44GB for Qwen3-Next-80B) + between both engines, cutting memory usage in half compared to running + separate servers. 
+ """ + + def __init__( + self, + model_name: str, + draft_model: str | None = None, + num_draft_tokens: int = 5, + scheduler_config: Any | None = None, + stream_interval: int = 1, + trust_remote_code: bool = True, + force_mllm: bool = False, + switch_threshold: int = 2, + ): + """ + Initialize the hybrid engine. + + Args: + model_name: HuggingFace model name or local path + draft_model: Draft model for speculative decoding (e.g., Qwen3-0.6B-4bit) + num_draft_tokens: Number of tokens to generate speculatively (default: 5) + scheduler_config: Scheduler config for batched mode + stream_interval: Tokens to batch before streaming (batched mode only) + trust_remote_code: Whether to trust remote code + force_mllm: Force loading as MLLM even if not auto-detected + switch_threshold: Number of concurrent requests to trigger batch mode (default: 2) + """ + self._model_name = model_name + self._draft_model_name = draft_model + self._num_draft_tokens = num_draft_tokens + self._scheduler_config = scheduler_config + self._stream_interval = stream_interval + self._trust_remote_code = trust_remote_code + self._force_mllm = force_mllm + self._switch_threshold = switch_threshold + + # Shared resources (loaded once) + self._shared_model = None + self._shared_tokenizer = None + + # Engine instances + self._simple: SimpleEngine | None = None + self._batched: BatchedEngine | None = None + self._current_mode: str | None = None # 'simple' or 'batched' + + # Concurrency tracking + self._active_requests = 0 + self._lock = asyncio.Lock() + self._switch_lock = asyncio.Lock() + + # State + self._loaded = False + self._is_mllm = False + + @property + def model_name(self) -> str: + """Get the model name.""" + return self._model_name + + @property + def is_mllm(self) -> bool: + """Check if this is a multimodal model.""" + return self._is_mllm + + @property + def tokenizer(self) -> Any: + """Get the tokenizer.""" + return self._shared_tokenizer + + async def start(self) -> None: + """Start the 
engine (load shared model and initialize sub-engines).""" + if self._loaded: + return + + logger.info(f"HybridEngine loading shared model: {self._model_name}") + + # Load model once using mlx-lm + self._shared_model, self._shared_tokenizer = load(self._model_name) + + # Check if MLLM + from ..api.utils import is_mllm_model + + self._is_mllm = self._force_mllm or is_mllm_model(self._model_name) + + if self._is_mllm: + logger.warning( + "HybridEngine does not support MLLM models yet. " + "Using BatchedEngine only." + ) + # For MLLM, just use BatchedEngine (no speculative) + self._batched = BatchedEngine( + model_name=self._model_name, + scheduler_config=self._scheduler_config, + stream_interval=self._stream_interval, + force_mllm=True, + ) + await self._batched.start() + self._current_mode = "batched" + else: + # Create SimpleEngine with draft model support + self._simple = SimpleEngine( + model_name=self._model_name, + trust_remote_code=self._trust_remote_code, + draft_model=self._draft_model_name, + num_draft_tokens=self._num_draft_tokens, + ) + # Inject shared model instead of loading again + await self._simple._inject_shared_model( + self._shared_model, + self._shared_tokenizer, + ) + + # Create BatchedEngine (lazy start - don't start engine loop yet) + self._batched = BatchedEngine( + model_name=self._model_name, + trust_remote_code=self._trust_remote_code, + scheduler_config=self._scheduler_config, + stream_interval=self._stream_interval, + ) + # Inject shared model but DON'T start engine loop yet + # Engine will be started on first switch to batched mode + await self._batched._inject_shared_model( + self._shared_model, + self._shared_tokenizer, + start_engine=False, # Lazy start for HybridEngine + ) + self._batched_engine_started = False + + # Start in simple mode (speculative decoding) + self._current_mode = "simple" + + self._loaded = True + + spec_info = "" + if self._draft_model_name and not self._is_mllm: + spec_info = f", draft={self._draft_model_name}, 
k={self._num_draft_tokens}" + + logger.info( + f"HybridEngine ready: {self._model_name} " + f"(mode={self._current_mode}, threshold={self._switch_threshold}{spec_info})" + ) + + async def stop(self) -> None: + """Stop the engine and cleanup resources.""" + if self._simple: + await self._simple.stop() + self._simple = None + + if self._batched: + await self._batched.stop() + self._batched = None + + self._shared_model = None + self._shared_tokenizer = None + self._loaded = False + self._current_mode = None + + logger.info("HybridEngine stopped") + + def _get_engine_for_request(self) -> BaseEngine: + """ + Get the appropriate engine for the current request. + + Note: This doesn't switch modes - it returns the engine based on + current mode. Mode switching happens separately. + """ + # For MLLM, always use batched + if self._is_mllm: + return self._batched + + return self._simple if self._current_mode == "simple" else self._batched + + async def _switch_to_mode(self, target_mode: str) -> None: + """ + Switch to the specified mode, handling ownership transfer. + + This method ensures proper model ownership transfer between engines + to prevent KV cache conflicts. 
+ + Args: + target_mode: 'simple' or 'batched' + """ + if self._current_mode == target_mode: + return + + async with self._switch_lock: + # Double-check after acquiring lock + if self._current_mode == target_mode: + return + + old_mode = self._current_mode + + if target_mode == "batched": + # Switching to batched mode + if self._batched and self._batched._engine: + # Start BatchedEngine's engine loop if not started yet (lazy start) + if not getattr(self, "_batched_engine_started", True): + logger.info("HybridEngine: starting BatchedEngine (lazy start)") + await self._batched._engine.engine.start() + self._batched_engine_started = True + + # BatchedEngine needs model ownership for its BatchGenerator + registry = get_registry() + try: + registry.acquire( + model=self._shared_model, + engine=self._batched._engine.engine, + engine_id=self._batched._engine.engine.engine_id, + force=True, # Force transfer from SimpleEngine + ) + logger.info( + "HybridEngine: model ownership transferred to BatchedEngine" + ) + except Exception as e: + logger.warning(f"Ownership transfer failed: {e}") + # Continue anyway - the transfer may have succeeded partially + else: + # Switching to simple mode + # SimpleEngine doesn't need registry ownership (uses mlx_lm.generate directly) + # Just reset BatchedEngine's scheduler to clear KV cache state + if self._batched and self._batched._engine: + try: + self._batched._engine.engine.scheduler.deep_reset() + logger.debug("Reset BatchedEngine scheduler for mode switch") + except Exception as e: + logger.warning(f"Failed to reset BatchedEngine: {e}") + + self._current_mode = target_mode + logger.info(f"HybridEngine: mode switched {old_mode} -> {target_mode}") + + async def _decide_and_switch_mode(self, entering: bool = True) -> str: + """ + Decide which mode to use and switch if necessary. 
+ + Args: + entering: True when entering a request, False when exiting + + Returns: + The mode to use for this request ('simple' or 'batched') + """ + if self._is_mllm: + return "batched" + + # Decide based on active request count + # When entering: count includes this request + # When exiting: count excludes this request + active = self._active_requests + + if active >= self._switch_threshold: + target_mode = "batched" + else: + target_mode = "simple" + + # Only switch when safe (no active requests in the other engine) + # This is a simplified heuristic - we switch when crossing the threshold + if target_mode != self._current_mode: + # Log the decision + logger.debug( + f"HybridEngine: active_requests={active}, " + f"threshold={self._switch_threshold}, " + f"current={self._current_mode}, target={target_mode}" + ) + + # Only switch to batched immediately, switch to simple when quiet + if target_mode == "batched": + await self._switch_to_mode("batched") + elif active == 0: + # Switch back to simple only when completely idle + await self._switch_to_mode("simple") + + return self._current_mode + + async def generate( + self, + prompt: str, + max_tokens: int = 256, + temperature: float = 0.7, + top_p: float = 0.9, + stop: list[str] | None = None, + **kwargs, + ) -> GenerationOutput: + """Generate a complete response (non-streaming).""" + if not self._loaded: + await self.start() + + async with self._lock: + self._active_requests += 1 + + try: + # Decide mode and switch if needed + mode = await self._decide_and_switch_mode(entering=True) + engine = self._simple if mode == "simple" else self._batched + + return await engine.generate( + prompt=prompt, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + stop=stop, + **kwargs, + ) + finally: + async with self._lock: + self._active_requests -= 1 + # Check if we should switch back to simple mode + await self._decide_and_switch_mode(entering=False) + + async def stream_generate( + self, + prompt: str, + 
max_tokens: int = 256, + temperature: float = 0.7, + top_p: float = 0.9, + stop: list[str] | None = None, + **kwargs, + ) -> AsyncIterator[GenerationOutput]: + """Stream generation token by token.""" + if not self._loaded: + await self.start() + + async with self._lock: + self._active_requests += 1 + + try: + # Decide mode and switch if needed + mode = await self._decide_and_switch_mode(entering=True) + engine = self._simple if mode == "simple" else self._batched + + async for output in engine.stream_generate( + prompt=prompt, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + stop=stop, + **kwargs, + ): + yield output + finally: + async with self._lock: + self._active_requests -= 1 + # Check if we should switch back to simple mode + await self._decide_and_switch_mode(entering=False) + + async def chat( + self, + messages: list[dict[str, Any]], + max_tokens: int = 256, + temperature: float = 0.7, + top_p: float = 0.9, + tools: list[dict] | None = None, + images: list[str] | None = None, + videos: list[str] | None = None, + **kwargs, + ) -> GenerationOutput: + """Chat completion (non-streaming).""" + if not self._loaded: + await self.start() + + async with self._lock: + self._active_requests += 1 + + try: + # Decide mode and switch if needed + mode = await self._decide_and_switch_mode(entering=True) + engine = self._simple if mode == "simple" else self._batched + + return await engine.chat( + messages=messages, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + tools=tools, + images=images, + videos=videos, + **kwargs, + ) + finally: + async with self._lock: + self._active_requests -= 1 + # Check if we should switch back to simple mode + await self._decide_and_switch_mode(entering=False) + + async def stream_chat( + self, + messages: list[dict[str, Any]], + max_tokens: int = 256, + temperature: float = 0.7, + top_p: float = 0.9, + tools: list[dict] | None = None, + images: list[str] | None = None, + videos: list[str] | None = 
None, + **kwargs, + ) -> AsyncIterator[GenerationOutput]: + """Stream chat completion token by token.""" + if not self._loaded: + await self.start() + + async with self._lock: + self._active_requests += 1 + + try: + # Decide mode and switch if needed + mode = await self._decide_and_switch_mode(entering=True) + engine = self._simple if mode == "simple" else self._batched + + async for output in engine.stream_chat( + messages=messages, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + tools=tools, + images=images, + videos=videos, + **kwargs, + ): + yield output + finally: + async with self._lock: + self._active_requests -= 1 + # Check if we should switch back to simple mode + await self._decide_and_switch_mode(entering=False) + + def get_stats(self) -> dict[str, Any]: + """Get engine statistics.""" + stats = { + "engine_type": "hybrid", + "model_name": self._model_name, + "is_mllm": self._is_mllm, + "loaded": self._loaded, + "current_mode": self._current_mode, + "active_requests": self._active_requests, + "switch_threshold": self._switch_threshold, + "draft_model": self._draft_model_name, + "num_draft_tokens": self._num_draft_tokens, + } + + if self._simple: + stats["simple_engine"] = self._simple.get_stats() + if self._batched: + stats["batched_engine"] = self._batched.get_stats() + + return stats + + def get_cache_stats(self) -> dict[str, Any] | None: + """Get cache statistics from active engine.""" + if self._current_mode == "simple" and self._simple: + return self._simple.get_cache_stats() + elif self._batched: + return self._batched.get_cache_stats() + return None diff --git a/vllm_mlx/engine/simple.py b/vllm_mlx/engine/simple.py index ed118bbb..df42b3e8 100644 --- a/vllm_mlx/engine/simple.py +++ b/vllm_mlx/engine/simple.py @@ -17,6 +17,15 @@ logger = logging.getLogger(__name__) +# Check for guided generation availability +try: + from ..api.guided import GuidedGenerator, is_guided_available + + HAS_GUIDED = is_guided_available() +except 
ImportError: + HAS_GUIDED = False + GuidedGenerator = None + class SimpleEngine(BaseEngine): """ @@ -32,6 +41,8 @@ def __init__( trust_remote_code: bool = True, enable_cache: bool = True, force_mllm: bool = False, + draft_model: str | None = None, + num_draft_tokens: int = 4, ): """ Initialize the simple engine. @@ -41,11 +52,15 @@ def __init__( trust_remote_code: Whether to trust remote code enable_cache: Enable VLM cache for multimodal models force_mllm: Force loading as MLLM even if not auto-detected + draft_model: Optional draft model path for speculative decoding + num_draft_tokens: Number of tokens to generate speculatively per step """ self._model_name = model_name self._trust_remote_code = trust_remote_code self._enable_cache = enable_cache self._is_mllm = force_mllm or is_mllm_model(model_name) + self._draft_model_name = draft_model + self._num_draft_tokens = num_draft_tokens self._model = None self._loaded = False @@ -85,17 +100,27 @@ async def start(self) -> None: trust_remote_code=self._trust_remote_code, enable_cache=self._enable_cache, ) + if self._draft_model_name: + logger.warning("Speculative decoding is not supported with MLLM models") else: from ..models.llm import MLXLanguageModel self._model = MLXLanguageModel( self._model_name, trust_remote_code=self._trust_remote_code, + draft_model=self._draft_model_name, + num_draft_tokens=self._num_draft_tokens, ) self._model.load() self._loaded = True - logger.info(f"SimpleEngine loaded: {self._model_name} (MLLM={self._is_mllm})") + + spec_info = "" + if self._draft_model_name and not self._is_mllm: + spec_info = f", speculative={self._draft_model_name}" + logger.info( + f"SimpleEngine loaded: {self._model_name} (MLLM={self._is_mllm}{spec_info})" + ) async def stop(self) -> None: """Stop the engine and cleanup resources.""" @@ -440,3 +465,174 @@ def get_cache_stats(self) -> dict[str, Any] | None: if self._is_mllm and self._model is not None: return self._model.get_cache_stats() return None + + async def 
_inject_shared_model( + self, + model, + tokenizer, + ) -> None: + """ + Inject a pre-loaded shared model instead of loading a new one. + + This is used by HybridEngine to share a single model instance + between SimpleEngine and BatchedEngine, saving ~44GB of RAM. + + Args: + model: Pre-loaded MLX model + tokenizer: Pre-loaded tokenizer + """ + from ..models.llm import MLXLanguageModel + + # Create MLXLanguageModel wrapper without loading + self._model = MLXLanguageModel.__new__(MLXLanguageModel) + self._model.model_name = self._model_name + self._model.tokenizer_name = self._model_name + self._model.trust_remote_code = self._trust_remote_code + self._model.draft_model_name = self._draft_model_name + self._model.num_draft_tokens = self._num_draft_tokens + self._model.model = model + self._model.tokenizer = tokenizer + self._model.draft_model = None + self._model._loaded = True + + # Load draft model separately if specified + if self._draft_model_name: + from mlx_lm import load as mlx_load + + logger.info( + f"Loading draft model for speculative decoding: {self._draft_model_name}" + ) + self._model.draft_model, draft_tokenizer = mlx_load(self._draft_model_name) + + # Validate tokenizer compatibility + if draft_tokenizer.vocab_size != tokenizer.vocab_size: + logger.warning( + f"Draft model tokenizer vocab size ({draft_tokenizer.vocab_size}) " + f"differs from main model ({tokenizer.vocab_size}). " + "This may reduce speculative decoding effectiveness." 
+ ) + + logger.info( + f"Speculative decoding enabled: draft={self._draft_model_name}, " + f"num_draft_tokens={self._num_draft_tokens}" + ) + + self._loaded = True + logger.info(f"SimpleEngine injected with shared model: {self._model_name}") + + @property + def supports_guided_generation(self) -> bool: + """Check if guided generation is available.""" + return HAS_GUIDED and not self._is_mllm + + async def generate_with_schema( + self, + messages: list[dict[str, Any]], + json_schema: dict[str, Any], + max_tokens: int = 256, + temperature: float = 0.7, + top_p: float = 0.9, + **kwargs, + ) -> GenerationOutput: + """ + Generate JSON output constrained to a schema using guided decoding. + + This method uses outlines for constrained generation to guarantee + the output is valid JSON matching the specified schema. + + Args: + messages: List of chat messages + json_schema: JSON schema to constrain output + max_tokens: Maximum tokens to generate + temperature: Sampling temperature + top_p: Top-p sampling + **kwargs: Additional parameters + + Returns: + GenerationOutput with JSON text matching the schema + """ + if not self.supports_guided_generation: + raise RuntimeError( + "Guided generation not available. 
" + "Install with: pip install 'vllm-mlx[guided]'" + ) + + if not self._loaded: + await self.start() + + # Build prompt from messages + tokenizer = self._model.tokenizer + if hasattr(tokenizer, "apply_chat_template"): + prompt = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + ) + else: + prompt = "\n".join(f"{m['role']}: {m['content']}" for m in messages) + prompt += "\nassistant:" + + async with self._generation_lock: + # Run guided generation in thread pool + result = await asyncio.to_thread( + self._run_guided_generation, + prompt=prompt, + json_schema=json_schema, + max_tokens=max_tokens, + temperature=temperature, + ) + + if result is None: + # Fallback to regular generation + logger.warning( + "Guided generation failed, falling back to regular generation" + ) + return await self.generate( + prompt=prompt, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + **kwargs, + ) + + # Tokenize for completion count + tokens = tokenizer.encode(result) + + return GenerationOutput( + text=result, + tokens=tokens, + prompt_tokens=len(tokenizer.encode(prompt)), + completion_tokens=len(tokens), + finish_reason="stop", + ) + + def _run_guided_generation( + self, + prompt: str, + json_schema: dict[str, Any], + max_tokens: int, + temperature: float, + ) -> str | None: + """ + Run guided generation synchronously (called from thread pool). 
+ + Args: + prompt: Input prompt + json_schema: JSON schema + max_tokens: Maximum tokens + temperature: Sampling temperature + + Returns: + JSON string or None if failed + """ + try: + generator = GuidedGenerator(self._model.model, self._model.tokenizer) + return generator.generate_json( + prompt=prompt, + json_schema=json_schema, + max_tokens=max_tokens, + temperature=temperature, + ) + except Exception as e: + logger.error(f"Guided generation error: {e}") + return None diff --git a/vllm_mlx/models/llm.py b/vllm_mlx/models/llm.py index 092c060e..1c75fb3a 100644 --- a/vllm_mlx/models/llm.py +++ b/vllm_mlx/models/llm.py @@ -50,6 +50,8 @@ def __init__( model_name: str, tokenizer_name: str | None = None, trust_remote_code: bool = False, + draft_model: str | None = None, + num_draft_tokens: int = 4, ): """ Initialize the MLX language model. @@ -58,13 +60,18 @@ def __init__( model_name: HuggingFace model name or local path tokenizer_name: Optional separate tokenizer name trust_remote_code: Whether to trust remote code + draft_model: Optional draft model path for speculative decoding + num_draft_tokens: Number of tokens to generate speculatively per step """ self.model_name = model_name self.tokenizer_name = tokenizer_name or model_name self.trust_remote_code = trust_remote_code + self.draft_model_name = draft_model + self.num_draft_tokens = num_draft_tokens self.model = None self.tokenizer = None + self.draft_model = None self._loaded = False def load(self) -> None: @@ -91,6 +98,28 @@ def load(self) -> None: tokenizer_config=tokenizer_config, ) + # Load draft model for speculative decoding if specified + if self.draft_model_name: + logger.info( + f"Loading draft model for speculative decoding: {self.draft_model_name}" + ) + from mlx_lm import load as mlx_load + + self.draft_model, draft_tokenizer = mlx_load(self.draft_model_name) + + # Validate tokenizer compatibility + if draft_tokenizer.vocab_size != self.tokenizer.vocab_size: + logger.warning( + f"Draft model 
tokenizer vocab size ({draft_tokenizer.vocab_size}) " + f"differs from main model ({self.tokenizer.vocab_size}). " + "This may reduce speculative decoding effectiveness." + ) + + logger.info( + f"Speculative decoding enabled: draft={self.draft_model_name}, " + f"num_draft_tokens={self.num_draft_tokens}" + ) + self._loaded = True logger.info(f"Model loaded successfully: {self.model_name}") @@ -147,15 +176,31 @@ def generate( # Create sampler with parameters sampler = self._create_sampler(temperature, top_p) - # Generate text - output_text = generate( - self.model, - self.tokenizer, - prompt=prompt, - max_tokens=max_tokens, - sampler=sampler, - verbose=False, - ) + # Note: mlx_lm.generate() doesn't support draft_model directly, + # speculative decoding is only available via stream_generate() + if self.draft_model is not None: + # Use streaming with draft model and collect result + output_text = "" + for chunk in self.stream_generate( + prompt=prompt, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + stop=stop, + ): + output_text += chunk.text + if chunk.finished: + break + else: + # Generate text without speculative decoding + output_text = generate( + self.model, + self.tokenizer, + prompt=prompt, + max_tokens=max_tokens, + sampler=sampler, + verbose=False, + ) # Tokenize output to get token IDs tokens = self.tokenizer.encode(output_text) @@ -203,12 +248,22 @@ def stream_generate( token_count = 0 accumulated_text = "" + # Build generation kwargs + gen_kwargs = { + "max_tokens": max_tokens, + "sampler": sampler, + } + + # Add draft model for speculative decoding if available + if self.draft_model is not None: + gen_kwargs["draft_model"] = self.draft_model + gen_kwargs["num_draft_tokens"] = self.num_draft_tokens + for response in stream_generate( self.model, self.tokenizer, prompt=prompt, - max_tokens=max_tokens, - sampler=sampler, + **gen_kwargs, ): token_count += 1 # response.text is the new token text (not accumulated) diff --git 
a/vllm_mlx/server.py b/vllm_mlx/server.py index 138c166a..30b6ef9a 100644 --- a/vllm_mlx/server.py +++ b/vllm_mlx/server.py @@ -93,16 +93,25 @@ from .api.tool_calling import ( build_json_system_prompt, convert_tools_for_template, + extract_json_schema_for_guided, parse_json_output, parse_tool_calls, ) from .api.utils import ( SPECIAL_TOKENS_PATTERN, clean_output_text, + extract_json_from_response, extract_multimodal_content, is_mllm_model, # noqa: F401 + strip_thinking_tags, +) +from .engine import ( + BaseEngine, + BatchedEngine, + GenerationOutput, + HybridEngine, + SimpleEngine, ) -from .engine import BaseEngine, BatchedEngine, GenerationOutput, SimpleEngine from .tool_parsers import ToolParserManager logging.basicConfig(level=logging.INFO) @@ -467,6 +476,8 @@ def load_model( stream_interval: int = 1, max_tokens: int = 32768, force_mllm: bool = False, + draft_model: str | None = None, + num_draft_tokens: int = 4, ): """ Load a model (auto-detects MLLM vs LLM). @@ -478,6 +489,8 @@ def load_model( stream_interval: Tokens to batch before streaming (batched mode only) max_tokens: Default max tokens for generation force_mllm: Force loading as MLLM even if not auto-detected + draft_model: Optional draft model for speculative decoding + num_draft_tokens: Number of tokens to generate speculatively per step """ global _engine, _model_name, _default_max_tokens, _tool_parser_instance @@ -489,7 +502,28 @@ def load_model( if force_mllm: logger.info("Force MLLM mode enabled via --mllm flag") - if use_batching: + if draft_model: + logger.info(f"Speculative decoding enabled with draft model: {draft_model}") + logger.info(f" num_draft_tokens: {num_draft_tokens}") + + if use_batching and draft_model: + # Hybrid mode: shared model with speculative decoding + continuous batching + logger.info(f"Loading model with HybridEngine: {model_name}") + logger.info( + " Hybrid mode: speculative decoding for single user, " + "batching for multiple users" + ) + _engine = HybridEngine( + 
model_name=model_name, + draft_model=draft_model, + num_draft_tokens=num_draft_tokens, + scheduler_config=scheduler_config, + stream_interval=stream_interval, + force_mllm=force_mllm, + ) + # HybridEngine will be started in lifespan (uvicorn's event loop) + logger.info(f"Model loaded (hybrid mode): {model_name}") + elif use_batching: logger.info(f"Loading model with BatchedEngine: {model_name}") _engine = BatchedEngine( model_name=model_name, @@ -502,7 +536,12 @@ def load_model( logger.info(f"Model loaded (batched mode): {model_name}") else: logger.info(f"Loading model with SimpleEngine: {model_name}") - _engine = SimpleEngine(model_name=model_name, force_mllm=force_mllm) + _engine = SimpleEngine( + model_name=model_name, + force_mllm=force_mllm, + draft_model=draft_model, + num_draft_tokens=num_draft_tokens, + ) # Start SimpleEngine synchronously (no background loop) # Use new_event_loop() for Python 3.10+ compatibility (get_event_loop() is deprecated) loop = asyncio.new_event_loop() @@ -1370,11 +1409,34 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re start_time = time.perf_counter() timeout = request.timeout or _default_timeout - output = await _wait_with_disconnect( - engine.chat(messages=messages, **chat_kwargs), - raw_request, - timeout=timeout, - ) + # Check if we should use guided generation for JSON schema + use_guided = False + json_schema = None + if response_format and not request.tools: + json_schema = extract_json_schema_for_guided(response_format) + if json_schema and hasattr(engine, "supports_guided_generation"): + use_guided = engine.supports_guided_generation + if use_guided: + logger.info("Using guided generation for JSON schema enforcement") + + if use_guided and json_schema: + # Use guided generation for constrained JSON output + output = await _wait_with_disconnect( + engine.generate_with_schema( + messages=messages, + json_schema=json_schema, + **chat_kwargs, + ), + raw_request, + timeout=timeout, + ) + 
else: + # Standard generation + output = await _wait_with_disconnect( + engine.chat(messages=messages, **chat_kwargs), + raw_request, + timeout=timeout, + ) if output is None: return Response(status_code=499) # Client closed request @@ -1408,12 +1470,22 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re # Determine finish reason finish_reason = "tool_calls" if tool_calls else output.finish_reason + # Clean and strip thinking tags from content + # Thinking tags break JSON parsing in clients expecting pure content + # Also extract JSON if model outputs reasoning text before JSON + final_content = None + if cleaned_text: + final_content = strip_thinking_tags(clean_output_text(cleaned_text)) + # If response looks like it ends with JSON, extract just the JSON part + # This handles Qwen3 reasoning mode: "Let me think... {json}" + final_content = extract_json_from_response(final_content) + return ChatCompletionResponse( model=request.model, choices=[ ChatCompletionChoice( message=AssistantMessage( - content=clean_output_text(cleaned_text) if cleaned_text else None, + content=final_content, reasoning=reasoning_text, tool_calls=tool_calls, ), @@ -1432,7 +1504,11 @@ def _inject_json_instruction(messages: list, instruction: str) -> list: """ Inject JSON instruction into messages. - If a system message exists, append to it. Otherwise, prepend a new system message. + PREPENDS instruction to system message for better instruction following. + JSON formatting instructions at the beginning of system message are more + likely to be followed than instructions at the end. + + If a system message exists, prepend to it. Otherwise, prepend a new system message. """ messages = list(messages) # Make a copy @@ -1445,14 +1521,15 @@ def _inject_json_instruction(messages: list, instruction: str) -> list: break if system_idx is not None: - # Append to existing system message + # PREPEND to existing system message (not append!) 
+ # Instructions at the start are more effective msg = messages[system_idx] if isinstance(msg, dict): existing = msg.get("content", "") - msg["content"] = f"{existing}\n\n{instruction}" + msg["content"] = f"{instruction}\n\n{existing}" else: existing = getattr(msg, "content", "") or "" - msg.content = f"{existing}\n\n{instruction}" + msg.content = f"{instruction}\n\n{existing}" else: # Prepend new system message messages.insert(0, {"role": "system", "content": instruction}) @@ -2219,6 +2296,18 @@ def main(): default=None, help="Default top_p for generation when not specified in request", ) + parser.add_argument( + "--draft-model", + type=str, + default=None, + help="Draft model for speculative decoding (must use same tokenizer as main model)", + ) + parser.add_argument( + "--num-draft-tokens", + type=int, + default=4, + help="Number of tokens to generate speculatively per step (default: 4)", + ) args = parser.parse_args() @@ -2276,6 +2365,8 @@ def main(): use_batching=args.continuous_batching, max_tokens=args.max_tokens, force_mllm=args.mllm, + draft_model=args.draft_model, + num_draft_tokens=args.num_draft_tokens, ) # Start server diff --git a/vllm_mlx/speculative/__init__.py b/vllm_mlx/speculative/__init__.py new file mode 100644 index 00000000..019b8895 --- /dev/null +++ b/vllm_mlx/speculative/__init__.py @@ -0,0 +1,8 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Speculative decoding utilities for vllm-mlx. +""" + +from .prompt_lookup import PromptLookupDecoder + +__all__ = ["PromptLookupDecoder"] diff --git a/vllm_mlx/speculative/prompt_lookup.py b/vllm_mlx/speculative/prompt_lookup.py new file mode 100644 index 00000000..31c06ef8 --- /dev/null +++ b/vllm_mlx/speculative/prompt_lookup.py @@ -0,0 +1,312 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Prompt Lookup Decoding for speculative token generation. 
"""
Prompt Lookup Decoding for speculative token generation.

This module implements Prompt Lookup Decoding, a draft-model-free approach
to speculative decoding that uses n-gram matching over the prompt and the
already-generated text to predict future tokens.

Reference: https://github.com/apoorvumang/prompt-lookup-decoding
"""

import logging
from collections import defaultdict

logger = logging.getLogger(__name__)


class PromptLookupDecoder:
    """
    Prompt Lookup Decoder for speculative token generation.

    Uses n-gram matching to find repeating patterns in the prompt and
    generated text so future tokens can be speculated without a draft model.

    This is particularly effective for:
    - Code generation (repetitive patterns)
    - Structured text (JSON, XML, markdown)
    - Text with common phrases
    - Translation tasks (similar patterns)

    Args:
        num_draft_tokens: Number of tokens to draft per step (default: 4)
        ngram_size: Size of the n-gram used for matching (default: 3)
        min_matches: Minimum draft length required to emit a draft (default: 2)
    """

    def __init__(
        self,
        num_draft_tokens: int = 4,
        ngram_size: int = 3,
        min_matches: int = 2,
    ):
        self.num_draft_tokens = num_draft_tokens
        self.ngram_size = ngram_size
        self.min_matches = min_matches

        # Full token history (prompt + generated) used for n-gram lookup.
        self._token_history: list[int] = []

        # Maps an n-gram tuple -> list of positions of the token that
        # FOLLOWED that n-gram, i.e. for every stored pos:
        #   history[pos - n : pos] == ngram
        # so history[pos] is the first continuation token.
        self._ngram_index: dict[tuple, list[int]] = defaultdict(list)

        # Statistics
        self.total_drafts = 0
        self.successful_drafts = 0
        self.total_draft_tokens = 0
        self.accepted_tokens = 0

    def reset(self):
        """Reset history and index for a new generation (stats are kept)."""
        self._token_history = []
        self._ngram_index = defaultdict(list)

    def add_prompt_tokens(self, tokens: list[int]):
        """
        Add prompt tokens to the history for lookup.

        Args:
            tokens: List of prompt token IDs
        """
        for token in tokens:
            self._add_token(token)

    def _add_token(self, token: int):
        """Append one token and index every n-gram that precedes it."""
        self._token_history.append(token)

        # pos is the index of the token just appended; record it as the
        # continuation of each n-gram (n = 1..ngram_size) ending right before it.
        pos = len(self._token_history) - 1
        for n in range(1, min(self.ngram_size + 1, pos + 1)):
            if pos >= n:
                ngram = tuple(self._token_history[pos - n : pos])
                self._ngram_index[ngram].append(pos)

    def add_generated_token(self, token: int):
        """
        Add a generated token to history.

        Args:
            token: Generated token ID
        """
        self._add_token(token)

    def get_draft_tokens(self) -> list[int]:
        """
        Get draft tokens based on n-gram lookup.

        Returns:
            List of draft token IDs (may be empty if no match found)
        """
        if len(self._token_history) < self.ngram_size:
            return []

        # The last ngram_size tokens form the lookup query.
        query = tuple(self._token_history[-self.ngram_size :])

        positions = self._ngram_index.get(query, [])
        if not positions:
            return []

        # Each stored position is the index of the token that followed the
        # n-gram, so the continuation starts AT pos, not pos + 1.
        # (Starting at pos + 1 would shift every draft one token late and
        # make the verifier reject essentially all drafts.)
        draft_tokens: list[int] = []
        best_continuation_length = 0
        for pos in positions:
            continuation = self._token_history[pos : pos + self.num_draft_tokens]
            # Prefer the longest available continuation; on ties the earliest
            # occurrence wins (same tie-break as before the fix).
            if len(continuation) > best_continuation_length:
                best_continuation_length = len(continuation)
                draft_tokens = continuation

        if len(draft_tokens) >= self.min_matches:
            self.total_drafts += 1
            self.total_draft_tokens += len(draft_tokens)
            return draft_tokens

        return []

    def record_accepted(self, num_accepted: int):
        """Record statistics about accepted draft tokens."""
        if num_accepted > 0:
            self.successful_drafts += 1
            self.accepted_tokens += num_accepted

    def get_stats(self) -> dict:
        """Get decoder statistics."""
        acceptance_rate = 0.0
        if self.total_draft_tokens > 0:
            acceptance_rate = self.accepted_tokens / self.total_draft_tokens

        return {
            "total_drafts": self.total_drafts,
            "successful_drafts": self.successful_drafts,
            "total_draft_tokens": self.total_draft_tokens,
            "accepted_tokens": self.accepted_tokens,
            "acceptance_rate": acceptance_rate,
            "history_size": len(self._token_history),
        }


def prompt_lookup_generate_step(
    prompt: "mx.array",
    model,
    *,
    num_draft_tokens: int = 4,
    ngram_size: int = 3,
    max_tokens: int = 256,
    sampler=None,
    logits_processors=None,
    prompt_cache=None,
    prefill_step_size: int = 512,
):
    """
    Generator for token generation with prompt lookup speculation.

    This is a drop-in replacement for generate_step that uses n-gram
    matching for speculation instead of a draft model.

    Args:
        prompt: Input token array
        model: The main model
        num_draft_tokens: Number of tokens to draft
        ngram_size: N-gram size for matching
        max_tokens: Maximum tokens to generate
        sampler: Token sampler (default: argmax)
        logits_processors: Accepted for API parity; currently unused here
        prompt_cache: Optional KV cache
        prefill_step_size: Prefill chunk size

    Yields:
        Tuple of (token, logprobs, from_draft)
    """
    # Imported lazily so the pure-Python PromptLookupDecoder above remains
    # importable (and unit-testable) without mlx installed.
    import mlx.core as mx
    from mlx_lm.models import cache

    y = prompt.astype(mx.uint32)

    # Create KV cache if not provided.
    if prompt_cache is None:
        prompt_cache = cache.make_prompt_cache(model)

    sampler = sampler or (lambda x: mx.argmax(x, axis=-1))

    # Initialize prompt lookup decoder and seed it with the prompt.
    decoder = PromptLookupDecoder(
        num_draft_tokens=num_draft_tokens,
        ngram_size=ngram_size,
    )
    decoder.add_prompt_tokens(prompt.tolist())

    def _step(tokens, n_predict: int = 1):
        """Run the model on tokens; return the last n_predict samples + logprobs."""
        logits = model(tokens[None], cache=prompt_cache)
        logits = logits[:, -n_predict:, :]

        logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
        sampled = sampler(logprobs)

        return sampled.squeeze(0), logprobs.squeeze(0)

    def _prefill(tokens):
        """Feed prompt tokens through the model in chunks to bound memory."""
        while tokens.size > prefill_step_size:
            model(tokens[:prefill_step_size][None], cache=prompt_cache)
            mx.eval([c.state for c in prompt_cache])
            tokens = tokens[prefill_step_size:]
            mx.clear_cache()
        return tokens

    # Prefill prompt, then generate the first token.
    y = _prefill(y)
    current_token, logprobs = _step(y)
    mx.eval(current_token, logprobs)

    ntoks = 0

    while ntoks < max_tokens:
        # Yield the current (verified) token.
        yield current_token.item(), logprobs, False
        ntoks += 1

        if ntoks >= max_tokens:
            break

        decoder.add_generated_token(current_token.item())

        # Try to speculate via n-gram lookup.
        draft_tokens = decoder.get_draft_tokens()

        if draft_tokens:
            # Speculative path: one forward pass over [current] + drafts,
            # then keep the longest verified prefix of the drafts.
            verify_input = mx.array([current_token.item()] + draft_tokens, mx.uint32)
            verified_tokens, verified_logprobs = _step(
                verify_input, n_predict=len(draft_tokens) + 1
            )
            mx.eval(verified_tokens, verified_logprobs)

            verified_tokens = verified_tokens.tolist()

            n_accepted = 0
            for i, (draft_t, verify_t) in enumerate(
                zip(draft_tokens, verified_tokens[:-1])
            ):
                if draft_t == verify_t:
                    n_accepted += 1
                    ntoks += 1
                    decoder.add_generated_token(draft_t)
                    yield draft_t, verified_logprobs[i], True  # from_draft=True
                    if ntoks >= max_tokens:
                        break
                else:
                    break

            decoder.record_accepted(n_accepted)

            if ntoks >= max_tokens:
                break

            # Roll the KV cache back over rejected draft tokens so it
            # matches the accepted sequence exactly.
            if n_accepted < len(draft_tokens):
                cache.trim_prompt_cache(prompt_cache, len(draft_tokens) - n_accepted)

            # The first non-accepted verification result is the next token.
            current_token = mx.array(verified_tokens[n_accepted], mx.uint32)
            logprobs = verified_logprobs[n_accepted]
        else:
            # Standard path: single token generation.
            next_token, next_logprobs = _step(
                mx.array([current_token.item()], mx.uint32)
            )
            mx.eval(next_token, next_logprobs)
            current_token = next_token
            logprobs = next_logprobs

        if ntoks % 256 == 0:
            mx.clear_cache()

    # Log final stats.
    stats = decoder.get_stats()
    if stats["total_drafts"] > 0:
        logger.info(
            f"Prompt Lookup stats: {stats['accepted_tokens']}/{stats['total_draft_tokens']} "
            f"tokens accepted ({stats['acceptance_rate']:.1%})"
        )