diff --git a/pyproject.toml b/pyproject.toml
index fa92e960..b036c78e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -77,6 +77,10 @@ vision = [
"torch>=2.3.0",
"torchvision>=0.18.0",
]
+# Guided decoding with outlines for structured JSON output
+guided = [
+ "outlines[mlxlm]>=1.0.0",
+]
# Audio dependencies for TTS/STT (mlx-audio)
audio = [
"mlx-audio>=0.2.9",
diff --git a/vllm_mlx/api/guided.py b/vllm_mlx/api/guided.py
new file mode 100644
index 00000000..4f20dd2b
--- /dev/null
+++ b/vllm_mlx/api/guided.py
@@ -0,0 +1,238 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+Guided generation for structured JSON output using outlines.
+
+This module provides constrained decoding for JSON schema enforcement,
+ensuring model outputs strictly adhere to specified schemas.
+"""
+
+import logging
+from typing import Any
+
+logger = logging.getLogger(__name__)
+
+# Check for outlines availability
+try:
+ import mlx_lm
+ import outlines
+
+ HAS_OUTLINES = True
+except ImportError:
+ HAS_OUTLINES = False
+ outlines = None
+ mlx_lm = None
+
+
+def is_guided_available() -> bool:
+ """Check if guided generation with outlines is available."""
+ return HAS_OUTLINES
+
+
+def json_schema_to_pydantic(schema: dict[str, Any]) -> type | None:
+ """
+ Convert a JSON schema to a Pydantic model dynamically.
+
+ Args:
+ schema: JSON schema dict
+
+ Returns:
+ Dynamically created Pydantic model class, or None if conversion fails
+ """
+ try:
+ from pydantic import create_model
+
+ # Extract properties from schema
+ properties = schema.get("properties", {})
+ required = set(schema.get("required", []))
+
+ # Build field definitions for Pydantic
+ field_definitions = {}
+
+ type_mapping = {
+ "string": str,
+ "integer": int,
+ "number": float,
+ "boolean": bool,
+ "null": type(None),
+ }
+
+ for prop_name, prop_spec in properties.items():
+ prop_type = prop_spec.get("type", "string")
+
+ # Handle array type
+ if prop_type == "array":
+ items_type = prop_spec.get("items", {}).get("type", "string")
+ inner_type = type_mapping.get(items_type, str)
+ python_type = list[inner_type]
+ # Handle object type (nested)
+ elif prop_type == "object":
+ # For nested objects, use dict
+ python_type = dict
+ else:
+ python_type = type_mapping.get(prop_type, str)
+
+ # Make optional if not required
+ if prop_name not in required:
+ python_type = python_type | None
+ default = None
+ else:
+ default = ...
+
+ field_definitions[prop_name] = (python_type, default)
+
+ # Create the model dynamically
+ model = create_model("DynamicModel", **field_definitions)
+ return model
+
+ except Exception as e:
+ logger.warning(f"Failed to convert JSON schema to Pydantic: {e}")
+ return None
+
+
+class GuidedGenerator:
+ """
+ Guided generation using outlines for constrained JSON decoding.
+
+ This class wraps an MLX model to provide structured output generation
+ that guarantees valid JSON matching a specified schema.
+ """
+
+ def __init__(self, model, tokenizer):
+ """
+ Initialize the guided generator.
+
+ Args:
+ model: MLX model instance
+ tokenizer: Tokenizer instance
+ """
+ if not HAS_OUTLINES:
+ raise ImportError(
+ "outlines is required for guided generation. "
+ "Install with: pip install 'vllm-mlx[guided]'"
+ )
+
+ self._model = model
+ self._tokenizer = tokenizer
+ self._outlines_model = None
+
+ def _get_outlines_model(self):
+ """Get or create the outlines model wrapper."""
+ if self._outlines_model is None:
+ self._outlines_model = outlines.from_mlxlm(self._model, self._tokenizer)
+ return self._outlines_model
+
+ def generate_json(
+ self,
+ prompt: str,
+ json_schema: dict[str, Any],
+ max_tokens: int = 256,
+ temperature: float = 0.7,
+    ) -> str | None:
+ """
+ Generate JSON output constrained to a schema.
+
+ Args:
+ prompt: Input prompt
+ json_schema: JSON schema to constrain output
+ max_tokens: Maximum tokens to generate
+ temperature: Sampling temperature
+
+ Returns:
+            JSON string matching the schema, or None if generation fails
+ """
+ # Convert schema to Pydantic model
+ pydantic_model = json_schema_to_pydantic(json_schema)
+
+ if pydantic_model is None:
+ logger.warning(
+ "Could not convert schema to Pydantic, falling back to raw generation"
+ )
+ return None
+
+ try:
+ outlines_model = self._get_outlines_model()
+
+ # Generate with schema constraint
+ result = outlines_model(
+ prompt,
+ output_type=pydantic_model,
+ max_tokens=max_tokens,
+ )
+
+ # result is a JSON string, validate and return
+ return result
+
+ except Exception as e:
+ logger.error(f"Guided generation failed: {e}")
+ return None
+
+ def generate_json_object(
+ self,
+ prompt: str,
+ max_tokens: int = 256,
+ temperature: float = 0.7,
+    ) -> str | None:
+ """
+ Generate any valid JSON object.
+
+ Args:
+ prompt: Input prompt
+ max_tokens: Maximum tokens to generate
+ temperature: Sampling temperature
+
+ Returns:
+            JSON string, or None if generation fails
+ """
+ try:
+            from outlines.types import Regex
+
+            outlines_model = self._get_outlines_model()
+
+            # Constrain output to a flat JSON object (outlines >= 1.0 API;
+            # generate.regex was removed in v1 in favor of Generator)
+            json_regex = r"\{[^{}]*\}"
+            generator = outlines.Generator(outlines_model, Regex(json_regex))
+            result = generator(prompt, max_tokens=max_tokens)
+
+ return result
+
+ except Exception as e:
+ logger.error(f"JSON object generation failed: {e}")
+ return None
+
+
+def generate_with_schema(
+ model,
+ tokenizer,
+ prompt: str,
+ json_schema: dict[str, Any],
+ max_tokens: int = 256,
+ temperature: float = 0.7,
+) -> str | None:
+ """
+ Convenience function for one-shot guided JSON generation.
+
+ Args:
+ model: MLX model
+ tokenizer: Tokenizer
+ prompt: Input prompt
+ json_schema: JSON schema
+ max_tokens: Maximum tokens
+ temperature: Sampling temperature
+
+ Returns:
+ JSON string or None if guided generation unavailable/failed
+ """
+ if not HAS_OUTLINES:
+ return None
+
+ try:
+ generator = GuidedGenerator(model, tokenizer)
+ return generator.generate_json(
+ prompt=prompt,
+ json_schema=json_schema,
+ max_tokens=max_tokens,
+ temperature=temperature,
+ )
+ except Exception as e:
+ logger.error(f"generate_with_schema failed: {e}")
+ return None
diff --git a/vllm_mlx/api/tool_calling.py b/vllm_mlx/api/tool_calling.py
index 1443c167..764159b6 100644
--- a/vllm_mlx/api/tool_calling.py
+++ b/vllm_mlx/api/tool_calling.py
@@ -21,6 +21,42 @@
from .models import FunctionCall, ResponseFormat, ToolCall
+def _is_tool_call_json(obj: dict) -> bool:
+ """
+ Check if a JSON object looks like a tool call.
+
+ A tool call must have:
+ - "name" key with a string value (function name)
+ - "arguments" key (the function arguments)
+
+ This prevents false positives where regular JSON like {"name": "John", "age": 30}
+ would be incorrectly parsed as a tool call.
+
+ Args:
+ obj: JSON object to check
+
+ Returns:
+ True if object appears to be a tool call
+ """
+ if not isinstance(obj, dict):
+ return False
+
+ # Must have both "name" and "arguments" keys
+ if "name" not in obj or "arguments" not in obj:
+ return False
+
+ # "name" must be a non-empty string (function name)
+ if not isinstance(obj["name"], str) or not obj["name"].strip():
+ return False
+
+ # "arguments" must be a dict or string
+ args = obj["arguments"]
+ if not isinstance(args, (dict, str)):
+ return False
+
+ return True
+
+
def _parse_raw_json_tool_calls(text: str) -> Optional[List[dict]]:
"""
Parse raw JSON tool calls from model output.
@@ -30,6 +66,9 @@ def _parse_raw_json_tool_calls(text: str) -> Optional[List[dict]]:
- Multiple objects separated by commas: {...}, {...}
- JSON array: [{...}, {...}]
+ Only objects with BOTH "name" AND "arguments" keys are considered tool calls.
+ This prevents false positives with regular JSON objects.
+
Args:
text: Raw model output text
@@ -46,7 +85,7 @@ def _parse_raw_json_tool_calls(text: str) -> Optional[List[dict]]:
try:
parsed = json.loads(text)
if isinstance(parsed, list) and all(
- isinstance(item, dict) and "name" in item for item in parsed
+ _is_tool_call_json(item) for item in parsed
):
return [
{"name": item["name"], "arguments": item.get("arguments", {})}
@@ -71,7 +110,8 @@ def _parse_raw_json_tool_calls(text: str) -> Optional[List[dict]]:
json_str = text[start : i + 1]
try:
obj = json.loads(json_str)
- if isinstance(obj, dict) and "name" in obj:
+ # Only consider as tool call if it has both "name" AND "arguments"
+ if _is_tool_call_json(obj):
tool_calls.append(
{"name": obj["name"], "arguments": obj.get("arguments", {})}
)
@@ -524,8 +564,14 @@ def build_json_system_prompt(
if format_type == "json_object":
return (
- "You must respond with valid JSON only. "
- "Do not include any explanation or text outside the JSON object."
+ "⚠️ JSON OUTPUT REQUIRED ⚠️\n\n"
+ "You MUST respond with ONLY a valid JSON object.\n\n"
+ "RULES:\n"
+ "- Start response with { or [\n"
+ "- NO text before or after JSON\n"
+ "- NO thinking or explanations\n"
+ "- NO markdown code blocks (```)\n"
+ "- ONLY the raw JSON object"
)
if format_type == "json_schema":
@@ -533,14 +579,89 @@ def build_json_system_prompt(
schema = json_schema_spec.get("schema", {})
name = json_schema_spec.get("name", "response")
description = json_schema_spec.get("description", "")
+ strict = json_schema_spec.get("strict", False)
+
+ # If strict mode is enabled and guided generation is available,
+ # the server should use guided decoding instead of prompt injection
+ if strict:
+ # Return stronger instruction for strict mode
+ prompt = (
+ f"⚠️ STRICT JSON OUTPUT REQUIRED ⚠️\n\n"
+ f"You MUST respond with ONLY a valid JSON object matching the '{name}' schema.\n"
+ )
+ if description:
+ prompt += f"Purpose: {description}\n"
+ prompt += (
+ f"\nJSON Schema:\n```json\n{json.dumps(schema, indent=2)}\n```\n\n"
+ "STRICT RULES:\n"
+                "- Start response with { or [\n"
+ "- NO text before or after JSON\n"
+ "- NO thinking, reasoning, or explanations\n"
+ "- NO markdown code blocks\n"
+ "- ONLY the JSON object"
+ )
+ return prompt
- prompt = f"You must respond with valid JSON matching the '{name}' schema."
+ # Standard (non-strict) mode - still strong but less aggressive
+ prompt = (
+ f"⚠️ JSON OUTPUT REQUIRED ⚠️\n\n"
+ f"Respond with a valid JSON object matching the '{name}' schema.\n"
+ )
if description:
- prompt += f" {description}"
+ prompt += f"Purpose: {description}\n"
prompt += (
- f"\n\nJSON Schema:\n```json\n{json.dumps(schema, indent=2)}\n```\n\n"
- "Respond with only the JSON object, no additional text or explanation."
+ f"\nJSON Schema:\n```json\n{json.dumps(schema, indent=2)}\n```\n\n"
+ "RULES:\n"
+ "- Start response with { or [\n"
+ "- NO text before or after JSON\n"
+ "- NO markdown code blocks\n"
+ "- ONLY the JSON object"
)
return prompt
return None
+
+
+def extract_json_schema_for_guided(
+ response_format: Optional[Union[ResponseFormat, Dict[str, Any]]] = None,
+) -> Optional[Dict[str, Any]]:
+ """
+ Extract JSON schema from response_format for guided generation.
+
+ Returns the schema dict if response_format specifies json_schema type,
+ otherwise returns None.
+
+ Args:
+ response_format: ResponseFormat specification
+
+ Returns:
+ JSON schema dict or None
+ """
+ if response_format is None:
+ return None
+
+ # Normalize to dict
+ if isinstance(response_format, ResponseFormat):
+ rf_dict = {"type": response_format.type, "json_schema": None}
+ if response_format.json_schema:
+ rf_dict["json_schema"] = {
+ "name": response_format.json_schema.name,
+ "description": response_format.json_schema.description,
+ "schema": response_format.json_schema.schema_,
+ "strict": response_format.json_schema.strict,
+ }
+ else:
+ rf_dict = response_format
+
+ format_type = rf_dict.get("type", "text")
+
+ if format_type != "json_schema":
+ return None
+
+ json_schema_spec = rf_dict.get("json_schema", {})
+ schema = json_schema_spec.get("schema", {})
+
+ if not schema:
+ return None
+
+ return schema
diff --git a/vllm_mlx/api/utils.py b/vllm_mlx/api/utils.py
index e916ae19..bcc1a405 100644
--- a/vllm_mlx/api/utils.py
+++ b/vllm_mlx/api/utils.py
@@ -101,6 +101,86 @@ def clean_output_text(text: str) -> str:
return text
+# Pattern to match <think>...</think> blocks
+THINK_PATTERN = re.compile(r"<think>[\s\S]*?</think>\s*", re.DOTALL)
+
+
+def strip_thinking_tags(text: str) -> str:
+ """
+    Remove <think>...</think> blocks from model output.
+
+ Used when the client expects pure content (e.g., JSON) without
+ reasoning blocks that would break parsing.
+
+ Args:
+ text: Model output that may contain thinking blocks
+
+ Returns:
+ Text with thinking blocks removed
+ """
+ if not text:
+ return text
+ return THINK_PATTERN.sub("", text).strip()
+
+
+def extract_json_from_response(text: str) -> str:
+ """
+ Extract JSON object/array from model response that may contain reasoning text.
+
+ Qwen3 and other reasoning models often output:
+ "Let me think... reasoning text... {\"result\": 123}"
+
+ This function extracts just the JSON part when present.
+
+ Args:
+ text: Model output that may contain text before/after JSON
+
+ Returns:
+ Extracted JSON string if found, otherwise original text
+ """
+ if not text:
+ return text
+
+ text = text.strip()
+
+ # If already valid JSON, return as-is
+ if (text.startswith("{") and text.endswith("}")) or (
+ text.startswith("[") and text.endswith("]")
+ ):
+ return text
+
+    # Try to find a JSON object at the end of the response.
+    # Scan backwards so nested braces resolve to the outermost object
+    # ending at the final character (rfind("{") would pick an inner
+    # brace and return an unbalanced slice for nested JSON).
+    if text.endswith("}"):
+        depth = 0
+        for i in range(len(text) - 1, -1, -1):
+            char = text[i]
+            if char == "}":
+                depth += 1
+            elif char == "{":
+                depth -= 1
+                if depth == 0:
+                    return text[i:]
+
+    # Try to find a JSON array at the end of the response, scanning
+    # backwards so nested brackets resolve to the outermost array.
+    if text.endswith("]"):
+        depth = 0
+        for i in range(len(text) - 1, -1, -1):
+            char = text[i]
+            if char == "]":
+                depth += 1
+            elif char == "[":
+                depth -= 1
+                if depth == 0:
+                    return text[i:]
+
+ # No JSON found, return original
+ return text
+
+
# =============================================================================
# Model Detection
# =============================================================================
diff --git a/vllm_mlx/cli.py b/vllm_mlx/cli.py
index dcbee8ac..c9923296 100644
--- a/vllm_mlx/cli.py
+++ b/vllm_mlx/cli.py
@@ -107,6 +107,10 @@ def serve_command(args):
print(f"Loading model: {args.model}")
print(f"Default max tokens: {args.max_tokens}")
+ if args.draft_model:
+ print("Speculative decoding: ENABLED")
+ print(f" Draft model: {args.draft_model}")
+ print(f" Draft tokens: {args.num_draft_tokens}")
# Store MCP config path for FastAPI startup
if args.mcp_config:
@@ -187,6 +191,8 @@ def serve_command(args):
stream_interval=args.stream_interval if args.continuous_batching else 1,
max_tokens=args.max_tokens,
force_mllm=args.mllm,
+ draft_model=args.draft_model,
+ num_draft_tokens=args.num_draft_tokens,
)
# Start server
@@ -786,6 +792,19 @@ def main():
"Required for --enable-auto-tool-choice."
),
)
+ # Speculative decoding options
+ serve_parser.add_argument(
+ "--draft-model",
+ type=str,
+ default=None,
+ help="Draft model for speculative decoding (must use same tokenizer as main model)",
+ )
+ serve_parser.add_argument(
+ "--num-draft-tokens",
+ type=int,
+ default=4,
+ help="Number of tokens to generate speculatively per step (default: 4)",
+ )
# Reasoning parser options - choices loaded dynamically from registry
from .reasoning import list_parsers
diff --git a/vllm_mlx/engine/__init__.py b/vllm_mlx/engine/__init__.py
index f6625abd..d33c4bc4 100644
--- a/vllm_mlx/engine/__init__.py
+++ b/vllm_mlx/engine/__init__.py
@@ -12,6 +12,7 @@
from .base import BaseEngine, GenerationOutput
from .simple import SimpleEngine
from .batched import BatchedEngine
+from .hybrid import HybridEngine
# Re-export from parent engine.py for backwards compatibility
from ..engine_core import EngineCore, AsyncEngineCore, EngineConfig
@@ -21,6 +22,7 @@
"GenerationOutput",
"SimpleEngine",
"BatchedEngine",
+ "HybridEngine",
# Core engine components
"EngineCore",
"AsyncEngineCore",
diff --git a/vllm_mlx/engine/batched.py b/vllm_mlx/engine/batched.py
index ce33e628..05906192 100644
--- a/vllm_mlx/engine/batched.py
+++ b/vllm_mlx/engine/batched.py
@@ -161,6 +161,7 @@ def __init__(
self._mllm_scheduler = None # MLLMScheduler for MLLM
self._mllm_instance = None # MLXMultimodalLM instance
self._loaded = False
+ self._engine_started = False # Track if engine loop is running
@property
def model_name(self) -> str:
@@ -800,3 +801,52 @@ def load_cache_from_disk(self, cache_dir: str) -> int:
if self._engine:
return self._engine.load_cache_from_disk(cache_dir)
return 0
+
+ async def _inject_shared_model(
+ self,
+ model,
+ tokenizer,
+ start_engine: bool = True,
+ ) -> None:
+ """
+ Inject a pre-loaded shared model instead of loading a new one.
+
+ This is used by HybridEngine to share a single model instance
+ between SimpleEngine and BatchedEngine, saving ~44GB of RAM.
+
+ Args:
+ model: Pre-loaded MLX model
+ tokenizer: Pre-loaded tokenizer
+ start_engine: Whether to start the engine loop immediately.
+ Set to False for HybridEngine (lazy start on first use).
+ """
+ from ..engine_core import AsyncEngineCore, EngineConfig
+ from ..scheduler import SchedulerConfig
+
+ self._model = model
+ self._tokenizer = tokenizer
+
+ # Create engine config
+ scheduler_config = self._scheduler_config or SchedulerConfig()
+ engine_config = EngineConfig(
+ model_name=self._model_name,
+ scheduler_config=scheduler_config,
+ stream_interval=self._stream_interval,
+ )
+
+ # Create async engine with shared model
+ self._engine = AsyncEngineCore(
+ model=self._model,
+ tokenizer=self._tokenizer,
+ config=engine_config,
+ )
+
+ # Only start engine loop if requested (HybridEngine starts lazily)
+ if start_engine:
+ await self._engine.engine.start()
+
+ self._loaded = True
+ self._engine_started = start_engine
+ logger.info(
+ f"BatchedEngine injected with shared model: {self._model_name} (started={start_engine})"
+ )
diff --git a/vllm_mlx/engine/hybrid.py b/vllm_mlx/engine/hybrid.py
new file mode 100644
index 00000000..c0b7c2aa
--- /dev/null
+++ b/vllm_mlx/engine/hybrid.py
@@ -0,0 +1,509 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+Hybrid engine combining speculative decoding with continuous batching.
+
+This engine shares a single model instance between SimpleEngine and BatchedEngine,
+saving ~44GB RAM while providing:
+- Speculative decoding (80+ tok/s) for single-user scenarios
+- Continuous batching (60-70 tok/s) for multi-user scenarios
+
+The engine automatically switches between modes based on concurrent request count.
+
+Architecture:
+ HybridEngine
+ ├── Shared components (loaded once)
+ │ ├── _shared_model (44GB)
+ │ ├── _shared_tokenizer
+ │ └── _ownership_lock
+ │
+ ├── SimpleEngine (speculative)
+ │ └── + draft_model (350MB)
+ │
+ └── BatchedEngine (batching)
+ └── Scheduler + BatchGenerator
+
+Mode switching logic:
+ if active_requests < switch_threshold:
+ → SimpleEngine (speculative, 80+ tok/s)
+ else:
+ → BatchedEngine (batching, 60-70 tok/s)
+"""
+
+import asyncio
+import logging
+from collections.abc import AsyncIterator
+from typing import Any
+
+from mlx_lm import load
+
+from .base import BaseEngine, GenerationOutput
+from .batched import BatchedEngine
+from .simple import SimpleEngine
+from ..model_registry import get_registry
+
+logger = logging.getLogger(__name__)
+
+
+class HybridEngine(BaseEngine):
+ """
+ Hybrid engine: speculative decoding for single-user,
+ continuous batching for multi-user. Shares model instance.
+
+ This engine provides the best of both worlds:
+ - When serving a single user: uses SimpleEngine with speculative decoding
+ for maximum throughput (80+ tok/s with Qwen3-0.6B draft model)
+ - When serving multiple concurrent users: switches to BatchedEngine
+ for efficient continuous batching
+
+ The key innovation is sharing a single model instance (~44GB for Qwen3-Next-80B)
+ between both engines, cutting memory usage in half compared to running
+ separate servers.
+ """
+
+ def __init__(
+ self,
+ model_name: str,
+ draft_model: str | None = None,
+ num_draft_tokens: int = 5,
+ scheduler_config: Any | None = None,
+ stream_interval: int = 1,
+ trust_remote_code: bool = True,
+ force_mllm: bool = False,
+ switch_threshold: int = 2,
+ ):
+ """
+ Initialize the hybrid engine.
+
+ Args:
+ model_name: HuggingFace model name or local path
+ draft_model: Draft model for speculative decoding (e.g., Qwen3-0.6B-4bit)
+ num_draft_tokens: Number of tokens to generate speculatively (default: 5)
+ scheduler_config: Scheduler config for batched mode
+ stream_interval: Tokens to batch before streaming (batched mode only)
+ trust_remote_code: Whether to trust remote code
+ force_mllm: Force loading as MLLM even if not auto-detected
+ switch_threshold: Number of concurrent requests to trigger batch mode (default: 2)
+ """
+ self._model_name = model_name
+ self._draft_model_name = draft_model
+ self._num_draft_tokens = num_draft_tokens
+ self._scheduler_config = scheduler_config
+ self._stream_interval = stream_interval
+ self._trust_remote_code = trust_remote_code
+ self._force_mllm = force_mllm
+ self._switch_threshold = switch_threshold
+
+ # Shared resources (loaded once)
+ self._shared_model = None
+ self._shared_tokenizer = None
+
+ # Engine instances
+ self._simple: SimpleEngine | None = None
+ self._batched: BatchedEngine | None = None
+ self._current_mode: str | None = None # 'simple' or 'batched'
+
+ # Concurrency tracking
+ self._active_requests = 0
+ self._lock = asyncio.Lock()
+ self._switch_lock = asyncio.Lock()
+
+ # State
+ self._loaded = False
+ self._is_mllm = False
+
+ @property
+ def model_name(self) -> str:
+ """Get the model name."""
+ return self._model_name
+
+ @property
+ def is_mllm(self) -> bool:
+ """Check if this is a multimodal model."""
+ return self._is_mllm
+
+ @property
+ def tokenizer(self) -> Any:
+ """Get the tokenizer."""
+ return self._shared_tokenizer
+
+ async def start(self) -> None:
+ """Start the engine (load shared model and initialize sub-engines)."""
+ if self._loaded:
+ return
+
+ logger.info(f"HybridEngine loading shared model: {self._model_name}")
+
+ # Load model once using mlx-lm
+ self._shared_model, self._shared_tokenizer = load(self._model_name)
+
+ # Check if MLLM
+ from ..api.utils import is_mllm_model
+
+ self._is_mllm = self._force_mllm or is_mllm_model(self._model_name)
+
+ if self._is_mllm:
+ logger.warning(
+ "HybridEngine does not support MLLM models yet. "
+ "Using BatchedEngine only."
+ )
+ # For MLLM, just use BatchedEngine (no speculative)
+ self._batched = BatchedEngine(
+ model_name=self._model_name,
+ scheduler_config=self._scheduler_config,
+ stream_interval=self._stream_interval,
+ force_mllm=True,
+ )
+ await self._batched.start()
+ self._current_mode = "batched"
+ else:
+ # Create SimpleEngine with draft model support
+ self._simple = SimpleEngine(
+ model_name=self._model_name,
+ trust_remote_code=self._trust_remote_code,
+ draft_model=self._draft_model_name,
+ num_draft_tokens=self._num_draft_tokens,
+ )
+ # Inject shared model instead of loading again
+ await self._simple._inject_shared_model(
+ self._shared_model,
+ self._shared_tokenizer,
+ )
+
+ # Create BatchedEngine (lazy start - don't start engine loop yet)
+ self._batched = BatchedEngine(
+ model_name=self._model_name,
+ trust_remote_code=self._trust_remote_code,
+ scheduler_config=self._scheduler_config,
+ stream_interval=self._stream_interval,
+ )
+ # Inject shared model but DON'T start engine loop yet
+ # Engine will be started on first switch to batched mode
+ await self._batched._inject_shared_model(
+ self._shared_model,
+ self._shared_tokenizer,
+ start_engine=False, # Lazy start for HybridEngine
+ )
+ self._batched_engine_started = False
+
+ # Start in simple mode (speculative decoding)
+ self._current_mode = "simple"
+
+ self._loaded = True
+
+ spec_info = ""
+ if self._draft_model_name and not self._is_mllm:
+ spec_info = f", draft={self._draft_model_name}, k={self._num_draft_tokens}"
+
+ logger.info(
+ f"HybridEngine ready: {self._model_name} "
+ f"(mode={self._current_mode}, threshold={self._switch_threshold}{spec_info})"
+ )
+
+ async def stop(self) -> None:
+ """Stop the engine and cleanup resources."""
+ if self._simple:
+ await self._simple.stop()
+ self._simple = None
+
+ if self._batched:
+ await self._batched.stop()
+ self._batched = None
+
+ self._shared_model = None
+ self._shared_tokenizer = None
+ self._loaded = False
+ self._current_mode = None
+
+ logger.info("HybridEngine stopped")
+
+ def _get_engine_for_request(self) -> BaseEngine:
+ """
+ Get the appropriate engine for the current request.
+
+ Note: This doesn't switch modes - it returns the engine based on
+ current mode. Mode switching happens separately.
+ """
+ # For MLLM, always use batched
+ if self._is_mllm:
+ return self._batched
+
+ return self._simple if self._current_mode == "simple" else self._batched
+
+ async def _switch_to_mode(self, target_mode: str) -> None:
+ """
+ Switch to the specified mode, handling ownership transfer.
+
+ This method ensures proper model ownership transfer between engines
+ to prevent KV cache conflicts.
+
+ Args:
+ target_mode: 'simple' or 'batched'
+ """
+ if self._current_mode == target_mode:
+ return
+
+ async with self._switch_lock:
+ # Double-check after acquiring lock
+ if self._current_mode == target_mode:
+ return
+
+ old_mode = self._current_mode
+
+ if target_mode == "batched":
+ # Switching to batched mode
+ if self._batched and self._batched._engine:
+ # Start BatchedEngine's engine loop if not started yet (lazy start)
+ if not getattr(self, "_batched_engine_started", True):
+ logger.info("HybridEngine: starting BatchedEngine (lazy start)")
+ await self._batched._engine.engine.start()
+ self._batched_engine_started = True
+
+ # BatchedEngine needs model ownership for its BatchGenerator
+ registry = get_registry()
+ try:
+ registry.acquire(
+ model=self._shared_model,
+ engine=self._batched._engine.engine,
+ engine_id=self._batched._engine.engine.engine_id,
+ force=True, # Force transfer from SimpleEngine
+ )
+ logger.info(
+ "HybridEngine: model ownership transferred to BatchedEngine"
+ )
+ except Exception as e:
+ logger.warning(f"Ownership transfer failed: {e}")
+ # Continue anyway - the transfer may have succeeded partially
+ else:
+ # Switching to simple mode
+ # SimpleEngine doesn't need registry ownership (uses mlx_lm.generate directly)
+ # Just reset BatchedEngine's scheduler to clear KV cache state
+ if self._batched and self._batched._engine:
+ try:
+ self._batched._engine.engine.scheduler.deep_reset()
+ logger.debug("Reset BatchedEngine scheduler for mode switch")
+ except Exception as e:
+ logger.warning(f"Failed to reset BatchedEngine: {e}")
+
+ self._current_mode = target_mode
+ logger.info(f"HybridEngine: mode switched {old_mode} -> {target_mode}")
+
+ async def _decide_and_switch_mode(self, entering: bool = True) -> str:
+ """
+ Decide which mode to use and switch if necessary.
+
+ Args:
+ entering: True when entering a request, False when exiting
+
+ Returns:
+ The mode to use for this request ('simple' or 'batched')
+ """
+ if self._is_mllm:
+ return "batched"
+
+ # Decide based on active request count
+ # When entering: count includes this request
+ # When exiting: count excludes this request
+ active = self._active_requests
+
+ if active >= self._switch_threshold:
+ target_mode = "batched"
+ else:
+ target_mode = "simple"
+
+ # Only switch when safe (no active requests in the other engine)
+ # This is a simplified heuristic - we switch when crossing the threshold
+ if target_mode != self._current_mode:
+ # Log the decision
+ logger.debug(
+ f"HybridEngine: active_requests={active}, "
+ f"threshold={self._switch_threshold}, "
+ f"current={self._current_mode}, target={target_mode}"
+ )
+
+ # Only switch to batched immediately, switch to simple when quiet
+ if target_mode == "batched":
+ await self._switch_to_mode("batched")
+ elif active == 0:
+ # Switch back to simple only when completely idle
+ await self._switch_to_mode("simple")
+
+ return self._current_mode
+
+ async def generate(
+ self,
+ prompt: str,
+ max_tokens: int = 256,
+ temperature: float = 0.7,
+ top_p: float = 0.9,
+ stop: list[str] | None = None,
+ **kwargs,
+ ) -> GenerationOutput:
+ """Generate a complete response (non-streaming)."""
+ if not self._loaded:
+ await self.start()
+
+ async with self._lock:
+ self._active_requests += 1
+
+ try:
+ # Decide mode and switch if needed
+ mode = await self._decide_and_switch_mode(entering=True)
+ engine = self._simple if mode == "simple" else self._batched
+
+ return await engine.generate(
+ prompt=prompt,
+ max_tokens=max_tokens,
+ temperature=temperature,
+ top_p=top_p,
+ stop=stop,
+ **kwargs,
+ )
+ finally:
+ async with self._lock:
+ self._active_requests -= 1
+ # Check if we should switch back to simple mode
+ await self._decide_and_switch_mode(entering=False)
+
+ async def stream_generate(
+ self,
+ prompt: str,
+ max_tokens: int = 256,
+ temperature: float = 0.7,
+ top_p: float = 0.9,
+ stop: list[str] | None = None,
+ **kwargs,
+ ) -> AsyncIterator[GenerationOutput]:
+ """Stream generation token by token."""
+ if not self._loaded:
+ await self.start()
+
+ async with self._lock:
+ self._active_requests += 1
+
+ try:
+ # Decide mode and switch if needed
+ mode = await self._decide_and_switch_mode(entering=True)
+ engine = self._simple if mode == "simple" else self._batched
+
+ async for output in engine.stream_generate(
+ prompt=prompt,
+ max_tokens=max_tokens,
+ temperature=temperature,
+ top_p=top_p,
+ stop=stop,
+ **kwargs,
+ ):
+ yield output
+ finally:
+ async with self._lock:
+ self._active_requests -= 1
+ # Check if we should switch back to simple mode
+ await self._decide_and_switch_mode(entering=False)
+
+ async def chat(
+ self,
+ messages: list[dict[str, Any]],
+ max_tokens: int = 256,
+ temperature: float = 0.7,
+ top_p: float = 0.9,
+ tools: list[dict] | None = None,
+ images: list[str] | None = None,
+ videos: list[str] | None = None,
+ **kwargs,
+ ) -> GenerationOutput:
+ """Chat completion (non-streaming)."""
+ if not self._loaded:
+ await self.start()
+
+ async with self._lock:
+ self._active_requests += 1
+
+ try:
+ # Decide mode and switch if needed
+ mode = await self._decide_and_switch_mode(entering=True)
+ engine = self._simple if mode == "simple" else self._batched
+
+ return await engine.chat(
+ messages=messages,
+ max_tokens=max_tokens,
+ temperature=temperature,
+ top_p=top_p,
+ tools=tools,
+ images=images,
+ videos=videos,
+ **kwargs,
+ )
+ finally:
+ async with self._lock:
+ self._active_requests -= 1
+ # Check if we should switch back to simple mode
+ await self._decide_and_switch_mode(entering=False)
+
+ async def stream_chat(
+ self,
+ messages: list[dict[str, Any]],
+ max_tokens: int = 256,
+ temperature: float = 0.7,
+ top_p: float = 0.9,
+ tools: list[dict] | None = None,
+ images: list[str] | None = None,
+ videos: list[str] | None = None,
+ **kwargs,
+ ) -> AsyncIterator[GenerationOutput]:
+ """Stream chat completion token by token."""
+ if not self._loaded:
+ await self.start()
+
+ async with self._lock:
+ self._active_requests += 1
+
+ try:
+ # Decide mode and switch if needed
+ mode = await self._decide_and_switch_mode(entering=True)
+ engine = self._simple if mode == "simple" else self._batched
+
+ async for output in engine.stream_chat(
+ messages=messages,
+ max_tokens=max_tokens,
+ temperature=temperature,
+ top_p=top_p,
+ tools=tools,
+ images=images,
+ videos=videos,
+ **kwargs,
+ ):
+ yield output
+ finally:
+ async with self._lock:
+ self._active_requests -= 1
+ # Check if we should switch back to simple mode
+ await self._decide_and_switch_mode(entering=False)
+
+ def get_stats(self) -> dict[str, Any]:
+ """Get engine statistics."""
+ stats = {
+ "engine_type": "hybrid",
+ "model_name": self._model_name,
+ "is_mllm": self._is_mllm,
+ "loaded": self._loaded,
+ "current_mode": self._current_mode,
+ "active_requests": self._active_requests,
+ "switch_threshold": self._switch_threshold,
+ "draft_model": self._draft_model_name,
+ "num_draft_tokens": self._num_draft_tokens,
+ }
+
+ if self._simple:
+ stats["simple_engine"] = self._simple.get_stats()
+ if self._batched:
+ stats["batched_engine"] = self._batched.get_stats()
+
+ return stats
+
+ def get_cache_stats(self) -> dict[str, Any] | None:
+ """Get cache statistics from active engine."""
+ if self._current_mode == "simple" and self._simple:
+ return self._simple.get_cache_stats()
+ elif self._batched:
+ return self._batched.get_cache_stats()
+ return None
diff --git a/vllm_mlx/engine/simple.py b/vllm_mlx/engine/simple.py
index ed118bbb..df42b3e8 100644
--- a/vllm_mlx/engine/simple.py
+++ b/vllm_mlx/engine/simple.py
@@ -17,6 +17,15 @@
logger = logging.getLogger(__name__)
+# Check for guided generation availability
+try:
+ from ..api.guided import GuidedGenerator, is_guided_available
+
+ HAS_GUIDED = is_guided_available()
+except ImportError:
+ HAS_GUIDED = False
+ GuidedGenerator = None
+
class SimpleEngine(BaseEngine):
"""
@@ -32,6 +41,8 @@ def __init__(
trust_remote_code: bool = True,
enable_cache: bool = True,
force_mllm: bool = False,
+ draft_model: str | None = None,
+ num_draft_tokens: int = 4,
):
"""
Initialize the simple engine.
@@ -41,11 +52,15 @@ def __init__(
trust_remote_code: Whether to trust remote code
enable_cache: Enable VLM cache for multimodal models
force_mllm: Force loading as MLLM even if not auto-detected
+ draft_model: Optional draft model path for speculative decoding
+ num_draft_tokens: Number of tokens to generate speculatively per step
"""
self._model_name = model_name
self._trust_remote_code = trust_remote_code
self._enable_cache = enable_cache
self._is_mllm = force_mllm or is_mllm_model(model_name)
+ self._draft_model_name = draft_model
+ self._num_draft_tokens = num_draft_tokens
self._model = None
self._loaded = False
@@ -85,17 +100,27 @@ async def start(self) -> None:
trust_remote_code=self._trust_remote_code,
enable_cache=self._enable_cache,
)
+ if self._draft_model_name:
+ logger.warning("Speculative decoding is not supported with MLLM models")
else:
from ..models.llm import MLXLanguageModel
self._model = MLXLanguageModel(
self._model_name,
trust_remote_code=self._trust_remote_code,
+ draft_model=self._draft_model_name,
+ num_draft_tokens=self._num_draft_tokens,
)
self._model.load()
self._loaded = True
- logger.info(f"SimpleEngine loaded: {self._model_name} (MLLM={self._is_mllm})")
+
+ spec_info = ""
+ if self._draft_model_name and not self._is_mllm:
+ spec_info = f", speculative={self._draft_model_name}"
+ logger.info(
+ f"SimpleEngine loaded: {self._model_name} (MLLM={self._is_mllm}{spec_info})"
+ )
async def stop(self) -> None:
"""Stop the engine and cleanup resources."""
@@ -440,3 +465,174 @@ def get_cache_stats(self) -> dict[str, Any] | None:
if self._is_mllm and self._model is not None:
return self._model.get_cache_stats()
return None
+
+ async def _inject_shared_model(
+ self,
+ model,
+ tokenizer,
+ ) -> None:
+ """
+ Inject a pre-loaded shared model instead of loading a new one.
+
+ This is used by HybridEngine to share a single model instance
+    between SimpleEngine and BatchedEngine, avoiding a duplicate
+
+ Args:
+ model: Pre-loaded MLX model
+ tokenizer: Pre-loaded tokenizer
+ """
+ from ..models.llm import MLXLanguageModel
+
+ # Create MLXLanguageModel wrapper without loading
+ self._model = MLXLanguageModel.__new__(MLXLanguageModel)
+ self._model.model_name = self._model_name
+ self._model.tokenizer_name = self._model_name
+ self._model.trust_remote_code = self._trust_remote_code
+ self._model.draft_model_name = self._draft_model_name
+ self._model.num_draft_tokens = self._num_draft_tokens
+ self._model.model = model
+ self._model.tokenizer = tokenizer
+ self._model.draft_model = None
+ self._model._loaded = True
+
+ # Load draft model separately if specified
+ if self._draft_model_name:
+ from mlx_lm import load as mlx_load
+
+ logger.info(
+ f"Loading draft model for speculative decoding: {self._draft_model_name}"
+ )
+ self._model.draft_model, draft_tokenizer = mlx_load(self._draft_model_name)
+
+ # Validate tokenizer compatibility
+ if draft_tokenizer.vocab_size != tokenizer.vocab_size:
+ logger.warning(
+ f"Draft model tokenizer vocab size ({draft_tokenizer.vocab_size}) "
+ f"differs from main model ({tokenizer.vocab_size}). "
+ "This may reduce speculative decoding effectiveness."
+ )
+
+ logger.info(
+ f"Speculative decoding enabled: draft={self._draft_model_name}, "
+ f"num_draft_tokens={self._num_draft_tokens}"
+ )
+
+ self._loaded = True
+ logger.info(f"SimpleEngine injected with shared model: {self._model_name}")
+
+ @property
+ def supports_guided_generation(self) -> bool:
+ """Check if guided generation is available."""
+ return HAS_GUIDED and not self._is_mllm
+
+ async def generate_with_schema(
+ self,
+ messages: list[dict[str, Any]],
+ json_schema: dict[str, Any],
+ max_tokens: int = 256,
+ temperature: float = 0.7,
+ top_p: float = 0.9,
+ **kwargs,
+ ) -> GenerationOutput:
+ """
+ Generate JSON output constrained to a schema using guided decoding.
+
+ This method uses outlines for constrained generation to guarantee
+ the output is valid JSON matching the specified schema.
+
+ Args:
+ messages: List of chat messages
+ json_schema: JSON schema to constrain output
+ max_tokens: Maximum tokens to generate
+ temperature: Sampling temperature
+ top_p: Top-p sampling
+ **kwargs: Additional parameters
+
+ Returns:
+ GenerationOutput with JSON text matching the schema
+ """
+ if not self.supports_guided_generation:
+ raise RuntimeError(
+ "Guided generation not available. "
+ "Install with: pip install 'vllm-mlx[guided]'"
+ )
+
+ if not self._loaded:
+ await self.start()
+
+ # Build prompt from messages
+ tokenizer = self._model.tokenizer
+ if hasattr(tokenizer, "apply_chat_template"):
+ prompt = tokenizer.apply_chat_template(
+ messages,
+ tokenize=False,
+ add_generation_prompt=True,
+ )
+ else:
+ prompt = "\n".join(f"{m['role']}: {m['content']}" for m in messages)
+ prompt += "\nassistant:"
+
+ async with self._generation_lock:
+ # Run guided generation in thread pool
+ result = await asyncio.to_thread(
+ self._run_guided_generation,
+ prompt=prompt,
+ json_schema=json_schema,
+ max_tokens=max_tokens,
+ temperature=temperature,
+ )
+
+ if result is None:
+ # Fallback to regular generation
+ logger.warning(
+ "Guided generation failed, falling back to regular generation"
+ )
+ return await self.generate(
+ prompt=prompt,
+ max_tokens=max_tokens,
+ temperature=temperature,
+ top_p=top_p,
+ **kwargs,
+ )
+
+ # Tokenize for completion count
+ tokens = tokenizer.encode(result)
+
+ return GenerationOutput(
+ text=result,
+ tokens=tokens,
+ prompt_tokens=len(tokenizer.encode(prompt)),
+ completion_tokens=len(tokens),
+ finish_reason="stop",
+ )
+
+ def _run_guided_generation(
+ self,
+ prompt: str,
+ json_schema: dict[str, Any],
+ max_tokens: int,
+ temperature: float,
+ ) -> str | None:
+ """
+ Run guided generation synchronously (called from thread pool).
+
+ Args:
+ prompt: Input prompt
+ json_schema: JSON schema
+ max_tokens: Maximum tokens
+ temperature: Sampling temperature
+
+ Returns:
+ JSON string or None if failed
+ """
+ try:
+ generator = GuidedGenerator(self._model.model, self._model.tokenizer)
+ return generator.generate_json(
+ prompt=prompt,
+ json_schema=json_schema,
+ max_tokens=max_tokens,
+ temperature=temperature,
+ )
+ except Exception as e:
+ logger.error(f"Guided generation error: {e}")
+ return None
diff --git a/vllm_mlx/models/llm.py b/vllm_mlx/models/llm.py
index 092c060e..1c75fb3a 100644
--- a/vllm_mlx/models/llm.py
+++ b/vllm_mlx/models/llm.py
@@ -50,6 +50,8 @@ def __init__(
model_name: str,
tokenizer_name: str | None = None,
trust_remote_code: bool = False,
+ draft_model: str | None = None,
+ num_draft_tokens: int = 4,
):
"""
Initialize the MLX language model.
@@ -58,13 +60,18 @@ def __init__(
model_name: HuggingFace model name or local path
tokenizer_name: Optional separate tokenizer name
trust_remote_code: Whether to trust remote code
+ draft_model: Optional draft model path for speculative decoding
+ num_draft_tokens: Number of tokens to generate speculatively per step
"""
self.model_name = model_name
self.tokenizer_name = tokenizer_name or model_name
self.trust_remote_code = trust_remote_code
+ self.draft_model_name = draft_model
+ self.num_draft_tokens = num_draft_tokens
self.model = None
self.tokenizer = None
+ self.draft_model = None
self._loaded = False
def load(self) -> None:
@@ -91,6 +98,28 @@ def load(self) -> None:
tokenizer_config=tokenizer_config,
)
+ # Load draft model for speculative decoding if specified
+ if self.draft_model_name:
+ logger.info(
+ f"Loading draft model for speculative decoding: {self.draft_model_name}"
+ )
+ from mlx_lm import load as mlx_load
+
+ self.draft_model, draft_tokenizer = mlx_load(self.draft_model_name)
+
+ # Validate tokenizer compatibility
+ if draft_tokenizer.vocab_size != self.tokenizer.vocab_size:
+ logger.warning(
+ f"Draft model tokenizer vocab size ({draft_tokenizer.vocab_size}) "
+ f"differs from main model ({self.tokenizer.vocab_size}). "
+ "This may reduce speculative decoding effectiveness."
+ )
+
+ logger.info(
+ f"Speculative decoding enabled: draft={self.draft_model_name}, "
+ f"num_draft_tokens={self.num_draft_tokens}"
+ )
+
self._loaded = True
logger.info(f"Model loaded successfully: {self.model_name}")
@@ -147,15 +176,31 @@ def generate(
# Create sampler with parameters
sampler = self._create_sampler(temperature, top_p)
- # Generate text
- output_text = generate(
- self.model,
- self.tokenizer,
- prompt=prompt,
- max_tokens=max_tokens,
- sampler=sampler,
- verbose=False,
- )
+ # Note: mlx_lm.generate() doesn't support draft_model directly,
+ # speculative decoding is only available via stream_generate()
+ if self.draft_model is not None:
+ # Use streaming with draft model and collect result
+ output_text = ""
+ for chunk in self.stream_generate(
+ prompt=prompt,
+ max_tokens=max_tokens,
+ temperature=temperature,
+ top_p=top_p,
+ stop=stop,
+ ):
+ output_text += chunk.text
+ if chunk.finished:
+ break
+ else:
+ # Generate text without speculative decoding
+ output_text = generate(
+ self.model,
+ self.tokenizer,
+ prompt=prompt,
+ max_tokens=max_tokens,
+ sampler=sampler,
+ verbose=False,
+ )
# Tokenize output to get token IDs
tokens = self.tokenizer.encode(output_text)
@@ -203,12 +248,22 @@ def stream_generate(
token_count = 0
accumulated_text = ""
+ # Build generation kwargs
+ gen_kwargs = {
+ "max_tokens": max_tokens,
+ "sampler": sampler,
+ }
+
+ # Add draft model for speculative decoding if available
+ if self.draft_model is not None:
+ gen_kwargs["draft_model"] = self.draft_model
+ gen_kwargs["num_draft_tokens"] = self.num_draft_tokens
+
for response in stream_generate(
self.model,
self.tokenizer,
prompt=prompt,
- max_tokens=max_tokens,
- sampler=sampler,
+ **gen_kwargs,
):
token_count += 1
# response.text is the new token text (not accumulated)
diff --git a/vllm_mlx/server.py b/vllm_mlx/server.py
index 138c166a..30b6ef9a 100644
--- a/vllm_mlx/server.py
+++ b/vllm_mlx/server.py
@@ -93,16 +93,25 @@
from .api.tool_calling import (
build_json_system_prompt,
convert_tools_for_template,
+ extract_json_schema_for_guided,
parse_json_output,
parse_tool_calls,
)
from .api.utils import (
SPECIAL_TOKENS_PATTERN,
clean_output_text,
+ extract_json_from_response,
extract_multimodal_content,
is_mllm_model, # noqa: F401
+ strip_thinking_tags,
+)
+from .engine import (
+ BaseEngine,
+ BatchedEngine,
+ GenerationOutput,
+ HybridEngine,
+ SimpleEngine,
)
-from .engine import BaseEngine, BatchedEngine, GenerationOutput, SimpleEngine
from .tool_parsers import ToolParserManager
logging.basicConfig(level=logging.INFO)
@@ -467,6 +476,8 @@ def load_model(
stream_interval: int = 1,
max_tokens: int = 32768,
force_mllm: bool = False,
+ draft_model: str | None = None,
+ num_draft_tokens: int = 4,
):
"""
Load a model (auto-detects MLLM vs LLM).
@@ -478,6 +489,8 @@ def load_model(
stream_interval: Tokens to batch before streaming (batched mode only)
max_tokens: Default max tokens for generation
force_mllm: Force loading as MLLM even if not auto-detected
+ draft_model: Optional draft model for speculative decoding
+ num_draft_tokens: Number of tokens to generate speculatively per step
"""
global _engine, _model_name, _default_max_tokens, _tool_parser_instance
@@ -489,7 +502,28 @@ def load_model(
if force_mllm:
logger.info("Force MLLM mode enabled via --mllm flag")
- if use_batching:
+ if draft_model:
+ logger.info(f"Speculative decoding enabled with draft model: {draft_model}")
+ logger.info(f" num_draft_tokens: {num_draft_tokens}")
+
+ if use_batching and draft_model:
+ # Hybrid mode: shared model with speculative decoding + continuous batching
+ logger.info(f"Loading model with HybridEngine: {model_name}")
+ logger.info(
+ " Hybrid mode: speculative decoding for single user, "
+ "batching for multiple users"
+ )
+ _engine = HybridEngine(
+ model_name=model_name,
+ draft_model=draft_model,
+ num_draft_tokens=num_draft_tokens,
+ scheduler_config=scheduler_config,
+ stream_interval=stream_interval,
+ force_mllm=force_mllm,
+ )
+ # HybridEngine will be started in lifespan (uvicorn's event loop)
+ logger.info(f"Model loaded (hybrid mode): {model_name}")
+ elif use_batching:
logger.info(f"Loading model with BatchedEngine: {model_name}")
_engine = BatchedEngine(
model_name=model_name,
@@ -502,7 +536,12 @@ def load_model(
logger.info(f"Model loaded (batched mode): {model_name}")
else:
logger.info(f"Loading model with SimpleEngine: {model_name}")
- _engine = SimpleEngine(model_name=model_name, force_mllm=force_mllm)
+ _engine = SimpleEngine(
+ model_name=model_name,
+ force_mllm=force_mllm,
+ draft_model=draft_model,
+ num_draft_tokens=num_draft_tokens,
+ )
# Start SimpleEngine synchronously (no background loop)
# Use new_event_loop() for Python 3.10+ compatibility (get_event_loop() is deprecated)
loop = asyncio.new_event_loop()
@@ -1370,11 +1409,34 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re
start_time = time.perf_counter()
timeout = request.timeout or _default_timeout
- output = await _wait_with_disconnect(
- engine.chat(messages=messages, **chat_kwargs),
- raw_request,
- timeout=timeout,
- )
+ # Check if we should use guided generation for JSON schema
+ use_guided = False
+ json_schema = None
+ if response_format and not request.tools:
+ json_schema = extract_json_schema_for_guided(response_format)
+ if json_schema and hasattr(engine, "supports_guided_generation"):
+ use_guided = engine.supports_guided_generation
+ if use_guided:
+ logger.info("Using guided generation for JSON schema enforcement")
+
+ if use_guided and json_schema:
+ # Use guided generation for constrained JSON output
+ output = await _wait_with_disconnect(
+ engine.generate_with_schema(
+ messages=messages,
+ json_schema=json_schema,
+ **chat_kwargs,
+ ),
+ raw_request,
+ timeout=timeout,
+ )
+ else:
+ # Standard generation
+ output = await _wait_with_disconnect(
+ engine.chat(messages=messages, **chat_kwargs),
+ raw_request,
+ timeout=timeout,
+ )
if output is None:
return Response(status_code=499) # Client closed request
@@ -1408,12 +1470,22 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re
# Determine finish reason
finish_reason = "tool_calls" if tool_calls else output.finish_reason
+ # Clean and strip thinking tags from content
+ # Thinking tags break JSON parsing in clients expecting pure content
+ # Also extract JSON if model outputs reasoning text before JSON
+ final_content = None
+ if cleaned_text:
+ final_content = strip_thinking_tags(clean_output_text(cleaned_text))
+ # If response looks like it ends with JSON, extract just the JSON part
+ # This handles Qwen3 reasoning mode: "Let me think... {json}"
+ final_content = extract_json_from_response(final_content)
+
return ChatCompletionResponse(
model=request.model,
choices=[
ChatCompletionChoice(
message=AssistantMessage(
- content=clean_output_text(cleaned_text) if cleaned_text else None,
+ content=final_content,
reasoning=reasoning_text,
tool_calls=tool_calls,
),
@@ -1432,7 +1504,11 @@ def _inject_json_instruction(messages: list, instruction: str) -> list:
"""
Inject JSON instruction into messages.
- If a system message exists, append to it. Otherwise, prepend a new system message.
+ PREPENDS instruction to system message for better instruction following.
+ JSON formatting instructions at the beginning of system message are more
+ likely to be followed than instructions at the end.
+
+ If a system message exists, prepend to it. Otherwise, prepend a new system message.
"""
messages = list(messages) # Make a copy
@@ -1445,14 +1521,15 @@ def _inject_json_instruction(messages: list, instruction: str) -> list:
break
if system_idx is not None:
- # Append to existing system message
+ # PREPEND to existing system message (not append!)
+ # Instructions at the start are more effective
msg = messages[system_idx]
if isinstance(msg, dict):
existing = msg.get("content", "")
- msg["content"] = f"{existing}\n\n{instruction}"
+ msg["content"] = f"{instruction}\n\n{existing}"
else:
existing = getattr(msg, "content", "") or ""
- msg.content = f"{existing}\n\n{instruction}"
+ msg.content = f"{instruction}\n\n{existing}"
else:
# Prepend new system message
messages.insert(0, {"role": "system", "content": instruction})
@@ -2219,6 +2296,18 @@ def main():
default=None,
help="Default top_p for generation when not specified in request",
)
+ parser.add_argument(
+ "--draft-model",
+ type=str,
+ default=None,
+ help="Draft model for speculative decoding (must use same tokenizer as main model)",
+ )
+ parser.add_argument(
+ "--num-draft-tokens",
+ type=int,
+ default=4,
+ help="Number of tokens to generate speculatively per step (default: 4)",
+ )
args = parser.parse_args()
@@ -2276,6 +2365,8 @@ def main():
use_batching=args.continuous_batching,
max_tokens=args.max_tokens,
force_mllm=args.mllm,
+ draft_model=args.draft_model,
+ num_draft_tokens=args.num_draft_tokens,
)
# Start server
diff --git a/vllm_mlx/speculative/__init__.py b/vllm_mlx/speculative/__init__.py
new file mode 100644
index 00000000..019b8895
--- /dev/null
+++ b/vllm_mlx/speculative/__init__.py
@@ -0,0 +1,8 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+Speculative decoding utilities for vllm-mlx.
+"""
+
+from .prompt_lookup import PromptLookupDecoder
+
+__all__ = ["PromptLookupDecoder"]
diff --git a/vllm_mlx/speculative/prompt_lookup.py b/vllm_mlx/speculative/prompt_lookup.py
new file mode 100644
index 00000000..31c06ef8
--- /dev/null
+++ b/vllm_mlx/speculative/prompt_lookup.py
@@ -0,0 +1,312 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+Prompt Lookup Decoding for speculative token generation.
+
+This module implements Prompt Lookup Decoding, a draft-model-free approach
+to speculative decoding that uses n-gram matching from the prompt and
+generated text to predict future tokens.
+
+Reference: https://github.com/apoorvumang/prompt-lookup-decoding
+"""
+
+import logging
+from collections import defaultdict
+import mlx.core as mx
+
+logger = logging.getLogger(__name__)
+
+
+class PromptLookupDecoder:
+ """
+ Prompt Lookup Decoder for speculative token generation.
+
+ Uses n-gram matching to find repeating patterns in the prompt
+ and generated text to speculate future tokens without a draft model.
+
+ This is particularly effective for:
+ - Code generation (repetitive patterns)
+ - Structured text (JSON, XML, markdown)
+ - Text with common phrases
+ - Translation tasks (similar patterns)
+
+ Args:
+ num_draft_tokens: Number of tokens to draft (default: 4)
+ ngram_size: Size of n-gram to match (default: 3)
+        min_matches: Minimum draft length required to emit a speculation (default: 2)
+ """
+
+ def __init__(
+ self,
+ num_draft_tokens: int = 4,
+ ngram_size: int = 3,
+ min_matches: int = 2,
+ ):
+ self.num_draft_tokens = num_draft_tokens
+ self.ngram_size = ngram_size
+ self.min_matches = min_matches
+
+ # Token history for n-gram lookup
+ self._token_history: list[int] = []
+
+ # N-gram index: maps (token_1, ..., token_n) -> [positions]
+ self._ngram_index: dict[tuple, list[int]] = defaultdict(list)
+
+ # Statistics
+ self.total_drafts = 0
+ self.successful_drafts = 0
+ self.total_draft_tokens = 0
+ self.accepted_tokens = 0
+
+ def reset(self):
+ """Reset the decoder state for a new generation."""
+ self._token_history = []
+ self._ngram_index = defaultdict(list)
+
+ def add_prompt_tokens(self, tokens: list[int]):
+ """
+ Add prompt tokens to the history for lookup.
+
+ Args:
+ tokens: List of prompt token IDs
+ """
+ for token in tokens:
+ self._add_token(token)
+
+    def _add_token(self, token: int):
+        """Add a single token to history and update the n-gram index."""
+        self._token_history.append(token)
+
+        # Only ngram_size-grams are ever queried by get_draft_tokens,
+        # so indexing shorter n-grams just wastes memory and time.
+        pos = len(self._token_history) - 1
+        if pos >= self.ngram_size:
+            ngram = tuple(self._token_history[pos - self.ngram_size : pos])
+            self._ngram_index[ngram].append(pos)
+
+ def add_generated_token(self, token: int):
+ """
+ Add a generated token to history.
+
+ Args:
+ token: Generated token ID
+ """
+ self._add_token(token)
+
+    def get_draft_tokens(self) -> list[int]:
+        """
+        Get draft tokens based on n-gram lookup.
+
+        Returns:
+            List of draft token IDs (may be empty if no match found)
+        """
+        if len(self._token_history) < self.ngram_size:
+            return []
+
+        # Get the last ngram_size tokens as the query
+        query = tuple(self._token_history[-self.ngram_size :])
+
+        # Look up positions where this n-gram occurred
+        positions = self._ngram_index.get(query, [])
+
+        if not positions:
+            return []
+
+        # Find the best match (the one with the longest continuation)
+        draft_tokens = []
+        best_continuation_length = 0
+
+        for pos in positions:
+            # The n-gram index maps a pattern ending at pos - 1 to pos,
+            # so history[pos] is the first token that followed the match
+            # the last time it occurred. The draft therefore starts at
+            # pos, not pos + 1 -- starting one token later shifts every
+            # draft by one position and makes verification reject
+            # otherwise-correct speculations.
+
+            continuation_end = min(
+                pos + self.num_draft_tokens, len(self._token_history)
+            )
+            continuation = self._token_history[pos:continuation_end]
+
+            if len(continuation) > best_continuation_length:
+                best_continuation_length = len(continuation)
+                draft_tokens = continuation[: self.num_draft_tokens]
+
+        if len(draft_tokens) >= self.min_matches:
+            self.total_drafts += 1
+            self.total_draft_tokens += len(draft_tokens)
+            return draft_tokens
+
+        return []
+
+ def record_accepted(self, num_accepted: int):
+ """Record statistics about accepted draft tokens."""
+ if num_accepted > 0:
+ self.successful_drafts += 1
+ self.accepted_tokens += num_accepted
+
+ def get_stats(self) -> dict:
+ """Get decoder statistics."""
+ acceptance_rate = 0.0
+ if self.total_draft_tokens > 0:
+ acceptance_rate = self.accepted_tokens / self.total_draft_tokens
+
+ return {
+ "total_drafts": self.total_drafts,
+ "successful_drafts": self.successful_drafts,
+ "total_draft_tokens": self.total_draft_tokens,
+ "accepted_tokens": self.accepted_tokens,
+ "acceptance_rate": acceptance_rate,
+ "history_size": len(self._token_history),
+ }
+
+
+def prompt_lookup_generate_step(
+ prompt: mx.array,
+ model,
+ *,
+ num_draft_tokens: int = 4,
+ ngram_size: int = 3,
+ max_tokens: int = 256,
+ sampler=None,
+ logits_processors=None,
+ prompt_cache=None,
+ prefill_step_size: int = 512,
+):
+ """
+ Generator for token generation with prompt lookup speculation.
+
+ This is a drop-in replacement for generate_step that uses n-gram
+ matching for speculation instead of a draft model.
+
+ Args:
+ prompt: Input token array
+ model: The main model
+ num_draft_tokens: Number of tokens to draft
+ ngram_size: N-gram size for matching
+ max_tokens: Maximum tokens to generate
+ sampler: Token sampler (default: argmax)
+ logits_processors: Optional logits processors
+ prompt_cache: Optional KV cache
+ prefill_step_size: Prefill chunk size
+
+ Yields:
+ Tuple of (token, logprobs, from_draft)
+ """
+ from mlx_lm.models import cache
+
+ y = prompt.astype(mx.uint32)
+
+ # Create KV cache if not provided
+ if prompt_cache is None:
+ prompt_cache = cache.make_prompt_cache(model)
+
+ sampler = sampler or (lambda x: mx.argmax(x, axis=-1))
+
+ # Initialize prompt lookup decoder
+ decoder = PromptLookupDecoder(
+ num_draft_tokens=num_draft_tokens,
+ ngram_size=ngram_size,
+ )
+
+ # Add prompt tokens to decoder
+ decoder.add_prompt_tokens(prompt.tolist())
+
+ def _step(tokens: mx.array, n_predict: int = 1):
+ """Run model on tokens and get predictions."""
+ logits = model(tokens[None], cache=prompt_cache)
+ logits = logits[:, -n_predict:, :]
+
+ logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
+ sampled = sampler(logprobs)
+
+ return sampled.squeeze(0), logprobs.squeeze(0)
+
+ def _prefill(tokens: mx.array):
+ """Process prompt tokens in chunks."""
+ while tokens.size > prefill_step_size:
+ model(tokens[:prefill_step_size][None], cache=prompt_cache)
+ mx.eval([c.state for c in prompt_cache])
+ tokens = tokens[prefill_step_size:]
+ mx.clear_cache()
+ return tokens
+
+ # Prefill prompt
+ y = _prefill(y)
+
+ # Generate first token
+ current_token, logprobs = _step(y)
+ mx.eval(current_token, logprobs)
+
+ ntoks = 0
+
+ while ntoks < max_tokens:
+ # Yield current token
+ yield current_token.item(), logprobs, False
+ ntoks += 1
+
+ if ntoks >= max_tokens:
+ break
+
+ # Add token to decoder history
+ decoder.add_generated_token(current_token.item())
+
+ # Try to get draft tokens via n-gram lookup
+ draft_tokens = decoder.get_draft_tokens()
+
+ if draft_tokens and len(draft_tokens) > 0:
+ # Speculative path: verify draft tokens
+ verify_input = mx.array([current_token.item()] + draft_tokens, mx.uint32)
+ verified_tokens, verified_logprobs = _step(
+ verify_input, n_predict=len(draft_tokens) + 1
+ )
+ mx.eval(verified_tokens, verified_logprobs)
+
+ verified_tokens = verified_tokens.tolist()
+
+ # Check how many draft tokens were accepted
+ n_accepted = 0
+ for i, (draft_t, verify_t) in enumerate(
+ zip(draft_tokens, verified_tokens[:-1])
+ ):
+ if draft_t == verify_t:
+ n_accepted += 1
+ ntoks += 1
+ decoder.add_generated_token(draft_t)
+ yield draft_t, verified_logprobs[i], True # from_draft=True
+ if ntoks >= max_tokens:
+ break
+ else:
+ break
+
+ decoder.record_accepted(n_accepted)
+
+ if ntoks >= max_tokens:
+ break
+
+ # Trim cache for rejected tokens
+ if n_accepted < len(draft_tokens):
+ cache.trim_prompt_cache(prompt_cache, len(draft_tokens) - n_accepted)
+
+ # Next token is the first non-accepted verification result
+ current_token = mx.array(verified_tokens[n_accepted], mx.uint32)
+ logprobs = verified_logprobs[n_accepted]
+ else:
+ # Standard path: single token generation
+ next_token, next_logprobs = _step(
+ mx.array([current_token.item()], mx.uint32)
+ )
+ mx.eval(next_token, next_logprobs)
+ current_token = next_token
+ logprobs = next_logprobs
+
+ if ntoks % 256 == 0:
+ mx.clear_cache()
+
+ # Log final stats
+ stats = decoder.get_stats()
+ if stats["total_drafts"] > 0:
+ logger.info(
+ f"Prompt Lookup stats: {stats['accepted_tokens']}/{stats['total_draft_tokens']} "
+ f"tokens accepted ({stats['acceptance_rate']:.1%})"
+ )