Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
932 changes: 752 additions & 180 deletions tests/tool_parsers/test_mistral_tool_parser.py

Large diffs are not rendered by default.

483 changes: 480 additions & 3 deletions tests/tool_use/mistral/test_mistral_tool_calls.py

Large diffs are not rendered by default.

34 changes: 24 additions & 10 deletions tests/tool_use/mistral/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


from typing_extensions import TypedDict


class ServerConfig(TypedDict, total=False):
    """Launch configuration for one model server used by the tool-use tests."""

    # Model identifier the test server is started with.
    model: str
    # Extra CLI arguments appended to the server launch command.
    arguments: list[str]
    # Optional system prompt injected into chat requests (None = no prompt).
    system_prompt: str | None
    # Whether the model supports parallel tool calls — presumably gates
    # parallel-tool-call tests; confirm against callers.
    supports_parallel: bool | None
    # Whether this config runs on ROCm — NOTE(review): inferred from the
    # name; verify against the test harness that reads it.
    supports_rocm: bool | None

from tests.tool_use.utils import ServerConfig

ARGS: list[str] = ["--max-model-len", "1024"]

Expand All @@ -21,6 +12,11 @@ class ServerConfig(TypedDict, total=False):
"arguments": [
"--tokenizer-mode",
"mistral",
"--tool-call-parser",
"mistral",
"--enable-auto-tool-choice",
"--enforce-eager",
"--no-enable-prefix-caching",
'--ignore-patterns="consolidated.safetensors"',
],
"system_prompt": "You are a helpful assistant with access to tools. If a tool"
Expand All @@ -29,4 +25,22 @@ class ServerConfig(TypedDict, total=False):
"without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT "
"to the user's question - just respond to it normally.",
},
"ministral-3b": {
"model": "mistralai/Ministral-3-3B-Instruct-2512",
"arguments": [
"--tokenizer-mode",
"mistral",
"--tool-call-parser",
"mistral",
"--enable-auto-tool-choice",
"--enforce-eager",
"--no-enable-prefix-caching",
],
"system_prompt": "You are a helpful assistant with access to tools. If a tool"
" that you have would be helpful to answer a user query, "
"call the tool. Otherwise, answer the user's query directly "
"without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT "
"to the user's question - just respond to it normally.",
"supports_parallel": True,
},
}
12 changes: 4 additions & 8 deletions vllm/entrypoints/openai/chat_completion/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
ChatCompletionAudio as OpenAIChatCompletionAudio,
)
from openai.types.chat.chat_completion_message import Annotation as OpenAIAnnotation
from pydantic import Field, model_validator
from pydantic import Field, PrivateAttr, model_validator

from vllm.config import ModelConfig
from vllm.config.utils import replace
Expand Down Expand Up @@ -398,6 +398,9 @@ def _materialize_tool_calls_after(self) -> "ChatCompletionRequest":
msg["tool_calls"] = list(tool_calls)
return self

_grammar_from_tool_parser: bool = PrivateAttr(default=False)
"""CAUTION: Should only be set by ``ToolParser.adjust_request``."""

