Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ A high-performance OpenAI-compatible API server for MLX models. Run text, vision
- 🖼️ **Multimodal support** - Text, vision, audio, and image generation/editing
- 🎨 **Flux-series models** - Image generation (schnell, dev, krea-dev, flux-2-klein) and editing (kontext, qwen-image-edit)
- 🔌 **Easy integration** - Works with existing OpenAI client libraries
- ⚡ **Performance** - Configurable quantization (4/8/16-bit) and context length
- ⚡ **Performance** - Configurable quantization (4/8/16-bit), context length, and speculative decoding (lm)
- 🎛️ **LoRA adapters** - Fine-tuned image generation and editing
- 📈 **Queue management** - Built-in request queuing and monitoring

Expand Down Expand Up @@ -53,6 +53,13 @@ mlx-openai-server launch \
--model-path <path-to-mlx-model> \
--model-type <lm|multimodal>

# Text-only with speculative decoding (faster generation using a smaller draft model)
mlx-openai-server launch \
--model-path <path-to-main-model> \
--model-type lm \
--draft-model <path-to-draft-model> \
--num-draft-tokens 4

# Image generation (Flux-series)
mlx-openai-server launch \
--model-type image-generation \
Expand Down Expand Up @@ -85,6 +92,8 @@ mlx-openai-server launch \
- `--config-name`: For image models - `flux-schnell`, `flux-dev`, `flux-krea-dev`, `flux-kontext-dev`, `flux2-klein-4b`, `flux2-klein-9b`, `qwen-image`, `qwen-image-edit`, `z-image-turbo`, `fibo`
- `--quantize`: Quantization level - `4`, `8`, or `16` (image models)
- `--context-length`: Max sequence length for memory optimization
- `--draft-model`: Path to draft model for speculative decoding (lm only)
- `--num-draft-tokens`: Draft tokens per step for speculative decoding (lm only, default: 2)
- `--max-concurrency`: Concurrent requests (default: 1)
- `--queue-timeout`: Request timeout in seconds (default: 300)
- `--lora-paths`: Comma-separated LoRA adapter paths (image models)
Expand Down Expand Up @@ -329,6 +338,21 @@ mlx-openai-server launch \
--chat-template-file /path/to/template.jinja
```

### Speculative Decoding (lm)

Speculative decoding uses a smaller draft model to propose candidate tokens, which the main model then verifies, speeding up text generation. Supported only for `--model-type lm`.

```bash
mlx-openai-server launch \
--model-path mlx-community/MyModel-8B-4bit \
--model-type lm \
--draft-model mlx-community/MyModel-1B-4bit \
--num-draft-tokens 4
```

- **`--draft-model`**: Local path or Hugging Face repo ID of the draft model (a smaller model, typically from the same family as the main model).
- **`--num-draft-tokens`**: Number of tokens the draft model generates per verification step (default: 2). Higher values can increase throughput at the cost of more draft compute.

## Request Queue System

The server includes a request queue system with monitoring:
Expand Down
16 changes: 16 additions & 0 deletions app/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,18 @@ def cli():
type=int,
help="Maximum number of prompt KV cache entries to store. Only works with language models (lm). Default is 10.",
)
@click.option(
"--draft-model",
default=None,
type=str,
help="Path to the draft model for speculative decoding. Only supported with model type 'lm'. When set, --num-draft-tokens controls how many tokens the draft model generates per step.",
)
@click.option(
"--num-draft-tokens",
default=2,
type=int,
help="Number of draft tokens per step when using speculative decoding (--draft-model). Only supported with model type 'lm'. Default is 2.",
)
def launch(
model_path,
model_type,
Expand All @@ -228,6 +240,8 @@ def launch(
chat_template_file,
debug,
prompt_cache_size,
draft_model,
num_draft_tokens,
) -> None:
"""Start the FastAPI/Uvicorn server with the supplied flags.

Expand Down Expand Up @@ -261,6 +275,8 @@ def launch(
chat_template_file=chat_template_file,
debug=debug,
prompt_cache_size=prompt_cache_size,
draft_model_path=draft_model,
num_draft_tokens=num_draft_tokens,
)

asyncio.run(start(args))
11 changes: 11 additions & 0 deletions app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ class MLXServerConfig:
chat_template_file: str | None = None
debug: bool = False
prompt_cache_size: int = 10
draft_model_path: str | None = None
num_draft_tokens: int = 2

# Used to capture raw CLI input before processing
lora_paths_str: str | None = None
Expand Down Expand Up @@ -96,6 +98,15 @@ def __post_init__(self):
)
self.config_name = "flux-kontext-dev"

# Speculative decoding (draft model) is only supported for lm model type
if self.draft_model_path and self.model_type != "lm":
logger.warning(
"Draft model / num-draft-tokens are only supported for model type 'lm'. "
"Ignoring speculative decoding options."
)
self.draft_model_path = None
self.num_draft_tokens = 2

@property
def model_identifier(self) -> str:
"""Get the appropriate model identifier based on model type.
Expand Down
52 changes: 39 additions & 13 deletions app/handler/mlx_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,23 +30,49 @@ class MLXLMHandler:
Provides request queuing, metrics tracking, and robust error handling.
"""

def __init__(self, model_path: str, context_length: int | None = None, max_concurrency: int = 1, enable_auto_tool_choice: bool = False, tool_call_parser: str = None, reasoning_parser: str = None, message_converter: str = None, trust_remote_code: bool = False, chat_template_file: str = None, debug: bool = False, prompt_cache_size: int = 10):
def __init__(self, model_path: str, draft_model_path: str | None = None, num_draft_tokens: int = 2, context_length: int | None = None, max_concurrency: int = 1, enable_auto_tool_choice: bool = False, tool_call_parser: str = None, reasoning_parser: str = None, message_converter: str = None, trust_remote_code: bool = False, chat_template_file: str = None, debug: bool = False, prompt_cache_size: int = 10):
"""
Initialize the handler with the specified model path.

Args:
model_path (str): Path to the model directory.
context_length (int | None): Maximum context length for the model. If None, uses model default.
max_concurrency (int): Maximum number of concurrent model inference tasks.
enable_auto_tool_choice (bool): Enable automatic tool choice.
tool_call_parser (str): Name of the tool call parser to use (qwen3, glm4_moe, harmony, minimax, ...)
reasoning_parser (str): Name of the reasoning parser to use (qwen3, qwen3_next, glm4_moe, harmony, minimax, ...)
trust_remote_code (bool): Enable trust_remote_code when loading models.
chat_template_file (str): Path to a custom chat template file.
prompt_cache_size (int): Maximum number of prompt KV cache entries to store. Default is 10.

Parameters
----------
model_path : str
Path to the model directory.
draft_model_path : str | None
Path to the draft model for speculative decoding. If None, speculative decoding is disabled.
num_draft_tokens : int
Number of draft tokens per step when using speculative decoding. Default is 2.
context_length : int | None
Maximum context length for the model. If None, uses model default.
max_concurrency : int
Maximum number of concurrent model inference tasks.
enable_auto_tool_choice : bool
Enable automatic tool choice.
tool_call_parser : str | None
Name of the tool call parser to use (qwen3, glm4_moe, harmony, minimax, ...).
reasoning_parser : str | None
Name of the reasoning parser to use (qwen3, qwen3_next, glm4_moe, harmony, minimax, ...).
message_converter : str | None
Name of the message converter to use.
trust_remote_code : bool
Enable trust_remote_code when loading models.
chat_template_file : str | None
Path to a custom chat template file.
debug : bool
Enable debug mode.
prompt_cache_size : int
Maximum number of prompt KV cache entries to store. Default is 10.
"""
self.model_path = model_path
self.model = MLX_LM(model_path, context_length, trust_remote_code=trust_remote_code, chat_template_file=chat_template_file, debug=debug)
self.model = MLX_LM(
model_path,
draft_model_path=draft_model_path,
num_draft_tokens=num_draft_tokens,
context_length=context_length,
trust_remote_code=trust_remote_code,
chat_template_file=chat_template_file,
debug=debug,
)
self.model_created = int(time.time()) # Store creation time when model is loaded
self.model_type = self.model.get_model_type()

Expand Down
34 changes: 30 additions & 4 deletions app/models/mlx_lm.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import mlx.core as mx
from loguru import logger
from mlx_lm.utils import load
from mlx_lm.generate import (
stream_generate
Expand Down Expand Up @@ -54,13 +55,18 @@ class MLX_LM:
supporting both streaming and non-streaming modes.
"""

def __init__(self, model_path: str, context_length: int | None = None, trust_remote_code: bool = False, chat_template_file: str = None, debug: bool = False):
def __init__(self, model_path: str, draft_model_path: str = None, num_draft_tokens: int = 2, context_length: int | None = None, trust_remote_code: bool = False, chat_template_file: str = None, debug: bool = False):
try:
self.model, self.tokenizer = load(model_path, lazy=False, tokenizer_config = {"trust_remote_code": trust_remote_code})
self.context_length = context_length
self.draft_model = None
self.draft_tokenizer = None
self.num_draft_tokens = num_draft_tokens
if draft_model_path:
self._load_draft_model(draft_model_path, trust_remote_code)
self.pad_token_id = self.tokenizer.pad_token_id
self.bos_token = self.tokenizer.bos_token
self.model_type = self.model.model_type
self.context_length = context_length
self.debug = debug
self.outlines_tokenizer = OutlinesTransformerTokenizer(self.tokenizer)
if chat_template_file:
Expand All @@ -74,9 +80,27 @@ def __init__(self, model_path: str, context_length: int | None = None, trust_rem
except Exception as e:
raise ValueError(f"Error loading model: {str(e)}")

def _load_draft_model(self, draft_model_path: str, trust_remote_code: bool) -> None:
    """Load the smaller draft model used for speculative decoding.

    Parameters
    ----------
    draft_model_path : str
        Path or HuggingFace repo of the draft model.
    trust_remote_code : bool
        Forwarded to the tokenizer config when loading the draft tokenizer.

    Raises
    ------
    ValueError
        If the draft model or its tokenizer fails to load.
    """
    try:
        self.draft_model, self.draft_tokenizer = load(
            draft_model_path,
            lazy=False,
            tokenizer_config={"trust_remote_code": trust_remote_code},
        )
        # The speculative decoding path does not support a bounded KV
        # cache, so any configured context length is dropped here.
        self.context_length = None
        # Non-fatal: only warns if the draft tokenizer's vocab differs.
        self._validate_draft_tokenizer()
    except Exception as e:
        # Chain the original exception so the root cause stays visible
        # in the traceback instead of being flattened into a message.
        raise ValueError(f"Error loading draft model: {str(e)}") from e

def _validate_draft_tokenizer(self) -> None:
if self.draft_tokenizer.vocab_size != self.tokenizer.vocab_size:
logger.warning(
"Draft model tokenizer does not match model tokenizer. "
"Speculative decoding may not work as expected."
)

def create_prompt_cache(self) -> List[Any]:
    """Build a fresh prompt KV cache for the main model (and draft model, if loaded).

    Returns
    -------
    List[Any]
        Cache entries for the main model; when a draft model is loaded
        for speculative decoding, its cache entries are appended after
        the main model's in the same list.
    """
    cache = make_prompt_cache(self.model, max_kv_size=self.context_length)
    if self.draft_model:
        # NOTE(review): the draft model's cache entries are concatenated
        # onto the main cache — presumably the mlx_lm speculative
        # decoding path expects this combined layout; confirm against
        # stream_generate's documentation.
        cache += make_prompt_cache(self.draft_model, max_kv_size=self.context_length)
    return cache

def get_model_type(self) -> str:
    """Return the model type string recorded from the loaded model at init time."""
    return self.model_type

Expand Down Expand Up @@ -165,8 +189,10 @@ def __call__(
self.model,
self.tokenizer,
input_ids,
draft_model=self.draft_model,
sampler=sampler,
max_tokens=max_tokens,
num_draft_tokens=self.num_draft_tokens,
prompt_cache=prompt_cache,
logits_processors=logits_processors,
prompt_progress_callback=prompt_progress_callback
Expand Down
2 changes: 2 additions & 0 deletions app/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,8 @@ async def lifespan(app: FastAPI) -> None:
chat_template_file=config_args.chat_template_file,
debug=config_args.debug,
prompt_cache_size=config_args.prompt_cache_size,
draft_model_path=config_args.draft_model_path,
num_draft_tokens=config_args.num_draft_tokens,
)
# Initialize queue
await handler.initialize(
Expand Down
8 changes: 7 additions & 1 deletion app/utils/prompt_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,9 @@ def insert_cache(

if __name__ == "__main__":
from app.models.mlx_lm import MLX_LM
model = MLX_LM("mlx-community/MiniMax-M2.1-4bit")
model_path = "mlx-community/Qwen3-Coder-Next-8bit"
draft_model_path = "mlx-community/Qwen3-Coder-Next-4bit"
model = MLX_LM(model_path, draft_model_path)
prompt_cache = LRUPromptCache()

import time
Expand Down Expand Up @@ -352,11 +354,15 @@ def insert_cache(

start_time = time.time()
response_2 = model(rest_input_ids_2, cache, stream=True)
raw_text = ""
for chunk in response_2:
if chunk:
if first_token:
print("TIME TO FIRST TOKEN", time.time() - start_time)
first_token = False
raw_text += chunk.text
cache_key_2.append(chunk.token)

print("RAW TEXT", raw_text)

prompt_cache.insert_cache(cache_key_2, cache)