docs/guides/tool-calling.md (18 additions, 0 deletions)
@@ -45,6 +45,24 @@ if response.choices[0].message.tool_calls:
print(f"Arguments: {tc.function.arguments}")
```

## MLLM / VLM Tool Calling (I7)

Tool calling now runs through both the text (LLM) and multimodal (MLLM/VLM) chat paths.

For VLM models, start the server in MLLM mode with auto tool choice and a tool-call parser enabled:

```bash
vllm-mlx serve <vlm-model-id> \
--mllm \
--enable-auto-tool-choice \
--tool-call-parser auto
```

Notes:
- `tools` and `tool_choice` are passed into the MLLM chat-template path.
- Structured `tool_calls` are still parser/model-format dependent.
- `tool_choice` is best-effort: templates that do not support it fall back safely.
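
For example, a multimodal tool-calling request against the server's OpenAI-compatible endpoint might look like the sketch below. The `base_url`, model id, and image URL are placeholders, and the tool schema is only illustrative; as noted above, structured `tool_calls` still depend on the parser and model format, so the raw text is used as a fallback.

```python
from openai import OpenAI

# Placeholder endpoint and credentials for a locally running vllm-mlx server.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")

# Illustrative OpenAI-style function tool definition.
tools = [
    {
        "type": "function",
        "function": {
            "name": "search_files",
            "description": "Search files",
            "parameters": {
                "type": "object",
                "properties": {"q": {"type": "string"}},
                "required": ["q"],
            },
        },
    }
]

response = client.chat.completions.create(
    model="<vlm-model-id>",  # placeholder VLM model id
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Find files related to this screenshot"},
                {"type": "image_url", "image_url": {"url": "https://example.com/shot.png"}},
            ],
        }
    ],
    tools=tools,
    tool_choice="auto",
)

message = response.choices[0].message
if message.tool_calls:
    # The server-side parser produced structured tool calls.
    for tc in message.tool_calls:
        print(f"Tool: {tc.function.name}, Arguments: {tc.function.arguments}")
else:
    # Formats the parser does not recognize fall back to plain text.
    print(message.content)
```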

## Supported Parsers

Use `--tool-call-parser` to select a parser for your model family:
tests/test_simple_engine.py (138 additions, 0 deletions)
@@ -211,3 +211,141 @@ async def test_requests_complete_in_order(self, mock_model):
assert len(results) == 3
for result in results:
assert result.text == "test response"


class TestSimpleEngineToolChoicePassthrough:
"""Test tool/tool_choice propagation for LLM and MLLM paths."""

@pytest.mark.asyncio
async def test_mllm_chat_passes_tools_and_tool_choice(self):
from vllm_mlx.engine.simple import SimpleEngine

model = MagicMock()
model.chat = MagicMock(
return_value=MagicMock(
text='<tool_call>{"name":"search_files","arguments":{"q":"x"}}</tool_call>',
prompt_tokens=12,
completion_tokens=4,
finish_reason="stop",
)
)

tools = [
{
"type": "function",
"function": {
"name": "search_files",
"description": "Search files",
"parameters": {
"type": "object",
"properties": {"q": {"type": "string"}},
"required": ["q"],
},
},
}
]
tool_choice = {"type": "function", "function": {"name": "search_files"}}

with patch("vllm_mlx.engine.simple.is_mllm_model", return_value=True):
engine = SimpleEngine("test-mllm")
engine._model = model
engine._loaded = True

await engine.chat(
messages=[{"role": "user", "content": "Find X"}],
tools=tools,
tool_choice=tool_choice,
max_tokens=32,
)

_, kwargs = model.chat.call_args
assert kwargs["tools"] == tools
assert kwargs["tool_choice"] == tool_choice

@pytest.mark.asyncio
async def test_mllm_stream_chat_passes_tools_and_tool_choice(self):
from vllm_mlx.engine.simple import SimpleEngine

chunk1 = MagicMock()
chunk1.text = "<tool_call>"
chunk1.finish_reason = None
chunk1.prompt_tokens = 8
chunk2 = MagicMock()
chunk2.text = "</tool_call>"
chunk2.finish_reason = "stop"
chunk2.prompt_tokens = 8

model = MagicMock()
model.stream_chat = MagicMock(return_value=iter([chunk1, chunk2]))

tools = [
{
"type": "function",
"function": {
"name": "search_files",
"description": "Search files",
"parameters": {"type": "object", "properties": {}},
},
}
]
tool_choice = "required"

with patch("vllm_mlx.engine.simple.is_mllm_model", return_value=True):
engine = SimpleEngine("test-mllm")
engine._model = model
engine._loaded = True

outputs = []
async for output in engine.stream_chat(
messages=[{"role": "user", "content": "Find X"}],
tools=tools,
tool_choice=tool_choice,
max_tokens=16,
):
outputs.append(output)

assert outputs
_, kwargs = model.stream_chat.call_args
assert kwargs["tools"] == tools
assert kwargs["tool_choice"] == tool_choice

@pytest.mark.asyncio
async def test_llm_chat_does_not_leak_tool_choice_to_model_call(self):
from vllm_mlx.engine.simple import SimpleEngine

model = MagicMock()
model.tokenizer = MagicMock()
model.tokenizer.apply_chat_template = MagicMock(return_value="prompt")
model.chat = MagicMock(
return_value=MagicMock(
text="ok",
tokens=[1, 2],
finish_reason="stop",
)
)

tools = [
{
"type": "function",
"function": {
"name": "search_files",
"description": "Search files",
"parameters": {"type": "object", "properties": {}},
},
}
]

with patch("vllm_mlx.engine.simple.is_mllm_model", return_value=False):
engine = SimpleEngine("test-llm")
engine._model = model
engine._loaded = True

await engine.chat(
messages=[{"role": "user", "content": "Find X"}],
tools=tools,
tool_choice="required",
max_tokens=16,
)

_, chat_kwargs = model.chat.call_args
assert "tool_choice" not in chat_kwargs
vllm_mlx/engine/batched.py (27 additions, 5 deletions)
@@ -334,6 +334,7 @@ def _apply_chat_template(
self,
messages: list[dict[str, Any]],
tools: list[dict] | None = None,
tool_choice: str | dict | None = None,
num_images: int = 0,
) -> str:
"""Apply chat template to messages.
@@ -369,6 +370,8 @@ def _apply_chat_template(
}
if tools:
template_kwargs["tools"] = tools
if tool_choice is not None:
template_kwargs["tool_choice"] = tool_choice

try:
return template_applicator.apply_chat_template(
@@ -377,7 +380,7 @@
except TypeError as e:
# Some templates don't accept 'tools'; retry without them.
logger.debug(f"Chat template TypeError, retrying without extras: {e}")
for key in ["tools"]:
for key in ["tools", "tool_choice"]:
if key in template_kwargs:
del template_kwargs[key]
return template_applicator.apply_chat_template(
@@ -620,11 +623,13 @@ async def chat(

# Convert tools for template
template_tools = convert_tools_for_template(tools) if tools else None
template_tool_choice = kwargs.pop("tool_choice", None)

# Apply chat template
prompt = self._apply_chat_template(
messages,
template_tools,
template_tool_choice,
num_images=len(all_images),
)

@@ -639,7 +644,10 @@
)

def _compute_prefix_boundary(
self, messages: list[dict[str, Any]], tools: list[dict] | None = None
self,
messages: list[dict[str, Any]],
tools: list[dict] | None = None,
tool_choice: str | dict | None = None,
) -> int:
"""Compute token count for the shared prefix across message variations.

@@ -661,15 +669,23 @@
template_tools = convert_tools_for_template(tools) if tools else None

# Tokenize the real prompt
real_prompt = self._apply_chat_template(messages, template_tools)
real_prompt = self._apply_chat_template(
messages,
template_tools,
tool_choice,
)

# Build a dummy variant with different last user content
dummy_messages = list(messages)
dummy_messages[last_user_idx] = {
**messages[last_user_idx],
"content": "XXXXXXXXXX",
}
dummy_prompt = self._apply_chat_template(dummy_messages, template_tools)
dummy_prompt = self._apply_chat_template(
dummy_messages,
template_tools,
tool_choice,
)

tokenizer = self.tokenizer
if hasattr(tokenizer, "tokenizer"):
@@ -731,16 +747,22 @@ async def stream_chat(

# Convert tools for template
template_tools = convert_tools_for_template(tools) if tools else None
template_tool_choice = kwargs.pop("tool_choice", None)

# Apply chat template
prompt = self._apply_chat_template(
messages,
template_tools,
template_tool_choice,
num_images=len(all_images),
)

# Compute prefix boundary for cache
prefix_boundary = self._compute_prefix_boundary(messages, tools)
prefix_boundary = self._compute_prefix_boundary(
messages,
tools,
template_tool_choice,
)
if prefix_boundary > 0:
kwargs["prefix_boundary"] = prefix_boundary

vllm_mlx/engine/simple.py (9 additions, 1 deletion)
@@ -266,6 +266,7 @@ async def chat(

# Convert tools for template if provided
template_tools = convert_tools_for_template(tools) if tools else None
template_tool_choice = kwargs.pop("tool_choice", None)

async with self._generation_lock:
if self._is_mllm:
@@ -276,6 +277,8 @@
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
tools=template_tools,
tool_choice=template_tool_choice,
**kwargs,
)
text = clean_output_text(output.text)
@@ -337,6 +340,7 @@ async def stream_chat(

# Convert tools for template
template_tools = convert_tools_for_template(tools) if tools else None
template_tool_choice = kwargs.pop("tool_choice", None)

# Build prompt using tokenizer
if self._is_mllm:
@@ -351,6 +355,8 @@ def run_stream():
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
tools=template_tools,
tool_choice=template_tool_choice,
**kwargs,
)
)
@@ -390,12 +396,14 @@ def run_stream():
}
if template_tools:
template_kwargs["tools"] = template_tools
if template_tool_choice is not None:
template_kwargs["tool_choice"] = template_tool_choice

try:
prompt = tokenizer.apply_chat_template(messages, **template_kwargs)
except TypeError:
# Some templates don't support all kwargs
for key in ["tools", "enable_thinking"]:
for key in ["tools", "tool_choice", "enable_thinking"]:
if key in template_kwargs:
del template_kwargs[key]
prompt = tokenizer.apply_chat_template(messages, **template_kwargs)