16 changes: 11 additions & 5 deletions docs/guides/embeddings.md
@@ -17,7 +17,7 @@ pip install mlx-embeddings>=0.0.5
vllm-mlx serve my-llm-model --embedding-model mlx-community/all-MiniLM-L6-v2-4bit
```

If you don't use `--embedding-model`, the embedding model is loaded lazily on the first request.
If you don't use `--embedding-model`, the embedding model is loaded lazily on the first request, but only from the built-in request-time allowlist.

### Generate embeddings with the OpenAI SDK

@@ -59,19 +59,25 @@ curl http://localhost:8000/v1/embeddings \

## Supported Models

Any BERT, XLM-RoBERTa, or ModernBERT model from HuggingFace that is compatible with mlx-embeddings:
Supported request-time models:

| Model | Use Case | Size |
|-------|----------|------|
| `mlx-community/all-MiniLM-L6-v2-4bit` | Fast, compact | Small |
| `mlx-community/embeddinggemma-300m-6bit` | High quality | 300M |
| `mlx-community/bge-large-en-v1.5-4bit` | Best for English | Large |
| `mlx-community/multilingual-e5-small-mlx` | Multilingual retrieval | Small |
| `mlx-community/multilingual-e5-large-mlx` | Multilingual retrieval | Large |
| `mlx-community/bert-base-uncased-mlx` | General BERT baseline | Base |
| `mlx-community/ModernBERT-base-mlx` | ModernBERT baseline | Base |

Other embedding models require `--embedding-model` at server startup.
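
For illustration, requesting a model outside this list is rejected with HTTP 400 before any download or engine load is attempted. A minimal sketch using the `requests` library (the exact error wording may differ):

```python
import requests

# A non-allowlisted model is rejected up front; nothing is downloaded or loaded.
resp = requests.post(
    "http://localhost:8000/v1/embeddings",
    json={"model": "attacker/unknown-embedding", "input": "hello"},
)
assert resp.status_code == 400
print(resp.json()["detail"])  # names the allowlist and the --embedding-model flag
```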

## Model Management

### Lazy loading

By default, the embedding model is loaded on the first `/v1/embeddings` request. You can switch models between requests and the previous model will be unloaded automatically.
By default, the embedding model is loaded on the first `/v1/embeddings` request. You can switch between the supported request-time models above, and the previous model will be unloaded automatically.
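
A sketch of the hot-swap behavior with the OpenAI SDK, assuming the server runs on the default port (the API key is a placeholder, assumed unused by the local server):

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")

# First request lazily loads the small MiniLM model.
client.embeddings.create(
    model="mlx-community/all-MiniLM-L6-v2-4bit",
    input="hello world",
)

# Switching to another allowlisted model unloads the previous one
# and loads the new model for this request.
client.embeddings.create(
    model="mlx-community/multilingual-e5-small-mlx",
    input="bonjour le monde",
)
```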

### Pre-loading at startup

@@ -93,7 +99,7 @@ Create embeddings for the given input text(s).

| Field | Type | Required | Description |
|-------|------|----------|-------------|
| `model` | string | Yes | Model name from HuggingFace |
| `model` | string | Yes | Supported embedding model ID, or the startup-pinned model when `--embedding-model` is used |
| `input` | string or list[string] | Yes | Text(s) to embed |
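
For illustration, a minimal request with the `requests` library, assuming the response follows the standard OpenAI embeddings shape:

```python
import requests

resp = requests.post(
    "http://localhost:8000/v1/embeddings",
    json={
        "model": "mlx-community/all-MiniLM-L6-v2-4bit",
        "input": ["first text", "second text"],
    },
)
resp.raise_for_status()
# One embedding per input, returned in the same order as the inputs.
vectors = [item["embedding"] for item in resp.json()["data"]]
```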

**Response:**
@@ -137,7 +143,7 @@ pip install mlx-embeddings>=0.0.5

### Model not found

Make sure the model name matches a HuggingFace repository compatible with mlx-embeddings. You can pre-download models:
Make sure the model name matches one of the supported request-time IDs above, or start the server with `--embedding-model` to pin a custom model. You can pre-download supported models:

```bash
huggingface-cli download mlx-community/all-MiniLM-L6-v2-4bit
2 changes: 1 addition & 1 deletion docs/reference/models.md
@@ -55,7 +55,7 @@ Browse thousands of pre-optimized models at: **https://huggingface.co/mlx-commun
| Model Family | Example Models |
|--------------|----------------|
| **BERT** | `mlx-community/bert-base-uncased-mlx` |
| **XLM-RoBERTa** | `mlx-community/multilingual-e5-small-mlx`, `multilingual-e5-large-mlx` |
| **XLM-RoBERTa** | `mlx-community/multilingual-e5-small-mlx`, `mlx-community/multilingual-e5-large-mlx` |
| **ModernBERT** | `mlx-community/ModernBERT-base-mlx` |

## Audio Models (via mlx-audio)
33 changes: 25 additions & 8 deletions tests/test_embeddings.py
@@ -161,7 +161,7 @@ def test_batch_input_preserves_order(self, client):

texts = ["first", "second", "third"]
mock_engine = MagicMock()
mock_engine.model_name = "test-embed"
mock_engine.model_name = "mlx-community/all-MiniLM-L6-v2-4bit"
mock_engine.embed.return_value = [
[1.0, 0.0],
[0.0, 1.0],
@@ -174,7 +174,7 @@ def test_batch_input_preserves_order(self, client):
try:
resp = client.post(
"/v1/embeddings",
json={"model": "test-embed", "input": texts},
json={"model": "mlx-community/all-MiniLM-L6-v2-4bit", "input": texts},
)
finally:
srv._embedding_engine = original
@@ -193,22 +193,22 @@ def test_empty_input_returns_400(self, client):
import vllm_mlx.server as srv

mock_engine = MagicMock()
mock_engine.model_name = "test-embed"
mock_engine.model_name = "mlx-community/all-MiniLM-L6-v2-4bit"

original = srv._embedding_engine
srv._embedding_engine = mock_engine
try:
resp = client.post(
"/v1/embeddings",
json={"model": "test-embed", "input": []},
json={"model": "mlx-community/all-MiniLM-L6-v2-4bit", "input": []},
)
finally:
srv._embedding_engine = original

assert resp.status_code == 400

def test_model_hot_swap(self, client):
"""Test that requesting a different model triggers reload."""
"""Test that switching to another allowlisted model triggers reload."""
import vllm_mlx.server as srv

mock_engine = MagicMock()
@@ -222,17 +222,22 @@ def test_model_hot_swap(self, client):
try:
with patch("vllm_mlx.embedding.EmbeddingEngine") as mock_cls:
new_engine = MagicMock()
new_engine.model_name = "new-model"
new_engine.model_name = "mlx-community/multilingual-e5-small-mlx"
new_engine.embed.return_value = [[0.9]]
new_engine.count_tokens.return_value = 1
mock_cls.return_value = new_engine

resp = client.post(
"/v1/embeddings",
json={"model": "new-model", "input": "test"},
json={
"model": "mlx-community/multilingual-e5-small-mlx",
"input": "test",
},
)
assert resp.status_code == 200
mock_cls.assert_called_once_with("new-model")
mock_cls.assert_called_once_with(
"mlx-community/multilingual-e5-small-mlx"
)
new_engine.load.assert_called_once()
finally:
srv._embedding_engine = original
@@ -262,6 +267,18 @@ def test_model_locked_rejects_different_model(self, client):
srv._embedding_engine = original_engine
srv._embedding_model_locked = original_locked

def test_unknown_embedding_model_rejected(self, client):
"""Test that request-time embedding loads reject unknown models."""
resp = client.post(
"/v1/embeddings",
json={"model": "attacker/unknown-embedding", "input": "test"},
)

assert resp.status_code == 400
body = resp.json()
assert "attacker/unknown-embedding" in body["detail"]
assert "--embedding-model" in body["detail"]


# =============================================================================
# Slow Integration Test - Real Model
81 changes: 81 additions & 0 deletions tests/test_endpoint_model_policies.py
@@ -0,0 +1,81 @@
# SPDX-License-Identifier: Apache-2.0
"""Cross-platform tests for optional endpoint model resolution policies."""

import pytest
from fastapi import HTTPException

from vllm_mlx.endpoint_model_policies import (
resolve_embedding_model_name,
resolve_stt_model_name,
resolve_tts_model_name,
)


class TestEmbeddingModelPolicy:
def test_allowlisted_embedding_model_passes(self):
assert (
resolve_embedding_model_name("mlx-community/multilingual-e5-small-mlx")
== "mlx-community/multilingual-e5-small-mlx"
)

def test_unknown_embedding_model_rejected(self):
with pytest.raises(HTTPException) as exc_info:
resolve_embedding_model_name("attacker/unknown-embedding")

assert exc_info.value.status_code == 400
assert "attacker/unknown-embedding" in exc_info.value.detail
assert "--embedding-model" in exc_info.value.detail

def test_locked_embedding_model_can_be_custom(self):
assert (
resolve_embedding_model_name(
"custom/private-embedding",
locked_model="custom/private-embedding",
)
== "custom/private-embedding"
)

def test_locked_embedding_model_rejects_other_request(self):
with pytest.raises(HTTPException) as exc_info:
resolve_embedding_model_name(
"mlx-community/all-MiniLM-L6-v2-4bit",
locked_model="custom/private-embedding",
)

assert exc_info.value.status_code == 400
assert "custom/private-embedding" in exc_info.value.detail


class TestAudioModelPolicy:
def test_stt_alias_resolves_to_configured_model(self):
assert (
resolve_stt_model_name("whisper-large-v3")
== "mlx-community/whisper-large-v3-mlx"
)

def test_stt_full_model_id_is_accepted(self):
model_name = "mlx-community/parakeet-tdt-0.6b-v2"
assert resolve_stt_model_name(model_name) == model_name

def test_stt_unknown_model_rejected(self):
with pytest.raises(HTTPException) as exc_info:
resolve_stt_model_name("attacker/unknown-stt")

assert exc_info.value.status_code == 400
assert "attacker/unknown-stt" in exc_info.value.detail
assert "whisper-large-v3" in exc_info.value.detail

def test_tts_alias_resolves_to_configured_model(self):
assert resolve_tts_model_name("kokoro") == "mlx-community/Kokoro-82M-bf16"

def test_tts_full_model_id_is_accepted(self):
model_name = "mlx-community/chatterbox-turbo-fp16"
assert resolve_tts_model_name(model_name) == model_name

def test_tts_unknown_model_rejected(self):
with pytest.raises(HTTPException) as exc_info:
resolve_tts_model_name("attacker/unknown-tts")

assert exc_info.value.status_code == 400
assert "attacker/unknown-tts" in exc_info.value.detail
assert "kokoro" in exc_info.value.detail
118 changes: 118 additions & 0 deletions vllm_mlx/endpoint_model_policies.py
@@ -0,0 +1,118 @@
# SPDX-License-Identifier: Apache-2.0
"""Request-time model resolution policies for optional endpoints.

These endpoints intentionally do not expose arbitrary Hugging Face loading from
user-controlled request bodies. Unknown model names must be rejected before any
engine instantiation or download path is reached.
"""

from fastapi import HTTPException

_EMBEDDING_MODELS = frozenset(
{
"mlx-community/ModernBERT-base-mlx",
"mlx-community/all-MiniLM-L6-v2-4bit",
"mlx-community/bert-base-uncased-mlx",
"mlx-community/bge-large-en-v1.5-4bit",
"mlx-community/embeddinggemma-300m-6bit",
"mlx-community/multilingual-e5-large-mlx",
"mlx-community/multilingual-e5-small-mlx",
}
)

_STT_MODEL_ALIASES = {
"whisper-large-v3": "mlx-community/whisper-large-v3-mlx",
"whisper-large-v3-turbo": "mlx-community/whisper-large-v3-turbo",
"whisper-medium": "mlx-community/whisper-medium-mlx",
"whisper-small": "mlx-community/whisper-small-mlx",
"parakeet": "mlx-community/parakeet-tdt-0.6b-v2",
"parakeet-v3": "mlx-community/parakeet-tdt-0.6b-v3",
}

_TTS_MODEL_ALIASES = {
"kokoro": "mlx-community/Kokoro-82M-bf16",
"kokoro-4bit": "mlx-community/Kokoro-82M-4bit",
"chatterbox": "mlx-community/chatterbox-turbo-fp16",
"chatterbox-4bit": "mlx-community/chatterbox-turbo-4bit",
"vibevoice": "mlx-community/VibeVoice-Realtime-0.5B-4bit",
"voxcpm": "mlx-community/VoxCPM1.5",
}


def _with_identity_aliases(model_map: dict[str, str]) -> dict[str, str]:
expanded = dict(model_map)
for model_name in model_map.values():
expanded[model_name] = model_name
return expanded


_STT_MODEL_MAP = _with_identity_aliases(_STT_MODEL_ALIASES)
_TTS_MODEL_MAP = _with_identity_aliases(_TTS_MODEL_ALIASES)


def _reject_unknown_embedding_model(requested_model: str) -> None:
supported = ", ".join(sorted(_EMBEDDING_MODELS))
raise HTTPException(
status_code=400,
detail=(
f"Embedding model '{requested_model}' is not available. "
"Request-time embedding model loading is limited to the supported "
f"allowlist: {supported}. To use a different embedding model, start "
"the server with --embedding-model <model>."
),
)


def _reject_unknown_audio_model(
endpoint: str,
requested_model: str,
supported_aliases: dict[str, str],
) -> None:
aliases = ", ".join(sorted(supported_aliases))
raise HTTPException(
status_code=400,
detail=(
f"{endpoint} model '{requested_model}' is not available. "
f"Supported request models are: {aliases}. Exact configured model IDs "
"for those aliases are also accepted."
),
)


def resolve_embedding_model_name(
requested_model: str,
*,
locked_model: str | None = None,
) -> str:
"""Resolve the embedding model for a request or raise HTTP 400."""
if locked_model is not None:
if requested_model == locked_model:
return locked_model
raise HTTPException(
status_code=400,
detail=(
f"Embedding model '{requested_model}' is not available. "
f"This server was started with --embedding-model {locked_model}. "
f"Only '{locked_model}' can be used for embeddings. Restart the "
f"server with a different --embedding-model to use '{requested_model}'."
),
)

if requested_model in _EMBEDDING_MODELS:
return requested_model

_reject_unknown_embedding_model(requested_model)


def resolve_stt_model_name(requested_model: str) -> str:
"""Resolve an STT request model alias or configured model ID."""
if requested_model in _STT_MODEL_MAP:
return _STT_MODEL_MAP[requested_model]
_reject_unknown_audio_model("Transcription", requested_model, _STT_MODEL_ALIASES)


def resolve_tts_model_name(requested_model: str) -> str:
"""Resolve a TTS request model alias or configured model ID."""
if requested_model in _TTS_MODEL_MAP:
return _TTS_MODEL_MAP[requested_model]
_reject_unknown_audio_model("Speech", requested_model, _TTS_MODEL_ALIASES)