diff --git a/vllm_mlx/cli.py b/vllm_mlx/cli.py
index 5955109a5..8a90bc9be 100644
--- a/vllm_mlx/cli.py
+++ b/vllm_mlx/cli.py
@@ -197,6 +197,7 @@ def serve_command(args):
         stream_interval=args.stream_interval if args.continuous_batching else 1,
         max_tokens=args.max_tokens,
         force_mllm=args.mllm,
+        served_model_name=args.served_model_name,
         mtp=args.enable_mtp,
         prefill_step_size=args.prefill_step_size,
         specprefill_enabled=args.specprefill,
@@ -607,6 +608,12 @@ def main():
     # Serve command
     serve_parser = subparsers.add_parser("serve", help="Start OpenAI-compatible server")
     serve_parser.add_argument("model", type=str, help="Model to serve")
+    serve_parser.add_argument(
+        "--served-model-name",
+        type=str,
+        default=None,
+        help="The model name used in the API. If not specified, the model argument is used.",
+    )
     serve_parser.add_argument(
         "--host", type=str, default="0.0.0.0", help="Host to bind"
     )
diff --git a/vllm_mlx/server.py b/vllm_mlx/server.py
index f599ddc8f..cf3e66596 100644
--- a/vllm_mlx/server.py
+++ b/vllm_mlx/server.py
@@ -111,6 +111,9 @@
 # Global engine instance
 _engine: BaseEngine | None = None
 _model_name: str | None = None
+_model_path: str | None = (
+    None  # Actual model path (for cache dir, not affected by --served-model-name)
+)
 _default_max_tokens: int = 32768
 _default_timeout: float = 300.0  # Default request timeout in seconds (5 minutes)
 _default_temperature: float | None = None  # Set via --default-temperature
@@ -188,11 +191,14 @@ def _save_prefix_cache_to_disk() -> None:


 def _get_cache_dir() -> str:
-    """Get cache persistence directory based on model name."""
-    # Use global _model_name which is always a string, set during load_model()
-    model_name = _model_name if _model_name else "default"
+    """Get cache persistence directory based on actual model path."""
+    # Use _model_path (actual model path) not _model_name (which may be overridden
+    # by --served-model-name). This ensures cache is shared regardless of served name.
+    model_name = (
+        _model_path if _model_path else (_model_name if _model_name else "default")
+    )
     logger.info(
-        f"[_get_cache_dir] _model_name={_model_name!r} type={type(_model_name)}"
+        f"[_get_cache_dir] _model_path={_model_path!r} type={type(_model_path)}"
     )
     # Sanitize model name for filesystem
     safe_name = str(model_name).replace("/", "--").replace("\\", "--")
@@ -335,6 +341,16 @@ def get_engine() -> BaseEngine:
     return _engine


+def _validate_model_name(request_model: str) -> None:
+    """Validate that the request model name matches the served model."""
+    if _model_name and request_model != _model_name:
+        raise HTTPException(
+            status_code=404,
+            detail=f"The model `{request_model}` does not exist. "
+            f"Available model: `{_model_name}`",
+        )
+
+
 def _parse_tool_calls_with_parser(
     output_text: str, request: ChatCompletionRequest | None = None
 ) -> tuple[str, list | None]:
@@ -467,6 +483,7 @@ def load_model(
     stream_interval: int = 1,
     max_tokens: int = 32768,
     force_mllm: bool = False,
+    served_model_name: str | None = None,
     mtp: bool = False,
     prefill_step_size: int = 2048,
     specprefill_enabled: bool = False,
@@ -491,10 +508,11 @@ def load_model(
         specprefill_keep_pct: Fraction of tokens to keep (default: 0.3)
         specprefill_draft_model: Path to small draft model for SpecPrefill scoring
     """
-    global _engine, _model_name, _default_max_tokens, _tool_parser_instance
+    global _engine, _model_name, _model_path, _default_max_tokens, _tool_parser_instance

     _default_max_tokens = max_tokens
-    _model_name = model_name
+    _model_path = model_name
+    _model_name = served_model_name or model_name

     # Reset tool parser instance when model is reloaded (tokenizer may change)
     _tool_parser_instance = None
@@ -1188,6 +1206,7 @@ async def _wait_disconnect():
 )
 async def create_completion(request: CompletionRequest, raw_request: Request):
     """Create a text completion."""
+    _validate_model_name(request.model)
     engine = get_engine()

     # Handle single prompt or list of prompts
@@ -1252,7 +1271,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
         )

     return CompletionResponse(
-        model=request.model,
+        model=_model_name,
         choices=choices,
         usage=Usage(
             prompt_tokens=total_prompt_tokens,
@@ -1308,6 +1327,7 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re
         }
         ```
     """
+    _validate_model_name(request.model)
     engine = get_engine()

     # --- Detailed request logging ---
@@ -1442,7 +1462,7 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re
         finish_reason = "tool_calls" if tool_calls else output.finish_reason

     return ChatCompletionResponse(
-        model=request.model,
+        model=_model_name,
         choices=[
             ChatCompletionChoice(
                 message=AssistantMessage(
@@ -1516,6 +1536,8 @@ async def create_anthropic_message(
     body = await request.json()
     anthropic_request = AnthropicRequest(**body)

+    _validate_model_name(anthropic_request.model)
+
     # --- Detailed request logging ---
     n_msgs = len(anthropic_request.messages)
     total_chars = 0
@@ -1598,7 +1620,7 @@ async def create_anthropic_message(

     # Build OpenAI response to convert
     openai_response = ChatCompletionResponse(
-        model=openai_request.model,
+        model=_model_name,
         choices=[
             ChatCompletionChoice(
                 message=AssistantMessage(
@@ -1616,7 +1638,7 @@ async def create_anthropic_message(
     )

     # Convert to Anthropic response
-    anthropic_response = openai_to_anthropic(openai_response, anthropic_request.model)
+    anthropic_response = openai_to_anthropic(openai_response, _model_name)
     return Response(
         content=anthropic_response.model_dump_json(exclude_none=True),
         media_type="application/json",
@@ -1730,7 +1752,7 @@ async def _stream_anthropic_messages(
         "id": msg_id,
         "type": "message",
         "role": "assistant",
-        "model": anthropic_request.model,
+        "model": _model_name,
         "content": [],
         "stop_reason": None,
         "stop_sequence": None,
@@ -1858,7 +1880,7 @@ async def stream_completion(
             "id": f"cmpl-{uuid.uuid4().hex[:8]}",
             "object": "text_completion",
             "created": int(time.time()),
-            "model": request.model,
+            "model": _model_name,
             "choices": [
                 {
                     "index": 0,
@@ -1890,7 +1912,7 @@ async def stream_chat_completion(
     # First chunk with role
     first_chunk = ChatCompletionChunk(
         id=response_id,
-        model=request.model,
+        model=_model_name,
         choices=[
             ChatCompletionChunkChoice(
                 delta=ChatCompletionChunkDelta(role="assistant"),
@@ -1901,7 +1923,9 @@ async def stream_chat_completion(

     # Track if we need to add prefix for thinking models (when no reasoning parser)
     # The template adds <think> to the prompt, so the model output starts inside the think block
-    is_thinking_model = "nemotron" in request.model.lower() and not _reasoning_parser
+    is_thinking_model = (
+        "nemotron" in (engine.model_name or "").lower() and not _reasoning_parser
+    )
     think_prefix_sent = False

     # Reset reasoning parser state for this stream
@@ -1963,7 +1987,7 @@ async def stream_chat_completion(

             chunk = ChatCompletionChunk(
                 id=response_id,
-                model=request.model,
+                model=_model_name,
                 choices=[
                     ChatCompletionChunkChoice(
                         delta=ChatCompletionChunkDelta(
@@ -2015,7 +2039,7 @@ async def stream_chat_completion(
             tool_calls_detected = True
             chunk = ChatCompletionChunk(
                 id=response_id,
-                model=request.model,
+                model=_model_name,
                 choices=[
                     ChatCompletionChunkChoice(
                         delta=ChatCompletionChunkDelta(
@@ -2036,7 +2060,7 @@ async def stream_chat_completion(

             chunk = ChatCompletionChunk(
                 id=response_id,
-                model=request.model,
+                model=_model_name,
                 choices=[
                     ChatCompletionChunkChoice(
                         delta=ChatCompletionChunkDelta(
@@ -2065,7 +2089,7 @@ async def stream_chat_completion(
         if result.tools_called:
             tool_chunk = ChatCompletionChunk(
                 id=response_id,
-                model=request.model,
+                model=_model_name,
                 choices=[
                     ChatCompletionChunkChoice(
                         delta=ChatCompletionChunkDelta(
@@ -2099,7 +2123,7 @@ async def stream_chat_completion(
     if include_usage:
         usage_chunk = ChatCompletionChunk(
             id=response_id,
-            model=request.model,
+            model=_model_name,
            choices=[],  # Empty choices for usage-only chunk
            usage=Usage(
                prompt_tokens=prompt_tokens,
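For reviewers, a minimal client-side sketch of the behavior this patch introduces: requests that use the served alias succeed and the response echoes the alias (not the underlying model path), while any other model name is rejected by _validate_model_name() with a 404. The base URL, alias, and payload below are placeholders, not values from this PR; it assumes a server was started with the new --served-model-name flag pointing the alias at some local model path.

# Hypothetical smoke test for --served-model-name (not part of this patch).
# Assumes the server is listening on localhost:8000 and was launched with
# a served-model alias of "my-alias"; adjust to your setup.
import requests

BASE = "http://localhost:8000"

# Using the served alias: accepted, and the response reports the alias.
ok = requests.post(
    f"{BASE}/v1/chat/completions",
    json={
        "model": "my-alias",
        "messages": [{"role": "user", "content": "ping"}],
        "max_tokens": 8,
    },
)
assert ok.status_code == 200
assert ok.json()["model"] == "my-alias"

# Using any other name: rejected with 404 by _validate_model_name().
missing = requests.post(
    f"{BASE}/v1/chat/completions",
    json={
        "model": "some-other-name",
        "messages": [{"role": "user", "content": "ping"}],
    },
)
assert missing.status_code == 404

Note that the prefix-cache directory keys off _model_path (the real model path), so changing the served alias does not invalidate the on-disk cache.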