3 changes: 3 additions & 0 deletions vllm_mlx/api/models.py
@@ -376,6 +376,9 @@ class ChatCompletionChunkDelta(BaseModel):
role: Optional[str] = None
content: Optional[str] = None
tool_calls: Optional[List[dict]] = None
reasoning: Optional[str] = (
None # For reasoning/thinking content (Qwen3, DeepSeek-R1)
)


class ChatCompletionChunkChoice(BaseModel):
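
For orientation, a minimal sketch (not part of the diff) of how the new reasoning field surfaces on a streaming delta. The class, field names, and import path come from this file; the sample values are invented.

    from vllm_mlx.api.models import ChatCompletionChunkDelta

    # A delta that carries extracted thinking text instead of normal content.
    delta = ChatCompletionChunkDelta(
        role="assistant",
        reasoning="The user wants 2 + 2, so compute the sum first.",  # hypothetical reasoning text
        content=None,
    )

    # The server serializes chunks with model_dump_json(), so the SSE payload
    # carries a top-level "reasoning" key alongside "content" and "tool_calls".
    print(delta.model_dump_json())
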
22 changes: 21 additions & 1 deletion vllm_mlx/cli.py
@@ -53,6 +53,12 @@ def serve_command(args):
server._enable_auto_tool_choice = False
server._tool_call_parser = None

# Configure generation defaults
if args.default_temperature is not None:
server._default_temperature = args.default_temperature
if args.default_top_p is not None:
server._default_top_p = args.default_top_p

# Security summary at startup
print("=" * 60)
print("SECURITY CONFIGURATION")
@@ -511,14 +517,28 @@ def main():
"nemotron",
"xlam",
"functionary",
"glm47",
],
help=(
"Select the tool call parser for the model. Options: "
"auto (auto-detect), mistral, qwen, llama, hermes, deepseek, "
"kimi, granite, nemotron, xlam, functionary. "
"kimi, granite, nemotron, xlam, functionary, glm47. "
"Required for --enable-auto-tool-choice."
),
)
# Generation defaults
serve_parser.add_argument(
"--default-temperature",
type=float,
default=None,
help="Default temperature for generation when not specified in request",
)
serve_parser.add_argument(
"--default-top-p",
type=float,
default=None,
help="Default top_p for generation when not specified in request",
)

# Bench command
bench_parser = subparsers.add_parser("bench", help="Run benchmark")
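
A hedged usage sketch (not from the PR) of what the new serve flags change for clients: a request that omits temperature/top_p picks up --default-temperature/--default-top-p, while explicit request values still win. The base URL, port, model name, and API-key handling below are assumptions about how the server is deployed.

    from openai import OpenAI

    # Assumed OpenAI-compatible endpoint; adjust host/port to the actual deployment.
    client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")

    # No temperature/top_p in the request, so the server-side defaults apply.
    resp = client.chat.completions.create(
        model="my-mlx-model",  # hypothetical model name
        messages=[{"role": "user", "content": "Hello"}],
    )

    # Explicit sampling parameters in the request override the server defaults.
    resp = client.chat.completions.create(
        model="my-mlx-model",
        messages=[{"role": "user", "content": "Hello"}],
        temperature=0.2,
        top_p=0.95,
    )
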
200 changes: 190 additions & 10 deletions vllm_mlx/server.py
@@ -107,6 +107,8 @@
_model_name: str | None = None
_default_max_tokens: int = 32768
_default_timeout: float = 300.0 # Default request timeout in seconds (5 minutes)
_default_temperature: float | None = None # Set via --default-temperature
_default_top_p: float | None = None # Set via --default-top-p

# Global MCP manager
_mcp_manager = None
@@ -738,8 +740,14 @@ async def create_completion(request: CompletionRequest):
engine.generate(
prompt=prompt,
max_tokens=request.max_tokens or _default_max_tokens,
temperature=request.temperature,
top_p=request.top_p,
temperature=(
request.temperature
if request.temperature is not None
else _default_temperature
),
top_p=(
request.top_p if request.top_p is not None else _default_top_p
),
stop=request.stop,
),
timeout=timeout,
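
The resolution rule used in this hunk (and in the chat and streaming hunks below) is simply: request value if present, otherwise the server-wide default, otherwise None so the engine's own default applies. A standalone sketch of that precedence; the helper name is illustrative, not from the diff.

    def resolve_sampling(request_value, server_default):
        # Request-level value wins; fall back to --default-temperature/--default-top-p;
        # None means the engine keeps its built-in default.
        return request_value if request_value is not None else server_default

    assert resolve_sampling(0.2, 0.7) == 0.2     # request overrides the server default
    assert resolve_sampling(None, 0.7) == 0.7    # server default fills the gap
    assert resolve_sampling(None, None) is None  # neither set: engine default applies
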
@@ -856,8 +864,12 @@ async def create_chat_completion(request: ChatCompletionRequest):
# Prepare kwargs
chat_kwargs = {
"max_tokens": request.max_tokens or _default_max_tokens,
"temperature": request.temperature,
"top_p": request.top_p,
"temperature": (
request.temperature
if request.temperature is not None
else _default_temperature
),
"top_p": request.top_p if request.top_p is not None else _default_top_p,
}

# Add multimodal content
@@ -989,8 +1001,12 @@ async def stream_completion(
async for output in engine.stream_generate(
prompt=prompt,
max_tokens=request.max_tokens or _default_max_tokens,
temperature=request.temperature,
top_p=request.top_p,
temperature=(
request.temperature
if request.temperature is not None
else _default_temperature
),
top_p=request.top_p if request.top_p is not None else _default_top_p,
stop=request.stop,
):
data = {
@@ -1020,6 +1036,8 @@ async def stream_chat_completion(
**kwargs,
) -> AsyncIterator[str]:
"""Stream chat completion response."""
global _tool_parser_instance

response_id = f"chatcmpl-{uuid.uuid4().hex[:8]}"

# Check if we should include usage in the final chunk
@@ -1046,18 +1064,47 @@
if _reasoning_parser:
_reasoning_parser.reset_state()

# Track accumulated text for reasoning parser
# Track accumulated text for reasoning parser and tool call parsing
accumulated_text = ""

# Track token counts for usage reporting
prompt_tokens = 0
completion_tokens = 0
last_output = None

# Tool call streaming state
tool_call_enabled = (
_enable_auto_tool_choice
and _tool_call_parser
and request.tools # Only parse if tools were provided in request
)
tool_calls_emitted = False

# Initialize tool parser if needed for streaming
if tool_call_enabled and _tool_parser_instance is None:
try:
parser_cls = ToolParserManager.get_tool_parser(_tool_call_parser)
tokenizer = None
if _engine is not None and hasattr(_engine, "_tokenizer"):
tokenizer = _engine._tokenizer
_tool_parser_instance = parser_cls(tokenizer)
logger.info(
f"Initialized tool call parser for streaming: {_tool_call_parser}"
)
except Exception as e:
logger.warning(f"Failed to initialize tool parser for streaming: {e}")
tool_call_enabled = False

# Reset tool parser state for this stream
if tool_call_enabled and _tool_parser_instance:
_tool_parser_instance.reset()

# Stream content
async for output in engine.stream_chat(messages=messages, **kwargs):
delta_text = output.new_text
last_output = output
previous_text = accumulated_text
accumulated_text += delta_text

# Track token counts from output (updated each chunk)
if hasattr(output, "prompt_tokens") and output.prompt_tokens:
@@ -1067,8 +1114,6 @@

# Use reasoning parser if enabled
if _reasoning_parser and delta_text:
previous_text = accumulated_text
accumulated_text += delta_text
delta_msg = _reasoning_parser.extract_reasoning_streaming(
previous_text, accumulated_text, delta_text
)
@@ -1092,8 +1137,62 @@
usage=get_usage(output) if output.finished else None,
)
yield f"data: {chunk.model_dump_json()}\n\n"
elif tool_call_enabled and _tool_parser_instance:
# Tool call parsing path
streaming_result = _tool_parser_instance.extract_tool_calls_streaming(
previous_text=previous_text,
current_text=accumulated_text,
delta_text=delta_text,
)

if streaming_result is None:
# Buffering - inside tool call, don't emit yet
continue

if "tool_calls" in streaming_result and streaming_result["tool_calls"]:
# Emit tool calls chunk
chunk = ChatCompletionChunk(
id=response_id,
model=request.model,
choices=[
ChatCompletionChunkChoice(
delta=ChatCompletionChunkDelta(
tool_calls=streaming_result["tool_calls"],
),
finish_reason="tool_calls",
)
],
usage=get_usage(output) if output.finished else None,
)
yield f"data: {chunk.model_dump_json()}\n\n"
tool_calls_emitted = True
elif "content" in streaming_result and streaming_result["content"]:
# Emit content chunk
content = streaming_result["content"]

# Add <think> prefix on first content chunk for thinking models
if is_thinking_model and not think_prefix_sent and content:
content = "<think>" + content
think_prefix_sent = True

chunk = ChatCompletionChunk(
id=response_id,
model=request.model,
choices=[
ChatCompletionChunkChoice(
delta=ChatCompletionChunkDelta(
content=content,
),
finish_reason=(
output.finish_reason if output.finished else None
),
)
],
usage=get_usage(output) if output.finished else None,
)
yield f"data: {chunk.model_dump_json()}\n\n"
else:
# Standard path without reasoning parsing
# Standard path without reasoning or tool call parsing
content = delta_text

# Add <think> prefix on first content chunk for thinking models
@@ -1116,6 +1215,48 @@
)
yield f"data: {chunk.model_dump_json()}\n\n"

# If tool call parsing is enabled but no tool_calls were emitted during streaming,
# check final accumulated text for tool calls (handles cases where </tool_call>
# wasn't detected during streaming)
if tool_call_enabled and _tool_parser_instance and not tool_calls_emitted:
final_result = _tool_parser_instance.extract_tool_calls(accumulated_text)
if final_result.tools_called:
tool_calls_list = [
{
"index": i,
"id": tc["id"],
"type": "function",
"function": {
"name": tc["name"],
"arguments": tc["arguments"],
},
}
for i, tc in enumerate(final_result.tool_calls)
]
chunk = ChatCompletionChunk(
id=response_id,
model=request.model,
choices=[
ChatCompletionChunkChoice(
delta=ChatCompletionChunkDelta(
tool_calls=tool_calls_list,
),
finish_reason="tool_calls",
)
],
usage=(
Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
if last_output and last_output.finished
else None
),
)
yield f"data: {chunk.model_dump_json()}\n\n"
tool_calls_emitted = True

# Send final chunk with usage if requested
if include_usage:
usage_chunk = ChatCompletionChunk(
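
For readability, a rough sketch of the parser contract this streaming path assumes, inferred from how the results are consumed above rather than taken from the parser implementations: extract_tool_calls_streaming returns None while buffering inside a tool-call block, a dict with "tool_calls" once a call can be emitted, or a dict with "content" for plain text. The per-chunk tool_calls shape shown mirrors the OpenAI-style dicts the non-streaming fallback builds; the exact keys are up to each parser.

    # Illustrative return values from extract_tool_calls_streaming (assumed shapes):
    buffering = None                                   # inside a tool-call block, emit nothing yet
    plain_text = {"content": "The answer is 4."}       # ordinary text to stream as a delta
    tool_call_chunk = {
        "tool_calls": [
            {
                "index": 0,
                "id": "call_abc123",                   # hypothetical call id
                "type": "function",
                "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'},
            }
        ]
    }

    # Final-pass result from extract_tool_calls(accumulated_text), as used above:
    #   final_result.tools_called -> bool
    #   final_result.tool_calls   -> list of {"id": ..., "name": ..., "arguments": ...}
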
@@ -1248,13 +1389,42 @@ def main():
choices=["qwen3", "deepseek_r1"],
help="Enable reasoning content extraction with specified parser",
)
parser.add_argument(
"--default-temperature",
type=float,
default=None,
help="Default temperature for generation when not specified in request",
)
parser.add_argument(
"--default-top-p",
type=float,
default=None,
help="Default top_p for generation when not specified in request",
)
parser.add_argument(
"--enable-auto-tool-choice",
action="store_true",
help="Enable automatic tool call parsing for models that support it",
)
parser.add_argument(
"--tool-call-parser",
type=str,
default=None,
choices=ToolParserManager.list_registered(),
help="Tool call parser to use (requires --enable-auto-tool-choice)",
)

args = parser.parse_args()

# Set global configuration
global _api_key, _default_timeout, _rate_limiter
global _default_temperature, _default_top_p
_api_key = args.api_key
_default_timeout = args.timeout
if args.default_temperature is not None:
_default_temperature = args.default_temperature
if args.default_top_p is not None:
_default_top_p = args.default_top_p

# Configure rate limiter
if args.rate_limit > 0:
@@ -1291,6 +1461,16 @@ def main():
_reasoning_parser = parser_cls()
logger.info(f"Reasoning parser enabled: {args.reasoning_parser}")

# Configure tool call parsing
global _enable_auto_tool_choice, _tool_call_parser
if args.enable_auto_tool_choice:
_enable_auto_tool_choice = True
_tool_call_parser = args.tool_call_parser
if _tool_call_parser:
logger.info(f"Tool call parsing enabled with parser: {_tool_call_parser}")
else:
logger.info("Tool call parsing enabled with auto-detection")

# Load model before starting server
load_model(
args.model,
3 changes: 3 additions & 0 deletions vllm_mlx/tool_parsers/__init__.py
@@ -17,6 +17,7 @@
- nemotron/nemotron3: NVIDIA Nemotron models
- xlam: Salesforce xLAM models
- functionary/meetkai: MeetKai Functionary models
- glm47/glm4: GLM-4.7 and GLM-4.7-Flash models

Usage:
from vllm_mlx.tool_parsers import ToolParserManager
@@ -53,6 +54,7 @@
from .nemotron_tool_parser import NemotronToolParser
from .qwen_tool_parser import QwenToolParser
from .xlam_tool_parser import xLAMToolParser
from .glm47_tool_parser import Glm47ToolParser

__all__ = [
# Base classes
@@ -71,4 +73,5 @@
"NemotronToolParser",
"xLAMToolParser",
"FunctionaryToolParser",
"Glm47ToolParser",
]
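
To close the loop on the registry change, a small sketch of how the new parser is looked up, mirroring the server.py usage above. The sample model output is invented; the real tag format is defined by Glm47ToolParser.

    from vllm_mlx.tool_parsers import ToolParserManager

    parser_cls = ToolParserManager.get_tool_parser("glm47")
    parser = parser_cls(None)  # server.py passes the engine's tokenizer here when available

    # Hypothetical GLM-style output containing one tool call.
    text = '<tool_call>{"name": "get_weather", "arguments": {"city": "Paris"}}</tool_call>'
    result = parser.extract_tool_calls(text)
    print(result.tools_called, result.tool_calls)
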