vllm_mlx/cli.py (7 additions, 0 deletions)
@@ -197,6 +197,7 @@ def serve_command(args):
stream_interval=args.stream_interval if args.continuous_batching else 1,
max_tokens=args.max_tokens,
force_mllm=args.mllm,
+ served_model_name=args.served_model_name,
mtp=args.enable_mtp,
prefill_step_size=args.prefill_step_size,
specprefill_enabled=args.specprefill,
@@ -607,6 +608,12 @@ def main():
# Serve command
serve_parser = subparsers.add_parser("serve", help="Start OpenAI-compatible server")
serve_parser.add_argument("model", type=str, help="Model to serve")
+ serve_parser.add_argument(
+ "--served-model-name",
+ type=str,
+ default=None,
+ help="The model name used in the API. If not specified, the model argument is used.",
+ )
serve_parser.add_argument(
"--host", type=str, default="0.0.0.0", help="Host to bind"
)
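A minimal client-side sketch of how the new flag is meant to be exercised, assuming the serve command above was started with --served-model-name my-model, the server listens on localhost:8000, and it exposes the standard OpenAI-compatible /v1/chat/completions route (host, port, route, and alias are illustrative assumptions, not taken from this diff):

# Sketch only: the request must reference the served alias, not the underlying model path.
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "my-model",  # must match --served-model-name
        "messages": [{"role": "user", "content": "Hello"}],
    },
    timeout=60,
)
resp.raise_for_status()
print(resp.json()["model"])  # responses now echo the served alias, e.g. "my-model"

If --served-model-name is omitted, the positional model argument remains the API-visible name, so existing clients are unaffected.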
vllm_mlx/server.py (43 additions, 19 deletions)
@@ -111,6 +111,9 @@
# Global engine instance
_engine: BaseEngine | None = None
_model_name: str | None = None
+ _model_path: str | None = (
+ None # Actual model path (for cache dir, not affected by --served-model-name)
+ )
_default_max_tokens: int = 32768
_default_timeout: float = 300.0 # Default request timeout in seconds (5 minutes)
_default_temperature: float | None = None # Set via --default-temperature
@@ -188,11 +191,14 @@ def _save_prefix_cache_to_disk() -> None:


def _get_cache_dir() -> str:
"""Get cache persistence directory based on model name."""
# Use global _model_name which is always a string, set during load_model()
model_name = _model_name if _model_name else "default"
"""Get cache persistence directory based on actual model path."""
# Use _model_path (actual model path) not _model_name (which may be overridden
# by --served-model-name). This ensures cache is shared regardless of served name.
model_name = (
_model_path if _model_path else (_model_name if _model_name else "default")
)
logger.info(
f"[_get_cache_dir] _model_name={_model_name!r} type={type(_model_name)}"
f"[_get_cache_dir] _model_path={_model_path!r} type={type(_model_path)}"
)
# Sanitize model name for filesystem
safe_name = str(model_name).replace("/", "--").replace("\\", "--")
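For illustration, the sanitization above applied to a hypothetical model path (the path is an assumption, not taken from this diff). Because the directory is derived from _model_path, two servers loading the same weights share one cache directory even when they advertise different --served-model-name aliases:

# Mirrors the replace() calls above; the model path is purely illustrative.
model_path = "mlx-community/Qwen2.5-7B-Instruct-4bit"
safe_name = str(model_path).replace("/", "--").replace("\\", "--")
print(safe_name)  # mlx-community--Qwen2.5-7B-Instruct-4bit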
@@ -335,6 +341,16 @@ def get_engine() -> BaseEngine:
return _engine


+ def _validate_model_name(request_model: str) -> None:
+ """Validate that the request model name matches the served model."""
+ if _model_name and request_model != _model_name:
+ raise HTTPException(
+ status_code=404,
+ detail=f"The model `{request_model}` does not exist. "
+ f"Available model: `{_model_name}`",
+ )


def _parse_tool_calls_with_parser(
output_text: str, request: ChatCompletionRequest | None = None
) -> tuple[str, list | None]:
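A hedged sketch of the rejection path that _validate_model_name introduces: a request whose model field does not match the served name should now be answered with a 404 rather than silently served (URL and model names are illustrative assumptions):

# Sketch only: server assumed to be running with --served-model-name my-model.
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "some-other-model",  # does not match the served name
        "messages": [{"role": "user", "content": "Hello"}],
    },
    timeout=60,
)
print(resp.status_code)  # expected: 404
print(resp.json())  # FastAPI-style error body carrying the HTTPException detail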
@@ -467,6 +483,7 @@ def load_model(
stream_interval: int = 1,
max_tokens: int = 32768,
force_mllm: bool = False,
+ served_model_name: str | None = None,
mtp: bool = False,
prefill_step_size: int = 2048,
specprefill_enabled: bool = False,
@@ -491,10 +508,11 @@
specprefill_keep_pct: Fraction of tokens to keep (default: 0.3)
specprefill_draft_model: Path to small draft model for SpecPrefill scoring
"""
- global _engine, _model_name, _default_max_tokens, _tool_parser_instance
+ global _engine, _model_name, _model_path, _default_max_tokens, _tool_parser_instance

_default_max_tokens = max_tokens
- _model_name = model_name
+ _model_path = model_name
+ _model_name = served_model_name or model_name
# Reset tool parser instance when model is reloaded (tokenizer may change)
_tool_parser_instance = None

@@ -1188,6 +1206,7 @@ async def _wait_disconnect():
)
async def create_completion(request: CompletionRequest, raw_request: Request):
"""Create a text completion."""
+ _validate_model_name(request.model)
engine = get_engine()

# Handle single prompt or list of prompts
@@ -1252,7 +1271,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
)

return CompletionResponse(
- model=request.model,
+ model=_model_name,
choices=choices,
usage=Usage(
prompt_tokens=total_prompt_tokens,
@@ -1308,6 +1327,7 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Request):
}
```
"""
+ _validate_model_name(request.model)
engine = get_engine()

# --- Detailed request logging ---
@@ -1442,7 +1462,7 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Request):
finish_reason = "tool_calls" if tool_calls else output.finish_reason

return ChatCompletionResponse(
- model=request.model,
+ model=_model_name,
choices=[
ChatCompletionChoice(
message=AssistantMessage(
@@ -1516,6 +1536,8 @@ async def create_anthropic_message(
body = await request.json()
anthropic_request = AnthropicRequest(**body)

+ _validate_model_name(anthropic_request.model)

# --- Detailed request logging ---
n_msgs = len(anthropic_request.messages)
total_chars = 0
@@ -1598,7 +1620,7 @@ async def create_anthropic_message(

# Build OpenAI response to convert
openai_response = ChatCompletionResponse(
- model=openai_request.model,
+ model=_model_name,
choices=[
ChatCompletionChoice(
message=AssistantMessage(
@@ -1616,7 +1638,7 @@
)

# Convert to Anthropic response
- anthropic_response = openai_to_anthropic(openai_response, anthropic_request.model)
+ anthropic_response = openai_to_anthropic(openai_response, _model_name)
return Response(
content=anthropic_response.model_dump_json(exclude_none=True),
media_type="application/json",
@@ -1730,7 +1752,7 @@ async def _stream_anthropic_messages(
"id": msg_id,
"type": "message",
"role": "assistant",
"model": anthropic_request.model,
"model": _model_name,
"content": [],
"stop_reason": None,
"stop_sequence": None,
@@ -1858,7 +1880,7 @@ async def stream_completion(
"id": f"cmpl-{uuid.uuid4().hex[:8]}",
"object": "text_completion",
"created": int(time.time()),
"model": request.model,
"model": _model_name,
"choices": [
{
"index": 0,
@@ -1890,7 +1912,7 @@ async def stream_chat_completion(
# First chunk with role
first_chunk = ChatCompletionChunk(
id=response_id,
- model=request.model,
+ model=_model_name,
choices=[
ChatCompletionChunkChoice(
delta=ChatCompletionChunkDelta(role="assistant"),
@@ -1901,7 +1923,9 @@

# Track if we need to add <think> prefix for thinking models (when no reasoning parser)
# The template adds <think> to the prompt, so the model output starts inside the think block
is_thinking_model = "nemotron" in request.model.lower() and not _reasoning_parser
is_thinking_model = (
"nemotron" in (engine.model_name or "").lower() and not _reasoning_parser
)
think_prefix_sent = False

# Reset reasoning parser state for this stream
@@ -1963,7 +1987,7 @@

chunk = ChatCompletionChunk(
id=response_id,
- model=request.model,
+ model=_model_name,
choices=[
ChatCompletionChunkChoice(
delta=ChatCompletionChunkDelta(
@@ -2015,7 +2039,7 @@ async def stream_chat_completion(
tool_calls_detected = True
chunk = ChatCompletionChunk(
id=response_id,
- model=request.model,
+ model=_model_name,
choices=[
ChatCompletionChunkChoice(
delta=ChatCompletionChunkDelta(
@@ -2036,7 +2060,7 @@

chunk = ChatCompletionChunk(
id=response_id,
- model=request.model,
+ model=_model_name,
choices=[
ChatCompletionChunkChoice(
delta=ChatCompletionChunkDelta(
@@ -2065,7 +2089,7 @@
if result.tools_called:
tool_chunk = ChatCompletionChunk(
id=response_id,
- model=request.model,
+ model=_model_name,
choices=[
ChatCompletionChunkChoice(
delta=ChatCompletionChunkDelta(
@@ -2099,7 +2123,7 @@
if include_usage:
usage_chunk = ChatCompletionChunk(
id=response_id,
- model=request.model,
+ model=_model_name,
choices=[], # Empty choices for usage-only chunk
usage=Usage(
prompt_tokens=prompt_tokens,