diff --git a/tests/test_server.py b/tests/test_server.py
index 9fb86a3e..01067b69 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -167,6 +167,27 @@ def test_basic_completion_request(self):
         assert request.max_tokens is None  # uses _default_max_tokens when None
 
 
+class TestServerCli:
+    """Test CLI argument parsing."""
+
+    def test_served_model_name_argument(self):
+        """Test that --served-model-name is accepted and parsed."""
+        from vllm_mlx.server import build_arg_parser
+
+        parser = build_arg_parser()
+        args = parser.parse_args(
+            [
+                "--model",
+                "mlx-community/Qwen3.5-4B-MLX-8bit",
+                "--served-model-name",
+                "qwen",
+            ]
+        )
+
+        assert args.model == "mlx-community/Qwen3.5-4B-MLX-8bit"
+        assert args.served_model_name == "qwen"
+
+
 # =============================================================================
 # Helper Function Tests
 # =============================================================================
diff --git a/vllm_mlx/server.py b/vllm_mlx/server.py
index a0038d5f..367c66af 100644
--- a/vllm_mlx/server.py
+++ b/vllm_mlx/server.py
@@ -2184,8 +2184,8 @@ async def init_mcp(config_path: str):
 # =============================================================================
 
 
-def main():
-    """Run the server."""
+def build_arg_parser() -> argparse.ArgumentParser:
+    """Build the CLI argument parser."""
     parser = argparse.ArgumentParser(
         description="vllm-mlx OpenAI-compatible server for LLM and MLLM inference",
         formatter_class=argparse.RawDescriptionHelpFormatter,
@@ -2207,6 +2207,12 @@ def main():
         default="mlx-community/Llama-3.2-3B-Instruct-4bit",
         help="Model to load (HuggingFace model name or local path)",
     )
+    parser.add_argument(
+        "--served-model-name",
+        type=str,
+        default=None,
+        help="Alias exposed via /v1/models and request validation",
+    )
     parser.add_argument(
         "--host",
         type=str,
@@ -2291,6 +2297,12 @@ def main():
         default=None,
         help="Default top_p for generation when not specified in request",
     )
+    return parser
+
+
+def main():
+    """Run the server."""
+    parser = build_arg_parser()
 
     args = parser.parse_args()
 
@@ -2348,6 +2360,7 @@ def main():
         use_batching=args.continuous_batching,
         max_tokens=args.max_tokens,
         force_mllm=args.mllm,
+        served_model_name=args.served_model_name,
     )
 
     # Start server
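
Note: pulling argument construction out of main() into build_arg_parser() is what makes the new flag testable without starting the server. A quick local check, mirroring the added test (the model id below is just the string the test uses; parse_args does not validate that the model exists):

    from vllm_mlx.server import build_arg_parser

    args = build_arg_parser().parse_args(
        ["--model", "mlx-community/Qwen3.5-4B-MLX-8bit", "--served-model-name", "qwen"]
    )
    print(args.served_model_name)  # -> "qwen"

When --served-model-name is omitted it defaults to None, so downstream code presumably falls back to the model path as the public id in /v1/models; that fallback is outside this diff.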