diff --git a/tests/tool_parsers/test_mistral_tool_parser.py b/tests/tool_parsers/test_mistral_tool_parser.py index 064ccb39ef4b..473eb716266f 100644 --- a/tests/tool_parsers/test_mistral_tool_parser.py +++ b/tests/tool_parsers/test_mistral_tool_parser.py @@ -3,6 +3,7 @@ import json from collections.abc import Generator +from typing import Any from unittest.mock import MagicMock, patch import partial_json_parser @@ -23,24 +24,33 @@ ToolChoiceEnum as MistralToolChoiceEnum, ) from partial_json_parser.core.options import Allow +from pydantic import ValidationError from vllm.entrypoints.openai.chat_completion.protocol import ( ChatCompletionRequest, ) from vllm.entrypoints.openai.engine.protocol import ( + DeltaFunctionCall, DeltaMessage, DeltaToolCall, + ExtractedToolCallInformation, StructuralTagResponseFormat, ) +from vllm.entrypoints.openai.engine.protocol import FunctionCall as VllmFunctionCall +from vllm.reasoning.mistral_reasoning_parser import MistralReasoningParser from vllm.sampling_params import StructuredOutputsParams from vllm.tokenizers import TokenizerLike, get_tokenizer from vllm.tokenizers.detokenizer_utils import detokenize_incrementally from vllm.tokenizers.mistral import MistralTokenizer from vllm.tool_parsers.mistral_tool_parser import ( _DEFAULT_JSON_SCHEMA, + MistralStreamingResult, + MistralToolCall, MistralToolParser, ) +_DUMMY_REQUEST = ChatCompletionRequest(messages=[], model="test") + @pytest.fixture(scope="module") def mistral_pre_v11_tokenizer(): @@ -205,7 +215,7 @@ def stream_delta_message_generator( previous_token_ids, current_token_ids, delta_token_ids, - request=None, # type: ignore[arg-type] + request=_DUMMY_REQUEST, ) if delta_message: yield delta_message @@ -218,14 +228,18 @@ def stream_delta_message_generator( read_offset = new_read_offset -def test_extract_tool_calls_no_tools(mistral_pre_v11_tool_parser): +@pytest.mark.parametrize( + "parser_fixture", + ["mistral_pre_v11_tool_parser", "mistral_tool_parser"], + ids=["pre_v11", 
"v11"], +) +def test_extract_tool_calls_no_tools(parser_fixture, request): + parser = request.getfixturevalue(parser_fixture) model_output = "This is a test" - extracted_tool_calls = mistral_pre_v11_tool_parser.extract_tool_calls( - model_output, request=None - ) # type: ignore[arg-type] - assert not extracted_tool_calls.tools_called - assert extracted_tool_calls.tool_calls == [] - assert extracted_tool_calls.content == model_output + result = parser.extract_tool_calls(model_output, request=_DUMMY_REQUEST) + assert result == ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) @pytest.mark.parametrize( @@ -234,6 +248,8 @@ def test_extract_tool_calls_no_tools(mistral_pre_v11_tool_parser): "single_tool_weather", "argument_before_name", "argument_before_name_and_name_in_argument", + "multiple_tools", + "content_before_tool", ], argnames=["model_output", "expected_tool_calls", "expected_content"], argvalues=[ @@ -292,14 +308,44 @@ def test_extract_tool_calls_no_tools(mistral_pre_v11_tool_parser): ], None, ), + ( + """[TOOL_CALLS] [{"name": "add", "arguments": {"a": 3.5, "b": 4}}, {"name": "get_current_weather", "arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}}]""", # noqa: E501 + [ + ToolCall( + function=FunctionCall( + name="add", arguments=json.dumps({"a": 3.5, "b": 4}) + ) + ), + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + {"city": "San Francisco", "state": "CA", "unit": "celsius"} + ), + ) + ), + ], + None, + ), + ( + """Hello[TOOL_CALLS] [{"name": "add", "arguments":{"a": 1, "b": 2}}]""", # noqa: E501 + [ + ToolCall( + function=FunctionCall( + name="add", arguments=json.dumps({"a": 1, "b": 2}) + ) + ) + ], + "Hello", + ), ], ) def test_extract_tool_calls_pre_v11_tokenizer( mistral_pre_v11_tool_parser, model_output, expected_tool_calls, expected_content ): extracted_tool_calls = mistral_pre_v11_tool_parser.extract_tool_calls( - model_output, request=None - ) 
# type: ignore[arg-type] + model_output, request=_DUMMY_REQUEST + ) assert extracted_tool_calls.tools_called assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls) @@ -307,6 +353,46 @@ def test_extract_tool_calls_pre_v11_tokenizer( assert extracted_tool_calls.content == expected_content +def test_extract_tool_calls_pre_v11_multiple_bot_tokens_raises( + mistral_pre_v11_tool_parser, +): + model_output = ( + '[TOOL_CALLS] [{"name": "add", "arguments":{"a": 1}}]' + '[TOOL_CALLS] [{"name": "sub", "arguments":{"b": 2}}]' + ) + with pytest.raises(ValueError, match="Only one BOT token"): + mistral_pre_v11_tool_parser.extract_tool_calls( + model_output, request=_DUMMY_REQUEST + ) + + +def test_extract_tool_calls_pre_v11_regex_fallback_raises( + mistral_pre_v11_tool_parser, +): + """The regex fallback path finds valid JSON but does not re-serialize + the `arguments` dict to a string, causing a Pydantic + `ValidationError` when constructing `FunctionCall`.""" + model_output = ( + '[TOOL_CALLS] junk [{"name": "add", "arguments":{"a": 1, "b": 2}}] trail' + ) + with pytest.raises(ValidationError): + mistral_pre_v11_tool_parser.extract_tool_calls( + model_output, request=_DUMMY_REQUEST + ) + + +def test_extract_tool_calls_pre_v11_regex_fallback_fails( + mistral_pre_v11_tool_parser, +): + model_output = "[TOOL_CALLS] not json at all" + result = mistral_pre_v11_tool_parser.extract_tool_calls( + model_output, request=_DUMMY_REQUEST + ) + assert result == ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content="not json at all" + ) + + @pytest.mark.parametrize( ids=[ "single_tool_add", @@ -395,8 +481,8 @@ def test_extract_tool_calls( mistral_tool_parser, model_output, expected_tool_calls, expected_content ): extracted_tool_calls = mistral_tool_parser.extract_tool_calls( - model_output, request=None - ) # type: ignore[arg-type] + model_output, request=_DUMMY_REQUEST + ) assert extracted_tool_calls.tools_called 
assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls) @@ -404,6 +490,16 @@ def test_extract_tool_calls( assert extracted_tool_calls.content == expected_content +def test_extract_tool_calls_v11_without_args_skipped(mistral_tool_parser): + model_output = "[TOOL_CALLS]toolname_no_args" + result = mistral_tool_parser.extract_tool_calls( + model_output, request=_DUMMY_REQUEST + ) + assert result == ExtractedToolCallInformation( + tools_called=True, tool_calls=[], content=None + ) + + def _test_extract_tool_calls_streaming( tool_parser, tokenizer, model_output, tools, expected_tool_calls, expected_content ): @@ -669,17 +765,65 @@ def test_extract_tool_calls_streaming( ) +def test_extract_tool_calls_streaming_v11_no_tools( + mistral_tool_parser, mistral_tokenizer +): + model_output = "This is a test" + if isinstance(mistral_tokenizer, MistralTokenizer): + all_token_ids = mistral_tokenizer.encode(model_output) + else: + all_token_ids = mistral_tokenizer.encode(model_output, add_special_tokens=False) + skip_special = isinstance(mistral_tokenizer, MistralTokenizer) + collected_content = "" + previous_text = "" + previous_tokens = None + prefix_offset = 0 + read_offset = 0 + for i in range(len(all_token_ids)): + current_token_ids = all_token_ids[: i + 1] + previous_token_ids = all_token_ids[:i] + delta_token_ids = [all_token_ids[i]] + + new_tokens, delta_text, prefix_offset, read_offset = detokenize_incrementally( + tokenizer=mistral_tokenizer, + all_input_ids=current_token_ids, + prev_tokens=previous_tokens, + prefix_offset=prefix_offset, + read_offset=read_offset, + skip_special_tokens=skip_special, + spaces_between_special_tokens=True, + ) + current_text = previous_text + delta_text + previous_tokens = ( + previous_tokens + new_tokens if previous_tokens else new_tokens + ) + + delta_message = mistral_tool_parser.extract_tool_calls_streaming( + previous_text=previous_text, + current_text=current_text, + delta_text=delta_text, + 
previous_token_ids=previous_token_ids, + current_token_ids=current_token_ids, + delta_token_ids=delta_token_ids, + request=_DUMMY_REQUEST, + ) + if delta_message and delta_message.content: + collected_content += delta_message.content + if delta_message: + assert not delta_message.tool_calls + + previous_text = current_text + + assert collected_content == model_output + + @pytest.mark.parametrize( - ids=[ - "single_tool_add", - "single_tool_weather", - "multiple_tool_calls", - "content_before_tool", - "complex", - ], - argnames=["model_output", "expected_tool_calls", "expected_content"], - argvalues=[ - ( + "parser_fixture, tokenizer_fixture, model_output," + " expected_tool_calls, expected_content", + [ + pytest.param( + "mistral_tool_parser", + "mistral_tokenizer", """[TOOL_CALLS]add_this_and_that{"a": 3.5, "b": 4}""", # noqa: E501 [ ToolCall( @@ -690,8 +834,11 @@ def test_extract_tool_calls_streaming( ) ], "", + id="v11-single_tool_add", ), - ( + pytest.param( + "mistral_tool_parser", + "mistral_tokenizer", """[TOOL_CALLS]get_current_weather{"city": "San Francisco", "state": "CA", "unit": "celsius"}""", # noqa: E501 [ ToolCall( @@ -704,8 +851,11 @@ def test_extract_tool_calls_streaming( ) ], "", + id="v11-single_tool_weather", ), - ( + pytest.param( + "mistral_tool_parser", + "mistral_tokenizer", """[TOOL_CALLS]add{"a": 3.5, "b": 4}[TOOL_CALLS]multiply{"a": 3, "b": 6}""", # noqa: E501 [ ToolCall( @@ -720,9 +870,11 @@ def test_extract_tool_calls_streaming( ), ], "", + id="v11-multiple_tool_calls", ), - ( - # Additional content should not be after the tool calls + pytest.param( + "mistral_tool_parser", + "mistral_tokenizer", """bla[TOOL_CALLS]add_this_and_that{"a": 3.5, "b": 4}""", # noqa: E501 [ ToolCall( @@ -733,9 +885,11 @@ def test_extract_tool_calls_streaming( ) ], "bla", + id="v11-content_before_tool", ), - ( - # Complex + pytest.param( + "mistral_tool_parser", + "mistral_tokenizer", """hi{hi[TOOL_CALLS]bash{"command": "print(\\"hello 
world!\\")\\nre.compile(r\'{}\')"}""", # noqa: E501 [ ToolCall( @@ -748,58 +902,19 @@ def test_extract_tool_calls_streaming( ) ], "hi{hi", + id="v11-complex", ), - ], -) -def test_extract_tool_calls_streaming_one_chunk( - mistral_tool_parser, - mistral_tokenizer, - model_output, - expected_tool_calls, - expected_content, -): - if isinstance(mistral_tokenizer, MistralTokenizer): - all_token_ids = mistral_tokenizer.encode(model_output) - else: - all_token_ids = mistral_tokenizer.encode(model_output, add_special_tokens=False) - all_token_ids = fix_tool_call_tokenization( - all_token_ids, mistral_tool_parser, mistral_tokenizer - ) - - delta_message = mistral_tool_parser.extract_tool_calls_streaming( - previous_text="", - current_text=model_output, - delta_text=model_output, - previous_token_ids=[], - current_token_ids=all_token_ids, - delta_token_ids=all_token_ids, - request=None, - ) # type: ignore[arg-type] - assert isinstance(delta_message, DeltaMessage) - assert len(delta_message.tool_calls) == len(expected_tool_calls) - - assert_tool_calls(delta_message.tool_calls, expected_tool_calls) - - if delta_message.content is None: - assert expected_content == "" - else: - assert delta_message.content == expected_content - - -@pytest.mark.parametrize( - ids=[ - "no_tools", - "single_tool_add", - "single_tool_add_strings", - "single_tool_weather", - "argument_before_name", - "argument_before_name_and_name_in_argument", - "multiple_tools", - ], - argnames=["model_output", "expected_tool_calls", "expected_content"], - argvalues=[ - ("""This is a test""", [], """This is a test"""), - ( + pytest.param( + "mistral_pre_v11_tool_parser", + "mistral_pre_v11_tokenizer", + """This is a test""", + [], + """This is a test""", + id="pre_v11-no_tools", + ), + pytest.param( + "mistral_pre_v11_tool_parser", + "mistral_pre_v11_tokenizer", """[TOOL_CALLS] [ {"name":"add" , "arguments" : {"a": 3, "b": 4} } ]""", # noqa: E501 [ ToolCall( @@ -809,8 +924,11 @@ def 
test_extract_tool_calls_streaming_one_chunk( ) ], "", + id="pre_v11-single_tool_add", ), - ( + pytest.param( + "mistral_pre_v11_tool_parser", + "mistral_pre_v11_tokenizer", """[TOOL_CALLS] [{"name": "add", "arguments":{"a": "3", "b": "4"}}]""", # noqa: E501 [ ToolCall( @@ -820,8 +938,11 @@ def test_extract_tool_calls_streaming_one_chunk( ) ], "", + id="pre_v11-single_tool_add_strings", ), - ( + pytest.param( + "mistral_pre_v11_tool_parser", + "mistral_pre_v11_tokenizer", """[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"city": "San Francisco", "state": "CA", "unit": "celsius"}}]""", # noqa: E501 [ ToolCall( @@ -834,8 +955,11 @@ def test_extract_tool_calls_streaming_one_chunk( ) ], "", + id="pre_v11-single_tool_weather", ), - ( + pytest.param( + "mistral_pre_v11_tool_parser", + "mistral_pre_v11_tokenizer", """[TOOL_CALLS] [{"arguments": {"city": "San Francisco", "state": "CA", "unit": "celsius"}, "name": "get_current_weather"}]""", # noqa: E501 [ ToolCall( @@ -848,8 +972,11 @@ def test_extract_tool_calls_streaming_one_chunk( ) ], "", + id="pre_v11-argument_before_name", ), - ( + pytest.param( + "mistral_pre_v11_tool_parser", + "mistral_pre_v11_tokenizer", """[TOOL_CALLS] [{"arguments": {"name": "John Doe"}, "name": "get_age"}]""", # noqa: E501 [ ToolCall( @@ -864,8 +991,11 @@ def test_extract_tool_calls_streaming_one_chunk( ) ], "", + id="pre_v11-argument_before_name_and_name_in_argument", ), - ( + pytest.param( + "mistral_pre_v11_tool_parser", + "mistral_pre_v11_tokenizer", """[TOOL_CALLS] [{"arguments": {"a": 3.5, "b": 4}, "name": "add"}, {"arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}, "name": "get_current_weather"}]""", # noqa: E501 [ ToolCall( @@ -883,35 +1013,50 @@ def test_extract_tool_calls_streaming_one_chunk( ), ], "", + id="pre_v11-multiple_tools", + ), + pytest.param( + "mistral_pre_v11_tool_parser", + "mistral_pre_v11_tokenizer", + """Some text[TOOL_CALLS] [{"name": "add", "arguments":{"a": 1, "b": 2}}]""", # noqa: 
E501 + [ + ToolCall( + function=FunctionCall( + name="add", arguments=json.dumps({"a": 1, "b": 2}) + ) + ) + ], + "Some text", + id="pre_v11-content_before_tool", ), ], ) -def test_extract_tool_calls_streaming_pre_v11_tokenizer_one_chunk( - mistral_pre_v11_tool_parser, - mistral_pre_v11_tokenizer, +def test_extract_tool_calls_streaming_one_chunk( + parser_fixture, + tokenizer_fixture, model_output, expected_tool_calls, expected_content, + request, ): - if isinstance(mistral_pre_v11_tokenizer, MistralTokenizer): - all_token_ids = mistral_pre_v11_tokenizer.encode(model_output) + tool_parser = request.getfixturevalue(parser_fixture) + tokenizer = request.getfixturevalue(tokenizer_fixture) + + if isinstance(tokenizer, MistralTokenizer): + all_token_ids = tokenizer.encode(model_output) else: - all_token_ids = mistral_pre_v11_tokenizer.encode( - model_output, add_special_tokens=False - ) - all_token_ids = fix_tool_call_tokenization( - all_token_ids, mistral_pre_v11_tool_parser, mistral_pre_v11_tokenizer - ) + all_token_ids = tokenizer.encode(model_output, add_special_tokens=False) + all_token_ids = fix_tool_call_tokenization(all_token_ids, tool_parser, tokenizer) - delta_message = mistral_pre_v11_tool_parser.extract_tool_calls_streaming( + delta_message = tool_parser.extract_tool_calls_streaming( previous_text="", current_text=model_output, delta_text=model_output, previous_token_ids=[], current_token_ids=all_token_ids, delta_token_ids=all_token_ids, - request=None, - ) # type: ignore[arg-type] + request=_DUMMY_REQUEST, + ) assert isinstance(delta_message, DeltaMessage) assert len(delta_message.tool_calls) == len(expected_tool_calls) @@ -923,65 +1068,105 @@ def test_extract_tool_calls_streaming_pre_v11_tokenizer_one_chunk( assert delta_message.content == expected_content -def test_fast_detokenization_text_detection(mistral_tool_parser): +@pytest.mark.parametrize( + "parser_fixture, model_output, fake_count, two_phase", + [ + pytest.param( + "mistral_tool_parser", + 
'[TOOL_CALLS]add{"a": 1, "b": 2}', + 20, + True, + id="v11", + ), + pytest.param( + "mistral_pre_v11_tool_parser", + '[TOOL_CALLS] [{"name": "add", "arguments":{"a": 1, "b": 2}}]', + 30, + False, + id="pre_v11", + ), + ], +) +def test_fast_detokenization_text_detection( + parser_fixture, model_output, fake_count, two_phase, request +): """Regression: bot_token in text but not token_ids (PR #37209).""" - model_output = '[TOOL_CALLS]add{"a": 1, "b": 2}' + parser = request.getfixturevalue(parser_fixture) # Token IDs that do NOT contain bot_token_id. - fake_token_ids = list(range(99, 99 + 20)) - - # First delta: pure content, no bot token yet - delta_message_before = mistral_tool_parser.extract_tool_calls_streaming( - previous_text="", - current_text="Hello", - delta_text="Hello", - previous_token_ids=[], - current_token_ids=[99], - delta_token_ids=[99], - request=None, - ) - assert delta_message_before is not None - assert delta_message_before.content == "Hello" - assert not delta_message_before.tool_calls - - # Second delta: bot token in text but NOT in token_ids - delta_message = mistral_tool_parser.extract_tool_calls_streaming( - previous_text="Hello", - current_text="Hello" + model_output, + fake_token_ids = list(range(99, 99 + fake_count)) + + if two_phase: + # First delta: pure content, no bot token yet + delta_message_before = parser.extract_tool_calls_streaming( + previous_text="", + current_text="Hello", + delta_text="Hello", + previous_token_ids=[], + current_token_ids=[99], + delta_token_ids=[99], + request=_DUMMY_REQUEST, + ) + assert delta_message_before is not None + assert delta_message_before.content == "Hello" + assert not delta_message_before.tool_calls + + previous_text = "Hello" + current_text = "Hello" + model_output + previous_token_ids = [99] + delta_token_ids = fake_token_ids[1:] + else: + previous_text = "" + current_text = model_output + previous_token_ids = [] + delta_token_ids = fake_token_ids + + delta_message = 
parser.extract_tool_calls_streaming( + previous_text=previous_text, + current_text=current_text, delta_text=model_output, - previous_token_ids=[99], + previous_token_ids=previous_token_ids, current_token_ids=fake_token_ids, - delta_token_ids=fake_token_ids[1:], - request=None, + delta_token_ids=delta_token_ids, + request=_DUMMY_REQUEST, ) assert delta_message is not None assert delta_message.tool_calls is not None - assert len(delta_message.tool_calls) > 0 + assert len(delta_message.tool_calls) == 1 assert delta_message.tool_calls[0].function is not None assert delta_message.tool_calls[0].function.name == "add" -def test_fast_detokenization_text_detection_pre_v11( - mistral_pre_v11_tool_parser, +@pytest.mark.parametrize( + "parser_fixture, patched_method, current_text", + [ + ( + "mistral_tool_parser", + "_extract_tool_calls_streaming", + "[TOOL_CALLS]add{}", + ), + ( + "mistral_pre_v11_tool_parser", + "_extract_tool_calls_streaming_pre_v11_tokenizer", + '[TOOL_CALLS] [{"name":"a","arguments":{}}]', + ), + ], + ids=["v11", "pre_v11"], +) +def test_extract_tool_calls_streaming_exception_returns_none( + parser_fixture, patched_method, current_text, request ): - """Regression: bot_token text detection for pre-v11 tokenizer (PR #37209).""" - model_output = '[TOOL_CALLS] [{"name": "add", "arguments":{"a": 1, "b": 2}}]' - - fake_token_ids = list(range(99, 99 + 30)) - - delta_message = mistral_pre_v11_tool_parser.extract_tool_calls_streaming( - previous_text="", - current_text=model_output, - delta_text=model_output, - previous_token_ids=[], - current_token_ids=fake_token_ids, - delta_token_ids=fake_token_ids, - request=None, - ) - assert delta_message is not None - assert delta_message.tool_calls is not None - assert len(delta_message.tool_calls) > 0 - assert delta_message.tool_calls[0].function is not None - assert delta_message.tool_calls[0].function.name == "add" + parser = request.getfixturevalue(parser_fixture) + with patch.object(parser, patched_method, 
side_effect=RuntimeError("boom")): + result = parser.extract_tool_calls_streaming( + previous_text="", + current_text=current_text, + delta_text=current_text, + previous_token_ids=[], + current_token_ids=[parser.bot_token_id], + delta_token_ids=[parser.bot_token_id], + request=_DUMMY_REQUEST, + ) + assert result is None SAMPLE_TOOLS_DICTS = [ @@ -1238,57 +1423,444 @@ def test_adjust_request_response_format_generates_grammar( assert len(result.structured_outputs.grammar) > 0 -def test_adjust_request_tool_choice_none_with_json_schema_uses_json_schema_factory( +@pytest.mark.parametrize( + "tool_choice, expected_method, not_called_method", + [ + ("none", "get_lark_for_json_schema", None), + ("auto", "get_lark_from_jinja", "get_lark_for_json_schema"), + ], + ids=["none_uses_json_schema_factory", "auto_uses_jinja_factory"], +) +def test_adjust_request_tool_choice_with_json_schema_factory_routing( mistral_tool_parser: MistralToolParser, + tool_choice: str, + expected_method: str, + not_called_method: str | None, ) -> None: request = _make_request( - tool_choice="none", + tool_choice=tool_choice, structured_outputs=StructuredOutputsParams(json='{"type": "object"}'), ) factory = mistral_tool_parser.model_tokenizer.grammar_factory - with patch.object( - factory, - "get_lark_for_json_schema", - wraps=factory.get_lark_for_json_schema, - ) as mock_json_schema: - result = mistral_tool_parser.adjust_request(request) + patches = { + expected_method: patch.object( + factory, + expected_method, + wraps=getattr(factory, expected_method), + ), + } + if not_called_method: + patches[not_called_method] = patch.object( + factory, + not_called_method, + wraps=getattr(factory, not_called_method), + ) + + with patches[expected_method] as mock_expected: + ctx = patches[not_called_method] if not_called_method else None + if ctx: + with ctx as mock_not_called: + result = mistral_tool_parser.adjust_request(request) + mock_not_called.assert_not_called() + else: + result = 
mistral_tool_parser.adjust_request(request) - mock_json_schema.assert_called_once() - assert mock_json_schema.call_args.kwargs["json_schema"] == {"type": "object"} + mock_expected.assert_called_once() + assert mock_expected.call_args.kwargs["json_schema"] == {"type": "object"} assert result.structured_outputs is not None assert isinstance(result.structured_outputs.grammar, str) assert len(result.structured_outputs.grammar) > 0 -def test_adjust_request_tool_choice_auto_with_json_schema_uses_jinja_factory( +def test_grammar_from_tool_parser_default_false() -> None: + request = _make_request() + assert request._grammar_from_tool_parser is False + + +def test_grammar_from_tool_parser_set_by_adjust_request( mistral_tool_parser: MistralToolParser, ) -> None: - request = _make_request( - tool_choice="auto", - structured_outputs=StructuredOutputsParams(json='{"type": "object"}'), - ) - factory = mistral_tool_parser.model_tokenizer.grammar_factory + request = _make_request() + result = mistral_tool_parser.adjust_request(request) + assert result._grammar_from_tool_parser is True - with ( - patch.object( - factory, - "get_lark_for_json_schema", - wraps=factory.get_lark_for_json_schema, - ) as mock_json_schema, - patch.object( - factory, - "get_lark_from_jinja", - wraps=factory.get_lark_from_jinja, - ) as mock_jinja, - ): - result = mistral_tool_parser.adjust_request(request) - mock_jinja.assert_called_once() - assert mock_jinja.call_args.kwargs["json_schema"] == {"type": "object"} - mock_json_schema.assert_not_called() +@pytest.mark.parametrize( + "tool_calls, expected_len", + [ + (None, 0), + ([], 0), + ([VllmFunctionCall(id="abc123xyz", name="f", arguments="{}")], 1), + ([VllmFunctionCall(name="f", arguments="{}")], 1), + ( + [ + VllmFunctionCall(id="fixed1234", name="a", arguments='{"x": 1}'), + VllmFunctionCall(name="b", arguments='{"y": 2}'), + ], + 2, + ), + ], + ids=["none", "empty", "with_id", "without_id", "mixed"], +) +def test_build_non_streaming_tool_calls( + 
tool_calls: list[VllmFunctionCall] | None, + expected_len: int, +) -> None: + result = MistralToolParser.build_non_streaming_tool_calls(tool_calls) + assert len(result) == expected_len - assert result.structured_outputs is not None - assert isinstance(result.structured_outputs.grammar, str) - assert len(result.structured_outputs.grammar) > 0 + if tool_calls is None: + return + + for i, tc in enumerate(result): + assert isinstance(tc, MistralToolCall) + assert tc.type == "function" + + input_tc = tool_calls[i] + if input_tc.id: + assert tc.id == input_tc.id + else: + assert len(tc.id) == 9 + assert tc.id.isalnum() + + assert tc.function.name == input_tc.name + assert tc.function.arguments == input_tc.arguments + + +class TestExtractMaybeReasoningAndToolStreaming: + r"""Tests for `MistralToolParser.extract_maybe_reasoning_and_tool_streaming`.""" + + @pytest.fixture + def parser(self) -> MistralToolParser: + mock_tokenizer = MagicMock() + mock_tokenizer.get_vocab.return_value = {"[TOOL_CALLS]": 1} + return MistralToolParser(mock_tokenizer) + + @pytest.fixture + def request_obj(self) -> ChatCompletionRequest: + return _make_request() + + @staticmethod + def _call( + parser: MistralToolParser, + request: ChatCompletionRequest, + *, + reasoning_parser: Any = None, + previous_text: str = "", + current_text: str = "hello", + delta_text: str = "hello", + previous_token_ids: list[int] | None = None, + current_token_ids: list[int] | None = None, + output_token_ids: list[int] | None = None, + reasoning_ended: bool = False, + prompt_is_reasoning_end: bool | None = None, + ) -> MistralStreamingResult: + return parser.extract_maybe_reasoning_and_tool_streaming( + reasoning_parser=reasoning_parser, + previous_text=previous_text, + current_text=current_text, + delta_text=delta_text, + previous_token_ids=previous_token_ids or [], + current_token_ids=current_token_ids or [1, 2, 3], + output_token_ids=output_token_ids or [1, 2, 3], + reasoning_ended=reasoning_ended, + 
prompt_is_reasoning_end=prompt_is_reasoning_end, + request=request, + ) + + def test_no_reasoning_tools_called( + self, parser: MistralToolParser, request_obj: ChatCompletionRequest + ) -> None: + tool_delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=0, + function=DeltaFunctionCall(name="f", arguments="{}"), + ) + ] + ) + with patch.object( + parser, "extract_tool_calls_streaming", return_value=tool_delta + ): + result = self._call(parser, request_obj, reasoning_parser=None) + + assert result == MistralStreamingResult( + delta_message=tool_delta, + reasoning_ended=False, + tools_called=True, + current_text="hello", + current_token_ids=[1, 2, 3], + ) + + def test_no_reasoning_no_tools( + self, parser: MistralToolParser, request_obj: ChatCompletionRequest + ) -> None: + content_delta = DeltaMessage(content="hello") + with patch.object( + parser, "extract_tool_calls_streaming", return_value=content_delta + ): + result = self._call(parser, request_obj, reasoning_parser=None) + + assert result == MistralStreamingResult( + delta_message=content_delta, + reasoning_ended=False, + tools_called=False, + current_text="hello", + current_token_ids=[1, 2, 3], + ) + + def test_mistral_reasoning_parser_no_think_token( + self, parser: MistralToolParser, request_obj: ChatCompletionRequest + ) -> None: + mock_rp = MagicMock(spec=MistralReasoningParser) + mock_rp.start_token_id = 999 + content_delta = DeltaMessage(content="direct") + with patch.object( + parser, "extract_tool_calls_streaming", return_value=content_delta + ): + result = self._call( + parser, + request_obj, + reasoning_parser=mock_rp, + reasoning_ended=False, + current_token_ids=[1, 2, 3], + ) + + mock_rp.extract_reasoning_streaming.assert_not_called() + assert result == MistralStreamingResult( + delta_message=content_delta, + reasoning_ended=False, + tools_called=False, + current_text="hello", + current_token_ids=[1, 2, 3], + ) + + def test_mistral_reasoning_parser_with_think_token( + self, parser: 
MistralToolParser, request_obj: ChatCompletionRequest + ) -> None: + mock_rp = MagicMock(spec=MistralReasoningParser) + mock_rp.start_token_id = 999 + mock_rp.extract_reasoning_streaming.return_value = DeltaMessage( + reasoning="thinking..." + ) + mock_rp.is_reasoning_end_streaming.return_value = False + + result = self._call( + parser, + request_obj, + reasoning_parser=mock_rp, + reasoning_ended=False, + current_token_ids=[1, 999, 3], + ) + + mock_rp.extract_reasoning_streaming.assert_called_once() + assert result == MistralStreamingResult( + delta_message=DeltaMessage(reasoning="thinking..."), + reasoning_ended=False, + tools_called=False, + current_text="hello", + current_token_ids=[1, 999, 3], + ) + + def test_non_mistral_reasoning_parser_always_expects_thinking( + self, parser: MistralToolParser, request_obj: ChatCompletionRequest + ) -> None: + mock_rp = MagicMock() + mock_rp.start_token_id = 999 + mock_rp.extract_reasoning_streaming.return_value = DeltaMessage( + reasoning="thinking..." 
+ ) + mock_rp.is_reasoning_end_streaming.return_value = False + + result = self._call( + parser, + request_obj, + reasoning_parser=mock_rp, + reasoning_ended=False, + current_token_ids=[1, 2, 3], + ) + + mock_rp.extract_reasoning_streaming.assert_called_once() + assert result == MistralStreamingResult( + delta_message=DeltaMessage(reasoning="thinking..."), + reasoning_ended=False, + tools_called=False, + current_text="hello", + current_token_ids=[1, 2, 3], + ) + + def test_reasoning_already_ended_no_reset( + self, parser: MistralToolParser, request_obj: ChatCompletionRequest + ) -> None: + content_delta = DeltaMessage(content="content") + with patch.object( + parser, "extract_tool_calls_streaming", return_value=content_delta + ) as mock_extract: + result = self._call( + parser, + request_obj, + reasoning_parser=MagicMock(), + reasoning_ended=True, + previous_text="prior_tool_text", + previous_token_ids=[10, 20], + current_text="prior_tool_texthello", + current_token_ids=[10, 20, 1, 2, 3], + ) + + _, call_kwargs = mock_extract.call_args + assert call_kwargs["previous_text"] == "prior_tool_text" + assert call_kwargs["previous_token_ids"] == [10, 20] + + assert result == MistralStreamingResult( + delta_message=content_delta, + reasoning_ended=True, + tools_called=False, + current_text="prior_tool_texthello", + current_token_ids=[10, 20, 1, 2, 3], + ) + + def test_pre_v15_ignores_prompt_reasoning_end( + self, parser: MistralToolParser, request_obj: ChatCompletionRequest + ) -> None: + mock_tokenizer = MagicMock(spec=MistralTokenizer) + mock_tokenizer.version = 13 + parser.model_tokenizer = mock_tokenizer + + mock_rp = MagicMock(spec=MistralReasoningParser) + mock_rp.start_token_id = 999 + mock_rp.extract_reasoning_streaming.return_value = DeltaMessage( + reasoning="thinking..." 
+ ) + mock_rp.is_reasoning_end_streaming.return_value = False + + result = self._call( + parser, + request_obj, + reasoning_parser=mock_rp, + reasoning_ended=False, + prompt_is_reasoning_end=True, + current_token_ids=[999, 1, 2], + ) + + mock_rp.extract_reasoning_streaming.assert_called_once() + assert result == MistralStreamingResult( + delta_message=DeltaMessage(reasoning="thinking..."), + reasoning_ended=False, + tools_called=False, + current_text="hello", + current_token_ids=[999, 1, 2], + ) + + def test_non_pre_v15_prompt_reasoning_end( + self, parser: MistralToolParser, request_obj: ChatCompletionRequest + ) -> None: + mock_tokenizer = MagicMock(spec=MistralTokenizer) + mock_tokenizer.version = 15 + parser.model_tokenizer = mock_tokenizer + + mock_rp = MagicMock(spec=MistralReasoningParser) + mock_rp.start_token_id = 999 + + content_delta = DeltaMessage(content="after reasoning") + with patch.object( + parser, "extract_tool_calls_streaming", return_value=content_delta + ): + result = self._call( + parser, + request_obj, + reasoning_parser=mock_rp, + reasoning_ended=False, + prompt_is_reasoning_end=True, + current_token_ids=[999, 1, 2], + output_token_ids=[10, 20, 30], + ) + + mock_rp.extract_reasoning_streaming.assert_not_called() + assert result == MistralStreamingResult( + delta_message=content_delta, + reasoning_ended=True, + tools_called=False, + current_text="hello", + current_token_ids=[10, 20, 30], + ) + + def test_reasoning_end_transition_with_content( + self, parser: MistralToolParser, request_obj: ChatCompletionRequest + ) -> None: + """When reasoning ends and the delta has content, that content is + cleared from delta_message and used as current_text for tool parsing.""" + mock_rp = MagicMock() + mock_rp.start_token_id = 999 + mock_rp.extract_reasoning_streaming.return_value = DeltaMessage( + reasoning="think", content="leftover" + ) + mock_rp.is_reasoning_end_streaming.return_value = True + mock_rp.extract_content_ids.return_value = [50, 51] + + 
content_delta = DeltaMessage(content="leftover") + with patch.object( + parser, "extract_tool_calls_streaming", return_value=content_delta + ) as mock_extract: + result = self._call( + parser, + request_obj, + reasoning_parser=mock_rp, + reasoning_ended=False, + current_token_ids=[999, 1, 2], + output_token_ids=[10, 20, 30], + ) + + mock_rp.extract_content_ids.assert_called_once_with([10, 20, 30]) + _, call_kwargs = mock_extract.call_args + assert call_kwargs["previous_text"] == "" + assert call_kwargs["previous_token_ids"] == [] + assert call_kwargs["delta_text"] == "leftover" + assert call_kwargs["current_token_ids"] == [50, 51] + + assert result == MistralStreamingResult( + delta_message=content_delta, + reasoning_ended=True, + tools_called=False, + current_text="leftover", + current_token_ids=[50, 51], + ) + + def test_reasoning_end_transition_without_content( + self, parser: MistralToolParser, request_obj: ChatCompletionRequest + ) -> None: + """When reasoning ends but the delta has no content, current_text + is set to empty string.""" + mock_rp = MagicMock() + mock_rp.start_token_id = 999 + mock_rp.extract_reasoning_streaming.return_value = DeltaMessage( + reasoning="think" + ) + mock_rp.is_reasoning_end_streaming.return_value = True + mock_rp.extract_content_ids.return_value = [50, 51] + + empty_delta = DeltaMessage(content="") + with patch.object( + parser, "extract_tool_calls_streaming", return_value=empty_delta + ) as mock_extract: + result = self._call( + parser, + request_obj, + reasoning_parser=mock_rp, + reasoning_ended=False, + current_token_ids=[999, 1, 2], + output_token_ids=[10, 20, 30], + ) + + _, call_kwargs = mock_extract.call_args + assert call_kwargs["delta_text"] == "" + assert call_kwargs["current_token_ids"] == [50, 51] + + assert result == MistralStreamingResult( + delta_message=empty_delta, + reasoning_ended=True, + tools_called=False, + current_text="", + current_token_ids=[50, 51], + ) diff --git 
a/tests/tool_use/mistral/test_mistral_tool_calls.py b/tests/tool_use/mistral/test_mistral_tool_calls.py index 3c4a543abe41..6dcfd43a9497 100644 --- a/tests/tool_use/mistral/test_mistral_tool_calls.py +++ b/tests/tool_use/mistral/test_mistral_tool_calls.py @@ -1,25 +1,198 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import json +from dataclasses import dataclass, field + import openai import pytest -from tests.tool_use.utils import MESSAGES_ASKING_FOR_TOOLS, WEATHER_TOOL +from tests.tool_use.utils import ( + MESSAGES_ASKING_FOR_PARALLEL_TOOLS, + MESSAGES_ASKING_FOR_TOOLS, + MESSAGES_WITH_TOOL_RESPONSE, + MESSAGES_WITHOUT_TOOLS, + SEARCH_TOOL, + SEED, + WEATHER_TOOL, + ensure_system_prompt, +) + +from .utils import ServerConfig + + +def _requires_tool_parser(server_config: ServerConfig) -> None: + r"""Skip test if server was not started with --tool-call-parser.""" + if "--tool-call-parser" not in server_config.get("arguments", []): + pytest.skip( + f"Skipping: {server_config['model']} not configured with --tool-call-parser" + ) + + +def _is_pre_v11(server_config: ServerConfig) -> bool: + r"""Pre-v11 Mistral models lack grammar-based tool call enforcement.""" + return "7B" in server_config.get("model", "") + + +@dataclass +class StreamedToolCallResult: + r"""Accumulated result from streaming a single tool call.""" + + function_name: str | None = None + function_args_str: str = "" + tool_call_id: str | None = None + role_name: str | None = None + finish_reason_count: int = 0 + finish_reason: str | None = None + + +async def _collect_streamed_tool_call( + stream: openai.AsyncStream, + *, + expected_finish_reason: str = "tool_calls", +) -> StreamedToolCallResult: + result = StreamedToolCallResult() + + async for chunk in stream: + if chunk.choices[0].finish_reason: + result.finish_reason_count += 1 + result.finish_reason = chunk.choices[0].finish_reason + assert chunk.choices[0].finish_reason == 
expected_finish_reason + + if chunk.choices[0].delta.role: + assert not result.role_name or result.role_name == "assistant" + result.role_name = "assistant" + + streamed_tool_calls = chunk.choices[0].delta.tool_calls + if streamed_tool_calls and len(streamed_tool_calls) > 0: + assert len(streamed_tool_calls) == 1 + tool_call = streamed_tool_calls[0] + + if tool_call.id: + assert not result.tool_call_id + result.tool_call_id = tool_call.id + + if tool_call.function: + if tool_call.function.name: + assert result.function_name is None + result.function_name = tool_call.function.name + if tool_call.function.arguments: + result.function_args_str += tool_call.function.arguments + + return result + + +@dataclass +class StreamedContentResult: + r"""Accumulated result from streaming a content-only response.""" + + chunks: list[str] = field(default_factory=list) + finish_reason_count: int = 0 + finish_reason: str | None = None + role_sent: bool = False + + +async def _collect_streamed_content( + stream: openai.AsyncStream, + *, + expected_finish_reason: str | None = None, + no_tool_calls: bool = True, +) -> StreamedContentResult: + r"""Consume a streaming response and collect text content.""" + result = StreamedContentResult() + + async for chunk in stream: + delta = chunk.choices[0].delta + + if delta.role: + assert not result.role_sent + assert delta.role == "assistant" + result.role_sent = True + + if delta.content: + result.chunks.append(delta.content) + + if chunk.choices[0].finish_reason is not None: + result.finish_reason_count += 1 + result.finish_reason = chunk.choices[0].finish_reason + if expected_finish_reason is not None: + assert result.finish_reason == expected_finish_reason + + if no_tool_calls: + assert not delta.tool_calls or len(delta.tool_calls) == 0 + + return result + + +@dataclass +class StreamedParallelToolCallResult: + r"""Accumulated result from streaming parallel tool calls.""" + + function_names: list[str] = field(default_factory=list) + 
function_args_strs: list[str] = field(default_factory=list) + tool_call_ids: list[str] = field(default_factory=list) + role_name: str | None = None + finish_reason_count: int = 0 + + +async def _collect_streamed_parallel_tool_calls( + stream: openai.AsyncStream, +) -> StreamedParallelToolCallResult: + r"""Consume a streaming response and collect parallel tool calls.""" + result = StreamedParallelToolCallResult() + tool_call_idx: int = -1 + + async for chunk in stream: + if chunk.choices[0].finish_reason: + result.finish_reason_count += 1 + assert chunk.choices[0].finish_reason == "tool_calls" + + if chunk.choices[0].delta.role: + assert not result.role_name or result.role_name == "assistant" + result.role_name = "assistant" + + streamed_tool_calls = chunk.choices[0].delta.tool_calls + if streamed_tool_calls and len(streamed_tool_calls) > 0: + assert len(streamed_tool_calls) == 1 + tool_call = streamed_tool_calls[0] + + if tool_call.index != tool_call_idx: + tool_call_idx = tool_call.index + result.function_args_strs.append("") + result.tool_call_ids.append("") + + if tool_call.id: + result.tool_call_ids[tool_call.index] = tool_call.id + + if tool_call.function: + if tool_call.function.name: + result.function_names.append(tool_call.function.name) + if tool_call.function.arguments: + result.function_args_strs[tool_call.index] += ( + tool_call.function.arguments + ) + + return result # test: a tool_choice with mistral-tokenizer results in an ID of length 9 @pytest.mark.asyncio -async def test_tool_call_with_tool_choice(client: openai.AsyncOpenAI): +async def test_tool_call_with_tool_choice( + client: openai.AsyncOpenAI, server_config: ServerConfig +) -> None: + _requires_tool_parser(server_config) + models = await client.models.list() model_name: str = models.data[0].id chat_completion = await client.chat.completions.create( - messages=MESSAGES_ASKING_FOR_TOOLS, + messages=ensure_system_prompt(MESSAGES_ASKING_FOR_TOOLS, server_config), temperature=0, 
max_completion_tokens=100, model=model_name, tools=[WEATHER_TOOL], tool_choice=WEATHER_TOOL, logprobs=False, + seed=SEED, ) choice = chat_completion.choices[0] @@ -28,3 +201,307 @@ async def test_tool_call_with_tool_choice(client: openai.AsyncOpenAI): assert choice.message.role == "assistant" assert choice.message.tool_calls is None or len(choice.message.tool_calls) == 1 assert len(choice.message.tool_calls[0].id) == 9 # length of 9 for mistral + + +_NOT_SET = object() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "tools, tool_choice, streaming_id_len_pre_v11", + [ + pytest.param( + [WEATHER_TOOL, SEARCH_TOOL], + _NOT_SET, + 9, + id="auto", + ), + pytest.param( + [WEATHER_TOOL], + "required", + 30, + id="required", + ), + ], +) +async def test_tool_call_auto_or_required( + client: openai.AsyncOpenAI, + server_config: ServerConfig, + tools: list, + tool_choice: object, + streaming_id_len_pre_v11: int, +) -> None: + _requires_tool_parser(server_config) + + models = await client.models.list() + model_name: str = models.data[0].id + + create_kwargs: dict = { + "messages": ensure_system_prompt(MESSAGES_ASKING_FOR_TOOLS, server_config), + "temperature": 0, + "max_completion_tokens": 100, + "model": model_name, + "tools": tools, + "logprobs": False, + "seed": SEED, + } + if tool_choice is not _NOT_SET: + create_kwargs["tool_choice"] = tool_choice + + # --- non-streaming --- + chat_completion = await client.chat.completions.create(**create_kwargs) + + choice = chat_completion.choices[0] + tool_calls = choice.message.tool_calls + + assert choice.finish_reason == "tool_calls" + assert tool_calls is not None and len(tool_calls) >= 1 + assert tool_calls[0].function.name == "get_current_weather" + parsed_arguments = json.loads(tool_calls[0].function.arguments) + assert "city" in parsed_arguments + assert len(tool_calls[0].id) == 9 + + # --- streaming --- + stream = await client.chat.completions.create(**create_kwargs, stream=True) + + result = await 
_collect_streamed_tool_call(stream) + + assert result.finish_reason_count == 1 + assert result.role_name == "assistant" + assert result.function_name == "get_current_weather" + streamed_args = json.loads(result.function_args_str) + assert isinstance(result.tool_call_id, str) + if _is_pre_v11(server_config): + assert len(result.tool_call_id) == streaming_id_len_pre_v11 + else: + assert len(result.tool_call_id) == 9 + assert parsed_arguments == streamed_args + + +@pytest.mark.asyncio +async def test_tool_call_none_with_tools( + client: openai.AsyncOpenAI, server_config: ServerConfig +) -> None: + _requires_tool_parser(server_config) + + models = await client.models.list() + model_name: str = models.data[0].id + + # --- non-streaming --- + chat_completion = await client.chat.completions.create( + messages=ensure_system_prompt(MESSAGES_ASKING_FOR_TOOLS, server_config), + temperature=0, + max_completion_tokens=100, + model=model_name, + tools=[WEATHER_TOOL], + tool_choice="none", + logprobs=False, + seed=SEED, + ) + + choice = chat_completion.choices[0] + + assert choice.finish_reason != "tool_calls" + assert choice.message.tool_calls is None or len(choice.message.tool_calls) == 0 + assert choice.message.content is not None + # Without grammar enforcement, pre-v11 models may still emit [TOOL_CALLS] + if not _is_pre_v11(server_config): + assert "[TOOL_CALLS]" not in choice.message.content + + non_streaming_content = choice.message.content + + # --- streaming --- + stream = await client.chat.completions.create( + messages=ensure_system_prompt(MESSAGES_ASKING_FOR_TOOLS, server_config), + temperature=0, + max_completion_tokens=100, + model=model_name, + tools=[WEATHER_TOOL], + tool_choice="none", + logprobs=False, + seed=SEED, + stream=True, + ) + + # Pre-v11 models lack grammar enforcement, so the model may still + # emit tool calls even with tool_choice="none". 
+ pre_v11 = _is_pre_v11(server_config) + result = await _collect_streamed_content(stream, no_tool_calls=not pre_v11) + + assert result.finish_reason_count == 1 + if not pre_v11: + assert result.finish_reason != "tool_calls" + streamed_content = "".join(result.chunks) + if not pre_v11: + assert "[TOOL_CALLS]" not in streamed_content + assert streamed_content == non_streaming_content + + +@pytest.mark.asyncio +async def test_chat_without_tools( + client: openai.AsyncOpenAI, server_config: ServerConfig +) -> None: + models = await client.models.list() + model_name: str = models.data[0].id + + # --- non-streaming --- + chat_completion = await client.chat.completions.create( + messages=ensure_system_prompt(MESSAGES_WITHOUT_TOOLS, server_config), + temperature=0, + max_completion_tokens=150, + model=model_name, + logprobs=False, + seed=SEED, + ) + + choice = chat_completion.choices[0] + output_text = choice.message.content + + assert output_text is not None and len(output_text) > 0 + assert choice.finish_reason != "tool_calls" + assert choice.message.tool_calls is None or len(choice.message.tool_calls) == 0 + + # --- streaming --- + stream = await client.chat.completions.create( + messages=ensure_system_prompt(MESSAGES_WITHOUT_TOOLS, server_config), + temperature=0, + max_completion_tokens=150, + model=model_name, + logprobs=False, + seed=SEED, + stream=True, + ) + + result = await _collect_streamed_content( + stream, expected_finish_reason=choice.finish_reason + ) + + assert result.role_sent + assert result.finish_reason_count == 1 + assert len(result.chunks) + assert "".join(result.chunks) == output_text + + +@pytest.mark.asyncio +async def test_tool_call_with_results( + client: openai.AsyncOpenAI, server_config: ServerConfig +) -> None: + _requires_tool_parser(server_config) + + models = await client.models.list() + model_name: str = models.data[0].id + + # --- non-streaming --- + chat_completion = await client.chat.completions.create( + 
messages=ensure_system_prompt(MESSAGES_WITH_TOOL_RESPONSE, server_config), + temperature=0, + max_completion_tokens=100, + model=model_name, + tools=[WEATHER_TOOL, SEARCH_TOOL], + logprobs=False, + seed=SEED, + ) + + choice = chat_completion.choices[0] + + assert choice.finish_reason != "tool_calls" + assert choice.message.tool_calls is None or len(choice.message.tool_calls) == 0 + assert choice.message.content is not None + assert "98" in choice.message.content + + # --- streaming --- + stream = await client.chat.completions.create( + messages=ensure_system_prompt(MESSAGES_WITH_TOOL_RESPONSE, server_config), + temperature=0, + max_completion_tokens=100, + model=model_name, + tools=[WEATHER_TOOL, SEARCH_TOOL], + logprobs=False, + seed=SEED, + stream=True, + ) + + result = await _collect_streamed_content( + stream, expected_finish_reason=choice.finish_reason + ) + + assert result.role_sent + assert result.finish_reason_count == 1 + assert len(result.chunks) + assert "".join(result.chunks) == choice.message.content + + +def _requires_parallel(server_config: ServerConfig) -> None: + r"""Skip test if the model does not support parallel tool calls.""" + if not server_config.get("supports_parallel"): + pytest.skip( + f"Skipping: {server_config['model']} does not support parallel tool calls" + ) + + +@pytest.mark.asyncio +async def test_tool_call_parallel( + client: openai.AsyncOpenAI, server_config: ServerConfig +) -> None: + _requires_tool_parser(server_config) + _requires_parallel(server_config) + + models = await client.models.list() + model_name: str = models.data[0].id + + # --- non-streaming --- + chat_completion = await client.chat.completions.create( + messages=ensure_system_prompt( + MESSAGES_ASKING_FOR_PARALLEL_TOOLS, server_config + ), + temperature=0, + max_completion_tokens=200, + model=model_name, + tools=[WEATHER_TOOL], + logprobs=False, + seed=SEED, + ) + + choice = chat_completion.choices[0] + tool_calls = choice.message.tool_calls + + assert 
choice.finish_reason == "tool_calls" + assert tool_calls is not None and len(tool_calls) >= 2 + for tc in tool_calls: + assert tc.type == "function" + assert tc.function.name == "get_current_weather" + assert isinstance(tc.function.arguments, str) + parsed = json.loads(tc.function.arguments) + assert "city" in parsed + assert len(tc.id) == 9 + + non_streaming_tool_calls = tool_calls + + # --- streaming --- + stream = await client.chat.completions.create( + messages=ensure_system_prompt( + MESSAGES_ASKING_FOR_PARALLEL_TOOLS, server_config + ), + temperature=0, + max_completion_tokens=200, + model=model_name, + tools=[WEATHER_TOOL], + logprobs=False, + seed=SEED, + stream=True, + ) + + result = await _collect_streamed_parallel_tool_calls(stream) + + assert result.finish_reason_count == 1 + assert result.role_name == "assistant" + assert len(result.function_names) >= 2 + assert all(name == "get_current_weather" for name in result.function_names) + assert len(result.tool_call_ids) >= 2 + assert all(isinstance(tid, str) and len(tid) == 9 for tid in result.tool_call_ids) + + for args_str in result.function_args_strs: + streamed_args = json.loads(args_str) + assert "city" in streamed_args + + assert len(result.function_names) == len(non_streaming_tool_calls) diff --git a/tests/tool_use/mistral/utils.py b/tests/tool_use/mistral/utils.py index 4d772ba63793..01a2aaee6d2e 100644 --- a/tests/tool_use/mistral/utils.py +++ b/tests/tool_use/mistral/utils.py @@ -2,16 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing_extensions import TypedDict - - -class ServerConfig(TypedDict, total=False): - model: str - arguments: list[str] - system_prompt: str | None - supports_parallel: bool | None - supports_rocm: bool | None - +from tests.tool_use.utils import ServerConfig ARGS: list[str] = ["--max-model-len", "1024"] @@ -21,6 +12,11 @@ class ServerConfig(TypedDict, total=False): "arguments": [ "--tokenizer-mode", "mistral", + "--tool-call-parser", + 
"mistral", + "--enable-auto-tool-choice", + "--enforce-eager", + "--no-enable-prefix-caching", '--ignore-patterns="consolidated.safetensors"', ], "system_prompt": "You are a helpful assistant with access to tools. If a tool" @@ -29,4 +25,22 @@ class ServerConfig(TypedDict, total=False): "without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT " "to the user's question - just respond to it normally.", }, + "ministral-3b": { + "model": "mistralai/Ministral-3-3B-Instruct-2512", + "arguments": [ + "--tokenizer-mode", + "mistral", + "--tool-call-parser", + "mistral", + "--enable-auto-tool-choice", + "--enforce-eager", + "--no-enable-prefix-caching", + ], + "system_prompt": "You are a helpful assistant with access to tools. If a tool" + " that you have would be helpful to answer a user query, " + "call the tool. Otherwise, answer the user's query directly " + "without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT " + "to the user's question - just respond to it normally.", + "supports_parallel": True, + }, } diff --git a/vllm/entrypoints/openai/chat_completion/protocol.py b/vllm/entrypoints/openai/chat_completion/protocol.py index 2bc1b6e08750..437e41fb286e 100644 --- a/vllm/entrypoints/openai/chat_completion/protocol.py +++ b/vllm/entrypoints/openai/chat_completion/protocol.py @@ -11,7 +11,7 @@ ChatCompletionAudio as OpenAIChatCompletionAudio, ) from openai.types.chat.chat_completion_message import Annotation as OpenAIAnnotation -from pydantic import Field, model_validator +from pydantic import Field, PrivateAttr, model_validator from vllm.config import ModelConfig from vllm.config.utils import replace @@ -398,6 +398,9 @@ def _materialize_tool_calls_after(self) -> "ChatCompletionRequest": msg["tool_calls"] = list(tool_calls) return self + _grammar_from_tool_parser: bool = PrivateAttr(default=False) + """CAUTION: Should only be set by ``ToolParser.adjust_request``.""" + def build_chat_params( self, default_template: str | None, @@ -822,13 +825,6 @@ def 
check_system_message_content_type(cls, data): return data - @model_validator(mode="before") - @classmethod - def set_include_reasoning_for_none_effort(cls, data: Any) -> Any: - if data.get("reasoning_effort") == "none": - data["include_reasoning"] = False - return data - class BatchChatCompletionRequest(OpenAIBaseModel): """Request model for the /v1/chat/completions/batch endpoint. diff --git a/vllm/entrypoints/openai/chat_completion/serving.py b/vllm/entrypoints/openai/chat_completion/serving.py index 0b8dd0aa28ef..446f127a91e3 100644 --- a/vllm/entrypoints/openai/chat_completion/serving.py +++ b/vllm/entrypoints/openai/chat_completion/serving.py @@ -73,7 +73,10 @@ from vllm.renderers import ChatParams from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.tokenizers import TokenizerLike -from vllm.tool_parsers.mistral_tool_parser import MistralToolCall +from vllm.tool_parsers.mistral_tool_parser import ( + MistralToolCall, + MistralToolParser, +) from vllm.tool_parsers.utils import partial_json_loads from vllm.utils.collection_utils import as_list from vllm.utils.mistral import is_mistral_tokenizer @@ -140,6 +143,12 @@ def __init__( enable_auto_tools=enable_auto_tools, model_name=self.model_config.model, ) + _is_mistral_tool_parser = self.tool_parser is not None and issubclass( + self.tool_parser, MistralToolParser + ) + if _is_mistral_tool_parser and self.reasoning_parser_cls is not None: + MistralToolParser.model_can_reason = True + self.exclude_tools_when_tool_choice_none = exclude_tools_when_tool_choice_none self.enable_prompt_tokens_details = enable_prompt_tokens_details @@ -310,6 +319,11 @@ async def create_chat_completion( else: if not request.include_reasoning: reasoning_ended = True + elif request._grammar_from_tool_parser: + # The Mistral grammar already includes an optional + # `think?` rule that handles both reasoning and + # non-reasoning outputs. 
+ reasoning_ended = True elif reasoning_parser: reasoning_ended = reasoning_parser.is_reasoning_end( prompt_token_ids or [] @@ -530,6 +544,8 @@ async def chat_completion_stream_generator( harmony_tools_streamed = [False] * num_choices tools_streamed = [False] * num_choices + is_mistral_grammar_path = request._grammar_from_tool_parser + if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam): tool_choice_function_name = request.tool_choice.function.name else: @@ -553,7 +569,7 @@ async def chat_completion_stream_generator( # Only one of these will be used, thus previous_texts and # all_previous_token_ids will not be used twice in the same iteration. - if tool_choice_auto or reasoning_parser: + if is_mistral_grammar_path or tool_choice_auto or reasoning_parser: # These are only required in "auto" tool choice case all_previous_token_ids = [[] for _ in range(num_choices)] reasoning_end_arr = [False] * num_choices @@ -748,7 +764,7 @@ async def chat_completion_stream_generator( delta_message: DeltaMessage | None # just update previous_texts and previous_token_ids - if tool_choice_auto or reasoning_parser: + if is_mistral_grammar_path or tool_choice_auto or reasoning_parser: assert previous_texts is not None assert all_previous_token_ids is not None previous_text = previous_texts[i] @@ -772,6 +788,30 @@ async def chat_completion_stream_generator( ) ) harmony_tools_streamed[i] |= tools_streamed_flag + # Mistral grammar path: combined reasoning + tool streaming + elif is_mistral_grammar_path: + assert tool_parser is not None + assert isinstance(tool_parser, MistralToolParser) + assert reasoning_end_arr is not None + output_token_ids = as_list(output.token_ids) + result = tool_parser.extract_maybe_reasoning_and_tool_streaming( + reasoning_parser=reasoning_parser, + previous_text=previous_text, + current_text=current_text, + delta_text=delta_text, + previous_token_ids=previous_token_ids, + current_token_ids=current_token_ids, + output_token_ids=output_token_ids, 
+ reasoning_ended=reasoning_end_arr[i], + prompt_is_reasoning_end=(prompt_is_reasoning_end_arr[i]), + request=request, + ) + delta_message = result.delta_message + reasoning_end_arr[i] = result.reasoning_ended + current_text = result.current_text + current_token_ids = result.current_token_ids + if result.tools_called: + tools_streamed[i] = True # handle streaming deltas for tools with named tool_choice elif tool_choice_function_name: # When encountering think end id in prompt_token_ids @@ -925,7 +965,9 @@ async def chat_completion_stream_generator( delta_message = DeltaMessage(content=delta_text) # update the previous values for the next iteration - if (tool_choice_auto or reasoning_parser) and not self.use_harmony: + if ( + is_mistral_grammar_path or tool_choice_auto or reasoning_parser + ) and not self.use_harmony: assert previous_texts is not None assert all_previous_token_ids is not None previous_texts[i] = current_text @@ -1312,7 +1354,24 @@ async def chat_completion_full_generator( tool_call_class = ( MistralToolCall if is_mistral_tokenizer(tokenizer) else ToolCall ) - if (not self.enable_auto_tools or not self.tool_parser) and ( + + use_mistral_tool_parser = request._grammar_from_tool_parser + if use_mistral_tool_parser: + tool_call_items = MistralToolParser.build_non_streaming_tool_calls( + tool_calls + ) + if tool_call_items: + auto_tools_called = ( + request.tool_choice is None or request.tool_choice == "auto" + ) + message = ChatMessage( + role=role, + reasoning=reasoning, + content=content, + tool_calls=tool_call_items, + ) + + elif (not self.enable_auto_tools or not self.tool_parser) and ( not isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam) and request.tool_choice != "required" ): diff --git a/vllm/entrypoints/openai/engine/serving.py b/vllm/entrypoints/openai/engine/serving.py index 5bd415b4fbd2..85237604e5a1 100644 --- a/vllm/entrypoints/openai/engine/serving.py +++ b/vllm/entrypoints/openai/engine/serving.py @@ -65,6 +65,7 @@ 
from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.tokenizers import TokenizerLike from vllm.tool_parsers import ToolParser +from vllm.tool_parsers.mistral_tool_parser import MistralToolParser from vllm.tracing import ( contains_trace_headers, extract_trace_headers, @@ -610,16 +611,31 @@ def _parse_tool_calls_from_content( tool_parser_cls: type[ToolParser] | None, content: str | None = None, ) -> tuple[list[FunctionCall] | None, str | None]: + # When the Mistral grammar factory injected structured outputs, + # let the parser handle the output. + use_mistral_tool_parser = ( + isinstance(request, ChatCompletionRequest) + and tool_parser_cls is not None + and issubclass(tool_parser_cls, MistralToolParser) + and request._grammar_from_tool_parser + ) + function_calls = list[FunctionCall]() - if request.tool_choice and isinstance(request.tool_choice, ToolChoiceFunction): + if ( + not use_mistral_tool_parser + and request.tool_choice + and isinstance(request.tool_choice, ToolChoiceFunction) + ): assert content is not None # Forced Function Call function_calls.append( FunctionCall(name=request.tool_choice.name, arguments=content) ) content = None # Clear content since tool is called. - elif request.tool_choice and isinstance( - request.tool_choice, ChatCompletionNamedToolChoiceParam + elif ( + not use_mistral_tool_parser + and request.tool_choice + and isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam) ): assert content is not None # Forced Function Call @@ -627,7 +643,7 @@ def _parse_tool_calls_from_content( FunctionCall(name=request.tool_choice.function.name, arguments=content) ) content = None # Clear content since tool is called. 
- elif request.tool_choice == "required": + elif not use_mistral_tool_parser and request.tool_choice == "required": tool_calls = [] with contextlib.suppress(ValidationError): content = content or "" @@ -642,10 +658,12 @@ def _parse_tool_calls_from_content( ) ) content = None # Clear content since tool is called. - elif ( - tool_parser_cls - and enable_auto_tools - and (request.tool_choice == "auto" or request.tool_choice is None) + elif tool_parser_cls and ( + use_mistral_tool_parser + or ( + enable_auto_tools + and (request.tool_choice == "auto" or request.tool_choice is None) + ) ): if tokenizer is None: raise ValueError( diff --git a/vllm/entrypoints/serve/render/serving.py b/vllm/entrypoints/serve/render/serving.py index 5aa2449797b3..99b6cb47cb5c 100644 --- a/vllm/entrypoints/serve/render/serving.py +++ b/vllm/entrypoints/serve/render/serving.py @@ -53,6 +53,7 @@ prompt_to_seq, ) from vllm.tool_parsers import ToolParser +from vllm.tool_parsers.mistral_tool_parser import MistralToolParser from vllm.utils import random_uuid from vllm.utils.mistral import is_mistral_tokenizer from vllm.utils.mistral import mt as _mt @@ -555,9 +556,19 @@ async def preprocess_chat( # tool parsing is done only if a tool_parser has been set and if # tool_choice is not "none" (if tool_choice is "none" but a tool_parser # is set, we want to prevent parsing a tool_call hallucinated by the LLM + # + # Exception: Mistral grammar-capable tokenizers always call + # adjust_request — even for tool_choice="none" — so that the grammar + # factory can prevent special-token leakage. 
if tool_parser is not None: tool_choice = getattr(request, "tool_choice", "none") - if tool_choice != "none": + tokenizer = renderer.get_tokenizer() + is_mistral_grammar_eligible = ( + issubclass(tool_parser, MistralToolParser) + and is_mistral_tokenizer(tokenizer) + and tokenizer.supports_grammar + ) + if tool_choice != "none" or is_mistral_grammar_eligible: if not isinstance(request, ChatCompletionRequest | ResponsesRequest): msg = ( "Tool usage is only supported " @@ -565,7 +576,6 @@ async def preprocess_chat( f"but got {type(request).__name__}" ) raise NotImplementedError(msg) - tokenizer = renderer.get_tokenizer() request = tool_parser(tokenizer, request.tools).adjust_request( request=request ) diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 9bcc669591eb..77fa6402180e 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -157,6 +157,10 @@ def _is_non_tekken_mistral(tokenizer: TokenizerLike) -> bool: return is_mistral_tokenizer(tokenizer) and not tokenizer.is_tekken +def _get_llg_tokenizer(tokenizer: TokenizerLike) -> Any: + return tokenizer.llg_tokenizer if is_mistral_tokenizer(tokenizer) else None + + class SamplingParams( PydanticMsgspecMixin, msgspec.Struct, @@ -816,7 +820,10 @@ def _validate_structured_outputs( # allows <|special_token|> and similar, see # https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens # Without tokenizer these are disallowed in grammars. - validate_guidance_grammar(self, tokenizer=None) + validate_guidance_grammar( + self, + tokenizer=_get_llg_tokenizer(tokenizer), + ) elif backend == "outlines": # outlines backend validate_structured_output_request_outlines(self) @@ -862,7 +869,10 @@ def _validate_structured_outputs( self.structured_outputs._backend = "outlines" else: # Fall back to guidance by default. 
- validate_guidance_grammar(self, tokenizer=None) + validate_guidance_grammar( + self, + tokenizer=_get_llg_tokenizer(tokenizer), + ) self.structured_outputs._backend = "guidance" # Remember that this backend was set automatically self.structured_outputs._backend_was_auto = True diff --git a/vllm/tokenizers/mistral.py b/vllm/tokenizers/mistral.py index 147dca88877b..ef58b1b75d68 100644 --- a/vllm/tokenizers/mistral.py +++ b/vllm/tokenizers/mistral.py @@ -54,6 +54,50 @@ logger = init_logger(__name__) +def _pop_unallowed_keys_and_warn( + dictionary: dict[str, Any], allowed_keys: set[str], err_dict_name: str +): + keys = list(dictionary.keys()) + for key in keys: + if key not in allowed_keys: + dictionary.pop(key) + logger.warning_once( + f"'{key=}' is not supported by mistral-common " + f"for {err_dict_name}. It has been popped from the " + "object." + ) + + +# TODO(juliendenize): remove this once OpenAI API is better supported by +# `mistral-common`. +def adapt_inplace_to_mistral_tool( + tool: dict[str, Any], +) -> dict[str, Any]: + tools_fields = set(Tool.model_fields.keys()) + function_fields = set(Function.model_fields.keys()) + + # The Mistral client, in comparison to the OpenAI client, requires the + # "parameters" dict and the "description" string to be present + # even if they are empty. 
+ if function := tool.get("function"): + if function.get("parameters") is None: + function["parameters"] = {} + if function.get("description") is None: + function["description"] = "" + + _pop_unallowed_keys_and_warn( + dictionary=function, + allowed_keys=function_fields, + err_dict_name="function", + ) + + _pop_unallowed_keys_and_warn( + dictionary=tool, allowed_keys=tools_fields, err_dict_name="tools" + ) + + return tool + + def maybe_serialize_tool_calls(request: "MistralChatCompletionRequest"): # SEE: https://github.com/vllm-project/vllm/pull/9951 # Credits go to: @gcalmettes @@ -159,44 +203,11 @@ def _prepare_apply_chat_template_tools_and_messages( # Remove reasoning as unsupported by Mistral _ = message.pop("reasoning", None) # type: ignore - # The Mistral client, in comparison to the OpenAI client, requires the - # "parameters" dict and the "description" string to be present - # even if they are empty. - if tools: - for function in [ - tool["function"] for tool in tools if tool["type"] == "function" - ]: - if function.get("parameters") is None: - function["parameters"] = {} - if function.get("description") is None: - function["description"] = "" - - # We filter not supported arguments to avoid throwing an error. - # TODO(juliendenize): remove this once OpenAI API is better supported by - # `mistral-common`. - tools_fields = set(Tool.model_fields.keys()) - function_fields = set(Function.model_fields.keys()) - for tool in tools: - tool_keys = list(tool.keys()) - for tool_key in tool_keys: - if tool_key not in tools_fields: - tool.pop(tool_key) - logger.warning_once( - f"'{tool_key}' is not supported by mistral-common for tools. " - "It has been popped from the tool definition." 
- ) - if tool["type"] == "function": - function_keys = list(tool["function"].keys()) - for function_key in function_keys: - if function_key not in function_fields: - tool["function"].pop(function_key) - logger.warning_once( - f"'{function_key}' is not supported by mistral-common " - "for function tools. It has been popped from the " - "function definition." - ) - else: - raise ValueError("mistral-common only supports function tools.") + tools = ( + [adapt_inplace_to_mistral_tool(tool=tool) for tool in tools] + if tools is not None + else None + ) return messages, tools diff --git a/vllm/tool_parsers/mistral_tool_parser.py b/vllm/tool_parsers/mistral_tool_parser.py index 4d1aaffedd0e..1d2613104fc1 100644 --- a/vllm/tool_parsers/mistral_tool_parser.py +++ b/vllm/tool_parsers/mistral_tool_parser.py @@ -1,12 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from __future__ import annotations + import json from collections.abc import Sequence +from dataclasses import dataclass from enum import Enum, auto from random import choices from string import ascii_letters, digits -from typing import Any +from typing import TYPE_CHECKING, Any import ijson import regex as re @@ -37,14 +40,19 @@ ) from vllm.entrypoints.openai.responses.protocol import ResponsesRequest from vllm.logger import init_logger +from vllm.reasoning.mistral_reasoning_parser import MistralReasoningParser from vllm.sampling_params import StructuredOutputsParams from vllm.tokenizers import TokenizerLike +from vllm.tokenizers.mistral import MistralTokenizer, adapt_inplace_to_mistral_tool from vllm.tool_parsers.abstract_tool_parser import ( Tool, ToolParser, ) from vllm.utils.mistral import is_mistral_tokenizer +if TYPE_CHECKING: + from vllm.reasoning import ReasoningParser + logger = init_logger(__name__) ALPHANUMERIC = ascii_letters + digits @@ -86,13 +94,28 @@ def _is_pre_v11_tokeniser(model_tokenizer: TokenizerLike) -> bool: return not 
(is_mistral_tokenizer(model_tokenizer) and model_tokenizer.version >= 11)
 
 
-class MistralToolParser(ToolParser):
+@dataclass
+class MistralStreamingResult:
+    r"""Encapsulates the mutable state returned from
+    `MistralToolParser.extract_maybe_reasoning_and_tool_streaming`.
     """
-    Tool call parser for Mistral 7B Instruct v0.3, intended for use with
-    - [`mistral_common`](https://github.com/mistralai/mistral-common/)
-    - the examples/tool_chat_template_mistral.jinja template.
-    Used when --enable-auto-tool-choice --tool-call-parser mistral are all set
+    delta_message: DeltaMessage | None
+    reasoning_ended: bool
+    tools_called: bool
+    current_text: str
+    current_token_ids: list[int]
+
+
+class MistralToolParser(ToolParser):
+    r"""Tool call parser for Mistral models, intended for use with either:
+
+    - `mistral_common <https://github.com/mistralai/mistral-common/>`_
+      (recommended)
+    - the `examples/tool_chat_template_mistral.jinja` template.
+
+    Used when `--enable-auto-tool-choice --tool-call-parser mistral` are both
+    set.
     """
 
     # Used to generate correct grammar in `adjust_request`
@@ -210,9 +233,11 @@ def adjust_request(
                 reasoning=self.model_can_reason
             )
 
-        tools = (
+        mistral_tools = (
             [
-                MistralTool.from_openai(openai_tool=tool.model_dump())
+                MistralTool.model_validate(
+                    adapt_inplace_to_mistral_tool(tool.model_dump())
+                )
                 for tool in request.tools
             ]
             if request.tools is not None
@@ -244,15 +269,158 @@ def adjust_request(
             lark_grammar = grammar_factory.get_lark_from_jinja(
                 template=template,
                 mode=tool_choice,
-                tools=tools,
+                tools=mistral_tools,
                 json_schema=json_schema,
                 parallel_tool_calls=request.parallel_tool_calls,
                 json_only=False,
             )
 
             request.structured_outputs = StructuredOutputsParams(grammar=lark_grammar)
+            request._grammar_from_tool_parser = True
 
         return request
 
+    def extract_maybe_reasoning_and_tool_streaming(
+        self,
+        *,
+        reasoning_parser: ReasoningParser | None,
+        previous_text: str,
+        current_text: str,
+        delta_text: str,
+        previous_token_ids: list[int],
+        current_token_ids: list[int],
+        
output_token_ids: Sequence[int], + reasoning_ended: bool, + prompt_is_reasoning_end: bool | None, + request: ChatCompletionRequest, + ) -> MistralStreamingResult: + r"""Streaming extraction with reasoning followed by tool-call parsing. + + This method encapsulates the combined reasoning extraction and + tool-call streaming logic so that the serving layer only needs a + thin routing branch. + + The flow is: + + 1. If a *reasoning_parser* is present and reasoning has **not** ended, + extract reasoning tokens. Pre-v15 models may have pre-filled + `[THINK]...[/THINK]` in system prompts, so we skip the + prompt-level reasoning-end check for those. + 2. Once reasoning ends (or if there is no reasoning parser), delegate + to `extract_tool_calls_streaming` and track whether tools were + called. + + Args: + reasoning_parser: Optional reasoning parser instance. + previous_text: Accumulated text from prior chunks. + current_text: Full accumulated text including current chunk. + delta_text: New text in this chunk. + previous_token_ids: Token ids from prior chunks. + current_token_ids: Full token ids including current chunk. + output_token_ids: Raw output token ids from the engine. + reasoning_ended: Whether reasoning has already ended. + prompt_is_reasoning_end: Whether the prompt itself ends reasoning. + request: The originating chat completion request. + """ + delta_message: DeltaMessage | None = None + tools_called = False + reasoning_ended_at_entry = reasoning_ended + + # For MistralReasoningParser, only enter the reasoning block when + # the model has actually emitted a [THINK] token. Other reasoning + # parsers always expect thinking to be present. 
+ expect_thinking = ( + not isinstance(reasoning_parser, MistralReasoningParser) + or reasoning_parser.start_token_id in current_token_ids + ) + if reasoning_parser is not None and not reasoning_ended and expect_thinking: + # Pre-v15 models may have pre-filled [THINK]...[/THINK] in + # system prompts, so skip the prompt-level reasoning-end + # check and wait for the output's own end-of-think. + is_pre_v15 = ( + isinstance(self.model_tokenizer, MistralTokenizer) + and self.model_tokenizer.version < 15 + ) + + if not is_pre_v15 and prompt_is_reasoning_end: + reasoning_ended = True + current_token_ids = list(output_token_ids) + else: + delta_message = reasoning_parser.extract_reasoning_streaming( + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + output_token_ids, + ) + if reasoning_parser.is_reasoning_end_streaming( + current_token_ids, output_token_ids + ): + reasoning_ended = True + current_token_ids = reasoning_parser.extract_content_ids( + list(output_token_ids) + ) + if delta_message and delta_message.content: + current_text = delta_message.content + delta_message.content = None + else: + current_text = "" + + if not reasoning_ended: + return MistralStreamingResult( + delta_message=delta_message, + reasoning_ended=False, + tools_called=False, + current_text=current_text, + current_token_ids=current_token_ids, + ) + + delta_token_ids = list(output_token_ids) + + # On the iteration where reasoning just ended, reset the text/token + # state so the tool parser sees a clean history instead of the + # accumulated reasoning text. 
+ if not reasoning_ended_at_entry and reasoning_ended: + previous_text = "" + previous_token_ids = [] + delta_text = current_text + delta_token_ids = current_token_ids + + delta_message = self.extract_tool_calls_streaming( + previous_text=previous_text, + current_text=current_text, + delta_text=delta_text, + previous_token_ids=previous_token_ids, + current_token_ids=current_token_ids, + delta_token_ids=delta_token_ids, + request=request, + ) + if delta_message and delta_message.tool_calls: + tools_called = True + + return MistralStreamingResult( + delta_message=delta_message, + reasoning_ended=reasoning_ended, + tools_called=tools_called, + current_text=current_text, + current_token_ids=current_token_ids, + ) + + @staticmethod + def build_non_streaming_tool_calls( + tool_calls: list[FunctionCall] | None, + ) -> list[ToolCall]: + r"""Build `MistralToolCall` items for non-streaming responses.""" + if not tool_calls: + return [] + + return [ + MistralToolCall(id=tc.id, function=tc) + if tc.id + else MistralToolCall(function=tc) + for tc in tool_calls + ] + def extract_tool_calls( self, model_output: str, @@ -323,7 +491,7 @@ def extract_tool_calls( )[0] tool_calls = json.loads(raw_tool_call) except (IndexError, json.JSONDecodeError): - logger.exception("Error in extracting tool call from response: {e}") + logger.exception("Error in extracting tool call from response.") # If raw decoding and decoding post regex rule fails, then just # return content. return ExtractedToolCallInformation(