From d560d87664a98175f413c6c87b4f7d455654e997 Mon Sep 17 00:00:00 2001 From: Thump604 Date: Tue, 14 Apr 2026 09:32:12 -0500 Subject: [PATCH] fix(server): reject arbitrary endpoint model loads --- docs/guides/embeddings.md | 16 ++-- docs/reference/models.md | 2 +- tests/test_embeddings.py | 33 +++++-- tests/test_endpoint_model_policies.py | 81 ++++++++++++++++++ vllm_mlx/endpoint_model_policies.py | 118 ++++++++++++++++++++++++++ vllm_mlx/server.py | 59 +++++-------- 6 files changed, 256 insertions(+), 53 deletions(-) create mode 100644 tests/test_endpoint_model_policies.py create mode 100644 vllm_mlx/endpoint_model_policies.py diff --git a/docs/guides/embeddings.md b/docs/guides/embeddings.md index 3fdd4426c..f8f1ce9eb 100644 --- a/docs/guides/embeddings.md +++ b/docs/guides/embeddings.md @@ -17,7 +17,7 @@ pip install mlx-embeddings>=0.0.5 vllm-mlx serve my-llm-model --embedding-model mlx-community/all-MiniLM-L6-v2-4bit ``` -If you don't use `--embedding-model`, the embedding model is loaded lazily on the first request. +If you don't use `--embedding-model`, the embedding model is loaded lazily on the first request, but only from the built-in request-time allowlist. ### Generate embeddings with the OpenAI SDK @@ -59,19 +59,25 @@ curl http://localhost:8000/v1/embeddings \ ## Supported Models -Any BERT, XLM-RoBERTa, or ModernBERT model from HuggingFace that is compatible with mlx-embeddings: +Supported request-time models: | Model | Use Case | Size | |-------|----------|------| | `mlx-community/all-MiniLM-L6-v2-4bit` | Fast, compact | Small | | `mlx-community/embeddinggemma-300m-6bit` | High quality | 300M | | `mlx-community/bge-large-en-v1.5-4bit` | Best for English | Large | +| `mlx-community/multilingual-e5-small-mlx` | Multilingual retrieval | Small | +| `mlx-community/multilingual-e5-large-mlx` | Multilingual retrieval | Large | +| `mlx-community/bert-base-uncased-mlx` | General BERT baseline | Base | +| `mlx-community/ModernBERT-base-mlx` | ModernBERT baseline | Base | + +Other embedding models require `--embedding-model` at server startup. ## Model Management ### Lazy loading -By default, the embedding model is loaded on the first `/v1/embeddings` request. You can switch models between requests and the previous model will be unloaded automatically. +By default, the embedding model is loaded on the first `/v1/embeddings` request. You can switch between the supported request-time models above, and the previous model will be unloaded automatically. ### Pre-loading at startup @@ -93,7 +99,7 @@ Create embeddings for the given input text(s). | Field | Type | Required | Description | |-------|------|----------|-------------| -| `model` | string | Yes | Model name from HuggingFace | +| `model` | string | Yes | Supported embedding model ID, or the startup-pinned model when `--embedding-model` is used | | `input` | string or list[string] | Yes | Text(s) to embed | **Response:** @@ -137,7 +143,7 @@ pip install mlx-embeddings>=0.0.5 ### Model not found -Make sure the model name matches a HuggingFace repository compatible with mlx-embeddings. You can pre-download models: +Make sure the model name matches one of the supported request-time IDs above, or start the server with `--embedding-model` to pin a custom model. You can pre-download supported models: ```bash huggingface-cli download mlx-community/all-MiniLM-L6-v2-4bit diff --git a/docs/reference/models.md b/docs/reference/models.md index d378de003..0ca34a105 100644 --- a/docs/reference/models.md +++ b/docs/reference/models.md @@ -55,7 +55,7 @@ Browse thousands of pre-optimized models at: **https://huggingface.co/mlx-commun | Model Family | Example Models | |--------------|----------------| | **BERT** | `mlx-community/bert-base-uncased-mlx` | -| **XLM-RoBERTa** | `mlx-community/multilingual-e5-small-mlx`, `multilingual-e5-large-mlx` | +| **XLM-RoBERTa** | `mlx-community/multilingual-e5-small-mlx`, `mlx-community/multilingual-e5-large-mlx` | | **ModernBERT** | `mlx-community/ModernBERT-base-mlx` | ## Audio Models (via mlx-audio) diff --git a/tests/test_embeddings.py b/tests/test_embeddings.py index 65774df93..41790eaad 100644 --- a/tests/test_embeddings.py +++ b/tests/test_embeddings.py @@ -161,7 +161,7 @@ def test_batch_input_preserves_order(self, client): texts = ["first", "second", "third"] mock_engine = MagicMock() - mock_engine.model_name = "test-embed" + mock_engine.model_name = "mlx-community/all-MiniLM-L6-v2-4bit" mock_engine.embed.return_value = [ [1.0, 0.0], [0.0, 1.0], @@ -174,7 +174,7 @@ def test_batch_input_preserves_order(self, client): try: resp = client.post( "/v1/embeddings", - json={"model": "test-embed", "input": texts}, + json={"model": "mlx-community/all-MiniLM-L6-v2-4bit", "input": texts}, ) finally: srv._embedding_engine = original @@ -193,14 +193,14 @@ def test_empty_input_returns_400(self, client): import vllm_mlx.server as srv mock_engine = MagicMock() - mock_engine.model_name = "test-embed" + mock_engine.model_name = "mlx-community/all-MiniLM-L6-v2-4bit" original = srv._embedding_engine srv._embedding_engine = mock_engine try: resp = client.post( "/v1/embeddings", - json={"model": "test-embed", "input": []}, + json={"model": "mlx-community/all-MiniLM-L6-v2-4bit", "input": []}, ) finally: srv._embedding_engine = original @@ -208,7 +208,7 @@ def test_empty_input_returns_400(self, client): assert resp.status_code == 400 def test_model_hot_swap(self, client): - """Test that requesting a different model triggers reload.""" + """Test that switching to another allowlisted model triggers reload.""" import vllm_mlx.server as srv mock_engine = MagicMock() @@ -222,17 +222,22 @@ def test_model_hot_swap(self, client): try: with patch("vllm_mlx.embedding.EmbeddingEngine") as mock_cls: new_engine = MagicMock() - new_engine.model_name = "new-model" + new_engine.model_name = "mlx-community/multilingual-e5-small-mlx" new_engine.embed.return_value = [[0.9]] new_engine.count_tokens.return_value = 1 mock_cls.return_value = new_engine resp = client.post( "/v1/embeddings", - json={"model": "new-model", "input": "test"}, + json={ + "model": "mlx-community/multilingual-e5-small-mlx", + "input": "test", + }, ) assert resp.status_code == 200 - mock_cls.assert_called_once_with("new-model") + mock_cls.assert_called_once_with( + "mlx-community/multilingual-e5-small-mlx" + ) new_engine.load.assert_called_once() finally: srv._embedding_engine = original @@ -262,6 +267,18 @@ def test_model_locked_rejects_different_model(self, client): srv._embedding_engine = original_engine srv._embedding_model_locked = original_locked + def test_unknown_embedding_model_rejected(self, client): + """Test that request-time embedding loads reject unknown models.""" + resp = client.post( + "/v1/embeddings", + json={"model": "attacker/unknown-embedding", "input": "test"}, + ) + + assert resp.status_code == 400 + body = resp.json() + assert "attacker/unknown-embedding" in body["detail"] + assert "--embedding-model" in body["detail"] + # ============================================================================= # Slow Integration Test - Real Model diff --git a/tests/test_endpoint_model_policies.py b/tests/test_endpoint_model_policies.py new file mode 100644 index 000000000..b6a9e5d89 --- /dev/null +++ b/tests/test_endpoint_model_policies.py @@ -0,0 +1,81 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Cross-platform tests for optional endpoint model resolution policies.""" + +import pytest +from fastapi import HTTPException + +from vllm_mlx.endpoint_model_policies import ( + resolve_embedding_model_name, + resolve_stt_model_name, + resolve_tts_model_name, +) + + +class TestEmbeddingModelPolicy: + def test_allowlisted_embedding_model_passes(self): + assert ( + resolve_embedding_model_name("mlx-community/multilingual-e5-small-mlx") + == "mlx-community/multilingual-e5-small-mlx" + ) + + def test_unknown_embedding_model_rejected(self): + with pytest.raises(HTTPException) as exc_info: + resolve_embedding_model_name("attacker/unknown-embedding") + + assert exc_info.value.status_code == 400 + assert "attacker/unknown-embedding" in exc_info.value.detail + assert "--embedding-model" in exc_info.value.detail + + def test_locked_embedding_model_can_be_custom(self): + assert ( + resolve_embedding_model_name( + "custom/private-embedding", + locked_model="custom/private-embedding", + ) + == "custom/private-embedding" + ) + + def test_locked_embedding_model_rejects_other_request(self): + with pytest.raises(HTTPException) as exc_info: + resolve_embedding_model_name( + "mlx-community/all-MiniLM-L6-v2-4bit", + locked_model="custom/private-embedding", + ) + + assert exc_info.value.status_code == 400 + assert "custom/private-embedding" in exc_info.value.detail + + +class TestAudioModelPolicy: + def test_stt_alias_resolves_to_configured_model(self): + assert ( + resolve_stt_model_name("whisper-large-v3") + == "mlx-community/whisper-large-v3-mlx" + ) + + def test_stt_full_model_id_is_accepted(self): + model_name = "mlx-community/parakeet-tdt-0.6b-v2" + assert resolve_stt_model_name(model_name) == model_name + + def test_stt_unknown_model_rejected(self): + with pytest.raises(HTTPException) as exc_info: + resolve_stt_model_name("attacker/unknown-stt") + + assert exc_info.value.status_code == 400 + assert "attacker/unknown-stt" in exc_info.value.detail + assert "whisper-large-v3" in exc_info.value.detail + + def test_tts_alias_resolves_to_configured_model(self): + assert resolve_tts_model_name("kokoro") == "mlx-community/Kokoro-82M-bf16" + + def test_tts_full_model_id_is_accepted(self): + model_name = "mlx-community/chatterbox-turbo-fp16" + assert resolve_tts_model_name(model_name) == model_name + + def test_tts_unknown_model_rejected(self): + with pytest.raises(HTTPException) as exc_info: + resolve_tts_model_name("attacker/unknown-tts") + + assert exc_info.value.status_code == 400 + assert "attacker/unknown-tts" in exc_info.value.detail + assert "kokoro" in exc_info.value.detail diff --git a/vllm_mlx/endpoint_model_policies.py b/vllm_mlx/endpoint_model_policies.py new file mode 100644 index 000000000..6fdb9eada --- /dev/null +++ b/vllm_mlx/endpoint_model_policies.py @@ -0,0 +1,118 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Request-time model resolution policies for optional endpoints. + +These endpoints intentionally do not expose arbitrary Hugging Face loading from +user-controlled request bodies. Unknown model names must be rejected before any +engine instantiation or download path is reached. +""" + +from fastapi import HTTPException + +_EMBEDDING_MODELS = frozenset( + { + "mlx-community/ModernBERT-base-mlx", + "mlx-community/all-MiniLM-L6-v2-4bit", + "mlx-community/bert-base-uncased-mlx", + "mlx-community/bge-large-en-v1.5-4bit", + "mlx-community/embeddinggemma-300m-6bit", + "mlx-community/multilingual-e5-large-mlx", + "mlx-community/multilingual-e5-small-mlx", + } +) + +_STT_MODEL_ALIASES = { + "whisper-large-v3": "mlx-community/whisper-large-v3-mlx", + "whisper-large-v3-turbo": "mlx-community/whisper-large-v3-turbo", + "whisper-medium": "mlx-community/whisper-medium-mlx", + "whisper-small": "mlx-community/whisper-small-mlx", + "parakeet": "mlx-community/parakeet-tdt-0.6b-v2", + "parakeet-v3": "mlx-community/parakeet-tdt-0.6b-v3", +} + +_TTS_MODEL_ALIASES = { + "kokoro": "mlx-community/Kokoro-82M-bf16", + "kokoro-4bit": "mlx-community/Kokoro-82M-4bit", + "chatterbox": "mlx-community/chatterbox-turbo-fp16", + "chatterbox-4bit": "mlx-community/chatterbox-turbo-4bit", + "vibevoice": "mlx-community/VibeVoice-Realtime-0.5B-4bit", + "voxcpm": "mlx-community/VoxCPM1.5", +} + + +def _with_identity_aliases(model_map: dict[str, str]) -> dict[str, str]: + expanded = dict(model_map) + for model_name in model_map.values(): + expanded[model_name] = model_name + return expanded + + +_STT_MODEL_MAP = _with_identity_aliases(_STT_MODEL_ALIASES) +_TTS_MODEL_MAP = _with_identity_aliases(_TTS_MODEL_ALIASES) + + +def _reject_unknown_embedding_model(requested_model: str) -> None: + supported = ", ".join(sorted(_EMBEDDING_MODELS)) + raise HTTPException( + status_code=400, + detail=( + f"Embedding model '{requested_model}' is not available. " + "Request-time embedding model loading is limited to the supported " + f"allowlist: {supported}. To use a different embedding model, start " + "the server with --embedding-model ." + ), + ) + + +def _reject_unknown_audio_model( + endpoint: str, + requested_model: str, + supported_aliases: dict[str, str], +) -> None: + aliases = ", ".join(sorted(supported_aliases)) + raise HTTPException( + status_code=400, + detail=( + f"{endpoint} model '{requested_model}' is not available. " + f"Supported request models are: {aliases}. Exact configured model IDs " + "for those aliases are also accepted." + ), + ) + + +def resolve_embedding_model_name( + requested_model: str, + *, + locked_model: str | None = None, +) -> str: + """Resolve the embedding model for a request or raise HTTP 400.""" + if locked_model is not None: + if requested_model == locked_model: + return locked_model + raise HTTPException( + status_code=400, + detail=( + f"Embedding model '{requested_model}' is not available. " + f"This server was started with --embedding-model {locked_model}. " + f"Only '{locked_model}' can be used for embeddings. Restart the " + f"server with a different --embedding-model to use '{requested_model}'." + ), + ) + + if requested_model in _EMBEDDING_MODELS: + return requested_model + + _reject_unknown_embedding_model(requested_model) + + +def resolve_stt_model_name(requested_model: str) -> str: + """Resolve an STT request model alias or configured model ID.""" + if requested_model in _STT_MODEL_MAP: + return _STT_MODEL_MAP[requested_model] + _reject_unknown_audio_model("Transcription", requested_model, _STT_MODEL_ALIASES) + + +def resolve_tts_model_name(requested_model: str) -> str: + """Resolve a TTS request model alias or configured model ID.""" + if requested_model in _TTS_MODEL_MAP: + return _TTS_MODEL_MAP[requested_model] + _reject_unknown_audio_model("Speech", requested_model, _TTS_MODEL_ALIASES) diff --git a/vllm_mlx/server.py b/vllm_mlx/server.py index 6cd9581bf..3ed6ed185 100644 --- a/vllm_mlx/server.py +++ b/vllm_mlx/server.py @@ -110,6 +110,11 @@ is_mllm_model, # noqa: F401 ) from .engine import BaseEngine, BatchedEngine, GenerationOutput, SimpleEngine +from .endpoint_model_policies import ( + resolve_embedding_model_name, + resolve_stt_model_name, + resolve_tts_model_name, +) from .metrics import metrics as _metrics from .tool_parsers import ToolParserManager @@ -871,33 +876,27 @@ async def create_embeddings(request: EmbeddingRequest) -> EmbeddingResponse: } ``` - Supported models: + Supported request-time models: - mlx-community/all-MiniLM-L6-v2-4bit (fast, compact) - mlx-community/embeddinggemma-300m-6bit (high quality) - mlx-community/bge-large-en-v1.5-4bit (best for English) - - Any BERT/XLM-RoBERTa/ModernBERT model from HuggingFace + - mlx-community/multilingual-e5-small-mlx + - mlx-community/multilingual-e5-large-mlx + - mlx-community/bert-base-uncased-mlx + - mlx-community/ModernBERT-base-mlx + + Other embedding models must be pinned explicitly with --embedding-model at + server startup. """ global _embedding_engine tracker = _metrics.track_inference("embeddings", stream=False) try: - # Resolve model name - model_name = request.model - - # If an embedding model was pre-configured at startup, only allow that model - if ( - _embedding_model_locked is not None - and model_name != _embedding_model_locked - ): - raise HTTPException( - status_code=400, - detail=( - f"Embedding model '{model_name}' is not available. " - f"This server was started with --embedding-model {_embedding_model_locked}. " - f"Only '{_embedding_model_locked}' can be used for embeddings. " - f"Restart the server with a different --embedding-model to use '{model_name}'." - ), - ) + # Resolve model name before any lazy-load path is reached. + model_name = resolve_embedding_model_name( + request.model, + locked_model=_embedding_model_locked, + ) # Lazy-load or swap embedding engine load_embedding_model(model_name, lock=False, reuse_existing=True) @@ -1056,16 +1055,7 @@ async def create_transcription( try: from .audio.stt import STTEngine # Lazy import - optional feature - # Map model aliases to full names - model_map = { - "whisper-large-v3": "mlx-community/whisper-large-v3-mlx", - "whisper-large-v3-turbo": "mlx-community/whisper-large-v3-turbo", - "whisper-medium": "mlx-community/whisper-medium-mlx", - "whisper-small": "mlx-community/whisper-small-mlx", - "parakeet": "mlx-community/parakeet-tdt-0.6b-v2", - "parakeet-v3": "mlx-community/parakeet-tdt-0.6b-v3", - } - model_name = model_map.get(model, model) + model_name = resolve_stt_model_name(model) # Load engine if needed if _stt_engine is None or _stt_engine.model_name != model_name: @@ -1132,16 +1122,7 @@ async def create_speech( try: from .audio.tts import TTSEngine # Lazy import - optional feature - # Map model aliases to full names - model_map = { - "kokoro": "mlx-community/Kokoro-82M-bf16", - "kokoro-4bit": "mlx-community/Kokoro-82M-4bit", - "chatterbox": "mlx-community/chatterbox-turbo-fp16", - "chatterbox-4bit": "mlx-community/chatterbox-turbo-4bit", - "vibevoice": "mlx-community/VibeVoice-Realtime-0.5B-4bit", - "voxcpm": "mlx-community/VoxCPM1.5", - } - model_name = model_map.get(model, model) + model_name = resolve_tts_model_name(model) # Load engine if needed if _tts_engine is None or _tts_engine.model_name != model_name: