Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions vllm_mlx/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,8 @@ def serve_command(args):
print(f"Prefix cache: max_entries={args.prefix_cache_size}")
else:
print("Mode: Simple (maximum throughput)")
if args.enable_mtp:
print("MTP: enabled (native speculative decoding)")

# Load model with unified server
load_model(
Expand All @@ -187,6 +189,7 @@ def serve_command(args):
stream_interval=args.stream_interval if args.continuous_batching else 1,
max_tokens=args.max_tokens,
force_mllm=args.mllm,
mtp=args.enable_mtp,
)

# Start server
Expand Down
9 changes: 8 additions & 1 deletion vllm_mlx/engine/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def __init__(
trust_remote_code: bool = True,
enable_cache: bool = True,
force_mllm: bool = False,
mtp: bool = False,
):
"""
Initialize the simple engine.
Expand All @@ -41,11 +42,13 @@ def __init__(
trust_remote_code: Whether to trust remote code
enable_cache: Enable VLM cache for multimodal models
force_mllm: Force loading as MLLM even if not auto-detected
mtp: Enable native MTP speculative decoding (model must have MTP head)
"""
self._model_name = model_name
self._trust_remote_code = trust_remote_code
self._enable_cache = enable_cache
self._is_mllm = force_mllm or is_mllm_model(model_name)
self._mtp = mtp

self._model = None
self._loaded = False
Expand Down Expand Up @@ -91,11 +94,15 @@ async def start(self) -> None:
self._model = MLXLanguageModel(
self._model_name,
trust_remote_code=self._trust_remote_code,
mtp=self._mtp,
)

self._model.load()
self._loaded = True
logger.info(f"SimpleEngine loaded: {self._model_name} (MLLM={self._is_mllm})")
mtp_info = f", MTP={self._mtp}" if self._mtp else ""
logger.info(
f"SimpleEngine loaded: {self._model_name} (MLLM={self._is_mllm}{mtp_info})"
)

async def stop(self) -> None:
"""Stop the engine and cleanup resources."""
Expand Down
9 changes: 9 additions & 0 deletions vllm_mlx/models/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def __init__(
model_name: str,
tokenizer_name: str | None = None,
trust_remote_code: bool = False,
mtp: bool = False,
):
"""
Initialize the MLX language model.
Expand All @@ -58,10 +59,12 @@ def __init__(
model_name: HuggingFace model name or local path
tokenizer_name: Optional separate tokenizer name
trust_remote_code: Whether to trust remote code
mtp: Enable native MTP speculative decoding (model must have MTP head)
"""
self.model_name = model_name
self.tokenizer_name = tokenizer_name or model_name
self.trust_remote_code = trust_remote_code
self._mtp = mtp

self.model = None
self.tokenizer = None
Expand Down Expand Up @@ -203,12 +206,18 @@ def stream_generate(
token_count = 0
accumulated_text = ""

# Pass MTP flag for native speculative decoding if enabled
mtp_kwargs = {}
if self._mtp:
mtp_kwargs["mtp"] = True

for response in stream_generate(
self.model,
self.tokenizer,
prompt=prompt,
max_tokens=max_tokens,
sampler=sampler,
**mtp_kwargs,
):
token_count += 1
# response.text is the new token text (not accumulated)
Expand Down
4 changes: 3 additions & 1 deletion vllm_mlx/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,7 @@ def load_model(
stream_interval: int = 1,
max_tokens: int = 32768,
force_mllm: bool = False,
mtp: bool = False,
):
"""
Load a model (auto-detects MLLM vs LLM).
Expand All @@ -478,6 +479,7 @@ def load_model(
stream_interval: Tokens to batch before streaming (batched mode only)
max_tokens: Default max tokens for generation
force_mllm: Force loading as MLLM even if not auto-detected
mtp: Enable native MTP speculative decoding (SimpleEngine only)
"""
global _engine, _model_name, _default_max_tokens, _tool_parser_instance

Expand All @@ -502,7 +504,7 @@ def load_model(
logger.info(f"Model loaded (batched mode): {model_name}")
else:
logger.info(f"Loading model with SimpleEngine: {model_name}")
_engine = SimpleEngine(model_name=model_name, force_mllm=force_mllm)
_engine = SimpleEngine(model_name=model_name, force_mllm=force_mllm, mtp=mtp)
# Start SimpleEngine synchronously (no background loop)
# Use new_event_loop() for Python 3.10+ compatibility (get_event_loop() is deprecated)
loop = asyncio.new_event_loop()
Expand Down
Loading