diff --git a/vllm_mlx/cli.py b/vllm_mlx/cli.py index dcbee8ac..eb744f36 100644 --- a/vllm_mlx/cli.py +++ b/vllm_mlx/cli.py @@ -178,6 +178,8 @@ def serve_command(args): print(f"Prefix cache: max_entries={args.prefix_cache_size}") else: print("Mode: Simple (maximum throughput)") + if args.enable_mtp: + print("MTP: enabled (native speculative decoding)") # Load model with unified server load_model( @@ -187,6 +189,7 @@ def serve_command(args): stream_interval=args.stream_interval if args.continuous_batching else 1, max_tokens=args.max_tokens, force_mllm=args.mllm, + mtp=args.enable_mtp, ) # Start server diff --git a/vllm_mlx/engine/simple.py b/vllm_mlx/engine/simple.py index 3a7e14e2..6b9bb2eb 100644 --- a/vllm_mlx/engine/simple.py +++ b/vllm_mlx/engine/simple.py @@ -32,6 +32,7 @@ def __init__( trust_remote_code: bool = True, enable_cache: bool = True, force_mllm: bool = False, + mtp: bool = False, ): """ Initialize the simple engine. @@ -41,11 +42,13 @@ def __init__( trust_remote_code: Whether to trust remote code enable_cache: Enable VLM cache for multimodal models force_mllm: Force loading as MLLM even if not auto-detected + mtp: Enable native MTP speculative decoding (model must have MTP head) """ self._model_name = model_name self._trust_remote_code = trust_remote_code self._enable_cache = enable_cache self._is_mllm = force_mllm or is_mllm_model(model_name) + self._mtp = mtp self._model = None self._loaded = False @@ -91,11 +94,15 @@ async def start(self) -> None: self._model = MLXLanguageModel( self._model_name, trust_remote_code=self._trust_remote_code, + mtp=self._mtp, ) self._model.load() self._loaded = True - logger.info(f"SimpleEngine loaded: {self._model_name} (MLLM={self._is_mllm})") + mtp_info = f", MTP={self._mtp}" if self._mtp else "" + logger.info( + f"SimpleEngine loaded: {self._model_name} (MLLM={self._is_mllm}{mtp_info})" + ) async def stop(self) -> None: """Stop the engine and cleanup resources.""" diff --git a/vllm_mlx/models/llm.py b/vllm_mlx/models/llm.py index 092c060e..bb487859 100644 --- a/vllm_mlx/models/llm.py +++ b/vllm_mlx/models/llm.py @@ -50,6 +50,7 @@ def __init__( model_name: str, tokenizer_name: str | None = None, trust_remote_code: bool = False, + mtp: bool = False, ): """ Initialize the MLX language model. @@ -58,10 +59,12 @@ def __init__( model_name: HuggingFace model name or local path tokenizer_name: Optional separate tokenizer name trust_remote_code: Whether to trust remote code + mtp: Enable native MTP speculative decoding (model must have MTP head) """ self.model_name = model_name self.tokenizer_name = tokenizer_name or model_name self.trust_remote_code = trust_remote_code + self._mtp = mtp self.model = None self.tokenizer = None @@ -203,12 +206,18 @@ def stream_generate( token_count = 0 accumulated_text = "" + # Pass MTP flag for native speculative decoding if enabled + mtp_kwargs = {} + if self._mtp: + mtp_kwargs["mtp"] = True + for response in stream_generate( self.model, self.tokenizer, prompt=prompt, max_tokens=max_tokens, sampler=sampler, + **mtp_kwargs, ): token_count += 1 # response.text is the new token text (not accumulated) diff --git a/vllm_mlx/server.py b/vllm_mlx/server.py index f0328d4e..53b5ac6e 100644 --- a/vllm_mlx/server.py +++ b/vllm_mlx/server.py @@ -467,6 +467,7 @@ def load_model( stream_interval: int = 1, max_tokens: int = 32768, force_mllm: bool = False, + mtp: bool = False, ): """ Load a model (auto-detects MLLM vs LLM). @@ -478,6 +479,7 @@ def load_model( stream_interval: Tokens to batch before streaming (batched mode only) max_tokens: Default max tokens for generation force_mllm: Force loading as MLLM even if not auto-detected + mtp: Enable native MTP speculative decoding (SimpleEngine only) """ global _engine, _model_name, _default_max_tokens, _tool_parser_instance @@ -502,7 +504,7 @@ def load_model( logger.info(f"Model loaded (batched mode): {model_name}") else: logger.info(f"Loading model with SimpleEngine: {model_name}") - _engine = SimpleEngine(model_name=model_name, force_mllm=force_mllm) + _engine = SimpleEngine(model_name=model_name, force_mllm=force_mllm, mtp=mtp) # Start SimpleEngine synchronously (no background loop) # Use new_event_loop() for Python 3.10+ compatibility (get_event_loop() is deprecated) loop = asyncio.new_event_loop()