diff --git a/requirements/common.txt b/requirements/common.txt index 05666c5d14b0..b610fd678687 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -31,7 +31,7 @@ partial-json-parser # used for parsing partial JSON outputs pyzmq >= 25.0.0 msgspec gguf >= 0.17.0 -mistral_common[image] >= 1.10.0 +mistral_common[image] >= 1.11.0 opencv-python-headless >= 4.13.0 # required for video IO pyyaml six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12 diff --git a/requirements/rocm-test.txt b/requirements/rocm-test.txt index dd4c7c24f40c..3d5df9814cae 100644 --- a/requirements/rocm-test.txt +++ b/requirements/rocm-test.txt @@ -502,7 +502,7 @@ mbstrdecoder==1.1.4 # typepy mdurl==0.1.2 # via markdown-it-py -mistral-common==1.10.0 +mistral-common==1.11.0 # via # -c requirements/common.txt # -r requirements/rocm-test.in diff --git a/requirements/test.txt b/requirements/test.txt index 642e589a6a27..c8ff5fcabb28 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -508,7 +508,7 @@ mbstrdecoder==1.1.3 # typepy mdurl==0.1.2 # via markdown-it-py -mistral-common==1.10.0 +mistral-common==1.11.0 # via # -c requirements/common.txt # -r requirements/test.in diff --git a/tests/tokenizers_/test_mistral.py b/tests/tokenizers_/test_mistral.py index faff61150265..2b101e8f98d9 100644 --- a/tests/tokenizers_/test_mistral.py +++ b/tests/tokenizers_/test_mistral.py @@ -3,8 +3,10 @@ from typing import Any +import llguidance import pytest from mistral_common.exceptions import InvalidMessageStructureException +from mistral_common.guidance.grammar_factory import GrammarFactory from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy from vllm.tokenizers.mistral import ( @@ -2407,3 +2409,29 @@ def test_convert_ids_to_tokens( assert actual_tokens == expected_tokens assert mistral_tokenizer.convert_ids_to_tokens([]) == [] + + def test_grammar_factory(self, mistral_tokenizer: MistralTokenizer) 
-> None: + # works in this case cause Mistral 7B is < v11 and SPM + if not mistral_tokenizer.is_tekken: + with pytest.raises(AttributeError): + mistral_tokenizer.grammar_factory # noqa: B018 + return + factory = mistral_tokenizer.grammar_factory + assert isinstance(factory, GrammarFactory) + + # Test caching + factory_2 = mistral_tokenizer.grammar_factory + assert factory is factory_2 + + def test_llg_tokenizer(self, mistral_tokenizer: MistralTokenizer) -> None: + if not mistral_tokenizer.is_tekken: + with pytest.raises(ValueError): + mistral_tokenizer.llg_tokenizer # noqa: B018 + return + + llg_tokenizer = mistral_tokenizer.llg_tokenizer + assert isinstance(llg_tokenizer, llguidance.LLTokenizer) + + # Test caching + llg_tokenizer_2 = mistral_tokenizer.llg_tokenizer + assert llg_tokenizer is llg_tokenizer_2 diff --git a/tests/tool_parsers/test_mistral_tool_parser.py b/tests/tool_parsers/test_mistral_tool_parser.py index 4be5646669be..750dd4d15fc9 100644 --- a/tests/tool_parsers/test_mistral_tool_parser.py +++ b/tests/tool_parsers/test_mistral_tool_parser.py @@ -3,19 +3,49 @@ import json from collections.abc import Generator +from typing import Any +from unittest.mock import MagicMock, patch import partial_json_parser import pytest from mistral_common.protocol.instruct.messages import AssistantMessage from mistral_common.protocol.instruct.request import InstructRequest -from mistral_common.protocol.instruct.tool_calls import FunctionCall, ToolCall +from mistral_common.protocol.instruct.tool_calls import ( + FunctionCall, + ToolCall, +) +from mistral_common.protocol.instruct.tool_calls import ( + NamedToolChoice as MistralNamedToolChoice, +) +from mistral_common.protocol.instruct.tool_calls import ( + ToolChoice as MistralToolChoice, +) +from mistral_common.protocol.instruct.tool_calls import ( + ToolChoiceEnum as MistralToolChoiceEnum, +) from partial_json_parser.core.options import Allow -from vllm.entrypoints.openai.engine.protocol import DeltaMessage, 
DeltaToolCall +from vllm.entrypoints.openai.chat_completion.protocol import ( + ChatCompletionRequest, +) +from vllm.entrypoints.openai.engine.protocol import ( + DeltaFunctionCall, + DeltaMessage, + DeltaToolCall, + StructuralTagResponseFormat, +) +from vllm.entrypoints.openai.engine.protocol import FunctionCall as VllmFunctionCall +from vllm.reasoning.mistral_reasoning_parser import MistralReasoningParser +from vllm.sampling_params import StructuredOutputsParams from vllm.tokenizers import TokenizerLike, get_tokenizer from vllm.tokenizers.detokenizer_utils import detokenize_incrementally from vllm.tokenizers.mistral import MistralTokenizer -from vllm.tool_parsers.mistral_tool_parser import MistralToolParser +from vllm.tool_parsers.mistral_tool_parser import ( + _DEFAULT_JSON_SCHEMA, + MistralStreamingResult, + MistralToolCall, + MistralToolParser, +) @pytest.fixture(scope="module") @@ -40,6 +70,13 @@ def mistral_tool_parser(mistral_tokenizer): return MistralToolParser(mistral_tokenizer) +@pytest.fixture +def non_mistral_parser() -> MistralToolParser: + mock_tokenizer = MagicMock() + mock_tokenizer.get_vocab.return_value = {"[TOOL_CALLS]": 1} + return MistralToolParser(mock_tokenizer) + + def assert_tool_calls( actual_tool_calls: list[ToolCall] | list[DeltaToolCall], expected_tool_calls: list[ToolCall], @@ -951,3 +988,592 @@ def test_fast_detokenization_text_detection_pre_v11( assert len(delta_message.tool_calls) > 0 assert delta_message.tool_calls[0].function is not None assert delta_message.tool_calls[0].function.name == "add" + + +SAMPLE_TOOLS_DICTS = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the weather", + "parameters": { + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "add", + "description": "Add two numbers", + "parameters": { + "type": "object", + "properties": { + "a": {"type": "number"}, + "b": 
{"type": "number"}, + }, + "required": ["a", "b"], + }, + }, + }, +] + + +def _make_request(**kwargs) -> ChatCompletionRequest: + defaults: dict = { + "messages": [], + "model": "mistralai/Mistral-Small-3.2-24B-Instruct-2506", + "tools": SAMPLE_TOOLS_DICTS, + "tool_choice": "auto", + } + defaults.update(kwargs) + return ChatCompletionRequest(**defaults) + + +@pytest.mark.parametrize( + "request_kwargs,expected_mode,expected_parallel", + [ + ({"tool_choice": "auto"}, MistralToolChoiceEnum.auto, True), + ({"tool_choice": "none"}, MistralToolChoiceEnum.none, True), + ({"tool_choice": "required"}, MistralToolChoiceEnum.required, True), + ({"tool_choice": None, "tools": None}, MistralToolChoiceEnum.auto, True), + ( + { + "tool_choice": { + "type": "function", + "function": {"name": "get_weather"}, + } + }, + MistralNamedToolChoice.model_validate( + {"type": "function", "function": {"name": "get_weather"}} + ), + True, + ), + ( + {"tool_choice": "auto", "parallel_tool_calls": False}, + MistralToolChoiceEnum.auto, + False, + ), + ( + {"tool_choice": "auto", "response_format": {"type": "text"}}, + MistralToolChoiceEnum.auto, + True, + ), + ], + ids=[ + "auto", + "none", + "required", + "null_tool_choice", + "named_tool_choice", + "parallel_false", + "response_format_text", + ], +) +def test_adjust_request_grammar_factory( + mistral_tool_parser: MistralToolParser, + request_kwargs: dict, + expected_mode: MistralToolChoice, + expected_parallel: bool, +) -> None: + request = _make_request(**request_kwargs) + factory = mistral_tool_parser.model_tokenizer.grammar_factory + + with patch.object( + factory, + "get_lark_from_jinja", + wraps=factory.get_lark_from_jinja, + ) as mock_get_lark: + result = mistral_tool_parser.adjust_request(request) + + mock_get_lark.assert_called_once() + call_kwargs = mock_get_lark.call_args + + assert call_kwargs.kwargs["mode"] == expected_mode + assert call_kwargs.kwargs["json_schema"] is None + assert call_kwargs.kwargs["parallel_tool_calls"] == 
expected_parallel + + assert result.structured_outputs is not None + assert isinstance(result.structured_outputs.grammar, str) + assert len(result.structured_outputs.grammar) > 0 + + +def test_adjust_request_unsupported_grammar_for_tokenizer(mistral_tokenizer) -> None: + with patch.object( + type(mistral_tokenizer), + "supports_grammar", + new_callable=lambda: property(lambda self: False), + ): + parser = MistralToolParser(mistral_tokenizer) + request = _make_request() + result = parser.adjust_request(request) + + assert result.structured_outputs is None + + +@pytest.mark.parametrize( + "tool_choice,expected_skip", + [("auto", False), ("none", True)], + ids=["auto_skip_false", "none_skip_true"], +) +def test_adjust_request_non_mistral_tokenizer( + non_mistral_parser: MistralToolParser, + tool_choice: str, + expected_skip: bool, +) -> None: + request = _make_request(tool_choice=tool_choice) + result = non_mistral_parser.adjust_request(request) + + assert result.skip_special_tokens is expected_skip + + +@pytest.mark.parametrize( + "so_kwargs", + [ + {"regex": r"\d+"}, + {"choice": ["a", "b"]}, + {"structural_tag": '{"key": "value"}'}, + {"grammar": "start: 'hello'"}, + ], + ids=["regex", "choice", "structural_tag", "grammar"], +) +def test_adjust_request_unsupported_structured_outputs( + mistral_tool_parser: MistralToolParser, + so_kwargs: dict, +) -> None: + request = _make_request( + structured_outputs=StructuredOutputsParams(**so_kwargs), + ) + result = mistral_tool_parser.adjust_request(request) + + assert result.structured_outputs == request.structured_outputs + + +def test_adjust_request_unsupported_response_format( + mistral_tool_parser: MistralToolParser, +) -> None: + request = _make_request( + response_format=StructuralTagResponseFormat( + type="structural_tag", format={"some": "config"} + ), + ) + result = mistral_tool_parser.adjust_request(request) + assert result.structured_outputs is None + assert result.response_format == request.response_format + + 
+@pytest.mark.parametrize( + "so_kwargs,expected_json_schema", + [ + ({"json_object": True}, _DEFAULT_JSON_SCHEMA), + ({"json": '{"type": "object"}'}, {"type": "object"}), + ( + {"json": {"type": "object", "properties": {"x": {"type": "integer"}}}}, + {"type": "object", "properties": {"x": {"type": "integer"}}}, + ), + ], + ids=["json_object", "json_str", "json_dict"], +) +def test_adjust_request_structured_outputs_generates_grammar( + mistral_tool_parser: MistralToolParser, + so_kwargs: dict, + expected_json_schema: dict, +) -> None: + request = _make_request( + structured_outputs=StructuredOutputsParams(**so_kwargs), + ) + factory = mistral_tool_parser.model_tokenizer.grammar_factory + + with patch.object( + factory, + "get_lark_from_jinja", + wraps=factory.get_lark_from_jinja, + ) as mock_get_lark: + result = mistral_tool_parser.adjust_request(request) + + mock_get_lark.assert_called_once() + assert mock_get_lark.call_args.kwargs["json_schema"] == expected_json_schema + + assert result.structured_outputs is not None + assert isinstance(result.structured_outputs.grammar, str) + assert len(result.structured_outputs.grammar) > 0 + + +@pytest.mark.parametrize( + "response_format_kwargs,expected_json_schema", + [ + ({"type": "json_object"}, _DEFAULT_JSON_SCHEMA), + ( + { + "type": "json_schema", + "json_schema": { + "name": "my_schema", + "schema": { + "type": "object", + "properties": {"x": {"type": "integer"}}, + }, + }, + }, + {"type": "object", "properties": {"x": {"type": "integer"}}}, + ), + ], + ids=["json_object", "json_schema_with_schema"], +) +def test_adjust_request_response_format_generates_grammar( + mistral_tool_parser: MistralToolParser, + response_format_kwargs: dict, + expected_json_schema: dict, +) -> None: + request = _make_request(response_format=response_format_kwargs) + factory = mistral_tool_parser.model_tokenizer.grammar_factory + + with patch.object( + factory, + "get_lark_from_jinja", + wraps=factory.get_lark_from_jinja, + ) as mock_get_lark: 
+ result = mistral_tool_parser.adjust_request(request) + + mock_get_lark.assert_called_once() + assert mock_get_lark.call_args.kwargs["json_schema"] == expected_json_schema + + assert result.structured_outputs is not None + assert isinstance(result.structured_outputs.grammar, str) + assert len(result.structured_outputs.grammar) > 0 + + +def test_adjust_request_tool_choice_none_with_json_schema_uses_json_schema_factory( + mistral_tool_parser: MistralToolParser, +) -> None: + request = _make_request( + tool_choice="none", + structured_outputs=StructuredOutputsParams(json='{"type": "object"}'), + ) + factory = mistral_tool_parser.model_tokenizer.grammar_factory + + with patch.object( + factory, + "get_lark_for_json_schema", + wraps=factory.get_lark_for_json_schema, + ) as mock_json_schema: + result = mistral_tool_parser.adjust_request(request) + + mock_json_schema.assert_called_once() + assert mock_json_schema.call_args.kwargs["json_schema"] == {"type": "object"} + + assert result.structured_outputs is not None + assert isinstance(result.structured_outputs.grammar, str) + assert len(result.structured_outputs.grammar) > 0 + + +def test_adjust_request_tool_choice_auto_with_json_schema_uses_jinja_factory( + mistral_tool_parser: MistralToolParser, +) -> None: + request = _make_request( + tool_choice="auto", + structured_outputs=StructuredOutputsParams(json='{"type": "object"}'), + ) + factory = mistral_tool_parser.model_tokenizer.grammar_factory + + with ( + patch.object( + factory, + "get_lark_for_json_schema", + wraps=factory.get_lark_for_json_schema, + ) as mock_json_schema, + patch.object( + factory, + "get_lark_from_jinja", + wraps=factory.get_lark_from_jinja, + ) as mock_jinja, + ): + result = mistral_tool_parser.adjust_request(request) + + mock_jinja.assert_called_once() + assert mock_jinja.call_args.kwargs["json_schema"] == {"type": "object"} + mock_json_schema.assert_not_called() + + assert result.structured_outputs is not None + assert 
isinstance(result.structured_outputs.grammar, str) + assert len(result.structured_outputs.grammar) > 0 + + +@pytest.mark.parametrize( + "so, set_from_tool_parser, expected", + [ + (None, False, False), + (StructuredOutputsParams(grammar="user grammar"), False, False), + (StructuredOutputsParams(grammar="factory grammar"), True, True), + ], + ids=["no_structured_outputs", "user_supplied_grammar", "from_tool_parser"], +) +def test_is_mistral_grammar_path( + so: StructuredOutputsParams | None, + set_from_tool_parser: bool, + expected: bool, +) -> None: + request = _make_request(structured_outputs=so) + if set_from_tool_parser: + assert request.structured_outputs is not None + request.structured_outputs._from_tool_parser = True + + assert MistralToolParser.is_mistral_grammar_path(request) == expected + + +@pytest.mark.parametrize( + "tool_calls, expected_len", + [ + (None, 0), + ([], 0), + ([VllmFunctionCall(id="abc123xyz", name="f", arguments="{}")], 1), + ([VllmFunctionCall(name="f", arguments="{}")], 1), + ( + [ + VllmFunctionCall(id="fixed1234", name="a", arguments='{"x": 1}'), + VllmFunctionCall(name="b", arguments='{"y": 2}'), + ], + 2, + ), + ], + ids=["none", "empty", "with_id", "without_id", "mixed"], +) +def test_build_non_streaming_tool_calls( + tool_calls: list[VllmFunctionCall] | None, + expected_len: int, +) -> None: + result = MistralToolParser.build_non_streaming_tool_calls(tool_calls) + assert len(result) == expected_len + + if tool_calls is None: + return + + for i, tc in enumerate(result): + assert isinstance(tc, MistralToolCall) + assert tc.type == "function" + + input_tc = tool_calls[i] + if input_tc.id: + assert tc.id == input_tc.id + else: + assert len(tc.id) == 9 + assert tc.id.isalnum() + + assert tc.function.name == input_tc.name + assert tc.function.arguments == input_tc.arguments + + +class TestExtractMaybeReasoningAndToolStreaming: + r"""Tests for ``MistralToolParser.extract_maybe_reasoning_and_tool_streaming``.""" + + @pytest.fixture + def 
parser(self) -> MistralToolParser: + mock_tokenizer = MagicMock() + mock_tokenizer.get_vocab.return_value = {"[TOOL_CALLS]": 1} + return MistralToolParser(mock_tokenizer) + + @pytest.fixture + def request_obj(self) -> ChatCompletionRequest: + return _make_request() + + @staticmethod + def _call( + parser: MistralToolParser, + request: ChatCompletionRequest, + *, + reasoning_parser: Any = None, + previous_text: str = "", + current_text: str = "hello", + delta_text: str = "hello", + previous_token_ids: list[int] | None = None, + current_token_ids: list[int] | None = None, + output_token_ids: list[int] | None = None, + reasoning_ended: bool = False, + added_content_delta: bool = False, + prompt_is_reasoning_end: bool | None = None, + ) -> MistralStreamingResult: + return parser.extract_maybe_reasoning_and_tool_streaming( + reasoning_parser=reasoning_parser, + previous_text=previous_text, + current_text=current_text, + delta_text=delta_text, + previous_token_ids=previous_token_ids or [], + current_token_ids=current_token_ids or [1, 2, 3], + output_token_ids=output_token_ids or [1, 2, 3], + reasoning_ended=reasoning_ended, + added_content_delta=added_content_delta, + prompt_is_reasoning_end=prompt_is_reasoning_end, + request=request, + ) + + def test_no_reasoning_tools_called( + self, parser: MistralToolParser, request_obj: ChatCompletionRequest + ) -> None: + tool_delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=0, + function=DeltaFunctionCall(name="f", arguments="{}"), + ) + ] + ) + with patch.object( + parser, "extract_tool_calls_streaming", return_value=tool_delta + ): + result = self._call(parser, request_obj, reasoning_parser=None) + + assert result.tools_called + assert result.delta_message is not None + assert result.delta_message.tool_calls + + def test_no_reasoning_no_tools( + self, parser: MistralToolParser, request_obj: ChatCompletionRequest + ) -> None: + content_delta = DeltaMessage(content="hello") + with patch.object( + parser, 
"extract_tool_calls_streaming", return_value=content_delta + ): + result = self._call(parser, request_obj, reasoning_parser=None) + + assert not result.tools_called + assert result.delta_message is not None + assert result.delta_message.content == "hello" + + def test_mistral_reasoning_parser_no_think_token( + self, parser: MistralToolParser, request_obj: ChatCompletionRequest + ) -> None: + mock_rp = MagicMock(spec=MistralReasoningParser) + mock_rp.start_token_id = 999 + content_delta = DeltaMessage(content="direct") + with patch.object( + parser, "extract_tool_calls_streaming", return_value=content_delta + ): + result = self._call( + parser, + request_obj, + reasoning_parser=mock_rp, + reasoning_ended=False, + current_token_ids=[1, 2, 3], + ) + + mock_rp.extract_reasoning_streaming.assert_not_called() + assert result.delta_message is not None + assert result.delta_message.content == "direct" + + def test_mistral_reasoning_parser_with_think_token( + self, parser: MistralToolParser, request_obj: ChatCompletionRequest + ) -> None: + mock_rp = MagicMock(spec=MistralReasoningParser) + mock_rp.start_token_id = 999 + mock_rp.extract_reasoning_streaming.return_value = DeltaMessage( + reasoning="thinking..." + ) + mock_rp.is_reasoning_end_streaming.return_value = False + + result = self._call( + parser, + request_obj, + reasoning_parser=mock_rp, + reasoning_ended=False, + current_token_ids=[1, 999, 3], + ) + + mock_rp.extract_reasoning_streaming.assert_called_once() + assert not result.reasoning_ended + assert not result.tools_called + + def test_non_mistral_reasoning_parser_always_expects_thinking( + self, parser: MistralToolParser, request_obj: ChatCompletionRequest + ) -> None: + mock_rp = MagicMock() + mock_rp.start_token_id = 999 + mock_rp.extract_reasoning_streaming.return_value = DeltaMessage( + reasoning="thinking..." 
+ ) + mock_rp.is_reasoning_end_streaming.return_value = False + + result = self._call( + parser, + request_obj, + reasoning_parser=mock_rp, + reasoning_ended=False, + current_token_ids=[1, 2, 3], + ) + + mock_rp.extract_reasoning_streaming.assert_called_once() + assert not result.reasoning_ended + assert not result.tools_called + + def test_reasoning_ended_first_chunk_resets_state( + self, parser: MistralToolParser, request_obj: ChatCompletionRequest + ) -> None: + content_delta = DeltaMessage(content="content") + with patch.object( + parser, "extract_tool_calls_streaming", return_value=content_delta + ) as mock_extract: + result = self._call( + parser, + request_obj, + reasoning_parser=MagicMock(), + reasoning_ended=True, + added_content_delta=False, + ) + + _, call_kwargs = mock_extract.call_args + assert call_kwargs["previous_text"] == "" + assert call_kwargs["previous_token_ids"] == [] + + assert result.added_content_delta + + def test_pre_v15_ignores_prompt_reasoning_end( + self, parser: MistralToolParser, request_obj: ChatCompletionRequest + ) -> None: + mock_tokenizer = MagicMock(spec=MistralTokenizer) + mock_tokenizer.version = 13 + parser.model_tokenizer = mock_tokenizer + + mock_rp = MagicMock(spec=MistralReasoningParser) + mock_rp.start_token_id = 999 + mock_rp.extract_reasoning_streaming.return_value = DeltaMessage( + reasoning="thinking..." 
+ ) + mock_rp.is_reasoning_end_streaming.return_value = False + + result = self._call( + parser, + request_obj, + reasoning_parser=mock_rp, + reasoning_ended=False, + prompt_is_reasoning_end=True, + current_token_ids=[999, 1, 2], + ) + + mock_rp.extract_reasoning_streaming.assert_called_once() + assert not result.reasoning_ended + + def test_non_pre_v15_prompt_reasoning_end( + self, parser: MistralToolParser, request_obj: ChatCompletionRequest + ) -> None: + mock_tokenizer = MagicMock(spec=MistralTokenizer) + mock_tokenizer.version = 15 + parser.model_tokenizer = mock_tokenizer + + mock_rp = MagicMock(spec=MistralReasoningParser) + mock_rp.start_token_id = 999 + + content_delta = DeltaMessage(content="after reasoning") + with patch.object( + parser, "extract_tool_calls_streaming", return_value=content_delta + ): + result = self._call( + parser, + request_obj, + reasoning_parser=mock_rp, + reasoning_ended=False, + prompt_is_reasoning_end=True, + current_token_ids=[999, 1, 2], + ) + + mock_rp.extract_reasoning_streaming.assert_not_called() + assert result.reasoning_ended diff --git a/tests/tool_use/mistral/test_mistral_tool_calls.py b/tests/tool_use/mistral/test_mistral_tool_calls.py index 3c4a543abe41..fdb2846c664d 100644 --- a/tests/tool_use/mistral/test_mistral_tool_calls.py +++ b/tests/tool_use/mistral/test_mistral_tool_calls.py @@ -1,10 +1,30 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import json + import openai import pytest -from tests.tool_use.utils import MESSAGES_ASKING_FOR_TOOLS, WEATHER_TOOL +from tests.tool_use.utils import ( + MESSAGES_ASKING_FOR_TOOLS, + MESSAGES_WITH_TOOL_RESPONSE, + MESSAGES_WITHOUT_TOOLS, + SEARCH_TOOL, + SEED, + WEATHER_TOOL, + ensure_system_prompt, +) + +from .utils import ServerConfig + + +def _requires_tool_parser(server_config: ServerConfig) -> None: + r"""Skip test if server was not started with --tool-call-parser.""" + if "--tool-call-parser" not in 
server_config.get("arguments", []): + pytest.skip( + f"Skipping: {server_config['model']} not configured with --tool-call-parser" + ) # test: a tool_choice with mistral-tokenizer results in an ID of length 9 @@ -28,3 +48,361 @@ async def test_tool_call_with_tool_choice(client: openai.AsyncOpenAI): assert choice.message.role == "assistant" assert choice.message.tool_calls is None or len(choice.message.tool_calls) == 1 assert len(choice.message.tool_calls[0].id) == 9 # length of 9 for mistral + + +@pytest.mark.asyncio +async def test_tool_call_auto( + client: openai.AsyncOpenAI, server_config: ServerConfig +) -> None: + _requires_tool_parser(server_config) + + models = await client.models.list() + model_name: str = models.data[0].id + + # --- non-streaming --- + chat_completion = await client.chat.completions.create( + messages=ensure_system_prompt(MESSAGES_ASKING_FOR_TOOLS, server_config), + temperature=0, + max_completion_tokens=100, + model=model_name, + tools=[WEATHER_TOOL, SEARCH_TOOL], + logprobs=False, + seed=SEED, + ) + + choice = chat_completion.choices[0] + tool_calls = choice.message.tool_calls + + assert choice.finish_reason == "tool_calls" + assert tool_calls is not None and len(tool_calls) >= 1 + assert tool_calls[0].type == "function" + assert tool_calls[0].function.name == "get_current_weather" + assert isinstance(tool_calls[0].function.arguments, str) + parsed_arguments = json.loads(tool_calls[0].function.arguments) + assert "city" in parsed_arguments + assert len(tool_calls[0].id) == 9 + + # --- streaming --- + function_name: str | None = None + function_args_str: str = "" + tool_call_id: str | None = None + role_name: str | None = None + finish_reason_count: int = 0 + + stream = await client.chat.completions.create( + messages=ensure_system_prompt(MESSAGES_ASKING_FOR_TOOLS, server_config), + temperature=0, + max_completion_tokens=100, + model=model_name, + tools=[WEATHER_TOOL, SEARCH_TOOL], + logprobs=False, + seed=SEED, + stream=True, + ) + + 
async for chunk in stream: + if chunk.choices[0].finish_reason: + finish_reason_count += 1 + assert chunk.choices[0].finish_reason == "tool_calls" + + if chunk.choices[0].delta.role: + assert not role_name or role_name == "assistant" + role_name = "assistant" + + streamed_tool_calls = chunk.choices[0].delta.tool_calls + if streamed_tool_calls and len(streamed_tool_calls) > 0: + assert len(streamed_tool_calls) == 1 + tool_call = streamed_tool_calls[0] + + if tool_call.id: + assert not tool_call_id + tool_call_id = tool_call.id + + if tool_call.function: + if tool_call.function.name: + assert function_name is None + function_name = tool_call.function.name + if tool_call.function.arguments: + function_args_str += tool_call.function.arguments + + assert finish_reason_count == 1 + assert role_name == "assistant" + assert function_name == "get_current_weather" + streamed_args = json.loads(function_args_str) + assert "city" in streamed_args + assert isinstance(tool_call_id, str) and len(tool_call_id) == 9 + assert parsed_arguments == streamed_args + + +@pytest.mark.asyncio +async def test_tool_call_required( + client: openai.AsyncOpenAI, server_config: ServerConfig +) -> None: + _requires_tool_parser(server_config) + + models = await client.models.list() + model_name: str = models.data[0].id + + # --- non-streaming --- + chat_completion = await client.chat.completions.create( + messages=ensure_system_prompt(MESSAGES_ASKING_FOR_TOOLS, server_config), + temperature=0, + max_completion_tokens=100, + model=model_name, + tools=[WEATHER_TOOL], + tool_choice="required", + logprobs=False, + seed=SEED, + ) + + choice = chat_completion.choices[0] + tool_calls = choice.message.tool_calls + + assert choice.finish_reason == "tool_calls" + assert tool_calls is not None and len(tool_calls) >= 1 + assert tool_calls[0].function.name == "get_current_weather" + parsed_arguments = json.loads(tool_calls[0].function.arguments) + assert len(tool_calls[0].id) == 9 + + # --- streaming --- + 
function_name: str | None = None + function_args_str: str = "" + tool_call_id: str | None = None + role_name: str | None = None + finish_reason_count: int = 0 + + stream = await client.chat.completions.create( + messages=ensure_system_prompt(MESSAGES_ASKING_FOR_TOOLS, server_config), + temperature=0, + max_completion_tokens=100, + model=model_name, + tools=[WEATHER_TOOL], + tool_choice="required", + logprobs=False, + seed=SEED, + stream=True, + ) + + async for chunk in stream: + if chunk.choices[0].finish_reason: + finish_reason_count += 1 + assert chunk.choices[0].finish_reason == "tool_calls" + + if chunk.choices[0].delta.role: + assert not role_name or role_name == "assistant" + role_name = "assistant" + + streamed_tool_calls = chunk.choices[0].delta.tool_calls + if streamed_tool_calls and len(streamed_tool_calls) > 0: + assert len(streamed_tool_calls) == 1 + tool_call = streamed_tool_calls[0] + + if tool_call.id: + assert not tool_call_id + tool_call_id = tool_call.id + + if tool_call.function: + if tool_call.function.name: + assert function_name is None + function_name = tool_call.function.name + if tool_call.function.arguments: + function_args_str += tool_call.function.arguments + + assert finish_reason_count == 1 + assert role_name == "assistant" + assert function_name == "get_current_weather" + streamed_args = json.loads(function_args_str) + assert isinstance(tool_call_id, str) and len(tool_call_id) == 9 + assert parsed_arguments == streamed_args + + +@pytest.mark.asyncio +async def test_tool_call_none_with_tools( + client: openai.AsyncOpenAI, server_config: ServerConfig +) -> None: + _requires_tool_parser(server_config) + + models = await client.models.list() + model_name: str = models.data[0].id + + # --- non-streaming --- + chat_completion = await client.chat.completions.create( + messages=ensure_system_prompt(MESSAGES_ASKING_FOR_TOOLS, server_config), + temperature=0, + max_completion_tokens=100, + model=model_name, + tools=[WEATHER_TOOL], + 
tool_choice="none", + logprobs=False, + seed=SEED, + ) + + choice = chat_completion.choices[0] + + assert choice.finish_reason != "tool_calls" + assert choice.message.tool_calls is None or len(choice.message.tool_calls) == 0 + assert choice.message.content is not None + assert "[TOOL_CALLS]" not in choice.message.content + + non_streaming_content = choice.message.content + + # --- streaming --- + chunks: list[str] = [] + finish_reason_count: int = 0 + + stream = await client.chat.completions.create( + messages=ensure_system_prompt(MESSAGES_ASKING_FOR_TOOLS, server_config), + temperature=0, + max_completion_tokens=100, + model=model_name, + tools=[WEATHER_TOOL], + tool_choice="none", + logprobs=False, + seed=SEED, + stream=True, + ) + + async for chunk in stream: + delta = chunk.choices[0].delta + + if delta.content: + chunks.append(delta.content) + + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + assert chunk.choices[0].finish_reason != "tool_calls" + + assert not delta.tool_calls or len(delta.tool_calls) == 0 + + assert finish_reason_count == 1 + streamed_content = "".join(chunks) + assert "[TOOL_CALLS]" not in streamed_content + assert streamed_content == non_streaming_content + + +@pytest.mark.asyncio +async def test_chat_without_tools( + client: openai.AsyncOpenAI, server_config: ServerConfig +) -> None: + models = await client.models.list() + model_name: str = models.data[0].id + + # --- non-streaming --- + chat_completion = await client.chat.completions.create( + messages=ensure_system_prompt(MESSAGES_WITHOUT_TOOLS, server_config), + temperature=0, + max_completion_tokens=150, + model=model_name, + logprobs=False, + seed=SEED, + ) + + choice = chat_completion.choices[0] + output_text = choice.message.content + + assert output_text is not None and len(output_text) > 0 + assert choice.finish_reason != "tool_calls" + assert choice.message.tool_calls is None or len(choice.message.tool_calls) == 0 + + # --- streaming --- + chunks: 
list[str] = [] + finish_reason_count: int = 0 + role_sent: bool = False + + stream = await client.chat.completions.create( + messages=ensure_system_prompt(MESSAGES_WITHOUT_TOOLS, server_config), + temperature=0, + max_completion_tokens=150, + model=model_name, + logprobs=False, + seed=SEED, + stream=True, + ) + + async for chunk in stream: + delta = chunk.choices[0].delta + + if delta.role: + assert not role_sent + assert delta.role == "assistant" + role_sent = True + + if delta.content: + chunks.append(delta.content) + + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + assert chunk.choices[0].finish_reason == choice.finish_reason + + assert not delta.tool_calls or len(delta.tool_calls) == 0 + + assert role_sent + assert finish_reason_count == 1 + assert len(chunks) + assert "".join(chunks) == output_text + + +@pytest.mark.asyncio +async def test_tool_call_with_results( + client: openai.AsyncOpenAI, server_config: ServerConfig +) -> None: + _requires_tool_parser(server_config) + + models = await client.models.list() + model_name: str = models.data[0].id + + # --- non-streaming --- + chat_completion = await client.chat.completions.create( + messages=ensure_system_prompt(MESSAGES_WITH_TOOL_RESPONSE, server_config), + temperature=0, + max_completion_tokens=100, + model=model_name, + tools=[WEATHER_TOOL, SEARCH_TOOL], + logprobs=False, + seed=SEED, + ) + + choice = chat_completion.choices[0] + + assert choice.finish_reason != "tool_calls" + assert choice.message.tool_calls is None or len(choice.message.tool_calls) == 0 + assert choice.message.content is not None + assert "98" in choice.message.content + + # --- streaming --- + chunks: list[str] = [] + finish_reason_count: int = 0 + role_sent: bool = False + + stream = await client.chat.completions.create( + messages=ensure_system_prompt(MESSAGES_WITH_TOOL_RESPONSE, server_config), + temperature=0, + max_completion_tokens=100, + model=model_name, + tools=[WEATHER_TOOL, SEARCH_TOOL], + 
logprobs=False, + seed=SEED, + stream=True, + ) + + async for chunk in stream: + delta = chunk.choices[0].delta + + if delta.role: + assert not role_sent + assert delta.role == "assistant" + role_sent = True + + if delta.content: + chunks.append(delta.content) + + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + assert chunk.choices[0].finish_reason == choice.finish_reason + + assert not delta.tool_calls or len(delta.tool_calls) == 0 + + assert role_sent + assert finish_reason_count == 1 + assert len(chunks) + assert "".join(chunks) == choice.message.content diff --git a/tests/tool_use/mistral/utils.py b/tests/tool_use/mistral/utils.py index 4d772ba63793..6f6ee2d8654f 100644 --- a/tests/tool_use/mistral/utils.py +++ b/tests/tool_use/mistral/utils.py @@ -29,4 +29,20 @@ class ServerConfig(TypedDict, total=False): "without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT " "to the user's question - just respond to it normally.", }, + "ministral-3b": { + "model": "mistralai/Ministral-3-3B-Instruct-2512", + "arguments": [ + "--tokenizer-mode", + "mistral", + "--tool-call-parser", + "mistral", + "--enable-auto-tool-choice", + ], + "system_prompt": "You are a helpful assistant with access to tools. If a tool" + " that you have would be helpful to answer a user query, " + "call the tool. Otherwise, answer the user's query directly " + "without calling a tool. 
DO NOT CALL A TOOL THAT IS IRRELEVANT " + "to the user's question - just respond to it normally.", + "supports_parallel": True, + }, } diff --git a/tests/v1/structured_output/test_backend_guidance.py b/tests/v1/structured_output/test_backend_guidance.py index 704ed8b9c9e9..ca8c9b0d7853 100644 --- a/tests/v1/structured_output/test_backend_guidance.py +++ b/tests/v1/structured_output/test_backend_guidance.py @@ -11,6 +11,7 @@ from vllm.config.parallel import ParallelConfig from vllm.config.speculative import SpeculativeConfig from vllm.sampling_params import SamplingParams, StructuredOutputsParams +from vllm.tokenizers import get_tokenizer from vllm.v1.request import Request from vllm.v1.structured_output import StructuredOutputManager from vllm.v1.structured_output.backend_guidance import GuidanceBackend @@ -19,6 +20,14 @@ TOKENIZER = "gpt2" +@pytest.fixture(scope="module") +def mistral_tokenizer(): + return get_tokenizer( + tokenizer_name="mistralai/Mistral-Small-3.2-24B-Instruct-2506", + tokenizer_mode="mistral", + ) + + def test_backend_guidance_rollback_terminated(): # Test that the backend guidance successfully rollbacks from a # terminated state. 
This can happen with speculative decoding, @@ -187,3 +196,38 @@ def test_grammar_init_async_and_sync(async_grammar): # Verify the grammar can accept valid tokens assert grammar.accept_tokens(request.request_id, prompt) + + +@pytest.mark.parametrize( + "request_type,grammar_spec", + [ + pytest.param( + StructuredOutputOptions.JSON, + '{"type": "object"}', + id="json", + ), + pytest.param( + StructuredOutputOptions.GRAMMAR, + 'start: "hello" | "world"', + id="lark", + ), + ], +) +def test_mistral_tokenizer_compile_grammar( + mistral_tokenizer, + request_type: StructuredOutputOptions, + grammar_spec: str, +) -> None: + vllm_config = VllmConfig( + structured_outputs_config=StructuredOutputsConfig(backend="guidance"), + ) + backend = GuidanceBackend( + vllm_config, + tokenizer=mistral_tokenizer, + vocab_size=mistral_tokenizer.vocab_size, + ) + assert backend.ll_tokenizer is mistral_tokenizer.llg_tokenizer + + grammar = backend.compile_grammar(request_type, grammar_spec) + assert grammar is not None + assert not grammar.is_terminated() diff --git a/vllm/entrypoints/openai/chat_completion/protocol.py b/vllm/entrypoints/openai/chat_completion/protocol.py index 533959df6094..16d06f4651a5 100644 --- a/vllm/entrypoints/openai/chat_completion/protocol.py +++ b/vllm/entrypoints/openai/chat_completion/protocol.py @@ -781,13 +781,6 @@ def check_system_message_content_type(cls, data): return data - @model_validator(mode="before") - @classmethod - def set_include_reasoning_for_none_effort(cls, data: Any) -> Any: - if data.get("reasoning_effort") == "none": - data["include_reasoning"] = False - return data - class BatchChatCompletionRequest(OpenAIBaseModel): """Request model for the /v1/chat/completions/batch endpoint. 
diff --git a/vllm/entrypoints/openai/chat_completion/serving.py b/vllm/entrypoints/openai/chat_completion/serving.py index a426836afd35..e2479179d6dd 100644 --- a/vllm/entrypoints/openai/chat_completion/serving.py +++ b/vllm/entrypoints/openai/chat_completion/serving.py @@ -73,7 +73,10 @@ from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.tokenizers import TokenizerLike from vllm.tool_parsers import ToolParser -from vllm.tool_parsers.mistral_tool_parser import MistralToolCall +from vllm.tool_parsers.mistral_tool_parser import ( + MistralToolCall, + MistralToolParser, +) from vllm.tool_parsers.utils import partial_json_loads from vllm.utils.collection_utils import as_list from vllm.utils.mistral import is_mistral_tokenizer @@ -134,6 +137,12 @@ def __init__( enable_auto_tools=enable_auto_tools, model_name=self.model_config.model, ) + _is_mistral_tool_parser = self.tool_parser is not None and issubclass( + self.tool_parser, MistralToolParser + ) + if _is_mistral_tool_parser and self.reasoning_parser_cls is not None: + MistralToolParser.model_can_reason = True + self.exclude_tools_when_tool_choice_none = exclude_tools_when_tool_choice_none self.enable_prompt_tokens_details = enable_prompt_tokens_details @@ -305,6 +314,11 @@ async def create_chat_completion( else: if not request.include_reasoning: reasoning_ended = True + elif MistralToolParser.is_mistral_grammar_path(request): + # The Mistral grammar already includes an optional + # `think?` rule that handles both reasoning and + # non-reasoning outputs. 
+ reasoning_ended = True elif reasoning_parser: reasoning_ended = reasoning_parser.is_reasoning_end( prompt_token_ids or [] @@ -523,6 +537,8 @@ async def chat_completion_stream_generator( harmony_tools_streamed = [False] * num_choices tools_streamed = [False] * num_choices + is_mistral_grammar_path = MistralToolParser.is_mistral_grammar_path(request) + if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam): tool_choice_function_name = request.tool_choice.function.name else: @@ -546,7 +562,7 @@ async def chat_completion_stream_generator( # Only one of these will be used, thus previous_texts and # all_previous_token_ids will not be used twice in the same iteration. - if tool_choice_auto or reasoning_parser: + if is_mistral_grammar_path or tool_choice_auto or reasoning_parser: # These are only required in "auto" tool choice case all_previous_token_ids = [[] for _ in range(num_choices)] # For reasoning parser and tool call all enabled @@ -558,7 +574,7 @@ async def chat_completion_stream_generator( # Prepare the tool parser if it's needed try: - if tool_choice_auto and self.tool_parser: + if (is_mistral_grammar_path or tool_choice_auto) and self.tool_parser: if tokenizer is None: raise ValueError( "Tokenizer not available when `skip_tokenizer_init=True`" @@ -740,7 +756,7 @@ async def chat_completion_stream_generator( delta_message: DeltaMessage | None # just update previous_texts and previous_token_ids - if tool_choice_auto or reasoning_parser: + if is_mistral_grammar_path or tool_choice_auto or reasoning_parser: assert previous_texts is not None assert all_previous_token_ids is not None previous_text = previous_texts[i] @@ -764,6 +780,33 @@ async def chat_completion_stream_generator( ) ) harmony_tools_streamed[i] |= tools_streamed_flag + # Mistral grammar path: combined reasoning + tool streaming + elif is_mistral_grammar_path: + assert tool_parser is not None + assert isinstance(tool_parser, MistralToolParser) + assert added_content_delta_arr is not 
None + assert reasoning_end_arr is not None + output_token_ids = as_list(output.token_ids) + result = tool_parser.extract_maybe_reasoning_and_tool_streaming( + reasoning_parser=reasoning_parser, + previous_text=previous_text, + current_text=current_text, + delta_text=delta_text, + previous_token_ids=previous_token_ids, + current_token_ids=current_token_ids, + output_token_ids=output_token_ids, + reasoning_ended=reasoning_end_arr[i], + added_content_delta=added_content_delta_arr[i], + prompt_is_reasoning_end=(prompt_is_reasoning_end_arr[i]), + request=request, + ) + delta_message = result.delta_message + reasoning_end_arr[i] = result.reasoning_ended + added_content_delta_arr[i] = result.added_content_delta + current_text = result.current_text + current_token_ids = result.current_token_ids + if result.tools_called: + tools_streamed[i] = True # handle streaming deltas for tools with named tool_choice elif tool_choice_function_name: # When encountering think end id in prompt_token_ids @@ -1010,7 +1053,9 @@ async def chat_completion_stream_generator( delta_message = DeltaMessage(content=delta_text) # update the previous values for the next iteration - if (tool_choice_auto or reasoning_parser) and not self.use_harmony: + if ( + is_mistral_grammar_path or tool_choice_auto or reasoning_parser + ) and not self.use_harmony: assert previous_texts is not None assert all_previous_token_ids is not None previous_texts[i] = current_text @@ -1397,7 +1442,25 @@ async def chat_completion_full_generator( tool_call_class = ( MistralToolCall if is_mistral_tokenizer(tokenizer) else ToolCall ) - if (not self.enable_auto_tools or not self.tool_parser) and ( + + use_mistral_tool_parser = MistralToolParser.is_mistral_grammar_path(request) + if use_mistral_tool_parser: + tool_call_items = MistralToolParser.build_non_streaming_tool_calls( + tool_calls + ) + if tool_call_items: + auto_tools_called = not isinstance( + request.tool_choice, + ChatCompletionNamedToolChoiceParam, + ) + message = 
ChatMessage( + role=role, + reasoning=reasoning, + content=content, + tool_calls=tool_call_items, + ) + + elif (not self.enable_auto_tools or not self.tool_parser) and ( not isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam) and request.tool_choice != "required" ): diff --git a/vllm/entrypoints/openai/engine/serving.py b/vllm/entrypoints/openai/engine/serving.py index f5f011a96f27..7cf43dc1f45b 100644 --- a/vllm/entrypoints/openai/engine/serving.py +++ b/vllm/entrypoints/openai/engine/serving.py @@ -72,6 +72,7 @@ from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.tokenizers import TokenizerLike from vllm.tool_parsers import ToolParser +from vllm.tool_parsers.mistral_tool_parser import MistralToolParser from vllm.tracing import ( contains_trace_headers, extract_trace_headers, @@ -785,16 +786,31 @@ def _parse_tool_calls_from_content( tool_parser_cls: type[ToolParser] | None, content: str | None = None, ) -> tuple[list[FunctionCall] | None, str | None]: + # When the Mistral grammar factory injected structured outputs, + # let the parser handle the output. + use_mistral_tool_parser = ( + isinstance(request, ChatCompletionRequest) + and tool_parser_cls is not None + and issubclass(tool_parser_cls, MistralToolParser) + and MistralToolParser.is_mistral_grammar_path(request=request) + ) + function_calls = list[FunctionCall]() - if request.tool_choice and isinstance(request.tool_choice, ToolChoiceFunction): + if ( + not use_mistral_tool_parser + and request.tool_choice + and isinstance(request.tool_choice, ToolChoiceFunction) + ): assert content is not None # Forced Function Call function_calls.append( FunctionCall(name=request.tool_choice.name, arguments=content) ) content = None # Clear content since tool is called. 
- elif request.tool_choice and isinstance( - request.tool_choice, ChatCompletionNamedToolChoiceParam + elif ( + not use_mistral_tool_parser + and request.tool_choice + and isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam) ): assert content is not None # Forced Function Call @@ -802,7 +818,7 @@ def _parse_tool_calls_from_content( FunctionCall(name=request.tool_choice.function.name, arguments=content) ) content = None # Clear content since tool is called. - elif request.tool_choice == "required": + elif not use_mistral_tool_parser and request.tool_choice == "required": tool_calls = [] with contextlib.suppress(ValidationError): content = content or "" @@ -817,10 +833,12 @@ def _parse_tool_calls_from_content( ) ) content = None # Clear content since tool is called. - elif ( - tool_parser_cls - and enable_auto_tools - and (request.tool_choice == "auto" or request.tool_choice is None) + elif tool_parser_cls and ( + use_mistral_tool_parser + or ( + enable_auto_tools + and (request.tool_choice == "auto" or request.tool_choice is None) + ) ): if tokenizer is None: raise ValueError( diff --git a/vllm/entrypoints/serve/render/serving.py b/vllm/entrypoints/serve/render/serving.py index 83b41bbda2d0..02b3a7b17cd6 100644 --- a/vllm/entrypoints/serve/render/serving.py +++ b/vllm/entrypoints/serve/render/serving.py @@ -52,6 +52,7 @@ prompt_to_seq, ) from vllm.tool_parsers import ToolParser +from vllm.tool_parsers.mistral_tool_parser import MistralToolParser from vllm.utils import random_uuid from vllm.utils.mistral import is_mistral_tokenizer from vllm.utils.mistral import mt as _mt @@ -534,9 +535,19 @@ async def preprocess_chat( # tool parsing is done only if a tool_parser has been set and if # tool_choice is not "none" (if tool_choice is "none" but a tool_parser # is set, we want to prevent parsing a tool_call hallucinated by the LLM + # + # Exception: Mistral grammar-capable tokenizers always call + # adjust_request — even for tool_choice="none" — so that the 
grammar + # factory can prevent special-token leakage. if tool_parser is not None: tool_choice = getattr(request, "tool_choice", "none") - if tool_choice != "none": + tokenizer = renderer.get_tokenizer() + is_mistral_grammar_eligible = ( + issubclass(tool_parser, MistralToolParser) + and is_mistral_tokenizer(tokenizer) + and tokenizer.supports_grammar + ) + if tool_choice != "none" or is_mistral_grammar_eligible: if not isinstance(request, ChatCompletionRequest | ResponsesRequest): msg = ( "Tool usage is only supported " @@ -544,7 +555,6 @@ async def preprocess_chat( f"but got {type(request).__name__}" ) raise NotImplementedError(msg) - tokenizer = renderer.get_tokenizer() request = tool_parser(tokenizer, request.tools).adjust_request( request=request ) diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 97976b832097..93842620dc4a 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -51,6 +51,8 @@ class StructuredOutputsParams: """CAUTION: Should only be set by Processor._validate_structured_output""" _backend_was_auto: bool = field(default=False, init=False) """CAUTION: Should only be set by Processor._validate_structured_output""" + _from_tool_parser: bool = field(default=False, init=False) + """CAUTION: Should only be set by ToolParser.adjust_request""" def __post_init__(self): """Validate that some fields are mutually exclusive.""" @@ -153,6 +155,10 @@ class RequestOutputKind(Enum): FINAL_ONLY = 2 +def _is_non_tekken_mistral(tokenizer: TokenizerLike) -> bool: + return is_mistral_tokenizer(tokenizer) and not tokenizer.is_tekken + + class SamplingParams( PydanticMsgspecMixin, msgspec.Struct, @@ -801,17 +807,23 @@ def _validate_structured_outputs( # xgrammar with no fallback validate_xgrammar_grammar(self) elif backend.startswith("guidance"): + if _is_non_tekken_mistral(tokenizer=tokenizer): + raise ValueError( + "Non-tekken Mistral tokenizers are not supported for the 'guidance'" + " structured output backend. 
Please either use a more recent " + "Mistral model, the ['xgrammar', 'outlines'] " + "backends or tokenizer_mode='hf' instead." + ) # TODO: ideally we would have the LLTokenizer here as Lark syntax # allows <|special_token|> and similar, see # https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens # Without tokenizer these are disallowed in grammars. - if is_mistral_tokenizer(tokenizer): - raise ValueError( - "Mistral tokenizer is not supported for the 'guidance' " - "structured output backend. Please use ['xgrammar', 'outlines'] " - "backends or tokenizer_mode='hf' instead." - ) - validate_guidance_grammar(self, tokenizer=None) + validate_guidance_grammar( + self, + tokenizer=tokenizer.llg_tokenizer + if is_mistral_tokenizer(tokenizer) + else None, + ) elif backend == "outlines": # outlines backend validate_structured_output_request_outlines(self) @@ -839,24 +851,30 @@ def _validate_structured_outputs( # or includes some jsonschema feature(s) that # are not supported in xgrammar. + skip_guidance = _is_non_tekken_mistral(tokenizer) + # Check if schema has features unsupported by guidance so_params = self.structured_outputs - skip_guidance = False - if so_params.json: + if not skip_guidance and so_params.json: if isinstance(so_params.json, str): schema = json_mod.loads(so_params.json) else: schema = so_params.json skip_guidance = has_guidance_unsupported_json_features(schema) - if is_mistral_tokenizer(tokenizer) or skip_guidance: - # Fall back to outlines if the tokenizer is Mistral - # or if schema contains features unsupported by guidance + if skip_guidance: + # Fall back to outlines if the tokenizer is non-tekken Mistral or + # the schema contains features unsupported by guidance validate_structured_output_request_outlines(self) self.structured_outputs._backend = "outlines" else: # Fall back to guidance by default. 
- validate_guidance_grammar(self, tokenizer=None) + validate_guidance_grammar( + self, + tokenizer=tokenizer.llg_tokenizer + if is_mistral_tokenizer(tokenizer) + else None, + ) self.structured_outputs._backend = "guidance" # Remember that this backend was set automatically self.structured_outputs._backend_was_auto = True diff --git a/vllm/tokenizers/mistral.py b/vllm/tokenizers/mistral.py index e20f1edd472e..147dca88877b 100644 --- a/vllm/tokenizers/mistral.py +++ b/vllm/tokenizers/mistral.py @@ -1,9 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Sequence +from functools import cached_property from pathlib import Path from typing import TYPE_CHECKING, Any, cast, overload +from mistral_common.guidance.grammar_factory import GrammarFactory +from mistral_common.guidance.tokenizer import from_mistral_tokenizer from mistral_common.protocol.instruct.request import ( ChatCompletionRequest as MistralChatCompletionRequest, ) @@ -45,6 +48,7 @@ ) if TYPE_CHECKING: + import llguidance from transformers import BatchEncoding logger = init_logger(__name__) @@ -574,3 +578,24 @@ def convert_ids_to_tokens( ] return tokens + + @property + def supports_grammar(self) -> bool: + return GrammarFactory.is_supported(self.mistral) + + @cached_property + def grammar_factory(self) -> GrammarFactory: + if not self.supports_grammar: + raise AttributeError( + "This tokenizer does not support `grammar_factory`. " + "This is only supported for tekken tokenizers with " + "version >= 11." + ) + # Cache grammar factory to avoid creating a llguidance tokenizer at every usage. 
+ return GrammarFactory(self.mistral) + + @cached_property + def llg_tokenizer(self) -> "llguidance.LLTokenizer": + if not self.is_tekken: + raise ValueError("`llg_tokenizer` is only supported for Tekkenizers.") + return from_mistral_tokenizer(self.mistral) diff --git a/vllm/tool_parsers/mistral_tool_parser.py b/vllm/tool_parsers/mistral_tool_parser.py index dc92522a0520..f53da7fdd97c 100644 --- a/vllm/tool_parsers/mistral_tool_parser.py +++ b/vllm/tool_parsers/mistral_tool_parser.py @@ -1,15 +1,30 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from __future__ import annotations + import json from collections.abc import Sequence +from dataclasses import dataclass from enum import Enum, auto from random import choices from string import ascii_letters, digits -from typing import Any +from typing import TYPE_CHECKING, Any import ijson import regex as re +from mistral_common.protocol.instruct.tool_calls import ( + NamedToolChoice as MistralNamedToolChoice, +) +from mistral_common.protocol.instruct.tool_calls import ( + Tool as MistralTool, +) +from mistral_common.protocol.instruct.tool_calls import ( + ToolChoice as MistralToolChoice, +) +from mistral_common.protocol.instruct.tool_calls import ( + ToolChoiceEnum as MistralToolChoiceEnum, +) from pydantic import Field from vllm.entrypoints.openai.chat_completion.protocol import ( @@ -25,17 +40,25 @@ ) from vllm.entrypoints.openai.responses.protocol import ResponsesRequest from vllm.logger import init_logger +from vllm.reasoning.mistral_reasoning_parser import MistralReasoningParser +from vllm.sampling_params import StructuredOutputsParams from vllm.tokenizers import TokenizerLike +from vllm.tokenizers.mistral import MistralTokenizer from vllm.tool_parsers.abstract_tool_parser import ( Tool, ToolParser, ) from vllm.utils.mistral import is_mistral_tokenizer +if TYPE_CHECKING: + from vllm.reasoning import ReasoningParser + logger = init_logger(__name__) 
ALPHANUMERIC = ascii_letters + digits +_DEFAULT_JSON_SCHEMA = {"anyOf": [{"type": "object"}, {"type": "array"}]} + class StreamingState(Enum): """Enum for tracking the current streaming parsing state.""" @@ -71,15 +94,42 @@ def _is_pre_v11_tokeniser(model_tokenizer: TokenizerLike) -> bool: return not (is_mistral_tokenizer(model_tokenizer) and model_tokenizer.version >= 11) -class MistralToolParser(ToolParser): +@dataclass +class MistralStreamingResult: + r"""Encapsulates the mutable state returned from + `MistralToolParser.extract_maybe_reasoning_and_tool_streaming`. """ - Tool call parser for Mistral 7B Instruct v0.3, intended for use with - - [`mistral_common`](https://github.com/mistralai/mistral-common/) - - the examples/tool_chat_template_mistral.jinja template. - Used when --enable-auto-tool-choice --tool-call-parser mistral are all set + delta_message: DeltaMessage | None + reasoning_ended: bool + added_content_delta: bool + tools_called: bool + current_text: str + current_token_ids: list[int] + + +class MistralToolParser(ToolParser): + r"""Tool call parser for Mistral models, intended for use with either: + + - `mistral_common `_ + (recommended) + - the `examples/tool_chat_template_mistral.jinja` template. + + Used when `--enable-auto-tool-choice --tool-call-parser mistral` are all + set. 
""" + # Used to generate correct grammar in `adjust_request` + model_can_reason: bool = False + + @staticmethod + def is_mistral_grammar_path(request: ChatCompletionRequest) -> bool: + r"""Check if the request was adjusted via the Mistral grammar factory path.""" + return ( + request.structured_outputs is not None + and request.structured_outputs._from_tool_parser + ) + def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None): super().__init__(tokenizer, tools) @@ -115,20 +165,270 @@ def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None): def adjust_request( self, request: ChatCompletionRequest | ResponsesRequest ) -> ChatCompletionRequest | ResponsesRequest: - request = super().adjust_request(request) + so_non_supported_attributes = [ + "regex", + "choice", + "grammar", + # whitespace_pattern is not a constraint type but an option; + # Mistral grammar factory does not support it. + "whitespace_pattern", + "structural_tag", + ] + any_so_non_supported_active = request.structured_outputs is not None and any( + getattr(request.structured_outputs, attribute) is not None + for attribute in so_non_supported_attributes + ) + response_format_non_supported_active = ( + request.response_format is not None + and request.response_format.type == "structural_tag" + ) + if ( not is_mistral_tokenizer(self.model_tokenizer) - and request.tools - and request.tool_choice != "none" + or isinstance(request, ResponsesRequest) + or not self.model_tokenizer.supports_grammar + or any_so_non_supported_active + or response_format_non_supported_active + ): + request = super().adjust_request(request) + if request.tools and request.tool_choice != "none": + # Do not skip special tokens when using chat template + # with Mistral parser as TOOL_CALL token is needed + # for tool detection. 
+ # Note: we don't want skip_special_tokens=False + # with MistralTokenizer as it is incompatible + request.skip_special_tokens = False + return request + + json_schema: dict[str, Any] | None = None + if request.structured_outputs is not None: + if request.structured_outputs.json_object is not None: + json_schema = _DEFAULT_JSON_SCHEMA + elif request.structured_outputs.json is not None: + if isinstance(request.structured_outputs.json, str): + json_schema = json.loads(request.structured_outputs.json) + else: + json_schema = request.structured_outputs.json + else: + raise ValueError( + "Unsupported request.structured_outputs for MistralToolParser. " + "Only `json` and `json_object` are supported." + ) + elif ( + request.response_format is not None + and request.response_format.type != "text" ): - # Do not skip special tokens when using chat template - # with Mistral parser as TOOL_CALL token is needed - # for tool detection. - # Note: we don't want skip_special_tokens=False - # with MistralTokenizer as it is incompatible - request.skip_special_tokens = False + if request.response_format.type == "json_object": + json_schema = _DEFAULT_JSON_SCHEMA + elif request.response_format.type == "json_schema": + if request.response_format.json_schema is not None: + json_schema = request.response_format.json_schema.json_schema + else: + json_schema = _DEFAULT_JSON_SCHEMA + else: + raise ValueError( + "MistralToolParser only accepts `text`, `json_object` or " + f"`json_schema`, got {request.response_format=}" + ) + # Structured Outputs will be defined. + request.response_format = None + + grammar_factory = self.model_tokenizer.grammar_factory + + # TODO: Once unified parser, improve this. + # The issue is figuring out when a model is a reasoning one or not. 
+ template = grammar_factory.select_jinja_template( + reasoning=self.model_can_reason + ) + + tools = ( + [ + MistralTool.from_openai(openai_tool=tool.model_dump()) + for tool in request.tools + ] + if request.tools is not None + else None + ) + + tool_choice: MistralToolChoice + match request.tool_choice: + case "none" | "auto" | "required": + tool_choice = MistralToolChoiceEnum(request.tool_choice) + case None: + tool_choice = MistralToolChoiceEnum.auto + # _ == Named tool choice + case _: + tool_choice = MistralNamedToolChoice.model_validate( + { + "type": "function", + "function": {"name": request.tool_choice.function.name}, + } + ) + + # Rendering grammar is cached in mistral-common given tools, template and mode. + match tool_choice, json_schema is not None: + case MistralToolChoiceEnum.none, True: + lark_grammar = grammar_factory.get_lark_for_json_schema( + template=template, json_schema=json_schema + ) + case _, _: + lark_grammar = grammar_factory.get_lark_from_jinja( + template=template, + mode=tool_choice, + tools=tools, + json_schema=json_schema, + parallel_tool_calls=request.parallel_tool_calls, + json_only=False, + ) + + request.structured_outputs = StructuredOutputsParams(grammar=lark_grammar) + request.structured_outputs._from_tool_parser = True return request + def extract_maybe_reasoning_and_tool_streaming( + self, + *, + reasoning_parser: ReasoningParser | None, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: list[int], + current_token_ids: list[int], + output_token_ids: Sequence[int], + reasoning_ended: bool, + added_content_delta: bool, + prompt_is_reasoning_end: bool | None, + request: ChatCompletionRequest, + ) -> MistralStreamingResult: + r"""Streaming extraction with reasoning followed by tool-call parsing. + + This method encapsulates the combined reasoning extraction and + tool-call streaming logic so that the serving layer only needs a + thin routing branch. + + The flow is: + + 1. 
If a *reasoning_parser* is present and reasoning has **not** ended, + extract reasoning tokens. Pre-v15 models may have pre-filled + `[THINK]...[/THINK]` in system prompts, so we skip the + prompt-level reasoning-end check for those. + 2. Once reasoning ends (or if there is no reasoning parser), delegate + to `extract_tool_calls_streaming` and track whether tools were + called. + + Args: + reasoning_parser: Optional reasoning parser instance. + previous_text: Accumulated text from prior chunks. + current_text: Full accumulated text including current chunk. + delta_text: New text in this chunk. + previous_token_ids: Token ids from prior chunks. + current_token_ids: Full token ids including current chunk. + output_token_ids: Raw output token ids from the engine. + reasoning_ended: Whether reasoning has already ended. + added_content_delta: Whether the first content delta after + reasoning has been emitted. + prompt_is_reasoning_end: Whether the prompt itself ends reasoning. + request: The originating chat completion request. + """ + delta_message: DeltaMessage | None = None + tools_called = False + + # For MistralReasoningParser, only enter the reasoning block when + # the model has actually emitted a [THINK] token. Other reasoning + # parsers always expect thinking to be present. + expect_thinking = ( + not isinstance(reasoning_parser, MistralReasoningParser) + or reasoning_parser.start_token_id in current_token_ids + ) + if reasoning_parser is not None and not reasoning_ended and expect_thinking: + # Pre-v15 models may have pre-filled [THINK]...[/THINK] in + # system prompts, so skip the prompt-level reasoning-end + # check and wait for the output's own end-of-think. 
+ is_pre_v15 = ( + isinstance(self.model_tokenizer, MistralTokenizer) + and self.model_tokenizer.version < 15 + ) + + if not is_pre_v15 and prompt_is_reasoning_end: + reasoning_ended = True + current_token_ids = list(output_token_ids) + else: + delta_message = reasoning_parser.extract_reasoning_streaming( + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + output_token_ids, + ) + if reasoning_parser.is_reasoning_end_streaming( + current_token_ids, output_token_ids + ): + reasoning_ended = True + current_token_ids = reasoning_parser.extract_content_ids( + output_token_ids + ) + if delta_message and delta_message.content: + current_text = delta_message.content + delta_message.content = None + else: + current_text = "" + + if not reasoning_ended: + return MistralStreamingResult( + delta_message=delta_message, + reasoning_ended=False, + added_content_delta=added_content_delta, + tools_called=False, + current_text=current_text, + current_token_ids=current_token_ids, + ) + + delta_token_ids = list(output_token_ids) + if reasoning_parser is not None and not added_content_delta: + # First chunk after reasoning ended: reset text state. 
+ added_content_delta = True + previous_text = "" + previous_token_ids = [] + delta_text = current_text + delta_token_ids = current_token_ids + + delta_message = self.extract_tool_calls_streaming( + previous_text=previous_text, + current_text=current_text, + delta_text=delta_text, + previous_token_ids=previous_token_ids, + current_token_ids=current_token_ids, + delta_token_ids=delta_token_ids, + request=request, + ) + if delta_message and delta_message.tool_calls: + tools_called = True + + return MistralStreamingResult( + delta_message=delta_message, + reasoning_ended=reasoning_ended, + added_content_delta=added_content_delta, + tools_called=tools_called, + current_text=current_text, + current_token_ids=current_token_ids, + ) + + @staticmethod + def build_non_streaming_tool_calls( + tool_calls: list[FunctionCall] | None, + ) -> list[ToolCall]: + r"""Build `MistralToolCall` items for non-streaming responses.""" + if not tool_calls: + return [] + + return [ + MistralToolCall(id=tc.id, function=tc) + if tc.id + else MistralToolCall(function=tc) + for tc in tool_calls + ] + def extract_tool_calls( self, model_output: str, diff --git a/vllm/v1/structured_output/backend_guidance.py b/vllm/v1/structured_output/backend_guidance.py index 6063a2dc2a6d..31178e9f2462 100644 --- a/vllm/v1/structured_output/backend_guidance.py +++ b/vllm/v1/structured_output/backend_guidance.py @@ -12,6 +12,7 @@ from vllm.logger import init_logger from vllm.sampling_params import SamplingParams from vllm.utils.import_utils import LazyLoader +from vllm.utils.mistral import is_mistral_tokenizer from vllm.v1.structured_output.backend_types import ( StructuredOutputBackend, StructuredOutputGrammar, @@ -92,9 +93,12 @@ def __post_init__(self): self.vllm_config.structured_outputs_config.disable_additional_properties ) - self.ll_tokenizer = llguidance_hf.from_tokenizer( - self.tokenizer, max(self.vocab_size, len(self.tokenizer)) - ) + if is_mistral_tokenizer(self.tokenizer): + self.ll_tokenizer = 
self.tokenizer.llg_tokenizer + else: + self.ll_tokenizer = llguidance_hf.from_tokenizer( + self.tokenizer, max(self.vocab_size, len(self.tokenizer)) + ) def compile_grammar( self, request_type: StructuredOutputOptions, grammar_spec: str