diff --git a/docs/guides/tool-calling.md b/docs/guides/tool-calling.md
index f0adc9739..72486f4b4 100644
--- a/docs/guides/tool-calling.md
+++ b/docs/guides/tool-calling.md
@@ -45,6 +45,24 @@ if response.choices[0].message.tool_calls:
         print(f"Arguments: {tc.function.arguments}")
 ```
 
+## MLLM / VLM Tool Calling (I7)
+
+Tool calling now runs through both text (LLM) and multimodal (MLLM/VLM) chat paths.
+
+For VLM models, start the server in MLLM mode and keep parser flags enabled:
+
+```bash
+vllm-mlx serve \
+  --mllm \
+  --enable-auto-tool-choice \
+  --tool-call-parser auto
+```
+
+Notes:
+- `tools` and `tool_choice` are passed into the MLLM chat-template path.
+- Structured `tool_calls` are still parser/model-format dependent.
+- `tool_choice` is best-effort: templates that do not support it fall back safely.
+
 ## Supported Parsers
 
 Use `--tool-call-parser` to select a parser for your model family:
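For orientation, a minimal client-side sketch of the flow the documentation change above describes. It is not part of the patch: the base URL and port, the placeholder model name, and the use of the `openai` Python client are assumptions.

```python
# Hypothetical client-side check of the tools/tool_choice passthrough.
# Assumptions: a vllm-mlx server started with the flags shown above and
# reachable at http://localhost:8000/v1; the model name is a placeholder.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")

tools = [
    {
        "type": "function",
        "function": {
            "name": "search_files",
            "description": "Search files by keyword",
            "parameters": {
                "type": "object",
                "properties": {"q": {"type": "string"}},
                "required": ["q"],
            },
        },
    }
]

response = client.chat.completions.create(
    model="my-mlx-vlm",  # placeholder model name
    messages=[{"role": "user", "content": "Find the README"}],
    tools=tools,
    # Forwarded to the MLLM chat-template path; templates that do not
    # support tool_choice fall back safely (best-effort, per the notes above).
    tool_choice={"type": "function", "function": {"name": "search_files"}},
)

message = response.choices[0].message
if message.tool_calls:
    for tc in message.tool_calls:
        print(tc.function.name, tc.function.arguments)
else:
    # Structured tool_calls are parser/model-format dependent; the reply
    # may come back as plain text instead.
    print(message.content)
```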
kwargs["tool_choice"] == tool_choice + + @pytest.mark.asyncio + async def test_llm_chat_does_not_leak_tool_choice_to_model_call(self): + from vllm_mlx.engine.simple import SimpleEngine + + model = MagicMock() + model.tokenizer = MagicMock() + model.tokenizer.apply_chat_template = MagicMock(return_value="prompt") + model.chat = MagicMock( + return_value=MagicMock( + text="ok", + tokens=[1, 2], + finish_reason="stop", + ) + ) + + tools = [ + { + "type": "function", + "function": { + "name": "search_files", + "description": "Search files", + "parameters": {"type": "object", "properties": {}}, + }, + } + ] + + with patch("vllm_mlx.engine.simple.is_mllm_model", return_value=False): + engine = SimpleEngine("test-llm") + engine._model = model + engine._loaded = True + + await engine.chat( + messages=[{"role": "user", "content": "Find X"}], + tools=tools, + tool_choice="required", + max_tokens=16, + ) + + _, chat_kwargs = model.chat.call_args + assert "tool_choice" not in chat_kwargs diff --git a/vllm_mlx/engine/batched.py b/vllm_mlx/engine/batched.py index ce33e628e..9a00071cc 100644 --- a/vllm_mlx/engine/batched.py +++ b/vllm_mlx/engine/batched.py @@ -334,6 +334,7 @@ def _apply_chat_template( self, messages: list[dict[str, Any]], tools: list[dict] | None = None, + tool_choice: str | dict | None = None, num_images: int = 0, ) -> str: """Apply chat template to messages. @@ -369,6 +370,8 @@ def _apply_chat_template( } if tools: template_kwargs["tools"] = tools + if tool_choice is not None: + template_kwargs["tool_choice"] = tool_choice try: return template_applicator.apply_chat_template( @@ -377,7 +380,7 @@ def _apply_chat_template( except TypeError as e: # Some templates don't accept 'tools'; retry without them. logger.debug(f"Chat template TypeError, retrying without extras: {e}") - for key in ["tools"]: + for key in ["tools", "tool_choice"]: if key in template_kwargs: del template_kwargs[key] return template_applicator.apply_chat_template( @@ -620,11 +623,13 @@ async def chat( # Convert tools for template template_tools = convert_tools_for_template(tools) if tools else None + template_tool_choice = kwargs.pop("tool_choice", None) # Apply chat template prompt = self._apply_chat_template( messages, template_tools, + template_tool_choice, num_images=len(all_images), ) @@ -639,7 +644,10 @@ async def chat( ) def _compute_prefix_boundary( - self, messages: list[dict[str, Any]], tools: list[dict] | None = None + self, + messages: list[dict[str, Any]], + tools: list[dict] | None = None, + tool_choice: str | dict | None = None, ) -> int: """Compute token count for the shared prefix across message variations. 
@@ -661,7 +669,11 @@ def _compute_prefix_boundary(
         template_tools = convert_tools_for_template(tools) if tools else None
 
         # Tokenize the real prompt
-        real_prompt = self._apply_chat_template(messages, template_tools)
+        real_prompt = self._apply_chat_template(
+            messages,
+            template_tools,
+            tool_choice,
+        )
 
         # Build a dummy variant with different last user content
         dummy_messages = list(messages)
@@ -669,7 +681,11 @@
             **messages[last_user_idx],
             "content": "XXXXXXXXXX",
         }
-        dummy_prompt = self._apply_chat_template(dummy_messages, template_tools)
+        dummy_prompt = self._apply_chat_template(
+            dummy_messages,
+            template_tools,
+            tool_choice,
+        )
 
         tokenizer = self.tokenizer
         if hasattr(tokenizer, "tokenizer"):
@@ -731,16 +747,22 @@ async def stream_chat(
 
         # Convert tools for template
         template_tools = convert_tools_for_template(tools) if tools else None
+        template_tool_choice = kwargs.pop("tool_choice", None)
 
         # Apply chat template
         prompt = self._apply_chat_template(
             messages,
             template_tools,
+            template_tool_choice,
             num_images=len(all_images),
         )
 
         # Compute prefix boundary for cache
-        prefix_boundary = self._compute_prefix_boundary(messages, tools)
+        prefix_boundary = self._compute_prefix_boundary(
+            messages,
+            tools,
+            template_tool_choice,
+        )
         if prefix_boundary > 0:
             kwargs["prefix_boundary"] = prefix_boundary
diff --git a/vllm_mlx/engine/simple.py b/vllm_mlx/engine/simple.py
index ed118bbb0..2aa98ec67 100644
--- a/vllm_mlx/engine/simple.py
+++ b/vllm_mlx/engine/simple.py
@@ -266,6 +266,7 @@ async def chat(
 
         # Convert tools for template if provided
         template_tools = convert_tools_for_template(tools) if tools else None
+        template_tool_choice = kwargs.pop("tool_choice", None)
 
         async with self._generation_lock:
             if self._is_mllm:
@@ -276,6 +277,8 @@
                     messages=messages,
                     max_tokens=max_tokens,
                     temperature=temperature,
+                    tools=template_tools,
+                    tool_choice=template_tool_choice,
                     **kwargs,
                 )
                 text = clean_output_text(output.text)
@@ -337,6 +340,7 @@ async def stream_chat(
 
         # Convert tools for template
         template_tools = convert_tools_for_template(tools) if tools else None
+        template_tool_choice = kwargs.pop("tool_choice", None)
 
         # Build prompt using tokenizer
         if self._is_mllm:
@@ -351,6 +355,8 @@ def run_stream():
                         messages=messages,
                         max_tokens=max_tokens,
                         temperature=temperature,
+                        tools=template_tools,
+                        tool_choice=template_tool_choice,
                         **kwargs,
                     )
                 )
@@ -390,12 +396,14 @@
             }
             if template_tools:
                 template_kwargs["tools"] = template_tools
+            if template_tool_choice is not None:
+                template_kwargs["tool_choice"] = template_tool_choice
 
             try:
                 prompt = tokenizer.apply_chat_template(messages, **template_kwargs)
             except TypeError:
                 # Some templates don't support all kwargs
-                for key in ["tools", "enable_thinking"]:
+                for key in ["tools", "tool_choice", "enable_thinking"]:
                     if key in template_kwargs:
                         del template_kwargs[key]
                 prompt = tokenizer.apply_chat_template(messages, **template_kwargs)
diff --git a/vllm_mlx/models/mllm.py b/vllm_mlx/models/mllm.py
index 17b93de01..11fe8e11f 100644
--- a/vllm_mlx/models/mllm.py
+++ b/vllm_mlx/models/mllm.py
@@ -1024,6 +1024,8 @@ def chat(
         messages: list[dict],
         max_tokens: int = 256,
         temperature: float = 0.7,
+        tools: list[dict] | None = None,
+        tool_choice: str | dict | None = None,
         **kwargs,
     ) -> MLLMOutput:
         """
@@ -1048,6 +1050,10 @@
         from mlx_vlm import generate
         from mlx_vlm.prompt_utils import get_chat_template
 
+        template_tools = tools if tools is not None else kwargs.pop("tools", None)
+        template_tool_choice = (
+            tool_choice if tool_choice is not None else kwargs.pop("tool_choice", None)
+        )
 
         # Extract text and images from messages
         # Build chat_messages for multi-turn support WITH proper image tokens per message
@@ -1154,6 +1160,16 @@
         )
         try:
             # Use get_chat_template directly since messages are already properly formatted
+            template_kwargs = {"add_generation_prompt": True}
+            if template_tools:
+                template_kwargs["tools"] = template_tools
+            if template_tool_choice is not None:
+                template_kwargs["tool_choice"] = template_tool_choice
+            formatted_prompt = get_chat_template(self.processor, chat_messages, **template_kwargs)
+        except TypeError as e:
+            logger.debug(
+                f"Chat template rejected tools/tool_choice kwargs ({e}); retrying without extras"
+            )
             formatted_prompt = get_chat_template(
                 self.processor,
                 chat_messages,
@@ -1376,6 +1392,8 @@ def stream_chat(
         messages: list[dict],
         max_tokens: int = 256,
         temperature: float = 0.7,
+        tools: list[dict] | None = None,
+        tool_choice: str | dict | None = None,
         **kwargs,
     ) -> Iterator[MLLMOutput]:
         """
@@ -1401,12 +1419,20 @@
         try:
             from mlx_vlm import stream_generate
             from mlx_vlm.prompt_utils import get_chat_template
+
+            template_tools = tools if tools is not None else kwargs.pop("tools", None)
+            template_tool_choice = (
+                tool_choice
+                if tool_choice is not None
+                else kwargs.pop("tool_choice", None)
+            )
         except ImportError:
             # Fallback to non-streaming if stream_generate not available
             output = self.chat(
                 messages=messages,
                 max_tokens=max_tokens,
                 temperature=temperature,
+                tools=tools,
+                tool_choice=tool_choice,
                 **kwargs,
             )
             yield output
@@ -1506,6 +1532,16 @@
         # Apply chat template directly - messages are already properly structured
         try:
             # Use get_chat_template directly since messages are already properly formatted
+            template_kwargs = {"add_generation_prompt": True}
+            if template_tools:
+                template_kwargs["tools"] = template_tools
+            if template_tool_choice is not None:
+                template_kwargs["tool_choice"] = template_tool_choice
+            formatted_prompt = get_chat_template(self.processor, chat_messages, **template_kwargs)
+        except TypeError as e:
+            logger.debug(
+                f"Stream chat template rejected tools/tool_choice kwargs ({e}); retrying without extras"
+            )
             formatted_prompt = get_chat_template(
                 self.processor,
                 chat_messages,
diff --git a/vllm_mlx/server.py b/vllm_mlx/server.py
index f0328d4e6..c49938dc9 100644
--- a/vllm_mlx/server.py
+++ b/vllm_mlx/server.py
@@ -1362,6 +1362,8 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re
     # Add tools if provided
     if request.tools:
         chat_kwargs["tools"] = convert_tools_for_template(request.tools)
+    if request.tool_choice is not None:
+        chat_kwargs["tool_choice"] = request.tool_choice
 
     if request.stream:
         return StreamingResponse(
@@ -1538,6 +1540,8 @@ async def create_anthropic_message(
 
     if openai_request.tools:
         chat_kwargs["tools"] = convert_tools_for_template(openai_request.tools)
+    if openai_request.tool_choice is not None:
+        chat_kwargs["tool_choice"] = openai_request.tool_choice
 
     start_time = time.perf_counter()
     timeout = _default_timeout
@@ -1695,6 +1699,8 @@ async def _stream_anthropic_messages(
 
     if openai_request.tools:
         chat_kwargs["tools"] = convert_tools_for_template(openai_request.tools)
+    if openai_request.tool_choice is not None:
+        chat_kwargs["tool_choice"] = openai_request.tool_choice
 
     # Emit message_start
     message_start = {