diff --git a/mise.toml b/mise.toml
new file mode 100644
index 000000000..ed4e4262c
--- /dev/null
+++ b/mise.toml
@@ -0,0 +1,10 @@
+[tools]
+python = "3.12"
+
+[tasks.serve]
+description = "Start the vllm-mlx server"
+run = "python -m vllm_mlx.server"
+
+[tasks.test]
+description = "Run tests"
+run = "pytest"
diff --git a/tests/test_server.py b/tests/test_server.py
index 56f414946..1b0f84b3d 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -265,6 +265,417 @@ def test_extract_multimodal_content_with_video_url(self):
+# =============================================================================
+# Security and Reliability Tests (PR #4)
+# =============================================================================
+
+class TestRateLimiter:
+    """Test the RateLimiter class for rate limiting functionality."""
+
+    def test_rate_limiter_disabled_by_default(self):
+        """Test that rate limiter allows all requests when disabled."""
+        from vllm_mlx.server import RateLimiter
+
+        limiter = RateLimiter(requests_per_minute=5, enabled=False)
+
+        # Should allow unlimited requests when disabled
+        for _ in range(100):
+            allowed, retry_after = limiter.is_allowed("client1")
+            assert allowed is True
+            assert retry_after == 0
+
+    def test_rate_limiter_enforces_limit(self):
+        """Test that rate limiter enforces the request limit."""
+        from vllm_mlx.server import RateLimiter
+
+        limiter = RateLimiter(requests_per_minute=3, enabled=True)
+
+        # First 3 requests should be allowed
+        for i in range(3):
+            allowed, retry_after = limiter.is_allowed("client1")
+            assert allowed is True, f"Request {i+1} should be allowed"
+            assert retry_after == 0
+
+        # 4th request should be blocked
+        allowed, retry_after = limiter.is_allowed("client1")
+        assert allowed is False
+        assert retry_after > 0
+
+    def test_rate_limiter_per_client(self):
+        """Test that rate limits are tracked per client."""
+        from vllm_mlx.server import RateLimiter
+
+        limiter = RateLimiter(requests_per_minute=2, enabled=True)
+
+        # Client 1 uses its quota
+        limiter.is_allowed("client1")
+        limiter.is_allowed("client1")
+        allowed, _ = limiter.is_allowed("client1")
+        assert allowed is False
+
+        # Client 2 should still have quota
+        allowed, _ = limiter.is_allowed("client2")
+        assert allowed is True
+
+    def test_rate_limiter_thread_safety(self):
+        """Test that rate limiter is thread-safe."""
+        import threading
+        from vllm_mlx.server import RateLimiter
+
+        limiter = RateLimiter(requests_per_minute=100, enabled=True)
+        results = []
+        errors = []
+
+        def make_requests():
+            try:
+                for _ in range(10):
+                    allowed, _ = limiter.is_allowed("shared_client")
+                    results.append(allowed)
+            except Exception as e:
+                errors.append(e)
+
+        threads = [threading.Thread(target=make_requests) for _ in range(10)]
+        for t in threads:
+            t.start()
+        for t in threads:
+            t.join()
+
+        assert len(errors) == 0, f"Thread safety errors: {errors}"
+        assert len(results) == 100
+        # Exactly 100 requests allowed (our limit)
+        assert results.count(True) == 100
+
+
+class TestTempFileManager:
+    """Test the TempFileManager class for temp file cleanup."""
+
+    def test_register_and_cleanup_single_file(self):
+        """Test registering and cleaning up a single temp file."""
+        import tempfile
+        import os
+        from vllm_mlx.models.mllm import TempFileManager
+
+        manager = TempFileManager()
+
+        # Create a real temp file
+        temp = tempfile.NamedTemporaryFile(delete=False, suffix=".txt")
+        temp.write(b"test content")
+        temp.close()
+
+        # Register it
+        path = manager.register(temp.name)
+        assert path == temp.name
+        assert os.path.exists(temp.name)
+
+        # Cleanup
+        result = manager.cleanup(temp.name)
+        assert result is True
+        assert not os.path.exists(temp.name)
+
+    def test_cleanup_all_files(self):
+        """Test cleaning up all registered temp files."""
+        import tempfile
+        import os
+        from vllm_mlx.models.mllm import TempFileManager
+
+        manager = TempFileManager()
+        paths = []
+
+        # Create multiple temp files
+        for i in range(3):
+            temp = tempfile.NamedTemporaryFile(delete=False, suffix=f"_{i}.txt")
+            temp.write(f"content {i}".encode())
+            temp.close()
+            manager.register(temp.name)
+            paths.append(temp.name)
+
+        # Verify all exist
+        for p in paths:
+            assert os.path.exists(p)
+
+        # Cleanup all
+        cleaned = manager.cleanup_all()
+        assert cleaned == 3
+
+        # Verify all deleted
+        for p in paths:
+            assert not os.path.exists(p)
+
+    def test_cleanup_nonexistent_file(self):
+        """Test cleanup of a non-existent file."""
+        from vllm_mlx.models.mllm import TempFileManager
+
+        manager = TempFileManager()
+
+        # Cleanup a file that doesn't exist
+        result = manager.cleanup("/nonexistent/path/file.txt")
+        assert result is False
+
+    def test_thread_safe_registration(self):
+        """Test that TempFileManager is thread-safe."""
+        import threading
+        import tempfile
+        from vllm_mlx.models.mllm import TempFileManager
+
+        manager = TempFileManager()
+        paths = []
+        lock = threading.Lock()
+        errors = []
+
+        def register_files():
+            try:
+                for _ in range(5):
+                    temp = tempfile.NamedTemporaryFile(delete=False, suffix=".txt")
+                    temp.write(b"test")
+                    temp.close()
+                    path = manager.register(temp.name)
+                    with lock:
+                        paths.append(path)
+            except Exception as e:
+                errors.append(e)
+
+        threads = [threading.Thread(target=register_files) for _ in range(5)]
+        for t in threads:
+            t.start()
+        for t in threads:
+            t.join()
+
+        assert len(errors) == 0, f"Thread safety errors: {errors}"
+        assert len(paths) == 25
+
+        # Cleanup all
+        cleaned = manager.cleanup_all()
+        assert cleaned == 25
+
+
+class TestRequestOutputCollectorThreadSafety:
+    """Test thread-safety of RequestOutputCollector._waiting_consumers."""
+
+    def test_waiting_consumers_thread_safe(self):
+        """Test that _waiting_consumers counter is thread-safe."""
+        import threading
+        from vllm_mlx.output_collector import RequestOutputCollector
+
+        # Reset the counter
+        with RequestOutputCollector._waiting_lock:
+            RequestOutputCollector._waiting_consumers = 0
+
+        errors = []
+
+        def manipulate_counter():
+            try:
+                for _ in range(100):
+                    with RequestOutputCollector._waiting_lock:
+                        RequestOutputCollector._waiting_consumers += 1
+                    with RequestOutputCollector._waiting_lock:
+                        RequestOutputCollector._waiting_consumers -= 1
+            except Exception as e:
+                errors.append(e)
+
+        threads = [threading.Thread(target=manipulate_counter) for _ in range(10)]
+        for t in threads:
+            t.start()
+        for t in threads:
+            t.join()
+
+        assert len(errors) == 0, f"Thread safety errors: {errors}"
+        # Should return to zero
+        with RequestOutputCollector._waiting_lock:
+            assert RequestOutputCollector._waiting_consumers == 0
+
+    def test_has_waiting_consumers_method(self):
+        """Test has_waiting_consumers class method."""
+        from vllm_mlx.output_collector import RequestOutputCollector
+
+        # Reset counter
+        with RequestOutputCollector._waiting_lock:
+            RequestOutputCollector._waiting_consumers = 0
+
+        assert RequestOutputCollector.has_waiting_consumers() is False
+
+        with RequestOutputCollector._waiting_lock:
+            RequestOutputCollector._waiting_consumers = 1
+
+        assert RequestOutputCollector.has_waiting_consumers() is True
+
+        # Reset
+        with RequestOutputCollector._waiting_lock:
+            RequestOutputCollector._waiting_consumers = 0
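Review note: the two tests above pin the key invariant, namely that the shared counter returns to zero under concurrent updates and that `has_waiting_consumers()` reads it under the same lock. A minimal, self-contained sketch of the lock-guarded class-counter pattern being exercised (the names here are illustrative, not from the codebase):

```python
import threading

class WaitingGauge:
    """Illustrative stand-in for RequestOutputCollector's waiting counter."""

    _count: int = 0
    _lock = threading.Lock()

    @classmethod
    def enter(cls) -> None:
        with cls._lock:
            cls._count += 1

    @classmethod
    def leave(cls) -> None:
        with cls._lock:
            cls._count -= 1

    @classmethod
    def any_waiting(cls) -> bool:
        # Reading under the same lock prevents stale or torn reads
        with cls._lock:
            return cls._count > 0
```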
+
+
+class TestRequestTimeoutField:
+    """Test the new timeout field in request models."""
+
+    def test_chat_completion_request_timeout_field(self):
+        """Test that ChatCompletionRequest has timeout field."""
+        from vllm_mlx.server import ChatCompletionRequest, Message
+
+        # Default should be None
+        request = ChatCompletionRequest(
+            model="test-model",
+            messages=[Message(role="user", content="Hello")]
+        )
+        assert request.timeout is None
+
+        # Can set custom timeout
+        request_with_timeout = ChatCompletionRequest(
+            model="test-model",
+            messages=[Message(role="user", content="Hello")],
+            timeout=60.0
+        )
+        assert request_with_timeout.timeout == 60.0
+
+    def test_completion_request_timeout_field(self):
+        """Test that CompletionRequest has timeout field."""
+        from vllm_mlx.server import CompletionRequest
+
+        # Default should be None
+        request = CompletionRequest(
+            model="test-model",
+            prompt="Once upon a time"
+        )
+        assert request.timeout is None
+
+        # Can set custom timeout
+        request_with_timeout = CompletionRequest(
+            model="test-model",
+            prompt="Once upon a time",
+            timeout=120.0
+        )
+        assert request_with_timeout.timeout == 120.0
+
+
+class TestAPIKeyVerification:
+    """Test API key verification with timing attack prevention."""
+
+    def test_secrets_compare_digest_usage(self):
+        """Test that secrets.compare_digest is used (timing attack prevention)."""
+        import secrets
+
+        # Verify secrets.compare_digest works as expected
+        key1 = "test-api-key-12345"
+        key2 = "test-api-key-12345"
+        key3 = "different-key-67890"
+
+        # Same keys should match
+        assert secrets.compare_digest(key1, key2) is True
+
+        # Different keys should not match
+        assert secrets.compare_digest(key1, key3) is False
+
+        # The constant-time property itself can't be asserted in a unit test;
+        # just confirm the API the server relies on exists
+        assert hasattr(secrets, "compare_digest")
+
+    def test_verify_api_key_rejects_invalid(self):
+        """Test that invalid API key is rejected with 401."""
+        import asyncio
+        from fastapi import HTTPException
+        from fastapi.security import HTTPAuthorizationCredentials
+
+        # Import and set up the module
+        import vllm_mlx.server as server
+        original_key = server._api_key
+
+        try:
+            # Set a known API key
+            server._api_key = "valid-secret-key"
+
+            # Create mock credentials with invalid key
+            credentials = HTTPAuthorizationCredentials(
+                scheme="Bearer",
+                credentials="invalid-key"
+            )
+
+            # Should raise HTTPException with 401
+            with pytest.raises(HTTPException) as exc_info:
+                asyncio.run(server.verify_api_key(credentials))
+
+            assert exc_info.value.status_code == 401
+            assert "Invalid API key" in str(exc_info.value.detail)
+        finally:
+            server._api_key = original_key
+
+    def test_verify_api_key_accepts_valid(self):
+        """Test that valid API key is accepted."""
+        import asyncio
+        from fastapi.security import HTTPAuthorizationCredentials
+
+        import vllm_mlx.server as server
+        original_key = server._api_key
+
+        try:
+            # Set a known API key
+            server._api_key = "valid-secret-key"
+
+            # Create mock credentials with valid key
+            credentials = HTTPAuthorizationCredentials(
+                scheme="Bearer",
+                credentials="valid-secret-key"
+            )
+
+            # Should not raise any exception
+            result = asyncio.run(server.verify_api_key(credentials))
+            # verify_api_key returns True on success
+            assert result is True
+        finally:
+            server._api_key = original_key
+
+
+class TestRateLimiterHTTPResponse:
+    """Test rate limiter HTTP response behavior."""
+
+    def test_rate_limiter_returns_retry_after(self):
+        """Test that rate limiter returns retry_after when limit exceeded."""
+        from vllm_mlx.server import RateLimiter
+
+        limiter = RateLimiter(requests_per_minute=2, enabled=True)
+
+        # Exhaust the limit
+        limiter.is_allowed("test_client")
+        limiter.is_allowed("test_client")
+
+        # Next request should be denied with retry_after
+        allowed, retry_after = limiter.is_allowed("test_client")
+
+        assert allowed is False
+        assert retry_after is not None
+        assert retry_after > 0
+        assert retry_after <= 60  # Should be within a minute
+
+    def test_rate_limiter_window_cleanup(self):
+        """Test that rate limiter cleans up old requests from sliding window."""
+        from vllm_mlx.server import RateLimiter
+        import time
+
+        limiter = RateLimiter(requests_per_minute=2, enabled=True)
+
+        # Make some requests
+        limiter.is_allowed("test_client")
+        limiter.is_allowed("test_client")
+
+        # Should be denied (limit reached)
+        allowed, _ = limiter.is_allowed("test_client")
+        assert allowed is False
+
+        # Manually inject old timestamps to simulate time passing
+        # The sliding window should clean these up
+        old_time = time.time() - 120  # 2 minutes ago
+        with limiter._lock:
+            limiter._requests["test_client"] = [old_time, old_time]
+
+        # Now should be allowed again (old requests cleaned up)
+        allowed, _ = limiter.is_allowed("test_client")
+        assert allowed is True
+
+
 # =============================================================================
 # Integration Tests (require running server)
 # =============================================================================
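Review note: with rate limiting enabled, well-behaved clients should back off using the `Retry-After` header that the server attaches to 429 responses (tested above). A hedged client-side sketch; the endpoint URL, payload, and key are placeholders, not part of this PR:

```python
import time
import requests

def post_with_backoff(url: str, payload: dict, api_key: str, attempts: int = 3):
    """Retry on 429, sleeping for the server-provided Retry-After."""
    for _ in range(attempts):
        resp = requests.post(
            url, json=payload, headers={"Authorization": f"Bearer {api_key}"}
        )
        if resp.status_code != 429:
            return resp
        # Fall back to 1 second if the header is ever missing
        time.sleep(int(resp.headers.get("Retry-After", "1")))
    return resp
```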
diff --git a/vllm_mlx/api/models.py b/vllm_mlx/api/models.py
index d3f48656f..689cf8a4f 100644
--- a/vllm_mlx/api/models.py
+++ b/vllm_mlx/api/models.py
@@ -15,19 +15,21 @@
 from pydantic import BaseModel, Field
 
-
 # =============================================================================
 # Content Types (for multimodal messages)
 # =============================================================================
 
+
 class ImageUrl(BaseModel):
     """Image URL with optional detail level."""
+
     url: str
     detail: Optional[str] = None
 
 
 class VideoUrl(BaseModel):
     """Video URL."""
+
     url: str
 
 
@@ -41,6 +43,7 @@ class ContentPart(BaseModel):
     - video: Video from local path
     - video_url: Video from URL or base64
     """
+
     type: str  # "text", "image_url", "video", "video_url"
     text: Optional[str] = None
     image_url: Optional[Union[ImageUrl, dict, str]] = None
@@ -52,6 +55,7 @@ class ContentPart(BaseModel):
 # Messages
 # =============================================================================
 
+
 class Message(BaseModel):
     """
     A message in a chat conversation.
@@ -62,6 +66,7 @@ class Message(BaseModel):
     - Tool call messages (assistant with tool_calls)
     - Tool response messages (role="tool" with tool_call_id)
     """
+
     role: str
     content: Optional[Union[str, List[ContentPart], List[dict]]] = None
     # For assistant messages with tool calls
@@ -74,14 +79,17 @@
 # Tool Calling
 # =============================================================================
 
+
 class FunctionCall(BaseModel):
     """A function call with name and arguments."""
+
     name: str
     arguments: str  # JSON string
 
 
 class ToolCall(BaseModel):
     """A tool call from the model."""
+
     id: str
     type: str = "function"
     function: FunctionCall
@@ -89,6 +97,7 @@ class ToolCall(BaseModel):
 
 class ToolDefinition(BaseModel):
     """Definition of a tool that can be called by the model."""
+
     type: str = "function"
     function: dict
 
@@ -97,8 +106,10 @@ class ToolDefinition(BaseModel):
 # Structured Output (JSON Schema)
 # =============================================================================
 
+
 class ResponseFormatJsonSchema(BaseModel):
     """JSON Schema definition for structured output."""
+
     name: str
     description: Optional[str] = None
     schema_: dict = Field(alias="schema")  # JSON Schema specification
@@ -117,6 +128,7 @@ class ResponseFormat(BaseModel):
     - "json_object": Forces valid JSON output
     - "json_schema": Forces JSON matching a specific schema
     """
+
     type: str = "text"  # "text", "json_object", "json_schema"
     json_schema: Optional[ResponseFormatJsonSchema] = None
 
@@ -125,8 +137,10 @@ class ResponseFormat(BaseModel):
 # Chat Completion
 # =============================================================================
 
+
 class ChatCompletionRequest(BaseModel):
     """Request for chat completion."""
+
     model: str
     messages: List[Message]
     temperature: float = 0.7
@@ -142,10 +156,13 @@ class ChatCompletionRequest(BaseModel):
     # MLLM-specific parameters
     video_fps: Optional[float] = None
     video_max_frames: Optional[int] = None
+    # Request timeout in seconds (None = use server default)
+    timeout: Optional[float] = None
 
 
 class AssistantMessage(BaseModel):
     """Response message from the assistant."""
+
     role: str = "assistant"
     content: Optional[str] = None
     tool_calls: Optional[List[ToolCall]] = None
@@ -153,6 +170,7 @@ class AssistantMessage(BaseModel):
 
 class ChatCompletionChoice(BaseModel):
     """A single choice in chat completion response."""
+
     index: int = 0
     message: AssistantMessage
     finish_reason: Optional[str] = "stop"
@@ -160,6 +178,7 @@ class ChatCompletionChoice(BaseModel):
 
 class Usage(BaseModel):
     """Token usage statistics."""
+
     prompt_tokens: int = 0
     completion_tokens: int = 0
     total_tokens: int = 0
@@ -167,6 +186,7 @@ class Usage(BaseModel):
 
 class ChatCompletionResponse(BaseModel):
     """Response for chat completion."""
+
     id: str = Field(default_factory=lambda: f"chatcmpl-{uuid.uuid4().hex[:8]}")
     object: str = "chat.completion"
     created: int = Field(default_factory=lambda: int(time.time()))
@@ -179,8 +199,10 @@ class ChatCompletionResponse(BaseModel):
 # Text Completion
 # =============================================================================
 
+
 class CompletionRequest(BaseModel):
     """Request for text completion."""
+
     model: str
     prompt: Union[str, List[str]]
     temperature: float = 0.7
@@ -188,10 +210,13 @@ class CompletionRequest(BaseModel):
     max_tokens: Optional[int] = None
     stream: bool = False
     stop: Optional[List[str]] = None
+    # Request timeout in seconds (None = use server default)
+    timeout: Optional[float] = None
 
 
 class CompletionChoice(BaseModel):
     """A single choice in text completion response."""
+
     index: int = 0
     text: str
     finish_reason: Optional[str] = "stop"
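Review note: the new `timeout` field is plain Pydantic, so it can be exercised directly; `None` defers to the server-wide `--timeout` default. A quick construction example (the model name is a placeholder):

```python
from vllm_mlx.api.models import ChatCompletionRequest, Message

req = ChatCompletionRequest(
    model="some-model",
    messages=[Message(role="user", content="Hello")],
    timeout=30.0,  # per-request override; omit to use the server default
)
assert req.timeout == 30.0
```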
@@ -199,6 +224,7 @@ class CompletionChoice(BaseModel):
 
 class CompletionResponse(BaseModel):
     """Response for text completion."""
+
     id: str = Field(default_factory=lambda: f"cmpl-{uuid.uuid4().hex[:8]}")
     object: str = "text_completion"
     created: int = Field(default_factory=lambda: int(time.time()))
@@ -211,8 +237,10 @@ class CompletionResponse(BaseModel):
 # Models List
 # =============================================================================
 
+
 class ModelInfo(BaseModel):
     """Information about an available model."""
+
     id: str
     object: str = "model"
     created: int = Field(default_factory=lambda: int(time.time()))
@@ -221,6 +249,7 @@ class ModelInfo(BaseModel):
 
 class ModelsResponse(BaseModel):
     """Response for listing models."""
+
     object: str = "list"
     data: List[ModelInfo]
 
@@ -229,8 +258,10 @@ class ModelsResponse(BaseModel):
 # MCP (Model Context Protocol)
 # =============================================================================
 
+
 class MCPToolInfo(BaseModel):
     """Information about an MCP tool."""
+
     name: str
     description: str
     server: str
@@ -239,12 +270,14 @@ class MCPToolInfo(BaseModel):
 
 class MCPToolsResponse(BaseModel):
     """Response for listing MCP tools."""
+
     tools: List[MCPToolInfo]
     count: int
 
 
 class MCPServerInfo(BaseModel):
     """Information about an MCP server."""
+
     name: str
     state: str
     transport: str
@@ -254,17 +287,20 @@ class MCPServerInfo(BaseModel):
 
 class MCPServersResponse(BaseModel):
     """Response for listing MCP servers."""
+
     servers: List[MCPServerInfo]
 
 
 class MCPExecuteRequest(BaseModel):
     """Request to execute an MCP tool."""
+
     tool_name: str
     arguments: dict = Field(default_factory=dict)
 
 
 class MCPExecuteResponse(BaseModel):
     """Response from executing an MCP tool."""
+
     tool_name: str
     content: Optional[Union[str, list, dict]] = None
     is_error: bool = False
@@ -275,8 +311,10 @@ class MCPExecuteResponse(BaseModel):
 # Streaming (for SSE responses)
 # =============================================================================
 
+
 class ChatCompletionChunkDelta(BaseModel):
     """Delta content in a streaming chunk."""
+
     role: Optional[str] = None
     content: Optional[str] = None
     tool_calls: Optional[List[dict]] = None
@@ -284,6 +322,7 @@ class ChatCompletionChunkDelta(BaseModel):
 
 class ChatCompletionChunkChoice(BaseModel):
     """A single choice in a streaming chunk."""
+
     index: int = 0
     delta: ChatCompletionChunkDelta
     finish_reason: Optional[str] = None
@@ -291,6 +330,7 @@ class ChatCompletionChunkChoice(BaseModel):
 
 class ChatCompletionChunk(BaseModel):
     """A streaming chunk for chat completion."""
+
     id: str = Field(default_factory=lambda: f"chatcmpl-{uuid.uuid4().hex[:8]}")
     object: str = "chat.completion.chunk"
     created: int = Field(default_factory=lambda: int(time.time()))
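Review note: these chunk models are exactly what the SSE helpers later in this PR serialize with `model_dump_json()`. A small sketch of producing one wire-format line (model name is a placeholder):

```python
from vllm_mlx.api.models import (
    ChatCompletionChunk,
    ChatCompletionChunkChoice,
    ChatCompletionChunkDelta,
)

chunk = ChatCompletionChunk(
    model="some-model",
    choices=[ChatCompletionChunkChoice(delta=ChatCompletionChunkDelta(content="Hi"))],
)
# One server-sent event, as emitted by stream_chat_completion
sse_line = f"data: {chunk.model_dump_json()}\n\n"
```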
""" +import asyncio import logging from typing import Any, AsyncIterator, Dict, List, Optional @@ -120,7 +121,9 @@ async def generate( if not self._loaded: await self.start() - output = self._model.generate( + # Run in thread pool to allow asyncio timeout to work + output = await asyncio.to_thread( + self._model.generate, prompt=prompt, max_tokens=max_tokens, temperature=temperature, @@ -232,7 +235,9 @@ async def chat( if self._is_mllm: # For MLLM, use the chat method which handles images/videos - output = self._model.chat( + # Run in thread pool to allow asyncio timeout to work + output = await asyncio.to_thread( + self._model.chat, messages=messages, max_tokens=max_tokens, temperature=temperature, @@ -247,7 +252,9 @@ async def chat( ) else: # For LLM, use the chat method - output = self._model.chat( + # Run in thread pool to allow asyncio timeout to work + output = await asyncio.to_thread( + self._model.chat, messages=messages, max_tokens=max_tokens, temperature=temperature, diff --git a/vllm_mlx/models/mllm.py b/vllm_mlx/models/mllm.py index e951338b5..88a27bbb3 100644 --- a/vllm_mlx/models/mllm.py +++ b/vllm_mlx/models/mllm.py @@ -13,25 +13,88 @@ - VLM KV cache for repeated image/video+prompt combinations """ +import atexit import base64 import logging import math +import os import re import tempfile +import threading from dataclasses import dataclass, field from io import BytesIO from pathlib import Path -from typing import Iterator, List, Optional, Tuple, Union +from typing import Iterator, List, Optional, Set, Tuple, Union from urllib.parse import urlparse -import requests import numpy as np +import requests from vllm_mlx.vlm_cache import VLMCacheManager logger = logging.getLogger(__name__) +class TempFileManager: + """Thread-safe manager for tracking and cleaning up temporary files.""" + + def __init__(self): + self._files: Set[str] = set() + self._lock = threading.Lock() + atexit.register(self.cleanup_all) + + def register(self, path: str) -> str: + """Register a temp file for tracking. Returns the path for convenience.""" + with self._lock: + self._files.add(path) + return path + + def cleanup(self, path: str) -> bool: + """Clean up a specific temp file. Returns True if successful.""" + with self._lock: + if path in self._files: + self._files.discard(path) + try: + if os.path.exists(path): + os.unlink(path) + logger.debug(f"Cleaned up temp file: {path}") + return True + except OSError as e: + logger.warning(f"Failed to clean up temp file {path}: {e}") + return False + + def cleanup_all(self) -> int: + """Clean up all tracked temp files. Returns count of cleaned files.""" + with self._lock: + files_to_clean = list(self._files) + self._files.clear() + + cleaned = 0 + for path in files_to_clean: + try: + if os.path.exists(path): + os.unlink(path) + cleaned += 1 + except OSError: + pass + + if cleaned: + logger.info(f"Cleaned up {cleaned} temp files") + return cleaned + + +# Global temp file manager +_temp_manager = TempFileManager() + + +def cleanup_temp_file(path: str) -> bool: + """Clean up a specific temporary file.""" + return _temp_manager.cleanup(path) + + +def cleanup_all_temp_files() -> int: + """Clean up all tracked temporary files. 
diff --git a/vllm_mlx/models/mllm.py b/vllm_mlx/models/mllm.py
index e951338b5..88a27bbb3 100644
--- a/vllm_mlx/models/mllm.py
+++ b/vllm_mlx/models/mllm.py
@@ -13,25 +13,88 @@
 - VLM KV cache for repeated image/video+prompt combinations
 """
 
+import atexit
 import base64
 import logging
 import math
+import os
 import re
 import tempfile
+import threading
 from dataclasses import dataclass, field
 from io import BytesIO
 from pathlib import Path
-from typing import Iterator, List, Optional, Tuple, Union
+from typing import Iterator, List, Optional, Set, Tuple, Union
 from urllib.parse import urlparse
 
-import requests
 import numpy as np
+import requests
 
 from vllm_mlx.vlm_cache import VLMCacheManager
 
 logger = logging.getLogger(__name__)
 
 
+class TempFileManager:
+    """Thread-safe manager for tracking and cleaning up temporary files."""
+
+    def __init__(self):
+        self._files: Set[str] = set()
+        self._lock = threading.Lock()
+        atexit.register(self.cleanup_all)
+
+    def register(self, path: str) -> str:
+        """Register a temp file for tracking. Returns the path for convenience."""
+        with self._lock:
+            self._files.add(path)
+        return path
+
+    def cleanup(self, path: str) -> bool:
+        """Clean up a specific temp file. Returns True if successful."""
+        with self._lock:
+            if path in self._files:
+                self._files.discard(path)
+                try:
+                    if os.path.exists(path):
+                        os.unlink(path)
+                    logger.debug(f"Cleaned up temp file: {path}")
+                    return True
+                except OSError as e:
+                    logger.warning(f"Failed to clean up temp file {path}: {e}")
+        return False
+
+    def cleanup_all(self) -> int:
+        """Clean up all tracked temp files. Returns count of cleaned files."""
+        with self._lock:
+            files_to_clean = list(self._files)
+            self._files.clear()
+
+        cleaned = 0
+        for path in files_to_clean:
+            try:
+                if os.path.exists(path):
+                    os.unlink(path)
+                    cleaned += 1
+            except OSError:
+                pass
+
+        if cleaned:
+            logger.info(f"Cleaned up {cleaned} temp files")
+        return cleaned
+
+
+# Global temp file manager
+_temp_manager = TempFileManager()
+
+
+def cleanup_temp_file(path: str) -> bool:
+    """Clean up a specific temporary file."""
+    return _temp_manager.cleanup(path)
+
+
+def cleanup_all_temp_files() -> int:
+    """Clean up all tracked temporary files. Returns count of cleaned files."""
+    return _temp_manager.cleanup_all()
 
 
 # Video processing constants
@@ -45,15 +108,17 @@
 
 @dataclass
 class MultimodalInput:
     """Input for multimodal generation."""
+
     prompt: str
     images: list[str] = field(default_factory=list)  # Paths, URLs, or base64
     videos: list[str] = field(default_factory=list)  # Paths
-    audio: list[str] = field(default_factory=list) # Paths
+    audio: list[str] = field(default_factory=list)  # Paths
 
 
 @dataclass
 class MLLMOutput:
     """Output from multimodal language model."""
+
     text: str
     finish_reason: str | None = None
     prompt_tokens: int = 0
@@ -110,12 +175,12 @@ def download_image(url: str, timeout: int = 30) -> str:
     path = urlparse(url).path
     ext = Path(path).suffix or ".jpg"
 
-    # Save to temp file
+    # Save to temp file and register for cleanup
     temp_file = tempfile.NamedTemporaryFile(suffix=ext, delete=False)
     temp_file.write(response.content)
     temp_file.close()
 
-    return temp_file.name
+    return _temp_manager.register(temp_file.name)
 
 
 def download_video(url: str, timeout: int = 120) -> str:
@@ -154,16 +219,18 @@ def download_video(url: str, timeout: int = 120) -> str:
     path = urlparse(url).path
     ext = Path(path).suffix or ".mp4"
 
-    # Save to temp file (stream for larger files)
+    # Save to temp file (stream for larger files) and register for cleanup
     temp_file = tempfile.NamedTemporaryFile(suffix=ext, delete=False)
     for chunk in response.iter_content(chunk_size=8192):
         temp_file.write(chunk)
     temp_file.close()
 
     file_size = Path(temp_file.name).stat().st_size
-    logger.info(f"Video downloaded: {temp_file.name} ({file_size / 1024 / 1024:.1f} MB)")
+    logger.info(
+        f"Video downloaded: {temp_file.name} ({file_size / 1024 / 1024:.1f} MB)"
+    )
 
-    return temp_file.name
+    return _temp_manager.register(temp_file.name)
 
 
 def decode_base64_video(base64_string: str) -> str:
@@ -190,15 +257,17 @@ def decode_base64_video(base64_string: str) -> str:
         data = base64_string
         ext = ".mp4"
 
-    # Decode and save
+    # Decode, save, and register for cleanup
     video_bytes = base64.b64decode(data)
     temp_file = tempfile.NamedTemporaryFile(suffix=ext, delete=False)
     temp_file.write(video_bytes)
     temp_file.close()
 
-    logger.info(f"Base64 video decoded: {temp_file.name} ({len(video_bytes) / 1024 / 1024:.1f} MB)")
+    logger.info(
+        f"Base64 video decoded: {temp_file.name} ({len(video_bytes) / 1024 / 1024:.1f} MB)"
+    )
 
-    return temp_file.name
+    return _temp_manager.register(temp_file.name)
 
 
 def process_video_input(video: Union[str, dict]) -> str:
@@ -247,13 +316,13 @@ def save_base64_image(base64_string: str) -> str:
     image_bytes = decode_base64_image(base64_string)
 
     # Detect format from magic bytes
-    if image_bytes[:8] == b'\x89PNG\r\n\x1a\n':
+    if image_bytes[:8] == b"\x89PNG\r\n\x1a\n":
         ext = ".png"
-    elif image_bytes[:2] == b'\xff\xd8':
+    elif image_bytes[:2] == b"\xff\xd8":
         ext = ".jpg"
-    elif image_bytes[:6] in (b'GIF87a', b'GIF89a'):
+    elif image_bytes[:6] in (b"GIF87a", b"GIF89a"):
         ext = ".gif"
-    elif image_bytes[:4] == b'RIFF' and image_bytes[8:12] == b'WEBP':
+    elif image_bytes[:4] == b"RIFF" and image_bytes[8:12] == b"WEBP":
         ext = ".webp"
     else:
         ext = ".jpg"  # Default
@@ -262,7 +331,7 @@ def save_base64_image(base64_string: str) -> str:
     temp_file.write(image_bytes)
     temp_file.close()
 
-    return temp_file.name
+    return _temp_manager.register(temp_file.name)
 
 
 def process_image_input(image: Union[str, dict]) -> str:
@@ -419,7 +488,7 @@ def save_frames_to_temp(frames: list[np.ndarray]) -> list[str]:
         img = Image.fromarray(frame)
         temp_file = tempfile.NamedTemporaryFile(suffix=".jpg", delete=False)
         img.save(temp_file.name, "JPEG", quality=85)
-        paths.append(temp_file.name)
+        paths.append(_temp_manager.register(temp_file.name))
 
     return paths
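Review note: every temp file created in this module now flows through `_temp_manager.register`, so leaks are bounded by process lifetime (the `atexit` hook sweeps whatever is left). Minimal usage of the public surface:

```python
import tempfile
from vllm_mlx.models.mllm import TempFileManager

manager = TempFileManager()  # module code uses the global _temp_manager

tmp = tempfile.NamedTemporaryFile(suffix=".jpg", delete=False)
tmp.write(b"\xff\xd8")  # truncated JPEG bytes, illustrative only
tmp.close()

manager.register(tmp.name)            # track it for cleanup
assert manager.cleanup(tmp.name)      # deletes the file, returns True
assert not manager.cleanup(tmp.name)  # untracked paths return False
```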
@@ -600,8 +669,8 @@ def generate(
             self.load()
 
         from mlx_vlm import generate
-        from mlx_vlm.prompt_utils import apply_chat_template
         from mlx_vlm.models import cache as vlm_cache
+        from mlx_vlm.prompt_utils import apply_chat_template
 
         images = images or []
         videos = videos or []
@@ -626,11 +695,13 @@
                 all_images.extend(frames)
                 # Include video params in cache key
                 video_str = video_path if isinstance(video_path, str) else str(video_path)
-                all_sources.append(f"video:{video_str}:fps{video_fps}:max{video_max_frames}")
+                all_sources.append(
+                    f"video:{video_str}:fps{video_fps}:max{video_max_frames}"
+                )
                 logger.info(f"Added {len(frames)} frames from video: {video_path}")
 
         # Apply chat template if needed
-        if all_images and hasattr(self.processor, 'apply_chat_template'):
+        if all_images and hasattr(self.processor, "apply_chat_template"):
             try:
                 formatted_prompt = apply_chat_template(
                     self.processor,
@@ -679,7 +750,7 @@ def generate(
         if use_cache and self._cache_manager and all_sources and not cache_hit:
             if prompt_cache is not None:
                 try:
-                    num_tokens = getattr(result, 'prompt_tokens', 0)
+                    num_tokens = getattr(result, "prompt_tokens", 0)
                     self._cache_manager.store_cache(
                         all_sources, formatted_prompt, prompt_cache, num_tokens
                     )
@@ -688,10 +759,10 @@
                     logger.debug(f"Failed to store VLM cache: {e}")
 
         # Handle GenerationResult object or plain string
-        if hasattr(result, 'text'):
+        if hasattr(result, "text"):
             output_text = result.text
-            prompt_tokens = getattr(result, 'prompt_tokens', 0)
-            generation_tokens = getattr(result, 'generation_tokens', 0)
+            prompt_tokens = getattr(result, "prompt_tokens", 0)
+            generation_tokens = getattr(result, "generation_tokens", 0)
         else:
             output_text = str(result)
             prompt_tokens = 0
@@ -835,9 +906,9 @@ def chat(
                     continue
 
                 # Convert Pydantic models to dicts
-                if hasattr(item, 'model_dump'):
+                if hasattr(item, "model_dump"):
                     item = item.model_dump()
-                elif hasattr(item, 'dict'):
+                elif hasattr(item, "dict"):
                     item = item.dict()
 
                 if isinstance(item, dict):
@@ -868,7 +939,9 @@ def chat(
         video_fps = kwargs.pop("video_fps", DEFAULT_FPS)
         video_max_frames = kwargs.pop("video_max_frames", MAX_FRAMES)
         for video_path in videos:
-            frames = self._prepare_video(video_path, fps=video_fps, max_frames=video_max_frames)
+            frames = self._prepare_video(
+                video_path, fps=video_fps, max_frames=video_max_frames
+            )
             all_images.extend(frames)
             logger.info(f"Added {len(frames)} frames from video: {video_path}")
@@ -897,10 +970,10 @@
         )
 
         # Handle GenerationResult object or plain string
-        if hasattr(result, 'text'):
+        if hasattr(result, "text"):
             output_text = result.text
-            prompt_tokens = getattr(result, 'prompt_tokens', 0)
-            generation_tokens = getattr(result, 'generation_tokens', 0)
+            prompt_tokens = getattr(result, "prompt_tokens", 0)
+            generation_tokens = getattr(result, "generation_tokens", 0)
         else:
            output_text = str(result)
             prompt_tokens = 0
"DeepSeek-VL", + "-VL-", + "-VL/", + "VL-", + "llava", + "LLaVA", + "idefics", + "Idefics", + "paligemma", + "PaliGemma", + "pixtral", + "Pixtral", + "molmo", + "Molmo", + "phi3-vision", + "phi-3-vision", + "cogvlm", + "CogVLM", + "internvl", + "InternVL", + "minicpm-v", + "MiniCPM-V", + "florence", + "Florence", + "deepseek-vl", + "DeepSeek-VL", ] model_lower = model_name.lower() return any(pattern.lower() in model_lower for pattern in mllm_patterns) diff --git a/vllm_mlx/output_collector.py b/vllm_mlx/output_collector.py index 6c08e7580..6c345e9ab 100644 --- a/vllm_mlx/output_collector.py +++ b/vllm_mlx/output_collector.py @@ -7,6 +7,7 @@ """ import asyncio +import threading from dataclasses import dataclass, field from typing import List, Optional @@ -36,6 +37,7 @@ class RequestOutputCollector: # Global counter of collectors with waiting consumers # Used to optimize: only yield when someone is waiting _waiting_consumers: int = 0 + _waiting_lock: threading.Lock = threading.Lock() def __init__(self, aggregate: bool = True): """ @@ -100,7 +102,8 @@ async def get(self) -> RequestOutput: # Track that we're waiting (for yield optimization) if not self._is_waiting: self._is_waiting = True - RequestOutputCollector._waiting_consumers += 1 + with RequestOutputCollector._waiting_lock: + RequestOutputCollector._waiting_consumers += 1 try: while self.output is None: await self.ready.wait() @@ -111,7 +114,8 @@ async def get(self) -> RequestOutput: finally: if self._is_waiting: self._is_waiting = False - RequestOutputCollector._waiting_consumers -= 1 + with RequestOutputCollector._waiting_lock: + RequestOutputCollector._waiting_consumers -= 1 def _merge_outputs( self, @@ -153,7 +157,8 @@ def clear(self) -> None: self.ready.clear() if self._is_waiting: self._is_waiting = False - RequestOutputCollector._waiting_consumers -= 1 + with RequestOutputCollector._waiting_lock: + RequestOutputCollector._waiting_consumers -= 1 @classmethod def has_waiting_consumers(cls) -> bool: @@ -161,7 +166,8 @@ def has_waiting_consumers(cls) -> bool: Used by engine to optimize: only yield when someone is waiting. 
""" - return cls._waiting_consumers > 0 + with cls._waiting_lock: + return cls._waiting_consumers > 0 @dataclass diff --git a/vllm_mlx/server.py b/vllm_mlx/server.py index a95875589..5e2fef595 100644 --- a/vllm_mlx/server.py +++ b/vllm_mlx/server.py @@ -42,6 +42,7 @@ import json import logging import os +import secrets import time import uuid from contextlib import asynccontextmanager @@ -51,40 +52,19 @@ from fastapi.responses import StreamingResponse # Import from new modular API -from .api.models import ( - Message, - ChatCompletionRequest, - ChatCompletionResponse, - ChatCompletionChoice, - ChatCompletionChunk, - ChatCompletionChunkChoice, - ChatCompletionChunkDelta, - AssistantMessage, - CompletionRequest, - CompletionResponse, - CompletionChoice, - Usage, - ModelInfo, - ModelsResponse, - MCPToolInfo, - MCPToolsResponse, - MCPServerInfo, - MCPServersResponse, - MCPExecuteRequest, - MCPExecuteResponse, - # Re-export for backwards compatibility with tests - ContentPart, - ImageUrl, - VideoUrl, -) -from .api.utils import clean_output_text, extract_multimodal_content, is_mllm_model -from .api.tool_calling import ( - parse_tool_calls, - convert_tools_for_template, - parse_json_output, - build_json_system_prompt, -) -from .engine import BaseEngine, SimpleEngine, BatchedEngine +from .api.models import ( # Re-export for backwards compatibility with tests + AssistantMessage, ChatCompletionChoice, ChatCompletionChunk, + ChatCompletionChunkChoice, ChatCompletionChunkDelta, ChatCompletionRequest, + ChatCompletionResponse, CompletionChoice, CompletionRequest, + CompletionResponse, ContentPart, ImageUrl, MCPExecuteRequest, + MCPExecuteResponse, MCPServerInfo, MCPServersResponse, MCPToolInfo, + MCPToolsResponse, Message, ModelInfo, ModelsResponse, Usage, VideoUrl) +from .api.tool_calling import (build_json_system_prompt, + convert_tools_for_template, parse_json_output, + parse_tool_calls) +from .api.utils import (clean_output_text, extract_multimodal_content, + is_mllm_model) +from .engine import BaseEngine, BatchedEngine, SimpleEngine logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -93,6 +73,7 @@ _engine: Optional[BaseEngine] = None _model_name: Optional[str] = None _default_max_tokens: int = 32768 +_default_timeout: float = 300.0 # Default request timeout in seconds (5 minutes) # Global MCP manager _mcp_manager = None @@ -108,7 +89,7 @@ async def lifespan(app: FastAPI): global _engine, _mcp_manager # Startup: Start engine if loaded (needed for BatchedEngine in uvicorn's event loop) - if _engine is not None and hasattr(_engine, '_loaded') and not _engine._loaded: + if _engine is not None and hasattr(_engine, "_loaded") and not _engine._loaded: await _engine.start() # Initialize MCP if config provided @@ -135,19 +116,84 @@ async def lifespan(app: FastAPI): ) -from fastapi import Request, Depends -from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials +import threading +from collections import defaultdict + +from fastapi import Depends, Request +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer security = HTTPBearer(auto_error=False) +class RateLimiter: + """Simple in-memory rate limiter using sliding window.""" + + def __init__(self, requests_per_minute: int = 60, enabled: bool = False): + self.requests_per_minute = requests_per_minute + self.enabled = enabled + self.window_size = 60.0 # 1 minute window + self._requests: dict[str, list[float]] = defaultdict(list) + self._lock = threading.Lock() + + def is_allowed(self, client_id: str) 
 
+
+class RateLimiter:
+    """Simple in-memory rate limiter using sliding window."""
+
+    def __init__(self, requests_per_minute: int = 60, enabled: bool = False):
+        self.requests_per_minute = requests_per_minute
+        self.enabled = enabled
+        self.window_size = 60.0  # 1 minute window
+        self._requests: dict[str, list[float]] = defaultdict(list)
+        self._lock = threading.Lock()
+
+    def is_allowed(self, client_id: str) -> tuple[bool, int]:
+        """
+        Check if request is allowed for client.
+
+        Returns:
+            (is_allowed, retry_after_seconds)
+        """
+        if not self.enabled:
+            return True, 0
+
+        current_time = time.time()
+        window_start = current_time - self.window_size
+
+        with self._lock:
+            # Clean old requests outside window
+            self._requests[client_id] = [
+                t for t in self._requests[client_id] if t > window_start
+            ]
+
+            # Check rate limit
+            if len(self._requests[client_id]) >= self.requests_per_minute:
+                # Calculate retry-after
+                oldest = min(self._requests[client_id])
+                retry_after = int(oldest + self.window_size - current_time) + 1
+                return False, max(1, retry_after)
+
+            # Record this request
+            self._requests[client_id].append(current_time)
+            return True, 0
+
+
+# Global rate limiter (disabled by default)
+_rate_limiter = RateLimiter(requests_per_minute=60, enabled=False)
+
+
+async def check_rate_limit(request: Request):
+    """Rate limiting dependency."""
+    # Use API key as client ID if available, otherwise use IP
+    client_id = request.headers.get(
+        "Authorization", request.client.host if request.client else "unknown"
+    )
+
+    allowed, retry_after = _rate_limiter.is_allowed(client_id)
+    if not allowed:
+        raise HTTPException(
+            status_code=429,
+            detail=f"Rate limit exceeded. Retry after {retry_after} seconds.",
+            headers={"Retry-After": str(retry_after)},
+        )
+
+
 async def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
     """Verify API key if authentication is enabled."""
     if _api_key is None:
         return True  # No auth required
     if credentials is None:
         raise HTTPException(status_code=401, detail="API key required")
-    if credentials.credentials != _api_key:
+    # Use constant-time comparison to prevent timing attacks
+    if not secrets.compare_digest(credentials.credentials, _api_key):
         raise HTTPException(status_code=401, detail="Invalid API key")
     return True
 
@@ -212,7 +258,9 @@ async def health():
     """Health check endpoint."""
     mcp_info = None
     if _mcp_manager is not None:
-        connected = sum(1 for s in _mcp_manager.get_server_status() if s.state.value == "connected")
+        connected = sum(
+            1 for s in _mcp_manager.get_server_status() if s.state.value == "connected"
+        )
         total = len(_mcp_manager.get_server_status())
         mcp_info = {
             "enabled": True,
@@ -246,6 +294,7 @@ async def list_models() -> ModelsResponse:
 # MCP Endpoints
 # =============================================================================
 
+
 @app.get("/v1/mcp/tools", dependencies=[Depends(verify_api_key)])
 async def list_mcp_tools() -> MCPToolsResponse:
     """List all available MCP tools."""
@@ -254,12 +303,14 @@
 
     tools = []
     for tool in _mcp_manager.get_all_tools():
-        tools.append(MCPToolInfo(
-            name=tool.full_name,
-            description=tool.description,
-            server=tool.server_name,
-            parameters=tool.input_schema,
-        ))
+        tools.append(
+            MCPToolInfo(
+                name=tool.full_name,
+                description=tool.description,
+                server=tool.server_name,
+                parameters=tool.input_schema,
+            )
+        )
 
     return MCPToolsResponse(tools=tools, count=len(tools))
 
@@ -272,13 +323,15 @@
 
     servers = []
     for status in _mcp_manager.get_server_status():
-        servers.append(MCPServerInfo(
-            name=status.name,
-            state=status.state.value,
-            transport=status.transport.value,
-            tools_count=status.tools_count,
-            error=status.error,
-        ))
+        servers.append(
+            MCPServerInfo(
+                name=status.name,
+                state=status.state.value,
+                transport=status.transport.value,
+                tools_count=status.tools_count,
+                error=status.error,
+            )
+        )
 
     return MCPServersResponse(servers=servers)
@@ -288,8 +341,7 @@ async def execute_mcp_tool(request: MCPExecuteRequest) -> MCPExecuteResponse:
     """Execute an MCP tool."""
     if _mcp_manager is None:
         raise HTTPException(
-            status_code=503,
-            detail="MCP not configured. Start server with --mcp-config"
+            status_code=503, detail="MCP not configured. Start server with --mcp-config"
         )
 
     result = await _mcp_manager.execute_tool(
@@ -331,8 +383,8 @@ async def create_transcription(
     - parakeet-tdt-0.6b-v2 (English, fastest)
     """
     global _stt_engine
-    import tempfile
     import os
+    import tempfile
 
     try:
         from .audio.stt import STTEngine
@@ -376,7 +428,7 @@ async def create_transcription(
     except ImportError:
         raise HTTPException(
             status_code=503,
-            detail="mlx-audio not installed. Install with: pip install mlx-audio"
+            detail="mlx-audio not installed. Install with: pip install mlx-audio",
         )
     except Exception as e:
         logger.error(f"Transcription failed: {e}")
@@ -425,13 +477,15 @@ async def create_speech(
         audio = _tts_engine.generate(input, voice=voice, speed=speed)
         audio_bytes = _tts_engine.to_bytes(audio, format=response_format)
 
-        content_type = "audio/wav" if response_format == "wav" else f"audio/{response_format}"
+        content_type = (
+            "audio/wav" if response_format == "wav" else f"audio/{response_format}"
+        )
         return Response(content=audio_bytes, media_type=content_type)
 
     except ImportError:
         raise HTTPException(
             status_code=503,
-            detail="mlx-audio not installed. Install with: pip install mlx-audio"
+            detail="mlx-audio not installed. Install with: pip install mlx-audio",
         )
     except Exception as e:
         logger.error(f"TTS generation failed: {e}")
@@ -441,7 +495,7 @@ async def create_speech(
 @app.get("/v1/audio/voices", dependencies=[Depends(verify_api_key)])
 async def list_voices(model: str = "kokoro"):
     """List available voices for a TTS model."""
-    from .audio.tts import KOKORO_VOICES, CHATTERBOX_VOICES
+    from .audio.tts import CHATTERBOX_VOICES, KOKORO_VOICES
 
     if "kokoro" in model.lower():
         return {"voices": KOKORO_VOICES}
@@ -455,7 +509,10 @@
 # Completion Endpoints
 # =============================================================================
 
-@app.post("/v1/completions", dependencies=[Depends(verify_api_key)])
+
+@app.post(
+    "/v1/completions", dependencies=[Depends(verify_api_key), Depends(check_rate_limit)]
+)
 async def create_completion(request: CompletionRequest):
     """Create a text completion."""
     engine = get_engine()
@@ -469,30 +526,43 @@
             media_type="text/event-stream",
         )
 
-    # Non-streaming response with timing
+    # Non-streaming response with timing and timeout
     start_time = time.perf_counter()
+    timeout = request.timeout or _default_timeout
 
     choices = []
     total_completion_tokens = 0
 
     for i, prompt in enumerate(prompts):
-        output = await engine.generate(
-            prompt=prompt,
-            max_tokens=request.max_tokens or _default_max_tokens,
-            temperature=request.temperature,
-            top_p=request.top_p,
-            stop=request.stop,
-        )
-
-        choices.append(CompletionChoice(
-            index=i,
-            text=output.text,
-            finish_reason=output.finish_reason,
-        ))
+        try:
+            output = await asyncio.wait_for(
+                engine.generate(
+                    prompt=prompt,
+                    max_tokens=request.max_tokens or _default_max_tokens,
+                    temperature=request.temperature,
+                    top_p=request.top_p,
+                    stop=request.stop,
+                ),
+                timeout=timeout,
+            )
+        except asyncio.TimeoutError:
+            raise HTTPException(
+                status_code=504, detail=f"Request timed out after {timeout:.1f} seconds"
+            )
+
+        choices.append(
+            CompletionChoice(
+                index=i,
+                text=output.text,
+                finish_reason=output.finish_reason,
+            )
+        )
         total_completion_tokens += output.completion_tokens
 
     elapsed = time.perf_counter() - start_time
     tokens_per_sec = total_completion_tokens / elapsed if elapsed > 0 else 0
-    logger.info(f"Completion: {total_completion_tokens} tokens in {elapsed:.2f}s ({tokens_per_sec:.1f} tok/s)")
+    logger.info(
+        f"Completion: {total_completion_tokens} tokens in {elapsed:.2f}s ({tokens_per_sec:.1f} tok/s)"
+    )
 
     return CompletionResponse(
         model=request.model,
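Review note: the new 504 path is observable from any OpenAI-style client by setting the request-level `timeout` very low. A hedged sketch; it assumes a server with a loaded model on localhost:8000, and the model name is a placeholder:

```python
import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "some-model",
        "prompt": "Once upon a time",
        "timeout": 0.001,  # force the asyncio.wait_for deadline to fire
    },
)
print(resp.status_code)  # expect 504 with a "Request timed out after ..." detail
```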
@@ -504,7 +574,10 @@
 
-@app.post("/v1/chat/completions", dependencies=[Depends(verify_api_key)])
+@app.post(
+    "/v1/chat/completions",
+    dependencies=[Depends(verify_api_key), Depends(check_rate_limit)],
+)
 async def create_chat_completion(request: ChatCompletionRequest):
     """
     Create a chat completion (supports multimodal content for VLM models).
@@ -588,14 +661,24 @@
             media_type="text/event-stream",
         )
 
-    # Non-streaming response with timing
+    # Non-streaming response with timing and timeout
     start_time = time.perf_counter()
+    timeout = request.timeout or _default_timeout
 
-    output = await engine.chat(messages=messages, **chat_kwargs)
+    try:
+        output = await asyncio.wait_for(
+            engine.chat(messages=messages, **chat_kwargs), timeout=timeout
+        )
+    except asyncio.TimeoutError:
+        raise HTTPException(
+            status_code=504, detail=f"Request timed out after {timeout:.1f} seconds"
+        )
 
     elapsed = time.perf_counter() - start_time
     tokens_per_sec = output.completion_tokens / elapsed if elapsed > 0 else 0
-    logger.info(f"Chat completion: {output.completion_tokens} tokens in {elapsed:.2f}s ({tokens_per_sec:.1f} tok/s)")
+    logger.info(
+        f"Chat completion: {output.completion_tokens} tokens in {elapsed:.2f}s ({tokens_per_sec:.1f} tok/s)"
+    )
 
     # Parse tool calls from output
     cleaned_text, tool_calls = parse_tool_calls(output.text)
@@ -603,8 +686,7 @@
     # Process response_format if specified
     if response_format and not tool_calls:
         cleaned_text, parsed_json, is_valid, error = parse_json_output(
-            cleaned_text or output.text,
-            response_format
+            cleaned_text or output.text, response_format
         )
         if parsed_json is not None:
             # Return JSON as string
@@ -617,13 +699,15 @@
 
     return ChatCompletionResponse(
         model=request.model,
-        choices=[ChatCompletionChoice(
-            message=AssistantMessage(
-                content=clean_output_text(cleaned_text) if cleaned_text else None,
-                tool_calls=tool_calls,
-            ),
-            finish_reason=finish_reason,
-        )],
+        choices=[
+            ChatCompletionChoice(
+                message=AssistantMessage(
+                    content=clean_output_text(cleaned_text) if cleaned_text else None,
+                    tool_calls=tool_calls,
+                ),
+                finish_reason=finish_reason,
+            )
+        ],
         usage=Usage(
             prompt_tokens=output.prompt_tokens,
             completion_tokens=output.completion_tokens,
@@ -668,6 +752,7 @@ def _inject_json_instruction(messages: list, instruction: str) -> list:
 # Streaming Helpers
 # =============================================================================
 
+
 async def stream_completion(
     engine: BaseEngine,
     prompt: str,
"index": 0, + "text": output.new_text, + "finish_reason": output.finish_reason if output.finished else None, + } + ], } yield f"data: {json.dumps(data)}\n\n" @@ -710,9 +797,11 @@ async def stream_chat_completion( first_chunk = ChatCompletionChunk( id=response_id, model=request.model, - choices=[ChatCompletionChunkChoice( - delta=ChatCompletionChunkDelta(role="assistant"), - )], + choices=[ + ChatCompletionChunkChoice( + delta=ChatCompletionChunkDelta(role="assistant"), + ) + ], ) yield f"data: {first_chunk.model_dump_json()}\n\n" @@ -733,10 +822,14 @@ async def stream_chat_completion( chunk = ChatCompletionChunk( id=response_id, model=request.model, - choices=[ChatCompletionChunkChoice( - delta=ChatCompletionChunkDelta(content=content if content else None), - finish_reason=output.finish_reason if output.finished else None, - )], + choices=[ + ChatCompletionChunkChoice( + delta=ChatCompletionChunkDelta( + content=content if content else None + ), + finish_reason=output.finish_reason if output.finished else None, + ) + ], ) yield f"data: {chunk.model_dump_json()}\n\n" @@ -747,12 +840,14 @@ async def stream_chat_completion( # MCP Initialization # ============================================================================= + async def init_mcp(config_path: str): """Initialize MCP manager from config file.""" global _mcp_manager, _mcp_executor try: - from vllm_mlx.mcp import MCPClientManager, ToolExecutor, load_mcp_config + from vllm_mlx.mcp import (MCPClientManager, ToolExecutor, + load_mcp_config) config = load_mcp_config(config_path) _mcp_manager = MCPClientManager(config) @@ -774,6 +869,7 @@ async def init_mcp(config_path: str): # Main Entry Point # ============================================================================= + def main(): """Run the server.""" parser = argparse.ArgumentParser( @@ -837,12 +933,32 @@ def main(): default=None, help="API key for authentication (if not set, no auth required)", ) + parser.add_argument( + "--timeout", + type=float, + default=300.0, + help="Default request timeout in seconds (default: 300)", + ) + parser.add_argument( + "--rate-limit", + type=int, + default=0, + help="Rate limit requests per minute per client (0 = disabled)", + ) args = parser.parse_args() - # Set API key globally - global _api_key + # Set global configuration + global _api_key, _default_timeout, _rate_limiter _api_key = args.api_key + _default_timeout = args.timeout + + # Configure rate limiter + if args.rate_limit > 0: + _rate_limiter = RateLimiter(requests_per_minute=args.rate_limit, enabled=True) + logger.info( + f"Rate limiting enabled: {args.rate_limit} requests/minute per client" + ) # Set MCP config for lifespan if args.mcp_config: @@ -858,6 +974,7 @@ def main(): # Start server import uvicorn + uvicorn.run(app, host=args.host, port=args.port)