diff --git a/README.md b/README.md index bc63e127..cb128a64 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ A high-performance OpenAI-compatible API server for MLX models. Run text, vision - 🖼️ **Multimodal support** - Text, vision, audio, and image generation/editing - 🎨 **Flux-series models** - Image generation (schnell, dev, krea-dev, flux-2-klein) and editing (kontext, qwen-image-edit) - 🔌 **Easy integration** - Works with existing OpenAI client libraries -- ⚡ **Performance** - Configurable quantization (4/8/16-bit) and context length +- ⚡ **Performance** - Configurable quantization (4/8/16-bit), context length, and speculative decoding (lm) - 🎛️ **LoRA adapters** - Fine-tuned image generation and editing - 📈 **Queue management** - Built-in request queuing and monitoring @@ -53,6 +53,13 @@ mlx-openai-server launch \ --model-path \ --model-type +# Text-only with speculative decoding (faster generation using a smaller draft model) +mlx-openai-server launch \ + --model-path \ + --model-type lm \ + --draft-model \ + --num-draft-tokens 4 + # Image generation (Flux-series) mlx-openai-server launch \ --model-type image-generation \ @@ -85,6 +92,8 @@ mlx-openai-server launch \ - `--config-name`: For image models - `flux-schnell`, `flux-dev`, `flux-krea-dev`, `flux-kontext-dev`, `flux2-klein-4b`, `flux2-klein-9b`, `qwen-image`, `qwen-image-edit`, `z-image-turbo`, `fibo` - `--quantize`: Quantization level - `4`, `8`, or `16` (image models) - `--context-length`: Max sequence length for memory optimization +- `--draft-model`: Path to draft model for speculative decoding (lm only) +- `--num-draft-tokens`: Draft tokens per step for speculative decoding (lm only, default: 2) - `--max-concurrency`: Concurrent requests (default: 1) - `--queue-timeout`: Request timeout in seconds (default: 300) - `--lora-paths`: Comma-separated LoRA adapter paths (image models) @@ -329,6 +338,21 @@ mlx-openai-server launch \ --chat-template-file /path/to/template.jinja ``` +### 
Speculative Decoding (lm) + +Use a smaller draft model to propose tokens and verify them with the main model for faster text generation. Supported only for `--model-type lm`. + +```bash +mlx-openai-server launch \ + --model-path mlx-community/MyModel-8B-4bit \ + --model-type lm \ + --draft-model mlx-community/MyModel-1B-4bit \ + --num-draft-tokens 4 +``` + +- **`--draft-model`**: Path or HuggingFace repo of the draft model (a smaller model). +- **`--num-draft-tokens`**: Number of tokens the draft model generates per verification step (default: 2). Higher values can increase throughput at the cost of more draft compute. + ## Request Queue System The server includes a request queue system with monitoring: diff --git a/app/cli.py b/app/cli.py index c03aa42a..da1efb0d 100644 --- a/app/cli.py +++ b/app/cli.py @@ -203,6 +203,18 @@ def cli(): type=int, help="Maximum number of prompt KV cache entries to store. Only works with language models (lm). Default is 10.", ) +@click.option( + "--draft-model", + default=None, + type=str, + help="Path to the draft model for speculative decoding. Only supported with model type 'lm'. When set, --num-draft-tokens controls how many tokens the draft model generates per step.", +) +@click.option( + "--num-draft-tokens", + default=2, + type=int, + help="Number of draft tokens per step when using speculative decoding (--draft-model). Only supported with model type 'lm'. Default is 2.", +) def launch( model_path, model_type, @@ -228,6 +240,8 @@ def launch( chat_template_file, debug, prompt_cache_size, + draft_model, + num_draft_tokens, ) -> None: """Start the FastAPI/Uvicorn server with the supplied flags. 
@@ -261,6 +275,8 @@ def launch( chat_template_file=chat_template_file, debug=debug, prompt_cache_size=prompt_cache_size, + draft_model_path=draft_model, + num_draft_tokens=num_draft_tokens, ) asyncio.run(start(args)) diff --git a/app/config.py b/app/config.py index 963262e0..3ac1ed88 100644 --- a/app/config.py +++ b/app/config.py @@ -45,6 +45,8 @@ class MLXServerConfig: chat_template_file: str | None = None debug: bool = False prompt_cache_size: int = 10 + draft_model_path: str | None = None + num_draft_tokens: int = 2 # Used to capture raw CLI input before processing lora_paths_str: str | None = None @@ -96,6 +98,15 @@ def __post_init__(self): ) self.config_name = "flux-kontext-dev" + # Speculative decoding (draft model) is only supported for lm model type + if self.draft_model_path and self.model_type != "lm": + logger.warning( + "Draft model / num-draft-tokens are only supported for model type 'lm'. " + "Ignoring speculative decoding options." + ) + self.draft_model_path = None + self.num_draft_tokens = 2 + @property def model_identifier(self) -> str: """Get the appropriate model identifier based on model type. diff --git a/app/handler/mlx_lm.py b/app/handler/mlx_lm.py index 3073ded5..f0f320c9 100644 --- a/app/handler/mlx_lm.py +++ b/app/handler/mlx_lm.py @@ -30,23 +30,49 @@ class MLXLMHandler: Provides request queuing, metrics tracking, and robust error handling. 
""" - def __init__(self, model_path: str, context_length: int | None = None, max_concurrency: int = 1, enable_auto_tool_choice: bool = False, tool_call_parser: str = None, reasoning_parser: str = None, message_converter: str = None, trust_remote_code: bool = False, chat_template_file: str = None, debug: bool = False, prompt_cache_size: int = 10): + def __init__(self, model_path: str, draft_model_path: str | None = None, num_draft_tokens: int = 2, context_length: int | None = None, max_concurrency: int = 1, enable_auto_tool_choice: bool = False, tool_call_parser: str = None, reasoning_parser: str = None, message_converter: str = None, trust_remote_code: bool = False, chat_template_file: str = None, debug: bool = False, prompt_cache_size: int = 10): """ Initialize the handler with the specified model path. - - Args: - model_path (str): Path to the model directory. - context_length (int | None): Maximum context length for the model. If None, uses model default. - max_concurrency (int): Maximum number of concurrent model inference tasks. - enable_auto_tool_choice (bool): Enable automatic tool choice. - tool_call_parser (str): Name of the tool call parser to use (qwen3, glm4_moe, harmony, minimax, ...) - reasoning_parser (str): Name of the reasoning parser to use (qwen3, qwen3_next, glm4_moe, harmony, minimax, ...) - trust_remote_code (bool): Enable trust_remote_code when loading models. - chat_template_file (str): Path to a custom chat template file. - prompt_cache_size (int): Maximum number of prompt KV cache entries to store. Default is 10. + + Parameters + ---------- + model_path : str + Path to the model directory. + draft_model_path : str | None + Path to the draft model for speculative decoding. If None, speculative decoding is disabled. + num_draft_tokens : int + Number of draft tokens per step when using speculative decoding. Default is 2. + context_length : int | None + Maximum context length for the model. If None, uses model default. 
+ max_concurrency : int + Maximum number of concurrent model inference tasks. + enable_auto_tool_choice : bool + Enable automatic tool choice. + tool_call_parser : str | None + Name of the tool call parser to use (qwen3, glm4_moe, harmony, minimax, ...). + reasoning_parser : str | None + Name of the reasoning parser to use (qwen3, qwen3_next, glm4_moe, harmony, minimax, ...). + message_converter : str | None + Name of the message converter to use. + trust_remote_code : bool + Enable trust_remote_code when loading models. + chat_template_file : str | None + Path to a custom chat template file. + debug : bool + Enable debug mode. + prompt_cache_size : int + Maximum number of prompt KV cache entries to store. Default is 10. """ self.model_path = model_path - self.model = MLX_LM(model_path, context_length, trust_remote_code=trust_remote_code, chat_template_file=chat_template_file, debug=debug) + self.model = MLX_LM( + model_path, + draft_model_path=draft_model_path, + num_draft_tokens=num_draft_tokens, + context_length=context_length, + trust_remote_code=trust_remote_code, + chat_template_file=chat_template_file, + debug=debug, + ) self.model_created = int(time.time()) # Store creation time when model is loaded self.model_type = self.model.get_model_type() diff --git a/app/models/mlx_lm.py b/app/models/mlx_lm.py index d6e53981..c1dfe40c 100644 --- a/app/models/mlx_lm.py +++ b/app/models/mlx_lm.py @@ -1,5 +1,6 @@ import os import mlx.core as mx +from loguru import logger from mlx_lm.utils import load from mlx_lm.generate import ( stream_generate @@ -54,13 +55,18 @@ class MLX_LM: supporting both streaming and non-streaming modes. 
""" - def __init__(self, model_path: str, context_length: int | None = None, trust_remote_code: bool = False, chat_template_file: str = None, debug: bool = False): + def __init__(self, model_path: str, draft_model_path: str = None, num_draft_tokens: int = 2, context_length: int | None = None, trust_remote_code: bool = False, chat_template_file: str = None, debug: bool = False): try: self.model, self.tokenizer = load(model_path, lazy=False, tokenizer_config = {"trust_remote_code": trust_remote_code}) + self.context_length = context_length + self.draft_model = None + self.draft_tokenizer = None + self.num_draft_tokens = num_draft_tokens + if draft_model_path: + self._load_draft_model(draft_model_path, trust_remote_code) self.pad_token_id = self.tokenizer.pad_token_id self.bos_token = self.tokenizer.bos_token self.model_type = self.model.model_type - self.context_length = context_length self.debug = debug self.outlines_tokenizer = OutlinesTransformerTokenizer(self.tokenizer) if chat_template_file: @@ -74,9 +80,27 @@ def __init__(self, model_path: str, context_length: int | None = None, trust_rem except Exception as e: raise ValueError(f"Error loading model: {str(e)}") + def _load_draft_model(self, draft_model_path: str, trust_remote_code: bool) -> None: + try: + self.draft_model, self.draft_tokenizer = load(draft_model_path, lazy=False, tokenizer_config = {"trust_remote_code": trust_remote_code}) + self.context_length = None # speculative decoding does not support context length, should be set to None + self._validate_draft_tokenizer() + except Exception as e: + raise ValueError(f"Error loading draft model: {str(e)}") + + def _validate_draft_tokenizer(self) -> None: + if self.draft_tokenizer.vocab_size != self.tokenizer.vocab_size: + logger.warning( + "Draft model tokenizer does not match model tokenizer. " + "Speculative decoding may not work as expected." 
+ ) + def create_prompt_cache(self) -> List[Any]: - return make_prompt_cache(self.model, max_kv_size=self.context_length) - + cache = make_prompt_cache(self.model, max_kv_size=self.context_length) + if self.draft_model: + cache += make_prompt_cache(self.draft_model, max_kv_size=self.context_length) + return cache + def get_model_type(self) -> str: return self.model_type @@ -165,8 +189,10 @@ def __call__( self.model, self.tokenizer, input_ids, + draft_model=self.draft_model, sampler=sampler, max_tokens=max_tokens, + num_draft_tokens=self.num_draft_tokens, prompt_cache=prompt_cache, logits_processors=logits_processors, prompt_progress_callback=prompt_progress_callback diff --git a/app/server.py b/app/server.py index b182f132..c41ab222 100644 --- a/app/server.py +++ b/app/server.py @@ -209,6 +209,8 @@ async def lifespan(app: FastAPI) -> None: chat_template_file=config_args.chat_template_file, debug=config_args.debug, prompt_cache_size=config_args.prompt_cache_size, + draft_model_path=config_args.draft_model_path, + num_draft_tokens=config_args.num_draft_tokens, ) # Initialize queue await handler.initialize( diff --git a/app/utils/prompt_cache.py b/app/utils/prompt_cache.py index 0840d5e2..f6cec189 100644 --- a/app/utils/prompt_cache.py +++ b/app/utils/prompt_cache.py @@ -309,7 +309,9 @@ def insert_cache( if __name__ == "__main__": from app.models.mlx_lm import MLX_LM - model = MLX_LM("mlx-community/MiniMax-M2.1-4bit") + model_path = "mlx-community/Qwen3-Coder-Next-8bit" + draft_model_path = "mlx-community/Qwen3-Coder-Next-4bit" + model = MLX_LM(model_path, draft_model_path) prompt_cache = LRUPromptCache() import time @@ -352,11 +354,15 @@ def insert_cache( start_time = time.time() response_2 = model(rest_input_ids_2, cache, stream=True) + raw_text = "" for chunk in response_2: if chunk: if first_token: print("TIME TO FIRST TOKEN", time.time() - start_time) first_token = False + raw_text += chunk.text cache_key_2.append(chunk.token) + print("RAW TEXT", raw_text) + 
prompt_cache.insert_cache(cache_key_2, cache) \ No newline at end of file