def build_chat_params(
self,
default_template: str | None,
Expand Down Expand Up @@ -822,13 +825,6 @@ def check_system_message_content_type(cls, data):

return data

@model_validator(mode="before")
@classmethod
def set_include_reasoning_for_none_effort(cls, data: Any) -> Any:
    """Turn off reasoning output whenever the requested effort is "none"."""
    effort = data.get("reasoning_effort")
    if effort == "none":
        data["include_reasoning"] = False
    return data
Comment on lines -825 to -830
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was introduced by #36238, but it was a bad idea: sometimes the model might want to try to reason, and suppressing that forces it out of distribution (OOD).

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree thanks!



class BatchChatCompletionRequest(OpenAIBaseModel):
"""Request model for the /v1/chat/completions/batch endpoint.
Expand Down
69 changes: 64 additions & 5 deletions vllm/entrypoints/openai/chat_completion/serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,10 @@
from vllm.renderers import ChatParams
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.mistral_tool_parser import MistralToolCall
from vllm.tool_parsers.mistral_tool_parser import (
MistralToolCall,
MistralToolParser,
)
from vllm.tool_parsers.utils import partial_json_loads
from vllm.utils.collection_utils import as_list
from vllm.utils.mistral import is_mistral_tokenizer
Expand Down Expand Up @@ -140,6 +143,12 @@ def __init__(
enable_auto_tools=enable_auto_tools,
model_name=self.model_config.model,
)
_is_mistral_tool_parser = self.tool_parser is not None and issubclass(
self.tool_parser, MistralToolParser
)
if _is_mistral_tool_parser and self.reasoning_parser_cls is not None:
MistralToolParser.model_can_reason = True
Comment on lines +146 to +150
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

Setting a class attribute MistralToolParser.model_can_reason = True in the __init__ method of OpenAIServingChat is a global state mutation that will affect all instances of MistralToolParser across the entire application. This is a thread-safety issue and can lead to unpredictable behavior if multiple models with different reasoning capabilities are served simultaneously. This should be handled via instance-level configuration or a more robust mechanism.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah so indeed this is not clean but was discussed in previous PR. I don't know how else we should do this 😄

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I may not have seen all the previous discussion - did you consider just looking at the reasoning_effort in the request? That's what gates the actual prompt to enable reasoning outputs, right? Or do you need model_can_reason to be true any time a reasoning parser is set, even if the requests are using reasoning_effort=None?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

actually using reasoning_effort does not help because it:

  • was recently introduced, so older models won't support it
  • even if reasoning_effort is set to "high" or "none", the model sometimes won't follow the instruction. Even though this behavior is undesired and could be prevented by the grammar, we found that forcing it is usually not stable: the model is pushed to do something it didn't "want" to do, which can end up in an infinite loop

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's reasonable. To avoid mutating global state, can we just set this on the instance of the tool parser as opposed to on the module as a global mutation? I think it's just changing this line to self.tool_parser.model_can_reason = True instead of changing it for the module itself? And moving the definition of that model_can_reason field to inside the constructor of the Mistral tool parser?

That makes it set per instance of tool parser as opposed to globally.


self.exclude_tools_when_tool_choice_none = exclude_tools_when_tool_choice_none

self.enable_prompt_tokens_details = enable_prompt_tokens_details
Expand Down Expand Up @@ -310,6 +319,11 @@ async def create_chat_completion(
else:
if not request.include_reasoning:
reasoning_ended = True
elif request._grammar_from_tool_parser:
# The Mistral grammar already includes an optional
# `think?` rule that handles both reasoning and
# non-reasoning outputs.
reasoning_ended = True
elif reasoning_parser:
reasoning_ended = reasoning_parser.is_reasoning_end(
prompt_token_ids or []
Expand Down Expand Up @@ -530,6 +544,8 @@ async def chat_completion_stream_generator(
harmony_tools_streamed = [False] * num_choices
tools_streamed = [False] * num_choices

is_mistral_grammar_path = request._grammar_from_tool_parser

if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
tool_choice_function_name = request.tool_choice.function.name
else:
Expand All @@ -553,7 +569,7 @@ async def chat_completion_stream_generator(

# Only one of these will be used, thus previous_texts and
# all_previous_token_ids will not be used twice in the same iteration.
if tool_choice_auto or reasoning_parser:
if is_mistral_grammar_path or tool_choice_auto or reasoning_parser:
# These are only required in "auto" tool choice case
all_previous_token_ids = [[] for _ in range(num_choices)]
reasoning_end_arr = [False] * num_choices
Expand Down Expand Up @@ -748,7 +764,7 @@ async def chat_completion_stream_generator(
delta_message: DeltaMessage | None

# just update previous_texts and previous_token_ids
if tool_choice_auto or reasoning_parser:
if is_mistral_grammar_path or tool_choice_auto or reasoning_parser:
assert previous_texts is not None
assert all_previous_token_ids is not None
previous_text = previous_texts[i]
Expand All @@ -772,6 +788,30 @@ async def chat_completion_stream_generator(
)
)
harmony_tools_streamed[i] |= tools_streamed_flag
# Mistral grammar path: combined reasoning + tool streaming
elif is_mistral_grammar_path:
assert tool_parser is not None
assert isinstance(tool_parser, MistralToolParser)
assert reasoning_end_arr is not None
output_token_ids = as_list(output.token_ids)
result = tool_parser.extract_maybe_reasoning_and_tool_streaming(
reasoning_parser=reasoning_parser,
previous_text=previous_text,
current_text=current_text,
delta_text=delta_text,
previous_token_ids=previous_token_ids,
current_token_ids=current_token_ids,
output_token_ids=output_token_ids,
reasoning_ended=reasoning_end_arr[i],
prompt_is_reasoning_end=(prompt_is_reasoning_end_arr[i]),
request=request,
)
delta_message = result.delta_message
reasoning_end_arr[i] = result.reasoning_ended
current_text = result.current_text
current_token_ids = result.current_token_ids
if result.tools_called:
tools_streamed[i] = True
# handle streaming deltas for tools with named tool_choice
elif tool_choice_function_name:
# When encountering think end id in prompt_token_ids
Expand Down Expand Up @@ -925,7 +965,9 @@ async def chat_completion_stream_generator(
delta_message = DeltaMessage(content=delta_text)

# update the previous values for the next iteration
if (tool_choice_auto or reasoning_parser) and not self.use_harmony:
if (
is_mistral_grammar_path or tool_choice_auto or reasoning_parser
) and not self.use_harmony:
assert previous_texts is not None
assert all_previous_token_ids is not None
previous_texts[i] = current_text
Expand Down Expand Up @@ -1312,7 +1354,24 @@ async def chat_completion_full_generator(
tool_call_class = (
MistralToolCall if is_mistral_tokenizer(tokenizer) else ToolCall
)
if (not self.enable_auto_tools or not self.tool_parser) and (

use_mistral_tool_parser = request._grammar_from_tool_parser
if use_mistral_tool_parser:
tool_call_items = MistralToolParser.build_non_streaming_tool_calls(
tool_calls
)
if tool_call_items:
auto_tools_called = (
request.tool_choice is None or request.tool_choice == "auto"
)
message = ChatMessage(
role=role,
reasoning=reasoning,
content=content,
tool_calls=tool_call_items,
)

elif (not self.enable_auto_tools or not self.tool_parser) and (
not isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam)
and request.tool_choice != "required"
):
Expand Down
34 changes: 26 additions & 8 deletions vllm/entrypoints/openai/engine/serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers import ToolParser
from vllm.tool_parsers.mistral_tool_parser import MistralToolParser
from vllm.tracing import (
contains_trace_headers,
extract_trace_headers,
Expand Down Expand Up @@ -610,24 +611,39 @@ def _parse_tool_calls_from_content(
tool_parser_cls: type[ToolParser] | None,
content: str | None = None,
) -> tuple[list[FunctionCall] | None, str | None]:
# When the Mistral grammar factory injected structured outputs,
# let the parser handle the output.
use_mistral_tool_parser = (
isinstance(request, ChatCompletionRequest)
and tool_parser_cls is not None
and issubclass(tool_parser_cls, MistralToolParser)
and request._grammar_from_tool_parser
)

function_calls = list[FunctionCall]()
if request.tool_choice and isinstance(request.tool_choice, ToolChoiceFunction):
if (
not use_mistral_tool_parser
and request.tool_choice
and isinstance(request.tool_choice, ToolChoiceFunction)
):
assert content is not None
# Forced Function Call
function_calls.append(
FunctionCall(name=request.tool_choice.name, arguments=content)
)
content = None # Clear content since tool is called.
elif request.tool_choice and isinstance(
request.tool_choice, ChatCompletionNamedToolChoiceParam
elif (
not use_mistral_tool_parser
and request.tool_choice
and isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam)
):
assert content is not None
# Forced Function Call
function_calls.append(
FunctionCall(name=request.tool_choice.function.name, arguments=content)
)
content = None # Clear content since tool is called.
elif request.tool_choice == "required":
elif not use_mistral_tool_parser and request.tool_choice == "required":
tool_calls = []
with contextlib.suppress(ValidationError):
content = content or ""
Expand All @@ -642,10 +658,12 @@ def _parse_tool_calls_from_content(
)
)
content = None # Clear content since tool is called.
elif (
tool_parser_cls
and enable_auto_tools
and (request.tool_choice == "auto" or request.tool_choice is None)
elif tool_parser_cls and (
use_mistral_tool_parser
or (
enable_auto_tools
and (request.tool_choice == "auto" or request.tool_choice is None)
)
):
if tokenizer is None:
raise ValueError(
Expand Down
14 changes: 12 additions & 2 deletions vllm/entrypoints/serve/render/serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
prompt_to_seq,
)
from vllm.tool_parsers import ToolParser
from vllm.tool_parsers.mistral_tool_parser import MistralToolParser
from vllm.utils import random_uuid
from vllm.utils.mistral import is_mistral_tokenizer
from vllm.utils.mistral import mt as _mt
Expand Down Expand Up @@ -555,17 +556,26 @@ async def preprocess_chat(
# tool parsing is done only if a tool_parser has been set and if
# tool_choice is not "none" (if tool_choice is "none" but a tool_parser
# is set, we want to prevent parsing a tool_call hallucinated by the LLM
#
# Exception: Mistral grammar-capable tokenizers always call
# adjust_request — even for tool_choice="none" — so that the grammar
# factory can prevent special-token leakage.
if tool_parser is not None:
tool_choice = getattr(request, "tool_choice", "none")
if tool_choice != "none":
tokenizer = renderer.get_tokenizer()
is_mistral_grammar_eligible = (
issubclass(tool_parser, MistralToolParser)
and is_mistral_tokenizer(tokenizer)
and tokenizer.supports_grammar
)
if tool_choice != "none" or is_mistral_grammar_eligible:
if not isinstance(request, ChatCompletionRequest | ResponsesRequest):
msg = (
"Tool usage is only supported "
"for Chat Completions API or Responses API requests, "
f"but got {type(request).__name__}"
)
raise NotImplementedError(msg)
tokenizer = renderer.get_tokenizer()
request = tool_parser(tokenizer, request.tools).adjust_request(
request=request
)
Expand Down
14 changes: 12 additions & 2 deletions vllm/sampling_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,10 @@ def _is_non_tekken_mistral(tokenizer: TokenizerLike) -> bool:
return is_mistral_tokenizer(tokenizer) and not tokenizer.is_tekken


def _get_llg_tokenizer(tokenizer: TokenizerLike) -> Any:
    """Return the llguidance tokenizer for Mistral tokenizers, else ``None``."""
    if is_mistral_tokenizer(tokenizer):
        return tokenizer.llg_tokenizer
    return None


class SamplingParams(
PydanticMsgspecMixin,
msgspec.Struct,
Expand Down Expand Up @@ -816,7 +820,10 @@ def _validate_structured_outputs(
# allows <|special_token|> and similar, see
# https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
# Without tokenizer these are disallowed in grammars.
validate_guidance_grammar(self, tokenizer=None)
validate_guidance_grammar(
self,
tokenizer=_get_llg_tokenizer(tokenizer),
)
elif backend == "outlines":
# outlines backend
validate_structured_output_request_outlines(self)
Expand Down Expand Up @@ -862,7 +869,10 @@ def _validate_structured_outputs(
self.structured_outputs._backend = "outlines"
else:
# Fall back to guidance by default.
validate_guidance_grammar(self, tokenizer=None)
validate_guidance_grammar(
self,
tokenizer=_get_llg_tokenizer(tokenizer),
)
self.structured_outputs._backend = "guidance"
# Remember that this backend was set automatically
self.structured_outputs._backend_was_auto = True
Expand Down
Loading
Loading