From 0fa177a989509773138d6d521ccaf2676c0f148a Mon Sep 17 00:00:00 2001 From: juliendenize Date: Wed, 25 Mar 2026 09:12:06 +0000 Subject: [PATCH 01/16] Add Mistral grammar Signed-off-by: juliendenize --- tests/tokenizers_/test_mistral.py | 28 +++ .../tool_parsers/test_mistral_tool_parser.py | 191 +++++++++++++++++- .../test_backend_guidance.py | 44 ++++ vllm/sampling_params.py | 27 ++- vllm/tokenizers/mistral.py | 25 +++ vllm/tool_parsers/mistral_tool_parser.py | 89 +++++++- vllm/v1/structured_output/backend_guidance.py | 10 +- 7 files changed, 388 insertions(+), 26 deletions(-) diff --git a/tests/tokenizers_/test_mistral.py b/tests/tokenizers_/test_mistral.py index faff61150265..2b101e8f98d9 100644 --- a/tests/tokenizers_/test_mistral.py +++ b/tests/tokenizers_/test_mistral.py @@ -3,8 +3,10 @@ from typing import Any +import llguidance import pytest from mistral_common.exceptions import InvalidMessageStructureException +from mistral_common.guidance.grammar_factory import GrammarFactory from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy from vllm.tokenizers.mistral import ( @@ -2407,3 +2409,29 @@ def test_convert_ids_to_tokens( assert actual_tokens == expected_tokens assert mistral_tokenizer.convert_ids_to_tokens([]) == [] + + def test_grammar_factory(self, mistral_tokenizer: MistralTokenizer) -> None: + # works in this case cause Mistral 7B is < v11 and SPM + if not mistral_tokenizer.is_tekken: + with pytest.raises(AttributeError): + mistral_tokenizer.grammar_factory # noqa: B018 + return + factory = mistral_tokenizer.grammar_factory + assert isinstance(factory, GrammarFactory) + + # Test caching + factory_2 = mistral_tokenizer.grammar_factory + assert factory is factory_2 + + def test_llg_tokenizer(self, mistral_tokenizer: MistralTokenizer) -> None: + if not mistral_tokenizer.is_tekken: + with pytest.raises(ValueError): + mistral_tokenizer.llg_tokenizer # noqa: B018 + return + + llg_tokenizer = mistral_tokenizer.llg_tokenizer + assert isinstance(llg_tokenizer, llguidance.LLTokenizer) + + # Test caching + llg_tokenizer_2 = mistral_tokenizer.llg_tokenizer + assert llg_tokenizer is llg_tokenizer_2 diff --git a/tests/tool_parsers/test_mistral_tool_parser.py b/tests/tool_parsers/test_mistral_tool_parser.py index 4be5646669be..bade77459e53 100644 --- a/tests/tool_parsers/test_mistral_tool_parser.py +++ b/tests/tool_parsers/test_mistral_tool_parser.py @@ -3,15 +3,35 @@ import json from collections.abc import Generator +from unittest.mock import MagicMock, patch import partial_json_parser import pytest from mistral_common.protocol.instruct.messages import AssistantMessage from mistral_common.protocol.instruct.request import InstructRequest -from mistral_common.protocol.instruct.tool_calls import FunctionCall, ToolCall +from mistral_common.protocol.instruct.tool_calls import ( + FunctionCall, + ToolCall, +) +from mistral_common.protocol.instruct.tool_calls import ( + NamedToolChoice as MistralNamedToolChoice, +) +from mistral_common.protocol.instruct.tool_calls import ( + ToolChoice as MistralToolChoice, +) +from mistral_common.protocol.instruct.tool_calls import ( + ToolChoiceEnum as MistralToolChoiceEnum, +) from partial_json_parser.core.options import Allow -from vllm.entrypoints.openai.engine.protocol import DeltaMessage, DeltaToolCall +from vllm.entrypoints.openai.chat_completion.protocol import ( + ChatCompletionRequest, +) +from vllm.entrypoints.openai.engine.protocol import ( + DeltaMessage, + DeltaToolCall, +) +from vllm.sampling_params import StructuredOutputsParams from vllm.tokenizers import TokenizerLike, get_tokenizer from vllm.tokenizers.detokenizer_utils import detokenize_incrementally from vllm.tokenizers.mistral import MistralTokenizer @@ -40,6 +60,13 @@ def mistral_tool_parser(mistral_tokenizer): return MistralToolParser(mistral_tokenizer) +@pytest.fixture +def non_mistral_parser() -> MistralToolParser: + mock_tokenizer = MagicMock() + mock_tokenizer.get_vocab.return_value = {"[TOOL_CALLS]": 1} + return MistralToolParser(mock_tokenizer) + + def assert_tool_calls( actual_tool_calls: list[ToolCall] | list[DeltaToolCall], expected_tool_calls: list[ToolCall], @@ -951,3 +978,163 @@ def test_fast_detokenization_text_detection_pre_v11( assert len(delta_message.tool_calls) > 0 assert delta_message.tool_calls[0].function is not None assert delta_message.tool_calls[0].function.name == "add" + + +SAMPLE_TOOLS_DICTS = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the weather", + "parameters": { + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "add", + "description": "Add two numbers", + "parameters": { + "type": "object", + "properties": { + "a": {"type": "number"}, + "b": {"type": "number"}, + }, + "required": ["a", "b"], + }, + }, + }, +] + + +def _make_request(**kwargs) -> ChatCompletionRequest: + defaults: dict = { + "messages": [], + "model": "mistralai/Mistral-Small-3.2-24B-Instruct-2506", + "tools": SAMPLE_TOOLS_DICTS, + "tool_choice": "auto", + } + defaults.update(kwargs) + return ChatCompletionRequest(**defaults) + + +@pytest.mark.parametrize( + "request_kwargs,expected_mode,expected_parallel", + [ + ({"tool_choice": "auto"}, MistralToolChoiceEnum.auto, True), + ({"tool_choice": "none"}, MistralToolChoiceEnum.none, True), + ({"tool_choice": "required"}, MistralToolChoiceEnum.required, True), + ({"tool_choice": None, "tools": None}, MistralToolChoiceEnum.auto, True), + ( + { + "tool_choice": { + "type": "function", + "function": {"name": "get_weather"}, + } + }, + MistralNamedToolChoice.model_validate( + {"type": "function", "function": {"name": "get_weather"}} + ), + True, + ), + ( + {"tool_choice": "auto", "parallel_tool_calls": False}, + MistralToolChoiceEnum.auto, + False, + ), + ( + {"tool_choice": "auto", "response_format": {"type": "text"}}, + MistralToolChoiceEnum.auto, + True, + ), + ], + ids=[ + "auto", + "none", + "required", + "null_tool_choice", + "named_tool_choice", + "parallel_false", + "response_format_text", + ], +) +def test_adjust_request_grammar_factory( + mistral_tool_parser: MistralToolParser, + request_kwargs: dict, + expected_mode: MistralToolChoice, + expected_parallel: bool, +) -> None: + request = _make_request(**request_kwargs) + factory = mistral_tool_parser.model_tokenizer.grammar_factory + + with patch.object( + factory, + "get_lark_from_jinja", + wraps=factory.get_lark_from_jinja, + ) as mock_get_lark: + result = mistral_tool_parser.adjust_request(request) + + mock_get_lark.assert_called_once() + call_kwargs = mock_get_lark.call_args + + assert call_kwargs.kwargs["mode"] == expected_mode + assert call_kwargs.kwargs["json_schema"] is None + assert call_kwargs.kwargs["parallel_tool_calls"] == expected_parallel + + assert result.structured_outputs is not None + assert isinstance(result.structured_outputs.grammar, str) + assert len(result.structured_outputs.grammar) > 0 + + +@pytest.mark.parametrize( + "request_kwargs", + [ + {"structured_outputs": StructuredOutputsParams(json='{"type": "object"}')}, + {"response_format": {"type": "json_object"}}, + ], + ids=["existing_structured_outputs", "response_format_json_object"], +) +def test_user_grammar( + mistral_tool_parser: MistralToolParser, request_kwargs: dict +) -> None: + original_so = request_kwargs.get("structured_outputs") + request = _make_request(**request_kwargs) + result = mistral_tool_parser.adjust_request(request) + + if original_so is not None: + assert result.structured_outputs is original_so + else: + assert result.structured_outputs is None + + +def test_unsupported_grammar_for_tokenizer(mistral_tokenizer) -> None: + with patch.object( + type(mistral_tokenizer), + "supports_grammar", + new_callable=lambda: property(lambda self: False), + ): + parser = MistralToolParser(mistral_tokenizer) + request = _make_request() + result = parser.adjust_request(request) + + assert result.structured_outputs is None + + +@pytest.mark.parametrize( + "tool_choice,expected_skip", + [("auto", False), ("none", True)], + ids=["auto_skip_false", "none_skip_true"], +) +def test_non_mistral_tokenizer( + non_mistral_parser: MistralToolParser, + tool_choice: str, + expected_skip: bool, +) -> None: + request = _make_request(tool_choice=tool_choice) + result = non_mistral_parser.adjust_request(request) + + assert result.skip_special_tokens is expected_skip diff --git a/tests/v1/structured_output/test_backend_guidance.py b/tests/v1/structured_output/test_backend_guidance.py index 704ed8b9c9e9..ca8c9b0d7853 100644 --- a/tests/v1/structured_output/test_backend_guidance.py +++ b/tests/v1/structured_output/test_backend_guidance.py @@ -11,6 +11,7 @@ from vllm.config.parallel import ParallelConfig from vllm.config.speculative import SpeculativeConfig from vllm.sampling_params import SamplingParams, StructuredOutputsParams +from vllm.tokenizers import get_tokenizer from vllm.v1.request import Request from vllm.v1.structured_output import StructuredOutputManager from vllm.v1.structured_output.backend_guidance import GuidanceBackend @@ -19,6 +20,14 @@ TOKENIZER = "gpt2" +@pytest.fixture(scope="module") +def mistral_tokenizer(): + return get_tokenizer( + tokenizer_name="mistralai/Mistral-Small-3.2-24B-Instruct-2506", + tokenizer_mode="mistral", + ) + + def test_backend_guidance_rollback_terminated(): # Test that the backend guidance successfully rollbacks from a # terminated state. This can happen with speculative decoding, @@ -187,3 +196,38 @@ def test_grammar_init_async_and_sync(async_grammar): # Verify the grammar can accept valid tokens assert grammar.accept_tokens(request.request_id, prompt) + + +@pytest.mark.parametrize( + "request_type,grammar_spec", + [ + pytest.param( + StructuredOutputOptions.JSON, + '{"type": "object"}', + id="json", + ), + pytest.param( + StructuredOutputOptions.GRAMMAR, + 'start: "hello" | "world"', + id="lark", + ), + ], +) +def test_mistral_tokenizer_compile_grammar( + mistral_tokenizer, + request_type: StructuredOutputOptions, + grammar_spec: str, +) -> None: + vllm_config = VllmConfig( + structured_outputs_config=StructuredOutputsConfig(backend="guidance"), + ) + backend = GuidanceBackend( + vllm_config, + tokenizer=mistral_tokenizer, + vocab_size=mistral_tokenizer.vocab_size, + ) + assert backend.ll_tokenizer is mistral_tokenizer.llg_tokenizer + + grammar = backend.compile_grammar(request_type, grammar_spec) + assert grammar is not None + assert not grammar.is_terminated() diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 97976b832097..3b7fdebfa15e 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -153,6 +153,10 @@ class RequestOutputKind(Enum): FINAL_ONLY = 2 +def _is_non_tekken_mistral(tokenizer: TokenizerLike) -> bool: + return is_mistral_tokenizer(tokenizer) and not tokenizer.is_tekken + + class SamplingParams( PydanticMsgspecMixin, msgspec.Struct, @@ -801,16 +805,16 @@ def _validate_structured_outputs( # xgrammar with no fallback validate_xgrammar_grammar(self) elif backend.startswith("guidance"): + if _is_non_tekken_mistral(tokenizer=tokenizer): + raise ValueError( + "Non-tekken Mistral tokenizers are not supported for the 'guidance'" + " structured output backend. Please use ['xgrammar', 'outlines'] " + "backends or tokenizer_mode='hf' instead." + ) # TODO: ideally we would have the LLTokenizer here as Lark syntax # allows <|special_token|> and similar, see # https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens # Without tokenizer these are disallowed in grammars. - if is_mistral_tokenizer(tokenizer): - raise ValueError( - "Mistral tokenizer is not supported for the 'guidance' " - "structured output backend. Please use ['xgrammar', 'outlines'] " - "backends or tokenizer_mode='hf' instead." - ) validate_guidance_grammar(self, tokenizer=None) elif backend == "outlines": # outlines backend @@ -839,19 +843,20 @@ def _validate_structured_outputs( # or includes some jsonschema feature(s) that # are not supported in xgrammar. + skip_guidance = _is_non_tekken_mistral(tokenizer) + # Check if schema has features unsupported by guidance so_params = self.structured_outputs - skip_guidance = False - if so_params.json: + if not skip_guidance and so_params.json: if isinstance(so_params.json, str): schema = json_mod.loads(so_params.json) else: schema = so_params.json skip_guidance = has_guidance_unsupported_json_features(schema) - if is_mistral_tokenizer(tokenizer) or skip_guidance: - # Fall back to outlines if the tokenizer is Mistral - # or if schema contains features unsupported by guidance + if skip_guidance: + # Fall back to outlines if the tokenizer is non-tekken Mistral or + # the schema contains features unsupported by guidance validate_structured_output_request_outlines(self) self.structured_outputs._backend = "outlines" else: diff --git a/vllm/tokenizers/mistral.py b/vllm/tokenizers/mistral.py index e20f1edd472e..147dca88877b 100644 --- a/vllm/tokenizers/mistral.py +++ b/vllm/tokenizers/mistral.py @@ -1,9 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Sequence +from functools import cached_property from pathlib import Path from typing import TYPE_CHECKING, Any, cast, overload +from mistral_common.guidance.grammar_factory import GrammarFactory +from mistral_common.guidance.tokenizer import from_mistral_tokenizer from mistral_common.protocol.instruct.request import ( ChatCompletionRequest as MistralChatCompletionRequest, ) @@ -45,6 +48,7 @@ ) if TYPE_CHECKING: + import llguidance from transformers import BatchEncoding logger = init_logger(__name__) @@ -574,3 +578,24 @@ def convert_ids_to_tokens( ] return tokens + + @property + def supports_grammar(self) -> bool: + return GrammarFactory.is_supported(self.mistral) + + @cached_property + def grammar_factory(self) -> GrammarFactory: + if not self.supports_grammar: + raise AttributeError( + "This tokenizer does not support `grammar_factory`. " + "This is only supported for tekken tokenizers with " + "version >= 11." + ) + # Cache grammar factory to avoid creating a llguidance tokenizer at every usage. + return GrammarFactory(self.mistral) + + @cached_property + def llg_tokenizer(self) -> "llguidance.LLTokenizer": + if not self.is_tekken: + raise ValueError("`llg_tokenizer` is only supported for Tekkenizers.") + return from_mistral_tokenizer(self.mistral) diff --git a/vllm/tool_parsers/mistral_tool_parser.py b/vllm/tool_parsers/mistral_tool_parser.py index dc92522a0520..bd19a221ac6f 100644 --- a/vllm/tool_parsers/mistral_tool_parser.py +++ b/vllm/tool_parsers/mistral_tool_parser.py @@ -10,6 +10,18 @@ import ijson import regex as re +from mistral_common.protocol.instruct.tool_calls import ( + NamedToolChoice as MistralNamedToolChoice, +) +from mistral_common.protocol.instruct.tool_calls import ( + Tool as MistralTool, +) +from mistral_common.protocol.instruct.tool_calls import ( + ToolChoice as MistralToolChoice, +) +from mistral_common.protocol.instruct.tool_calls import ( + ToolChoiceEnum as MistralToolChoiceEnum, +) from pydantic import Field from vllm.entrypoints.openai.chat_completion.protocol import ( @@ -25,6 +37,7 @@ ) from vllm.entrypoints.openai.responses.protocol import ResponsesRequest from vllm.logger import init_logger +from vllm.sampling_params import StructuredOutputsParams from vllm.tokenizers import TokenizerLike from vllm.tool_parsers.abstract_tool_parser import ( Tool, @@ -80,6 +93,9 @@ class MistralToolParser(ToolParser): Used when --enable-auto-tool-choice --tool-call-parser mistral are all set """ + # Used to generate correct grammar in `adjust_request` + model_can_reason: bool = False + def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None): super().__init__(tokenizer, tools) @@ -115,18 +131,71 @@ def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None): def adjust_request( self, request: ChatCompletionRequest | ResponsesRequest ) -> ChatCompletionRequest | ResponsesRequest: - request = super().adjust_request(request) + if not is_mistral_tokenizer(self.model_tokenizer): + request = super().adjust_request(request) + if request.tools and request.tool_choice != "none": + # Do not skip special tokens when using chat template + # with Mistral parser as TOOL_CALL token is needed + # for tool detection. + # Note: we don't want skip_special_tokens=False + # with MistralTokenizer as it is incompatible + request.skip_special_tokens = False + return request + if ( - not is_mistral_tokenizer(self.model_tokenizer) - and request.tools - and request.tool_choice != "none" + not self.model_tokenizer.supports_grammar + or request.structured_outputs is not None + or ( + request.response_format is not None + and request.response_format.type != "text" + ) ): - # Do not skip special tokens when using chat template - # with Mistral parser as TOOL_CALL token is needed - # for tool detection. - # Note: we don't want skip_special_tokens=False - # with MistralTokenizer as it is incompatible - request.skip_special_tokens = False + return request + + grammar_factory = self.model_tokenizer.grammar_factory + + # TODO: Once unified parser, improve this. + # The issue is figuring out when a model is a reasoning one or not. + template = grammar_factory.select_jinja_template( + reasoning=request.include_reasoning + and request.reasoning_effort != "none" + and self.model_can_reason + ) + + tools = ( + [ + MistralTool.from_openai(openai_tool=tool.model_dump()) + for tool in request.tools + ] + if request.tools is not None + else None + ) + + tool_choice: MistralToolChoice + match request.tool_choice: + case "none" | "auto" | "required": + tool_choice = MistralToolChoiceEnum(request.tool_choice) + case None: + tool_choice = MistralToolChoiceEnum.auto + # _ == Named tool choice + case _: + tool_choice = MistralNamedToolChoice.model_validate( + { + "type": "function", + "function": {"name": request.tool_choice.function.name}, + } + ) + + # Rendering grammar is cached in mistral-common given tools, template and mode. + lark_grammar = grammar_factory.get_lark_from_jinja( + template=template, + mode=tool_choice, + tools=tools, + json_schema=None, + parallel_tool_calls=request.parallel_tool_calls, + ) + + request.structured_outputs = StructuredOutputsParams(grammar=lark_grammar) return request def extract_tool_calls( diff --git a/vllm/v1/structured_output/backend_guidance.py b/vllm/v1/structured_output/backend_guidance.py index 6063a2dc2a6d..31178e9f2462 100644 --- a/vllm/v1/structured_output/backend_guidance.py +++ b/vllm/v1/structured_output/backend_guidance.py @@ -12,6 +12,7 @@ from vllm.logger import init_logger from vllm.sampling_params import SamplingParams from vllm.utils.import_utils import LazyLoader +from vllm.utils.mistral import is_mistral_tokenizer from vllm.v1.structured_output.backend_types import ( StructuredOutputBackend, StructuredOutputGrammar, @@ -92,9 +93,12 @@ def __post_init__(self): self.vllm_config.structured_outputs_config.disable_additional_properties ) - self.ll_tokenizer = llguidance_hf.from_tokenizer( - self.tokenizer, max(self.vocab_size, len(self.tokenizer)) - ) + if is_mistral_tokenizer(self.tokenizer): + self.ll_tokenizer = self.tokenizer.llg_tokenizer + else: + self.ll_tokenizer = llguidance_hf.from_tokenizer( + self.tokenizer, max(self.vocab_size, len(self.tokenizer)) + ) def compile_grammar( self, request_type: StructuredOutputOptions, grammar_spec: str From 0be486137b4c7f671c7d533765901f25b28b6fd1 Mon Sep 17 00:00:00 2001 From: juliendenize Date: Thu, 26 Mar 2026 16:32:30 +0000 Subject: [PATCH 02/16] Improve adjust_request to support json and redirect Signed-off-by: juliendenize --- .../tool_parsers/test_mistral_tool_parser.py | 216 ++++++++++++++++-- vllm/tool_parsers/mistral_tool_parser.py | 93 ++++++-- 2 files changed, 268 insertions(+), 41 deletions(-) diff --git a/tests/tool_parsers/test_mistral_tool_parser.py b/tests/tool_parsers/test_mistral_tool_parser.py index bade77459e53..c254562b2a63 100644 --- a/tests/tool_parsers/test_mistral_tool_parser.py +++ b/tests/tool_parsers/test_mistral_tool_parser.py @@ -30,12 +30,16 @@ from vllm.entrypoints.openai.engine.protocol import ( DeltaMessage, DeltaToolCall, + StructuralTagResponseFormat, ) from vllm.sampling_params import StructuredOutputsParams from vllm.tokenizers import TokenizerLike, get_tokenizer from vllm.tokenizers.detokenizer_utils import detokenize_incrementally from vllm.tokenizers.mistral import MistralTokenizer -from vllm.tool_parsers.mistral_tool_parser import MistralToolParser +from vllm.tool_parsers.mistral_tool_parser import ( + _DEFAULT_JSON_SCHEMA, + MistralToolParser, +) @pytest.fixture(scope="module") @@ -1090,28 +1094,7 @@ def test_adjust_request_grammar_factory( assert len(result.structured_outputs.grammar) > 0 -@pytest.mark.parametrize( - "request_kwargs", - [ - {"structured_outputs": StructuredOutputsParams(json='{"type": "object"}')}, - {"response_format": {"type": "json_object"}}, - ], - ids=["existing_structured_outputs", "response_format_json_object"], -) -def test_user_grammar( - mistral_tool_parser: MistralToolParser, request_kwargs: dict -) -> None: - original_so = request_kwargs.get("structured_outputs") - request = _make_request(**request_kwargs) - result = mistral_tool_parser.adjust_request(request) - - if original_so is not None: - assert result.structured_outputs is original_so - else: - assert result.structured_outputs is None - - -def test_unsupported_grammar_for_tokenizer(mistral_tokenizer) -> None: +def test_adjust_request_unsupported_grammar_for_tokenizer(mistral_tokenizer) -> None: with patch.object( type(mistral_tokenizer), "supports_grammar", @@ -1129,7 +1112,7 @@ def test_unsupported_grammar_for_tokenizer(mistral_tokenizer) -> None: [("auto", False), ("none", True)], ids=["auto_skip_false", "none_skip_true"], ) -def test_non_mistral_tokenizer( +def test_adjust_request_non_mistral_tokenizer( non_mistral_parser: MistralToolParser, tool_choice: str, expected_skip: bool, @@ -1138,3 +1121,188 @@ def test_non_mistral_tokenizer( result = non_mistral_parser.adjust_request(request) assert result.skip_special_tokens is expected_skip + + +@pytest.mark.parametrize( + "so_kwargs", + [ + {"regex": r"\d+"}, + {"choice": ["a", "b"]}, + {"structural_tag": '{"key": "value"}'}, + {"grammar": "start: 'hello'"}, + ], + ids=["regex", "choice", "structural_tag", "grammar"], +) +def test_adjust_request_unsupported_structured_outputs( + mistral_tool_parser: MistralToolParser, + so_kwargs: dict, +) -> None: + request = _make_request( + structured_outputs=StructuredOutputsParams(**so_kwargs), + ) + result = mistral_tool_parser.adjust_request(request) + + assert result.structured_outputs == request.structured_outputs + + +def test_adjust_request_unsupported_response_format( + mistral_tool_parser: MistralToolParser, +) -> None: + request = _make_request( + response_format=StructuralTagResponseFormat( + type="structural_tag", format={"some": "config"} + ), + ) + result = mistral_tool_parser.adjust_request(request) + assert result.structured_outputs is None + assert result.response_format == request.response_format + + +@pytest.mark.parametrize( + "so_kwargs,expected_json_schema", + [ + ({"json_object": True}, _DEFAULT_JSON_SCHEMA), + ({"json": '{"type": "object"}'}, '{"type": "object"}'), + ( + {"json": {"type": "object", "properties": {"x": {"type": "integer"}}}}, + json.dumps( + {"type": "object", "properties": {"x": {"type": "integer"}}}, + ensure_ascii=False, + ), + ), + ], + ids=["json_object", "json_str", "json_dict"], +) +def test_adjust_request_structured_outputs_generates_grammar( + mistral_tool_parser: MistralToolParser, + so_kwargs: dict, + expected_json_schema: str, +) -> None: + request = _make_request( + structured_outputs=StructuredOutputsParams(**so_kwargs), + ) + factory = mistral_tool_parser.model_tokenizer.grammar_factory + + with patch.object( + factory, + "get_lark_from_jinja", + wraps=factory.get_lark_from_jinja, + ) as mock_get_lark: + result = mistral_tool_parser.adjust_request(request) + + mock_get_lark.assert_called_once() + assert mock_get_lark.call_args.kwargs["json_schema"] == expected_json_schema + + assert result.structured_outputs is not None + assert isinstance(result.structured_outputs.grammar, str) + assert len(result.structured_outputs.grammar) > 0 + + +@pytest.mark.parametrize( + "response_format_kwargs,expected_json_schema", + [ + ({"type": "json_object"}, _DEFAULT_JSON_SCHEMA), + ( + { + "type": "json_schema", + "json_schema": { + "name": "my_schema", + "schema": { + "type": "object", + "properties": {"x": {"type": "integer"}}, + }, + }, + }, + json.dumps( + {"type": "object", "properties": {"x": {"type": "integer"}}}, + ensure_ascii=False, + ), + ), + ], + ids=["json_object", "json_schema_with_schema"], +) +def test_adjust_request_response_format_generates_grammar( + mistral_tool_parser: MistralToolParser, + response_format_kwargs: dict, + expected_json_schema: str, +) -> None: + request = _make_request(response_format=response_format_kwargs) + factory = mistral_tool_parser.model_tokenizer.grammar_factory + + with patch.object( + factory, + "get_lark_from_jinja", + wraps=factory.get_lark_from_jinja, + ) as mock_get_lark: + result = mistral_tool_parser.adjust_request(request) + + mock_get_lark.assert_called_once() + assert mock_get_lark.call_args.kwargs["json_schema"] == expected_json_schema + + assert result.structured_outputs is not None + assert isinstance(result.structured_outputs.grammar, str) + assert len(result.structured_outputs.grammar) > 0 + + +def test_adjust_request_tool_choice_none_with_json_schema_uses_json_schema_factory( + mistral_tool_parser: MistralToolParser, +) -> None: + request = _make_request( + tool_choice="none", + structured_outputs=StructuredOutputsParams(json='{"type": "object"}'), + ) + factory = mistral_tool_parser.model_tokenizer.grammar_factory + + with ( + patch.object( + factory, + "get_lark_for_json_schema", + wraps=factory.get_lark_for_json_schema, + ) as mock_json_schema, + patch.object( + factory, + "get_lark_from_jinja", + wraps=factory.get_lark_from_jinja, + ) as mock_jinja, + ): + result = mistral_tool_parser.adjust_request(request) + + mock_json_schema.assert_called_once() + assert mock_json_schema.call_args.kwargs["json_schema"] == '{"type": "object"}' + mock_jinja.assert_not_called() + + assert result.structured_outputs is not None + assert isinstance(result.structured_outputs.grammar, str) + assert len(result.structured_outputs.grammar) > 0 + + +def test_adjust_request_tool_choice_auto_with_json_schema_uses_jinja_factory( + mistral_tool_parser: MistralToolParser, +) -> None: + request = _make_request( + tool_choice="auto", + structured_outputs=StructuredOutputsParams(json='{"type": "object"}'), + ) + factory = mistral_tool_parser.model_tokenizer.grammar_factory + + with ( + patch.object( + factory, + "get_lark_for_json_schema", + wraps=factory.get_lark_for_json_schema, + ) as mock_json_schema, + patch.object( + factory, + "get_lark_from_jinja", + wraps=factory.get_lark_from_jinja, + ) as mock_jinja, + ): + result = mistral_tool_parser.adjust_request(request) + + mock_jinja.assert_called_once() + assert mock_jinja.call_args.kwargs["json_schema"] == '{"type": "object"}' + mock_json_schema.assert_not_called() + + assert result.structured_outputs is not None + assert isinstance(result.structured_outputs.grammar, str) + assert len(result.structured_outputs.grammar) > 0 diff --git a/vllm/tool_parsers/mistral_tool_parser.py b/vllm/tool_parsers/mistral_tool_parser.py index bd19a221ac6f..393fe09d3966 100644 --- a/vllm/tool_parsers/mistral_tool_parser.py +++ b/vllm/tool_parsers/mistral_tool_parser.py @@ -49,6 +49,10 @@ ALPHANUMERIC = ascii_letters + digits +_DEFAULT_JSON_SCHEMA = json.dumps( + {"anyOf": [{"type": "object"}, {"type": "array"}]}, ensure_ascii=False +) + class StreamingState(Enum): """Enum for tracking the current streaming parsing state.""" @@ -131,7 +135,31 @@ def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None): def adjust_request( self, request: ChatCompletionRequest | ResponsesRequest ) -> ChatCompletionRequest | ResponsesRequest: - if not is_mistral_tokenizer(self.model_tokenizer): + so_non_supported_attributes = [ + "regex", + "choice", + "grammar", + # whitespace_pattern is not a constraint type but an option; + # Mistral grammar factory does not support it. + "whitespace_pattern", + "structural_tag", + ] + structured_outputs = request.structured_outputs + response_format = request.response_format + any_so_non_supported_active = structured_outputs is not None and any( + getattr(structured_outputs, attribute) is not None + for attribute in so_non_supported_attributes + ) + response_format_non_supported_active = ( + response_format is not None and response_format.type == "structural_tag" + ) + + if ( + not is_mistral_tokenizer(self.model_tokenizer) + or not self.model_tokenizer.supports_grammar + or any_so_non_supported_active + or response_format_non_supported_active + ): request = super().adjust_request(request) if request.tools and request.tool_choice != "none": # Do not skip special tokens when using chat template @@ -142,15 +170,40 @@ def adjust_request( request.skip_special_tokens = False return request - if ( - not self.model_tokenizer.supports_grammar - or request.structured_outputs is not None - or ( - request.response_format is not None - and request.response_format.type != "text" - ) - ): - return request + json_schema: str | None = None + if structured_outputs is not None: + if structured_outputs.json_object is not None: + json_schema = _DEFAULT_JSON_SCHEMA + elif structured_outputs.json is not None: + if isinstance(structured_outputs.json, str): + json_schema = structured_outputs.json + else: + json_schema = json.dumps( + structured_outputs.json, ensure_ascii=False + ) + else: + raise ValueError( + "Unsupported request.structured_outputs for MistralToolParser. " + "Only `json` and `json_object` are supported." + ) + elif response_format is not None and response_format.type != "text": + if response_format.type == "json_object": + json_schema = _DEFAULT_JSON_SCHEMA + elif response_format.type == "json_schema": + if response_format.json_schema is not None: + json_schema = json.dumps( + response_format.json_schema.json_schema, + ensure_ascii=False, + ) + else: + json_schema = _DEFAULT_JSON_SCHEMA + else: + raise ValueError( + "MistralToolParser only accepts `text`, `json_object` or " + f"`json_schema` for request.response_format, got {response_format=}" + ) + # Structured Outputs will be defined. + request.response_format = None grammar_factory = self.model_tokenizer.grammar_factory @@ -187,13 +240,19 @@ def adjust_request( ) # Rendering grammar is cached in mistral-common given tools, template and mode. - lark_grammar = grammar_factory.get_lark_from_jinja( - template=template, - mode=tool_choice, - tools=tools, - json_schema=None, - parallel_tool_calls=request.parallel_tool_calls, - ) + match tool_choice, json_schema is not None: + case MistralToolChoiceEnum.none, True: + lark_grammar = grammar_factory.get_lark_for_json_schema( + json_schema=json_schema + ) + case _, _: + lark_grammar = grammar_factory.get_lark_from_jinja( + template=template, + mode=tool_choice, + tools=tools, + json_schema=json_schema, + parallel_tool_calls=request.parallel_tool_calls, + ) request.structured_outputs = StructuredOutputsParams(grammar=lark_grammar) return request From 3a80fb49aef52c95be68c047f8303ff2a434fe6f Mon Sep 17 00:00:00 2001 From: juliendenize Date: Thu, 26 Mar 2026 17:00:57 +0000 Subject: [PATCH 03/16] Improve error message for Guidance backend Signed-off-by: juliendenize --- vllm/sampling_params.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 3b7fdebfa15e..9bcc669591eb 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -808,7 +808,8 @@ def _validate_structured_outputs( if _is_non_tekken_mistral(tokenizer=tokenizer): raise ValueError( "Non-tekken Mistral tokenizers are not supported for the 'guidance'" - " structured output backend. Please use ['xgrammar', 'outlines'] " + " structured output backend. Please either use a more recent " + "Mistral model, the ['xgrammar', 'outlines'] " "backends or tokenizer_mode='hf' instead." ) # TODO: ideally we would have the LLTokenizer here as Lark syntax From a355de4ea7a52f64c87df5f0dfb7f6a08509133e Mon Sep 17 00:00:00 2001 From: juliendenize Date: Fri, 27 Mar 2026 10:32:32 +0000 Subject: [PATCH 04/16] Fix json grammar arguments Signed-off-by: juliendenize --- .../tool_parsers/test_mistral_tool_parser.py | 16 +++++--------- vllm/tool_parsers/mistral_tool_parser.py | 21 ++++++------------- 2 files changed, 11 insertions(+), 26 deletions(-) diff --git a/tests/tool_parsers/test_mistral_tool_parser.py b/tests/tool_parsers/test_mistral_tool_parser.py index c254562b2a63..d761dddd38c3 100644 --- a/tests/tool_parsers/test_mistral_tool_parser.py +++ b/tests/tool_parsers/test_mistral_tool_parser.py @@ -1162,13 +1162,10 @@ def test_adjust_request_unsupported_response_format( "so_kwargs,expected_json_schema", [ ({"json_object": True}, _DEFAULT_JSON_SCHEMA), - ({"json": '{"type": "object"}'}, '{"type": "object"}'), + ({"json": '{"type": "object"}'}, {"type": "object"}), ( {"json": {"type": "object", "properties": {"x": {"type": "integer"}}}}, - json.dumps( - {"type": "object", "properties": {"x": {"type": "integer"}}}, - ensure_ascii=False, - ), + {"type": "object", "properties": {"x": {"type": "integer"}}}, ), ], ids=["json_object", "json_str", "json_dict"], @@ -1213,10 +1210,7 @@ def test_adjust_request_structured_outputs_generates_grammar( }, }, }, - json.dumps( - {"type": "object", "properties": {"x": {"type": "integer"}}}, - ensure_ascii=False, - ), + {"type": "object", "properties": {"x": {"type": "integer"}}}, ), ], ids=["json_object", "json_schema_with_schema"], @@ -1268,7 +1262,7 @@ def test_adjust_request_tool_choice_none_with_json_schema_uses_json_schema_facto result = mistral_tool_parser.adjust_request(request) mock_json_schema.assert_called_once() - assert mock_json_schema.call_args.kwargs["json_schema"] == '{"type": "object"}' + assert mock_json_schema.call_args.kwargs["json_schema"] == {"type": "object"} mock_jinja.assert_not_called() assert result.structured_outputs is not None @@ -1300,7 +1294,7 @@ def test_adjust_request_tool_choice_auto_with_json_schema_uses_jinja_factory( result = mistral_tool_parser.adjust_request(request) mock_jinja.assert_called_once() - assert mock_jinja.call_args.kwargs["json_schema"] == '{"type": "object"}' + assert mock_jinja.call_args.kwargs["json_schema"] == {"type": "object"} mock_json_schema.assert_not_called() assert result.structured_outputs is not None diff --git a/vllm/tool_parsers/mistral_tool_parser.py b/vllm/tool_parsers/mistral_tool_parser.py index 393fe09d3966..9d2d8734f26c 100644 --- a/vllm/tool_parsers/mistral_tool_parser.py +++ b/vllm/tool_parsers/mistral_tool_parser.py @@ -49,9 +49,7 @@ ALPHANUMERIC = ascii_letters + digits -_DEFAULT_JSON_SCHEMA = json.dumps( - {"anyOf": [{"type": "object"}, {"type": "array"}]}, ensure_ascii=False -) +_DEFAULT_JSON_SCHEMA = {"anyOf": [{"type": "object"}, {"type": "array"}]} class StreamingState(Enum): @@ -170,17 +168,15 @@ def adjust_request( request.skip_special_tokens = False return request - json_schema: str | None = None + json_schema: dict[str, Any] | None = None if structured_outputs is not None: if structured_outputs.json_object is not None: json_schema = _DEFAULT_JSON_SCHEMA elif structured_outputs.json is not None: if isinstance(structured_outputs.json, str): - json_schema = structured_outputs.json + json_schema = json.loads(structured_outputs.json) else: - json_schema = json.dumps( - structured_outputs.json, ensure_ascii=False - ) + json_schema = structured_outputs.json else: raise ValueError( "Unsupported request.structured_outputs for MistralToolParser. " @@ -191,10 +187,7 @@ def adjust_request( json_schema = _DEFAULT_JSON_SCHEMA elif response_format.type == "json_schema": if response_format.json_schema is not None: - json_schema = json.dumps( - response_format.json_schema.json_schema, - ensure_ascii=False, - ) + json_schema = response_format.json_schema.json_schema else: json_schema = _DEFAULT_JSON_SCHEMA else: @@ -210,9 +203,7 @@ def adjust_request( # TODO: Once unified parser, improve this. # The issue is figuring out when a model is a reasoning one or not. template = grammar_factory.select_jinja_template( - reasoning=request.include_reasoning - and request.reasoning_effort != "none" - and self.model_can_reason + reasoning=self.model_can_reason ) tools = ( From 23f76ca4de8453b7f7236e12dd5419640c5f73d1 Mon Sep 17 00:00:00 2001 From: juliendenize Date: Fri, 27 Mar 2026 14:13:26 +0000 Subject: [PATCH 05/16] Update with respect to mistral-common guidance Signed-off-by: juliendenize --- vllm/tool_parsers/mistral_tool_parser.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/tool_parsers/mistral_tool_parser.py b/vllm/tool_parsers/mistral_tool_parser.py index 9d2d8734f26c..6fc127f61911 100644 --- a/vllm/tool_parsers/mistral_tool_parser.py +++ b/vllm/tool_parsers/mistral_tool_parser.py @@ -234,7 +234,7 @@ def adjust_request( match tool_choice, json_schema is not None: case MistralToolChoiceEnum.none, True: lark_grammar = grammar_factory.get_lark_for_json_schema( - json_schema=json_schema + template=template, json_schema=json_schema ) case _, _: lark_grammar = grammar_factory.get_lark_from_jinja( @@ -243,6 +243,7 @@ def adjust_request( tools=tools, json_schema=json_schema, parallel_tool_calls=request.parallel_tool_calls, + json_only=False, ) request.structured_outputs = StructuredOutputsParams(grammar=lark_grammar) From 71be6ace23c71e1c4ea8a387933bbb23313a00ea Mon Sep 17 00:00:00 2001 From: juliendenize Date: Fri, 27 Mar 2026 14:40:18 +0000 Subject: [PATCH 06/16] Fix tests Signed-off-by: juliendenize --- tests/tool_parsers/test_mistral_tool_parser.py | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/tests/tool_parsers/test_mistral_tool_parser.py b/tests/tool_parsers/test_mistral_tool_parser.py index d761dddd38c3..064ccb39ef4b 100644 --- a/tests/tool_parsers/test_mistral_tool_parser.py +++ b/tests/tool_parsers/test_mistral_tool_parser.py @@ -1247,23 +1247,15 @@ def test_adjust_request_tool_choice_none_with_json_schema_uses_json_schema_facto ) factory = mistral_tool_parser.model_tokenizer.grammar_factory - with ( - patch.object( - factory, - "get_lark_for_json_schema", - wraps=factory.get_lark_for_json_schema, - ) as mock_json_schema, - patch.object( - factory, - "get_lark_from_jinja", - wraps=factory.get_lark_from_jinja, - ) as mock_jinja, - ): + with patch.object( + factory, + "get_lark_for_json_schema", + wraps=factory.get_lark_for_json_schema, + ) as mock_json_schema: result = mistral_tool_parser.adjust_request(request) mock_json_schema.assert_called_once() assert mock_json_schema.call_args.kwargs["json_schema"] == {"type": "object"} - mock_jinja.assert_not_called() assert result.structured_outputs is not None assert isinstance(result.structured_outputs.grammar, str) From 3772f0cdfd2c87b843e330b547c1610780cfde3f Mon Sep 17 00:00:00 2001 From: juliendenize Date: Wed, 1 Apr 2026 14:38:03 +0000 Subject: [PATCH 07/16] bump Signed-off-by: juliendenize --- requirements/common.txt | 2 +- requirements/rocm-test.txt | 2 +- requirements/test.txt | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/requirements/common.txt b/requirements/common.txt index 05666c5d14b0..b610fd678687 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -31,7 +31,7 @@ partial-json-parser # used for parsing partial JSON outputs pyzmq >= 25.0.0 msgspec gguf >= 0.17.0 -mistral_common[image] >= 1.10.0 +mistral_common[image] >= 1.11.0 opencv-python-headless >= 4.13.0 # required for video IO pyyaml six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12 diff --git a/requirements/rocm-test.txt b/requirements/rocm-test.txt index dd4c7c24f40c..3d5df9814cae 100644 --- a/requirements/rocm-test.txt +++ b/requirements/rocm-test.txt @@ -502,7 +502,7 @@ mbstrdecoder==1.1.4 # typepy mdurl==0.1.2 # via markdown-it-py -mistral-common==1.10.0 +mistral-common==1.11.0 # via # -c requirements/common.txt # -r requirements/rocm-test.in diff --git a/requirements/test.txt b/requirements/test.txt index 642e589a6a27..c8ff5fcabb28 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -508,7 +508,7 @@ mbstrdecoder==1.1.3 # typepy mdurl==0.1.2 # via markdown-it-py -mistral-common==1.10.0 +mistral-common==1.11.0 # via # -c requirements/common.txt # -r requirements/test.in From 0506d4e9e2d7e583f2042f0ee9aa23275667dac0 Mon Sep 17 00:00:00 2001 From: juliendenize Date: Wed, 1 Apr 2026 14:53:39 +0000 Subject: [PATCH 08/16] Exclude ResponsesRequest Signed-off-by: juliendenize --- vllm/tool_parsers/mistral_tool_parser.py | 37 +++++++++++++----------- 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/vllm/tool_parsers/mistral_tool_parser.py b/vllm/tool_parsers/mistral_tool_parser.py index 6fc127f61911..a7cfee602f6a 100644 --- a/vllm/tool_parsers/mistral_tool_parser.py +++ b/vllm/tool_parsers/mistral_tool_parser.py @@ -142,18 +142,18 @@ def adjust_request( "whitespace_pattern", "structural_tag", ] - structured_outputs = request.structured_outputs - response_format = request.response_format - any_so_non_supported_active = structured_outputs is not None and any( - getattr(structured_outputs, attribute) is not None + any_so_non_supported_active = request.structured_outputs is not None and any( + getattr(request.structured_outputs, attribute) is not None for attribute in so_non_supported_attributes ) response_format_non_supported_active = ( - response_format is not None and response_format.type == "structural_tag" + request.response_format is not None + and request.response_format.type == "structural_tag" ) if ( not is_mistral_tokenizer(self.model_tokenizer) + or isinstance(request, ResponsesRequest) or not self.model_tokenizer.supports_grammar or any_so_non_supported_active or response_format_non_supported_active @@ -169,31 +169,34 @@ def adjust_request( return request json_schema: dict[str, Any] | None = None - if structured_outputs is not None: - if structured_outputs.json_object is not None: + if request.structured_outputs is not None: + if request.structured_outputs.json_object is not None: json_schema = _DEFAULT_JSON_SCHEMA - elif structured_outputs.json is not None: - if isinstance(structured_outputs.json, str): - json_schema = json.loads(structured_outputs.json) + elif request.structured_outputs.json is not None: + if isinstance(request.structured_outputs.json, str): + json_schema = json.loads(request.structured_outputs.json) else: - json_schema = structured_outputs.json + json_schema = request.structured_outputs.json else: raise ValueError( "Unsupported request.structured_outputs for MistralToolParser. " "Only `json` and `json_object` are supported." ) - elif response_format is not None and response_format.type != "text": - if response_format.type == "json_object": + elif ( + request.response_format is not None + and request.response_format.type != "text" + ): + if request.response_format.type == "json_object": json_schema = _DEFAULT_JSON_SCHEMA - elif response_format.type == "json_schema": - if response_format.json_schema is not None: - json_schema = response_format.json_schema.json_schema + elif request.response_format.type == "json_schema": + if request.response_format.json_schema is not None: + json_schema = request.response_format.json_schema.json_schema else: json_schema = _DEFAULT_JSON_SCHEMA else: raise ValueError( "MistralToolParser only accepts `text`, `json_object` or " - f"`json_schema` for request.response_format, got {response_format=}" + f"`json_schema`, got {request.response_format=}" ) # Structured Outputs will be defined. request.response_format = None From 4b6eb0295d22545ac293d7876757691de8b15ff6 Mon Sep 17 00:00:00 2001 From: juliendenize Date: Thu, 26 Mar 2026 19:28:54 +0000 Subject: [PATCH 09/16] Add _from_tool_parser attribute Signed-off-by: juliendenize --- vllm/sampling_params.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 9bcc669591eb..1f8303efa312 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -51,6 +51,8 @@ class StructuredOutputsParams: """CAUTION: Should only be set by Processor._validate_structured_output""" _backend_was_auto: bool = field(default=False, init=False) """CAUTION: Should only be set by Processor._validate_structured_output""" + _from_tool_parser: bool = field(default=False, init=False) + """CAUTION: Should only be set by ToolParser.adjust_request""" def __post_init__(self): """Validate that some fields are mutually exclusive.""" From 86ba6cf34a61f38fefe0a87a2b5ef110af6a63e7 Mon Sep 17 00:00:00 2001 From: juliendenize Date: Thu, 26 Mar 2026 20:34:49 +0000 Subject: [PATCH 10/16] Core parser changes Signed-off-by: juliendenize --- vllm/tool_parsers/mistral_tool_parser.py | 189 ++++++++++++++++++++++- 1 file changed, 183 insertions(+), 6 deletions(-) diff --git a/vllm/tool_parsers/mistral_tool_parser.py b/vllm/tool_parsers/mistral_tool_parser.py index a7cfee602f6a..f53da7fdd97c 100644 --- a/vllm/tool_parsers/mistral_tool_parser.py +++ b/vllm/tool_parsers/mistral_tool_parser.py @@ -1,12 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from __future__ import annotations + import json from collections.abc import Sequence +from dataclasses import dataclass from enum import Enum, auto from random import choices from string import ascii_letters, digits -from typing import Any +from typing import TYPE_CHECKING, Any import ijson import regex as re @@ -37,14 +40,19 @@ ) from vllm.entrypoints.openai.responses.protocol import ResponsesRequest from vllm.logger import init_logger +from vllm.reasoning.mistral_reasoning_parser import MistralReasoningParser from vllm.sampling_params import StructuredOutputsParams from vllm.tokenizers import TokenizerLike +from vllm.tokenizers.mistral import MistralTokenizer from vllm.tool_parsers.abstract_tool_parser import ( Tool, ToolParser, ) from vllm.utils.mistral import is_mistral_tokenizer +if TYPE_CHECKING: + from vllm.reasoning import ReasoningParser + logger = init_logger(__name__) ALPHANUMERIC = ascii_letters + digits @@ -86,18 +94,42 @@ def _is_pre_v11_tokeniser(model_tokenizer: TokenizerLike) -> bool: return not (is_mistral_tokenizer(model_tokenizer) and model_tokenizer.version >= 11) -class MistralToolParser(ToolParser): +@dataclass +class MistralStreamingResult: + r"""Encapsulates the mutable state returned from + `MistralToolParser.extract_maybe_reasoning_and_tool_streaming`. """ - Tool call parser for Mistral 7B Instruct v0.3, intended for use with - - [`mistral_common`](https://github.com/mistralai/mistral-common/) - - the examples/tool_chat_template_mistral.jinja template. - Used when --enable-auto-tool-choice --tool-call-parser mistral are all set + delta_message: DeltaMessage | None + reasoning_ended: bool + added_content_delta: bool + tools_called: bool + current_text: str + current_token_ids: list[int] + + +class MistralToolParser(ToolParser): + r"""Tool call parser for Mistral models, intended for use with either: + + - `mistral_common `_ + (recommended) + - the `examples/tool_chat_template_mistral.jinja` template. + + Used when `--enable-auto-tool-choice --tool-call-parser mistral` are all + set. """ # Used to generate correct grammar in `adjust_request` model_can_reason: bool = False + @staticmethod + def is_mistral_grammar_path(request: ChatCompletionRequest) -> bool: + r"""Check if the request was adjusted via the Mistral grammar factory path.""" + return ( + request.structured_outputs is not None + and request.structured_outputs._from_tool_parser + ) + def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None): super().__init__(tokenizer, tools) @@ -250,8 +282,153 @@ def adjust_request( ) request.structured_outputs = StructuredOutputsParams(grammar=lark_grammar) + request.structured_outputs._from_tool_parser = True return request + def extract_maybe_reasoning_and_tool_streaming( + self, + *, + reasoning_parser: ReasoningParser | None, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: list[int], + current_token_ids: list[int], + output_token_ids: Sequence[int], + reasoning_ended: bool, + added_content_delta: bool, + prompt_is_reasoning_end: bool | None, + request: ChatCompletionRequest, + ) -> MistralStreamingResult: + r"""Streaming extraction with reasoning followed by tool-call parsing. + + This method encapsulates the combined reasoning extraction and + tool-call streaming logic so that the serving layer only needs a + thin routing branch. + + The flow is: + + 1. If a *reasoning_parser* is present and reasoning has **not** ended, + extract reasoning tokens. Pre-v15 models may have pre-filled + `[THINK]...[/THINK]` in system prompts, so we skip the + prompt-level reasoning-end check for those. + 2. Once reasoning ends (or if there is no reasoning parser), delegate + to `extract_tool_calls_streaming` and track whether tools were + called. + + Args: + reasoning_parser: Optional reasoning parser instance. + previous_text: Accumulated text from prior chunks. + current_text: Full accumulated text including current chunk. + delta_text: New text in this chunk. + previous_token_ids: Token ids from prior chunks. + current_token_ids: Full token ids including current chunk. + output_token_ids: Raw output token ids from the engine. + reasoning_ended: Whether reasoning has already ended. + added_content_delta: Whether the first content delta after + reasoning has been emitted. + prompt_is_reasoning_end: Whether the prompt itself ends reasoning. + request: The originating chat completion request. + """ + delta_message: DeltaMessage | None = None + tools_called = False + + # For MistralReasoningParser, only enter the reasoning block when + # the model has actually emitted a [THINK] token. Other reasoning + # parsers always expect thinking to be present. + expect_thinking = ( + not isinstance(reasoning_parser, MistralReasoningParser) + or reasoning_parser.start_token_id in current_token_ids + ) + if reasoning_parser is not None and not reasoning_ended and expect_thinking: + # Pre-v15 models may have pre-filled [THINK]...[/THINK] in + # system prompts, so skip the prompt-level reasoning-end + # check and wait for the output's own end-of-think. + is_pre_v15 = ( + isinstance(self.model_tokenizer, MistralTokenizer) + and self.model_tokenizer.version < 15 + ) + + if not is_pre_v15 and prompt_is_reasoning_end: + reasoning_ended = True + current_token_ids = list(output_token_ids) + else: + delta_message = reasoning_parser.extract_reasoning_streaming( + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + output_token_ids, + ) + if reasoning_parser.is_reasoning_end_streaming( + current_token_ids, output_token_ids + ): + reasoning_ended = True + current_token_ids = reasoning_parser.extract_content_ids( + output_token_ids + ) + if delta_message and delta_message.content: + current_text = delta_message.content + delta_message.content = None + else: + current_text = "" + + if not reasoning_ended: + return MistralStreamingResult( + delta_message=delta_message, + reasoning_ended=False, + added_content_delta=added_content_delta, + tools_called=False, + current_text=current_text, + current_token_ids=current_token_ids, + ) + + delta_token_ids = list(output_token_ids) + if reasoning_parser is not None and not added_content_delta: + # First chunk after reasoning ended: reset text state. + added_content_delta = True + previous_text = "" + previous_token_ids = [] + delta_text = current_text + delta_token_ids = current_token_ids + + delta_message = self.extract_tool_calls_streaming( + previous_text=previous_text, + current_text=current_text, + delta_text=delta_text, + previous_token_ids=previous_token_ids, + current_token_ids=current_token_ids, + delta_token_ids=delta_token_ids, + request=request, + ) + if delta_message and delta_message.tool_calls: + tools_called = True + + return MistralStreamingResult( + delta_message=delta_message, + reasoning_ended=reasoning_ended, + added_content_delta=added_content_delta, + tools_called=tools_called, + current_text=current_text, + current_token_ids=current_token_ids, + ) + + @staticmethod + def build_non_streaming_tool_calls( + tool_calls: list[FunctionCall] | None, + ) -> list[ToolCall]: + r"""Build `MistralToolCall` items for non-streaming responses.""" + if not tool_calls: + return [] + + return [ + MistralToolCall(id=tc.id, function=tc) + if tc.id + else MistralToolCall(function=tc) + for tc in tool_calls + ] + def extract_tool_calls( self, model_output: str, From 84d35f177239c8d9d35a7ce7792224fda36586c5 Mon Sep 17 00:00:00 2001 From: juliendenize Date: Thu, 26 Mar 2026 20:43:41 +0000 Subject: [PATCH 11/16] Allow `adjust_request` for `tool_choice="none"` Signed-off-by: juliendenize --- vllm/entrypoints/serve/render/serving.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/vllm/entrypoints/serve/render/serving.py b/vllm/entrypoints/serve/render/serving.py index 83b41bbda2d0..02b3a7b17cd6 100644 --- a/vllm/entrypoints/serve/render/serving.py +++ b/vllm/entrypoints/serve/render/serving.py @@ -52,6 +52,7 @@ prompt_to_seq, ) from vllm.tool_parsers import ToolParser +from vllm.tool_parsers.mistral_tool_parser import MistralToolParser from vllm.utils import random_uuid from vllm.utils.mistral import is_mistral_tokenizer from vllm.utils.mistral import mt as _mt @@ -534,9 +535,19 @@ async def preprocess_chat( # tool parsing is done only if a tool_parser has been set and if # tool_choice is not "none" (if tool_choice is "none" but a tool_parser # is set, we want to prevent parsing a tool_call hallucinated by the LLM + # + # Exception: Mistral grammar-capable tokenizers always call + # adjust_request — even for tool_choice="none" — so that the grammar + # factory can prevent special-token leakage. if tool_parser is not None: tool_choice = getattr(request, "tool_choice", "none") - if tool_choice != "none": + tokenizer = renderer.get_tokenizer() + is_mistral_grammar_eligible = ( + issubclass(tool_parser, MistralToolParser) + and is_mistral_tokenizer(tokenizer) + and tokenizer.supports_grammar + ) + if tool_choice != "none" or is_mistral_grammar_eligible: if not isinstance(request, ChatCompletionRequest | ResponsesRequest): msg = ( "Tool usage is only supported " @@ -544,7 +555,6 @@ async def preprocess_chat( f"but got {type(request).__name__}" ) raise NotImplementedError(msg) - tokenizer = renderer.get_tokenizer() request = tool_parser(tokenizer, request.tools).adjust_request( request=request ) From 0f4418a3ba6ff180b7c45bd0b080e8301bce2eee Mon Sep 17 00:00:00 2001 From: juliendenize Date: Thu, 26 Mar 2026 20:52:28 +0000 Subject: [PATCH 12/16] Streaming and Non-streaming mistral grammar Signed-off-by: juliendenize --- .../openai/chat_completion/serving.py | 70 +++++++++++++++++-- 1 file changed, 64 insertions(+), 6 deletions(-) diff --git a/vllm/entrypoints/openai/chat_completion/serving.py b/vllm/entrypoints/openai/chat_completion/serving.py index a426836afd35..ccbe4271ae0c 100644 --- a/vllm/entrypoints/openai/chat_completion/serving.py +++ b/vllm/entrypoints/openai/chat_completion/serving.py @@ -73,7 +73,10 @@ from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.tokenizers import TokenizerLike from vllm.tool_parsers import ToolParser -from vllm.tool_parsers.mistral_tool_parser import MistralToolCall +from vllm.tool_parsers.mistral_tool_parser import ( + MistralToolCall, + MistralToolParser, +) from vllm.tool_parsers.utils import partial_json_loads from vllm.utils.collection_utils import as_list from vllm.utils.mistral import is_mistral_tokenizer @@ -134,6 +137,12 @@ def __init__( enable_auto_tools=enable_auto_tools, model_name=self.model_config.model, ) + _is_mistral_tool_parser = self.tool_parser is not None and issubclass( + self.tool_parser, MistralToolParser + ) + if _is_mistral_tool_parser and self.reasoning_parser_cls is not None: + MistralToolParser.model_can_reason = True + self.exclude_tools_when_tool_choice_none = exclude_tools_when_tool_choice_none self.enable_prompt_tokens_details = enable_prompt_tokens_details @@ -523,6 +532,8 @@ async def chat_completion_stream_generator( harmony_tools_streamed = [False] * num_choices tools_streamed = [False] * num_choices + is_mistral_grammar_path = MistralToolParser.is_mistral_grammar_path(request) + if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam): tool_choice_function_name = request.tool_choice.function.name else: @@ -546,7 +557,7 @@ async def chat_completion_stream_generator( # Only one of these will be used, thus previous_texts and # all_previous_token_ids will not be used twice in the same iteration. - if tool_choice_auto or reasoning_parser: + if is_mistral_grammar_path or tool_choice_auto or reasoning_parser: # These are only required in "auto" tool choice case all_previous_token_ids = [[] for _ in range(num_choices)] # For reasoning parser and tool call all enabled @@ -558,7 +569,7 @@ async def chat_completion_stream_generator( # Prepare the tool parser if it's needed try: - if tool_choice_auto and self.tool_parser: + if (is_mistral_grammar_path or tool_choice_auto) and self.tool_parser: if tokenizer is None: raise ValueError( "Tokenizer not available when `skip_tokenizer_init=True`" @@ -740,7 +751,7 @@ async def chat_completion_stream_generator( delta_message: DeltaMessage | None # just update previous_texts and previous_token_ids - if tool_choice_auto or reasoning_parser: + if is_mistral_grammar_path or tool_choice_auto or reasoning_parser: assert previous_texts is not None assert all_previous_token_ids is not None previous_text = previous_texts[i] @@ -764,6 +775,33 @@ async def chat_completion_stream_generator( ) ) harmony_tools_streamed[i] |= tools_streamed_flag + # Mistral grammar path: combined reasoning + tool streaming + elif is_mistral_grammar_path: + assert tool_parser is not None + assert isinstance(tool_parser, MistralToolParser) + assert added_content_delta_arr is not None + assert reasoning_end_arr is not None + output_token_ids = as_list(output.token_ids) + result = tool_parser.extract_maybe_reasoning_and_tool_streaming( + reasoning_parser=reasoning_parser, + previous_text=previous_text, + current_text=current_text, + delta_text=delta_text, + previous_token_ids=previous_token_ids, + current_token_ids=current_token_ids, + output_token_ids=output_token_ids, + reasoning_ended=reasoning_end_arr[i], + added_content_delta=added_content_delta_arr[i], + prompt_is_reasoning_end=(prompt_is_reasoning_end_arr[i]), + request=request, + ) + delta_message = result.delta_message + reasoning_end_arr[i] = result.reasoning_ended + added_content_delta_arr[i] = result.added_content_delta + current_text = result.current_text + current_token_ids = result.current_token_ids + if result.tools_called: + tools_streamed[i] = True # handle streaming deltas for tools with named tool_choice elif tool_choice_function_name: # When encountering think end id in prompt_token_ids @@ -1010,7 +1048,9 @@ async def chat_completion_stream_generator( delta_message = DeltaMessage(content=delta_text) # update the previous values for the next iteration - if (tool_choice_auto or reasoning_parser) and not self.use_harmony: + if ( + is_mistral_grammar_path or tool_choice_auto or reasoning_parser + ) and not self.use_harmony: assert previous_texts is not None assert all_previous_token_ids is not None previous_texts[i] = current_text @@ -1397,7 +1437,25 @@ async def chat_completion_full_generator( tool_call_class = ( MistralToolCall if is_mistral_tokenizer(tokenizer) else ToolCall ) - if (not self.enable_auto_tools or not self.tool_parser) and ( + + use_mistral_tool_parser = MistralToolParser.is_mistral_grammar_path(request) + if use_mistral_tool_parser: + tool_call_items = MistralToolParser.build_non_streaming_tool_calls( + tool_calls + ) + if tool_call_items: + auto_tools_called = not isinstance( + request.tool_choice, + ChatCompletionNamedToolChoiceParam, + ) + message = ChatMessage( + role=role, + reasoning=reasoning, + content=content, + tool_calls=tool_call_items, + ) + + elif (not self.enable_auto_tools or not self.tool_parser) and ( not isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam) and request.tool_choice != "required" ): From 5a8cbcf0278cdb685c872730a0f1a3617013e4a0 Mon Sep 17 00:00:00 2001 From: juliendenize Date: Thu, 26 Mar 2026 21:09:28 +0000 Subject: [PATCH 13/16] Use MistralToolParser in _parse_tool_calls_from_content Signed-off-by: juliendenize --- vllm/entrypoints/openai/engine/serving.py | 30 +++++++++++++++++------ 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/vllm/entrypoints/openai/engine/serving.py b/vllm/entrypoints/openai/engine/serving.py index f5f011a96f27..3d4f4ff364d5 100644 --- a/vllm/entrypoints/openai/engine/serving.py +++ b/vllm/entrypoints/openai/engine/serving.py @@ -72,6 +72,7 @@ from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.tokenizers import TokenizerLike from vllm.tool_parsers import ToolParser +from vllm.tool_parsers.mistral_tool_parser import MistralToolParser from vllm.tracing import ( contains_trace_headers, extract_trace_headers, @@ -785,16 +786,31 @@ def _parse_tool_calls_from_content( tool_parser_cls: type[ToolParser] | None, content: str | None = None, ) -> tuple[list[FunctionCall] | None, str | None]: + # When the Mistral grammar factory injected structured outputs, + # let the parser handle the output. + use_mistral_tool_parser = ( + isinstance(request, ChatCompletionRequest) + and tool_parser_cls is not None + and issubclass(tool_parser_cls, MistralToolParser) + and MistralToolParser.is_mistral_grammar_path(request=request) + ) + function_calls = list[FunctionCall]() - if request.tool_choice and isinstance(request.tool_choice, ToolChoiceFunction): + if ( + not use_mistral_tool_parser + and request.tool_choice + and isinstance(request.tool_choice, ToolChoiceFunction) + ): assert content is not None # Forced Function Call function_calls.append( FunctionCall(name=request.tool_choice.name, arguments=content) ) content = None # Clear content since tool is called. - elif request.tool_choice and isinstance( - request.tool_choice, ChatCompletionNamedToolChoiceParam + elif ( + not use_mistral_tool_parser + and request.tool_choice + and isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam) ): assert content is not None # Forced Function Call @@ -802,7 +818,7 @@ def _parse_tool_calls_from_content( FunctionCall(name=request.tool_choice.function.name, arguments=content) ) content = None # Clear content since tool is called. - elif request.tool_choice == "required": + elif not use_mistral_tool_parser and request.tool_choice == "required": tool_calls = [] with contextlib.suppress(ValidationError): content = content or "" @@ -817,9 +833,9 @@ def _parse_tool_calls_from_content( ) ) content = None # Clear content since tool is called. - elif ( - tool_parser_cls - and enable_auto_tools + elif tool_parser_cls and ( + use_mistral_tool_parser + or enable_auto_tools and (request.tool_choice == "auto" or request.tool_choice is None) ): if tokenizer is None: From 7a0e426c4c162d3537785310d4c8fbf709dca725 Mon Sep 17 00:00:00 2001 From: juliendenize Date: Fri, 27 Mar 2026 10:36:26 +0000 Subject: [PATCH 14/16] Add tests Signed-off-by: juliendenize --- .../tool_parsers/test_mistral_tool_parser.py | 285 +++++++++++++ .../mistral/test_mistral_tool_calls.py | 379 +++++++++++++++++- tests/tool_use/mistral/utils.py | 16 + 3 files changed, 679 insertions(+), 1 deletion(-) diff --git a/tests/tool_parsers/test_mistral_tool_parser.py b/tests/tool_parsers/test_mistral_tool_parser.py index 064ccb39ef4b..750dd4d15fc9 100644 --- a/tests/tool_parsers/test_mistral_tool_parser.py +++ b/tests/tool_parsers/test_mistral_tool_parser.py @@ -3,6 +3,7 @@ import json from collections.abc import Generator +from typing import Any from unittest.mock import MagicMock, patch import partial_json_parser @@ -28,16 +29,21 @@ ChatCompletionRequest, ) from vllm.entrypoints.openai.engine.protocol import ( + DeltaFunctionCall, DeltaMessage, DeltaToolCall, StructuralTagResponseFormat, ) +from vllm.entrypoints.openai.engine.protocol import FunctionCall as VllmFunctionCall +from vllm.reasoning.mistral_reasoning_parser import MistralReasoningParser from vllm.sampling_params import StructuredOutputsParams from vllm.tokenizers import TokenizerLike, get_tokenizer from vllm.tokenizers.detokenizer_utils import detokenize_incrementally from vllm.tokenizers.mistral import MistralTokenizer from vllm.tool_parsers.mistral_tool_parser import ( _DEFAULT_JSON_SCHEMA, + MistralStreamingResult, + MistralToolCall, MistralToolParser, ) @@ -1292,3 +1298,282 @@ def test_adjust_request_tool_choice_auto_with_json_schema_uses_jinja_factory( assert result.structured_outputs is not None assert isinstance(result.structured_outputs.grammar, str) assert len(result.structured_outputs.grammar) > 0 + + +@pytest.mark.parametrize( + "so, set_from_tool_parser, expected", + [ + (None, False, False), + (StructuredOutputsParams(grammar="user grammar"), False, False), + (StructuredOutputsParams(grammar="factory grammar"), True, True), + ], + ids=["no_structured_outputs", "user_supplied_grammar", "from_tool_parser"], +) +def test_is_mistral_grammar_path( + so: StructuredOutputsParams | None, + set_from_tool_parser: bool, + expected: bool, +) -> None: + request = _make_request(structured_outputs=so) + if set_from_tool_parser: + assert request.structured_outputs is not None + request.structured_outputs._from_tool_parser = True + + assert MistralToolParser.is_mistral_grammar_path(request) == expected + + +@pytest.mark.parametrize( + "tool_calls, expected_len", + [ + (None, 0), + ([], 0), + ([VllmFunctionCall(id="abc123xyz", name="f", arguments="{}")], 1), + ([VllmFunctionCall(name="f", arguments="{}")], 1), + ( + [ + VllmFunctionCall(id="fixed1234", name="a", arguments='{"x": 1}'), + VllmFunctionCall(name="b", arguments='{"y": 2}'), + ], + 2, + ), + ], + ids=["none", "empty", "with_id", "without_id", "mixed"], +) +def test_build_non_streaming_tool_calls( + tool_calls: list[VllmFunctionCall] | None, + expected_len: int, +) -> None: + result = MistralToolParser.build_non_streaming_tool_calls(tool_calls) + assert len(result) == expected_len + + if tool_calls is None: + return + + for i, tc in enumerate(result): + assert isinstance(tc, MistralToolCall) + assert tc.type == "function" + + input_tc = tool_calls[i] + if input_tc.id: + assert tc.id == input_tc.id + else: + assert len(tc.id) == 9 + assert tc.id.isalnum() + + assert tc.function.name == input_tc.name + assert tc.function.arguments == input_tc.arguments + + +class TestExtractMaybeReasoningAndToolStreaming: + r"""Tests for ``MistralToolParser.extract_maybe_reasoning_and_tool_streaming``.""" + + @pytest.fixture + def parser(self) -> MistralToolParser: + mock_tokenizer = MagicMock() + mock_tokenizer.get_vocab.return_value = {"[TOOL_CALLS]": 1} + return MistralToolParser(mock_tokenizer) + + @pytest.fixture + def request_obj(self) -> ChatCompletionRequest: + return _make_request() + + @staticmethod + def _call( + parser: MistralToolParser, + request: ChatCompletionRequest, + *, + reasoning_parser: Any = None, + previous_text: str = "", + current_text: str = "hello", + delta_text: str = "hello", + previous_token_ids: list[int] | None = None, + current_token_ids: list[int] | None = None, + output_token_ids: list[int] | None = None, + reasoning_ended: bool = False, + added_content_delta: bool = False, + prompt_is_reasoning_end: bool | None = None, + ) -> MistralStreamingResult: + return parser.extract_maybe_reasoning_and_tool_streaming( + reasoning_parser=reasoning_parser, + previous_text=previous_text, + current_text=current_text, + delta_text=delta_text, + previous_token_ids=previous_token_ids or [], + current_token_ids=current_token_ids or [1, 2, 3], + output_token_ids=output_token_ids or [1, 2, 3], + reasoning_ended=reasoning_ended, + added_content_delta=added_content_delta, + prompt_is_reasoning_end=prompt_is_reasoning_end, + request=request, + ) + + def test_no_reasoning_tools_called( + self, parser: MistralToolParser, request_obj: ChatCompletionRequest + ) -> None: + tool_delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=0, + function=DeltaFunctionCall(name="f", arguments="{}"), + ) + ] + ) + with patch.object( + parser, "extract_tool_calls_streaming", return_value=tool_delta + ): + result = self._call(parser, request_obj, reasoning_parser=None) + + assert result.tools_called + assert result.delta_message is not None + assert result.delta_message.tool_calls + + def test_no_reasoning_no_tools( + self, parser: MistralToolParser, request_obj: ChatCompletionRequest + ) -> None: + content_delta = DeltaMessage(content="hello") + with patch.object( + parser, "extract_tool_calls_streaming", return_value=content_delta + ): + result = self._call(parser, request_obj, reasoning_parser=None) + + assert not result.tools_called + assert result.delta_message is not None + assert result.delta_message.content == "hello" + + def test_mistral_reasoning_parser_no_think_token( + self, parser: MistralToolParser, request_obj: ChatCompletionRequest + ) -> None: + mock_rp = MagicMock(spec=MistralReasoningParser) + mock_rp.start_token_id = 999 + content_delta = DeltaMessage(content="direct") + with patch.object( + parser, "extract_tool_calls_streaming", return_value=content_delta + ): + result = self._call( + parser, + request_obj, + reasoning_parser=mock_rp, + reasoning_ended=False, + current_token_ids=[1, 2, 3], + ) + + mock_rp.extract_reasoning_streaming.assert_not_called() + assert result.delta_message is not None + assert result.delta_message.content == "direct" + + def test_mistral_reasoning_parser_with_think_token( + self, parser: MistralToolParser, request_obj: ChatCompletionRequest + ) -> None: + mock_rp = MagicMock(spec=MistralReasoningParser) + mock_rp.start_token_id = 999 + mock_rp.extract_reasoning_streaming.return_value = DeltaMessage( + reasoning="thinking..." + ) + mock_rp.is_reasoning_end_streaming.return_value = False + + result = self._call( + parser, + request_obj, + reasoning_parser=mock_rp, + reasoning_ended=False, + current_token_ids=[1, 999, 3], + ) + + mock_rp.extract_reasoning_streaming.assert_called_once() + assert not result.reasoning_ended + assert not result.tools_called + + def test_non_mistral_reasoning_parser_always_expects_thinking( + self, parser: MistralToolParser, request_obj: ChatCompletionRequest + ) -> None: + mock_rp = MagicMock() + mock_rp.start_token_id = 999 + mock_rp.extract_reasoning_streaming.return_value = DeltaMessage( + reasoning="thinking..." + ) + mock_rp.is_reasoning_end_streaming.return_value = False + + result = self._call( + parser, + request_obj, + reasoning_parser=mock_rp, + reasoning_ended=False, + current_token_ids=[1, 2, 3], + ) + + mock_rp.extract_reasoning_streaming.assert_called_once() + assert not result.reasoning_ended + assert not result.tools_called + + def test_reasoning_ended_first_chunk_resets_state( + self, parser: MistralToolParser, request_obj: ChatCompletionRequest + ) -> None: + content_delta = DeltaMessage(content="content") + with patch.object( + parser, "extract_tool_calls_streaming", return_value=content_delta + ) as mock_extract: + result = self._call( + parser, + request_obj, + reasoning_parser=MagicMock(), + reasoning_ended=True, + added_content_delta=False, + ) + + _, call_kwargs = mock_extract.call_args + assert call_kwargs["previous_text"] == "" + assert call_kwargs["previous_token_ids"] == [] + + assert result.added_content_delta + + def test_pre_v15_ignores_prompt_reasoning_end( + self, parser: MistralToolParser, request_obj: ChatCompletionRequest + ) -> None: + mock_tokenizer = MagicMock(spec=MistralTokenizer) + mock_tokenizer.version = 13 + parser.model_tokenizer = mock_tokenizer + + mock_rp = MagicMock(spec=MistralReasoningParser) + mock_rp.start_token_id = 999 + mock_rp.extract_reasoning_streaming.return_value = DeltaMessage( + reasoning="thinking..." + ) + mock_rp.is_reasoning_end_streaming.return_value = False + + result = self._call( + parser, + request_obj, + reasoning_parser=mock_rp, + reasoning_ended=False, + prompt_is_reasoning_end=True, + current_token_ids=[999, 1, 2], + ) + + mock_rp.extract_reasoning_streaming.assert_called_once() + assert not result.reasoning_ended + + def test_non_pre_v15_prompt_reasoning_end( + self, parser: MistralToolParser, request_obj: ChatCompletionRequest + ) -> None: + mock_tokenizer = MagicMock(spec=MistralTokenizer) + mock_tokenizer.version = 15 + parser.model_tokenizer = mock_tokenizer + + mock_rp = MagicMock(spec=MistralReasoningParser) + mock_rp.start_token_id = 999 + + content_delta = DeltaMessage(content="after reasoning") + with patch.object( + parser, "extract_tool_calls_streaming", return_value=content_delta + ): + result = self._call( + parser, + request_obj, + reasoning_parser=mock_rp, + reasoning_ended=False, + prompt_is_reasoning_end=True, + current_token_ids=[999, 1, 2], + ) + + mock_rp.extract_reasoning_streaming.assert_not_called() + assert result.reasoning_ended diff --git a/tests/tool_use/mistral/test_mistral_tool_calls.py b/tests/tool_use/mistral/test_mistral_tool_calls.py index 3c4a543abe41..b1b6a9143309 100644 --- a/tests/tool_use/mistral/test_mistral_tool_calls.py +++ b/tests/tool_use/mistral/test_mistral_tool_calls.py @@ -1,10 +1,31 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import json + import openai import pytest -from tests.tool_use.utils import MESSAGES_ASKING_FOR_TOOLS, WEATHER_TOOL +from tests.tool_use.utils import ( + MESSAGES_ASKING_FOR_TOOLS, + MESSAGES_WITH_TOOL_RESPONSE, + MESSAGES_WITHOUT_TOOLS, + SEARCH_TOOL, + SEED, + WEATHER_TOOL, + ensure_system_prompt, +) + +from .utils import ServerConfig + + +def _requires_auto_tool_choice(server_config: ServerConfig) -> None: + r"""Skip test if server was not started with --enable-auto-tool-choice.""" + if "--enable-auto-tool-choice" not in server_config.get("arguments", []): + pytest.skip( + f"Skipping: {server_config['model']} not configured with " + "--enable-auto-tool-choice" + ) # test: a tool_choice with mistral-tokenizer results in an ID of length 9 @@ -28,3 +49,359 @@ async def test_tool_call_with_tool_choice(client: openai.AsyncOpenAI): assert choice.message.role == "assistant" assert choice.message.tool_calls is None or len(choice.message.tool_calls) == 1 assert len(choice.message.tool_calls[0].id) == 9 # length of 9 for mistral + + +@pytest.mark.asyncio +async def test_tool_call_auto( + client: openai.AsyncOpenAI, server_config: ServerConfig +) -> None: + _requires_auto_tool_choice(server_config) + + models = await client.models.list() + model_name: str = models.data[0].id + + # --- non-streaming --- + chat_completion = await client.chat.completions.create( + messages=ensure_system_prompt(MESSAGES_ASKING_FOR_TOOLS, server_config), + temperature=0, + max_completion_tokens=100, + model=model_name, + tools=[WEATHER_TOOL, SEARCH_TOOL], + logprobs=False, + seed=SEED, + ) + + choice = chat_completion.choices[0] + tool_calls = choice.message.tool_calls + + assert choice.finish_reason == "tool_calls" + assert tool_calls is not None and len(tool_calls) >= 1 + assert tool_calls[0].type == "function" + assert tool_calls[0].function.name == "get_current_weather" + assert isinstance(tool_calls[0].function.arguments, str) + parsed_arguments = json.loads(tool_calls[0].function.arguments) + assert "city" in parsed_arguments + assert len(tool_calls[0].id) == 9 + + # --- streaming --- + function_name: str | None = None + function_args_str: str = "" + tool_call_id: str | None = None + role_name: str | None = None + finish_reason_count: int = 0 + + stream = await client.chat.completions.create( + messages=ensure_system_prompt(MESSAGES_ASKING_FOR_TOOLS, server_config), + temperature=0, + max_completion_tokens=100, + model=model_name, + tools=[WEATHER_TOOL, SEARCH_TOOL], + logprobs=False, + seed=SEED, + stream=True, + ) + + async for chunk in stream: + if chunk.choices[0].finish_reason: + finish_reason_count += 1 + assert chunk.choices[0].finish_reason == "tool_calls" + + if chunk.choices[0].delta.role: + assert not role_name or role_name == "assistant" + role_name = "assistant" + + streamed_tool_calls = chunk.choices[0].delta.tool_calls + if streamed_tool_calls and len(streamed_tool_calls) > 0: + assert len(streamed_tool_calls) == 1 + tool_call = streamed_tool_calls[0] + + if tool_call.id: + assert not tool_call_id + tool_call_id = tool_call.id + + if tool_call.function: + if tool_call.function.name: + assert function_name is None + function_name = tool_call.function.name + if tool_call.function.arguments: + function_args_str += tool_call.function.arguments + + assert finish_reason_count == 1 + assert role_name == "assistant" + assert function_name == "get_current_weather" + streamed_args = json.loads(function_args_str) + assert "city" in streamed_args + assert isinstance(tool_call_id, str) and len(tool_call_id) == 9 + assert parsed_arguments == streamed_args + + +@pytest.mark.asyncio +async def test_tool_call_required( + client: openai.AsyncOpenAI, server_config: ServerConfig +) -> None: + _requires_auto_tool_choice(server_config) + + models = await client.models.list() + model_name: str = models.data[0].id + + # --- non-streaming --- + chat_completion = await client.chat.completions.create( + messages=ensure_system_prompt(MESSAGES_ASKING_FOR_TOOLS, server_config), + temperature=0, + max_completion_tokens=100, + model=model_name, + tools=[WEATHER_TOOL], + tool_choice="required", + logprobs=False, + seed=SEED, + ) + + choice = chat_completion.choices[0] + tool_calls = choice.message.tool_calls + + assert choice.finish_reason == "tool_calls" + assert tool_calls is not None and len(tool_calls) >= 1 + assert tool_calls[0].function.name == "get_current_weather" + parsed_arguments = json.loads(tool_calls[0].function.arguments) + assert len(tool_calls[0].id) == 9 + + # --- streaming --- + function_name: str | None = None + function_args_str: str = "" + tool_call_id: str | None = None + role_name: str | None = None + finish_reason_count: int = 0 + + stream = await client.chat.completions.create( + messages=ensure_system_prompt(MESSAGES_ASKING_FOR_TOOLS, server_config), + temperature=0, + max_completion_tokens=100, + model=model_name, + tools=[WEATHER_TOOL], + tool_choice="required", + logprobs=False, + seed=SEED, + stream=True, + ) + + async for chunk in stream: + if chunk.choices[0].finish_reason: + finish_reason_count += 1 + assert chunk.choices[0].finish_reason == "tool_calls" + + if chunk.choices[0].delta.role: + assert not role_name or role_name == "assistant" + role_name = "assistant" + + streamed_tool_calls = chunk.choices[0].delta.tool_calls + if streamed_tool_calls and len(streamed_tool_calls) > 0: + assert len(streamed_tool_calls) == 1 + tool_call = streamed_tool_calls[0] + + if tool_call.id: + assert not tool_call_id + tool_call_id = tool_call.id + + if tool_call.function: + if tool_call.function.name: + assert function_name is None + function_name = tool_call.function.name + if tool_call.function.arguments: + function_args_str += tool_call.function.arguments + + assert finish_reason_count == 1 + assert role_name == "assistant" + assert function_name == "get_current_weather" + streamed_args = json.loads(function_args_str) + assert isinstance(tool_call_id, str) and len(tool_call_id) == 9 + assert parsed_arguments == streamed_args + + +@pytest.mark.asyncio +async def test_tool_call_none_with_tools( + client: openai.AsyncOpenAI, server_config: ServerConfig +) -> None: + models = await client.models.list() + model_name: str = models.data[0].id + + # --- non-streaming --- + chat_completion = await client.chat.completions.create( + messages=ensure_system_prompt(MESSAGES_ASKING_FOR_TOOLS, server_config), + temperature=0, + max_completion_tokens=100, + model=model_name, + tools=[WEATHER_TOOL], + tool_choice="none", + logprobs=False, + seed=SEED, + ) + + choice = chat_completion.choices[0] + + assert choice.finish_reason != "tool_calls" + assert choice.message.tool_calls is None or len(choice.message.tool_calls) == 0 + assert choice.message.content is not None + assert "[TOOL_CALLS]" not in choice.message.content + + non_streaming_content = choice.message.content + + # --- streaming --- + chunks: list[str] = [] + finish_reason_count: int = 0 + + stream = await client.chat.completions.create( + messages=ensure_system_prompt(MESSAGES_ASKING_FOR_TOOLS, server_config), + temperature=0, + max_completion_tokens=100, + model=model_name, + tools=[WEATHER_TOOL], + tool_choice="none", + logprobs=False, + seed=SEED, + stream=True, + ) + + async for chunk in stream: + delta = chunk.choices[0].delta + + if delta.content: + chunks.append(delta.content) + + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + assert chunk.choices[0].finish_reason != "tool_calls" + + assert not delta.tool_calls or len(delta.tool_calls) == 0 + + assert finish_reason_count == 1 + streamed_content = "".join(chunks) + assert "[TOOL_CALLS]" not in streamed_content + assert streamed_content == non_streaming_content + + +@pytest.mark.asyncio +async def test_chat_without_tools( + client: openai.AsyncOpenAI, server_config: ServerConfig +) -> None: + models = await client.models.list() + model_name: str = models.data[0].id + + # --- non-streaming --- + chat_completion = await client.chat.completions.create( + messages=ensure_system_prompt(MESSAGES_WITHOUT_TOOLS, server_config), + temperature=0, + max_completion_tokens=150, + model=model_name, + logprobs=False, + seed=SEED, + ) + + choice = chat_completion.choices[0] + output_text = choice.message.content + + assert output_text is not None and len(output_text) > 0 + assert choice.finish_reason != "tool_calls" + assert choice.message.tool_calls is None or len(choice.message.tool_calls) == 0 + + # --- streaming --- + chunks: list[str] = [] + finish_reason_count: int = 0 + role_sent: bool = False + + stream = await client.chat.completions.create( + messages=ensure_system_prompt(MESSAGES_WITHOUT_TOOLS, server_config), + temperature=0, + max_completion_tokens=150, + model=model_name, + logprobs=False, + seed=SEED, + stream=True, + ) + + async for chunk in stream: + delta = chunk.choices[0].delta + + if delta.role: + assert not role_sent + assert delta.role == "assistant" + role_sent = True + + if delta.content: + chunks.append(delta.content) + + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + assert chunk.choices[0].finish_reason == choice.finish_reason + + assert not delta.tool_calls or len(delta.tool_calls) == 0 + + assert role_sent + assert finish_reason_count == 1 + assert len(chunks) + assert "".join(chunks) == output_text + + +@pytest.mark.asyncio +async def test_tool_call_with_results( + client: openai.AsyncOpenAI, server_config: ServerConfig +) -> None: + _requires_auto_tool_choice(server_config) + + models = await client.models.list() + model_name: str = models.data[0].id + + # --- non-streaming --- + chat_completion = await client.chat.completions.create( + messages=ensure_system_prompt(MESSAGES_WITH_TOOL_RESPONSE, server_config), + temperature=0, + max_completion_tokens=100, + model=model_name, + tools=[WEATHER_TOOL, SEARCH_TOOL], + logprobs=False, + seed=SEED, + ) + + choice = chat_completion.choices[0] + + assert choice.finish_reason != "tool_calls" + assert choice.message.tool_calls is None or len(choice.message.tool_calls) == 0 + assert choice.message.content is not None + assert "98" in choice.message.content + + # --- streaming --- + chunks: list[str] = [] + finish_reason_count: int = 0 + role_sent: bool = False + + stream = await client.chat.completions.create( + messages=ensure_system_prompt(MESSAGES_WITH_TOOL_RESPONSE, server_config), + temperature=0, + max_completion_tokens=100, + model=model_name, + tools=[WEATHER_TOOL, SEARCH_TOOL], + logprobs=False, + seed=SEED, + stream=True, + ) + + async for chunk in stream: + delta = chunk.choices[0].delta + + if delta.role: + assert not role_sent + assert delta.role == "assistant" + role_sent = True + + if delta.content: + chunks.append(delta.content) + + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + assert chunk.choices[0].finish_reason == choice.finish_reason + + assert not delta.tool_calls or len(delta.tool_calls) == 0 + + assert role_sent + assert finish_reason_count == 1 + assert len(chunks) + assert "".join(chunks) == choice.message.content diff --git a/tests/tool_use/mistral/utils.py b/tests/tool_use/mistral/utils.py index 4d772ba63793..6f6ee2d8654f 100644 --- a/tests/tool_use/mistral/utils.py +++ b/tests/tool_use/mistral/utils.py @@ -29,4 +29,20 @@ class ServerConfig(TypedDict, total=False): "without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT " "to the user's question - just respond to it normally.", }, + "ministral-3b": { + "model": "mistralai/Ministral-3-3B-Instruct-2512", + "arguments": [ + "--tokenizer-mode", + "mistral", + "--tool-call-parser", + "mistral", + "--enable-auto-tool-choice", + ], + "system_prompt": "You are a helpful assistant with access to tools. If a tool" + " that you have would be helpful to answer a user query, " + "call the tool. Otherwise, answer the user's query directly " + "without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT " + "to the user's question - just respond to it normally.", + "supports_parallel": True, + }, } From b32bc382d0dc42b09fa0984ef3186d7e6d5bcecd Mon Sep 17 00:00:00 2001 From: juliendenize Date: Fri, 27 Mar 2026 14:02:15 +0000 Subject: [PATCH 15/16] Fixes Signed-off-by: juliendenize --- .../entrypoints/openai/chat_completion/protocol.py | 7 ------- vllm/entrypoints/openai/chat_completion/serving.py | 5 +++++ vllm/sampling_params.py | 14 ++++++++++++-- 3 files changed, 17 insertions(+), 9 deletions(-) diff --git a/vllm/entrypoints/openai/chat_completion/protocol.py b/vllm/entrypoints/openai/chat_completion/protocol.py index 533959df6094..16d06f4651a5 100644 --- a/vllm/entrypoints/openai/chat_completion/protocol.py +++ b/vllm/entrypoints/openai/chat_completion/protocol.py @@ -781,13 +781,6 @@ def check_system_message_content_type(cls, data): return data - @model_validator(mode="before") - @classmethod - def set_include_reasoning_for_none_effort(cls, data: Any) -> Any: - if data.get("reasoning_effort") == "none": - data["include_reasoning"] = False - return data - class BatchChatCompletionRequest(OpenAIBaseModel): """Request model for the /v1/chat/completions/batch endpoint. diff --git a/vllm/entrypoints/openai/chat_completion/serving.py b/vllm/entrypoints/openai/chat_completion/serving.py index ccbe4271ae0c..e2479179d6dd 100644 --- a/vllm/entrypoints/openai/chat_completion/serving.py +++ b/vllm/entrypoints/openai/chat_completion/serving.py @@ -314,6 +314,11 @@ async def create_chat_completion( else: if not request.include_reasoning: reasoning_ended = True + elif MistralToolParser.is_mistral_grammar_path(request): + # The Mistral grammar already includes an optional + # `think?` rule that handles both reasoning and + # non-reasoning outputs. + reasoning_ended = True elif reasoning_parser: reasoning_ended = reasoning_parser.is_reasoning_end( prompt_token_ids or [] diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 1f8303efa312..93842620dc4a 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -818,7 +818,12 @@ def _validate_structured_outputs( # allows <|special_token|> and similar, see # https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens # Without tokenizer these are disallowed in grammars. - validate_guidance_grammar(self, tokenizer=None) + validate_guidance_grammar( + self, + tokenizer=tokenizer.llg_tokenizer + if is_mistral_tokenizer(tokenizer) + else None, + ) elif backend == "outlines": # outlines backend validate_structured_output_request_outlines(self) @@ -864,7 +869,12 @@ def _validate_structured_outputs( self.structured_outputs._backend = "outlines" else: # Fall back to guidance by default. - validate_guidance_grammar(self, tokenizer=None) + validate_guidance_grammar( + self, + tokenizer=tokenizer.llg_tokenizer + if is_mistral_tokenizer(tokenizer) + else None, + ) self.structured_outputs._backend = "guidance" # Remember that this backend was set automatically self.structured_outputs._backend_was_auto = True From 5b4589abc74437c7c47bdd5d6a61a1f5ca08bcd6 Mon Sep 17 00:00:00 2001 From: juliendenize Date: Fri, 27 Mar 2026 20:56:10 +0000 Subject: [PATCH 16/16] Minor improvements Signed-off-by: juliendenize --- .../tool_use/mistral/test_mistral_tool_calls.py | 17 +++++++++-------- vllm/entrypoints/openai/engine/serving.py | 6 ++++-- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/tests/tool_use/mistral/test_mistral_tool_calls.py b/tests/tool_use/mistral/test_mistral_tool_calls.py index b1b6a9143309..fdb2846c664d 100644 --- a/tests/tool_use/mistral/test_mistral_tool_calls.py +++ b/tests/tool_use/mistral/test_mistral_tool_calls.py @@ -19,12 +19,11 @@ from .utils import ServerConfig -def _requires_auto_tool_choice(server_config: ServerConfig) -> None: - r"""Skip test if server was not started with --enable-auto-tool-choice.""" - if "--enable-auto-tool-choice" not in server_config.get("arguments", []): +def _requires_tool_parser(server_config: ServerConfig) -> None: + r"""Skip test if server was not started with --tool-call-parser.""" + if "--tool-call-parser" not in server_config.get("arguments", []): pytest.skip( - f"Skipping: {server_config['model']} not configured with " - "--enable-auto-tool-choice" + f"Skipping: {server_config['model']} not configured with --tool-call-parser" ) @@ -55,7 +54,7 @@ async def test_tool_call_with_tool_choice(client: openai.AsyncOpenAI): async def test_tool_call_auto( client: openai.AsyncOpenAI, server_config: ServerConfig ) -> None: - _requires_auto_tool_choice(server_config) + _requires_tool_parser(server_config) models = await client.models.list() model_name: str = models.data[0].id @@ -139,7 +138,7 @@ async def test_tool_call_auto( async def test_tool_call_required( client: openai.AsyncOpenAI, server_config: ServerConfig ) -> None: - _requires_auto_tool_choice(server_config) + _requires_tool_parser(server_config) models = await client.models.list() model_name: str = models.data[0].id @@ -221,6 +220,8 @@ async def test_tool_call_required( async def test_tool_call_none_with_tools( client: openai.AsyncOpenAI, server_config: ServerConfig ) -> None: + _requires_tool_parser(server_config) + models = await client.models.list() model_name: str = models.data[0].id @@ -345,7 +346,7 @@ async def test_chat_without_tools( async def test_tool_call_with_results( client: openai.AsyncOpenAI, server_config: ServerConfig ) -> None: - _requires_auto_tool_choice(server_config) + _requires_tool_parser(server_config) models = await client.models.list() model_name: str = models.data[0].id diff --git a/vllm/entrypoints/openai/engine/serving.py b/vllm/entrypoints/openai/engine/serving.py index 3d4f4ff364d5..7cf43dc1f45b 100644 --- a/vllm/entrypoints/openai/engine/serving.py +++ b/vllm/entrypoints/openai/engine/serving.py @@ -835,8 +835,10 @@ def _parse_tool_calls_from_content( content = None # Clear content since tool is called. elif tool_parser_cls and ( use_mistral_tool_parser - or enable_auto_tools - and (request.tool_choice == "auto" or request.tool_choice is None) + or ( + enable_auto_tools + and (request.tool_choice == "auto" or request.tool_choice is None) + ) ): if tokenizer is None: raise ValueError(