diff --git a/requirements/common.txt b/requirements/common.txt index 5d4519204ee9..652738eebe74 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -24,7 +24,7 @@ outlines_core == 0.2.14 # required for outlines backend disk cache diskcache == 5.6.3 lark == 1.2.2 -xgrammar >= 0.1.32, < 1.0.0; platform_machine == "x86_64" or platform_machine == "aarch64" or platform_machine == "arm64" or platform_machine == "s390x" or platform_machine == "ppc64le" +xgrammar >= 0.2.0, < 1.0.0; platform_machine == "x86_64" or platform_machine == "aarch64" or platform_machine == "arm64" or platform_machine == "s390x" or platform_machine == "ppc64le" typing_extensions >= 4.10 filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317 partial-json-parser # used for parsing partial JSON outputs diff --git a/requirements/test/rocm.txt b/requirements/test/rocm.txt index 801af7db9db1..8445634ded40 100644 --- a/requirements/test/rocm.txt +++ b/requirements/test/rocm.txt @@ -42,6 +42,8 @@ anyio==4.13.0 # sse-starlette # starlette # watchfiles +apache-tvm-ffi==0.1.10 + # via xgrammar arctic-inference==0.1.1 # via -r requirements/test/rocm.in argcomplete==3.6.3 @@ -1264,6 +1266,7 @@ typing-extensions==4.15.0 # alembic # anthropic # anyio + # apache-tvm-ffi # azure-core # azure-identity # azure-storage-blob @@ -1345,7 +1348,7 @@ word2number==1.1 # via lm-eval wrapt==2.1.2 # via smart-open -xgrammar==0.1.33 +xgrammar==0.2.0 # via # -c requirements/common.txt # -r requirements/test/../common.txt diff --git a/tests/tool_parsers/test_deepseekv4_tool_parser.py b/tests/tool_parsers/test_deepseekv4_tool_parser.py index 631d0fb97b33..cc77a1f77756 100644 --- a/tests/tool_parsers/test_deepseekv4_tool_parser.py +++ b/tests/tool_parsers/test_deepseekv4_tool_parser.py @@ -6,6 +6,15 @@ import json from unittest.mock import MagicMock +import pytest +from xgrammar import StructuralTag + +from vllm.entrypoints.openai.chat_completion.protocol import ( + 
ChatCompletionNamedFunction, + ChatCompletionNamedToolChoiceParam, + ChatCompletionRequest, + ChatCompletionToolsParam, +) from vllm.tool_parsers import ToolParserManager from vllm.tool_parsers.deepseekv4_tool_parser import DeepSeekV4ToolParser @@ -20,6 +29,43 @@ PARAM_END = "" +@pytest.fixture +def sample_tools() -> list[ChatCompletionToolsParam]: + return [ + ChatCompletionToolsParam( + type="function", + function={ + "name": "get_current_weather", + "description": "Get the current weather", + "parameters": { + "type": "object", + "properties": { + "city": {"type": "string", "description": "The city name"}, + "state": {"type": "string", "description": "The state code"}, + "unit": {"type": "string", "enum": ["fahrenheit", "celsius"]}, + }, + "required": ["city", "state"], + }, + }, + ), + ChatCompletionToolsParam( + type="function", + function={ + "name": "calculate_area", + "description": "Calculate area of a shape", + "parameters": { + "type": "object", + "properties": { + "shape": {"type": "string"}, + "dimensions": {"type": "object"}, + "precision": {"type": "integer"}, + }, + }, + }, + ), + ] + + def make_parser(tools=None) -> DeepSeekV4ToolParser: return DeepSeekV4ToolParser(MOCK_TOKENIZER, tools=tools) @@ -121,3 +167,39 @@ def test_streaming_extracts_complete_invokes(): ] assert names == ["search"] assert json.loads(reconstruct_args(deltas)) == {"query": "deepseek v4"} + + +def test_get_vllm_registry_structural_tag_returns_structural_tag( + sample_tools: list[ChatCompletionToolsParam], +) -> None: + parser = make_parser() + req = ChatCompletionRequest( + messages=[], + model="m", + tools=sample_tools, + tool_choice="auto", + ) + tag = parser.get_structural_tag(req) + assert isinstance(tag, StructuralTag) + + req = ChatCompletionRequest( + messages=[], + model="m", + tools=sample_tools, + tool_choice="required", + ) + tag = parser.get_structural_tag(req) + assert isinstance(tag, StructuralTag) + + if sample_tools: + tool = sample_tools[0] + req = 
ChatCompletionRequest( + messages=[], + model="m", + tools=sample_tools, + ) + req.tool_choice = ChatCompletionNamedToolChoiceParam( + function=ChatCompletionNamedFunction(name=tool.function.name) + ) + tag = parser.get_structural_tag(req) + assert isinstance(tag, StructuralTag) diff --git a/tests/tool_parsers/test_qwen3coder_tool_parser.py b/tests/tool_parsers/test_qwen3coder_tool_parser.py index c62e95830243..26bbf1a044bc 100644 --- a/tests/tool_parsers/test_qwen3coder_tool_parser.py +++ b/tests/tool_parsers/test_qwen3coder_tool_parser.py @@ -6,8 +6,11 @@ import pytest from openai.types.responses.function_tool import FunctionTool +from xgrammar import StructuralTag from vllm.entrypoints.openai.chat_completion.protocol import ( + ChatCompletionNamedFunction, + ChatCompletionNamedToolChoiceParam, ChatCompletionRequest, ChatCompletionToolsParam, ) @@ -108,6 +111,27 @@ def sample_tools(request): ] +def _as_chat_completion_tools( + tools: list[ChatCompletionToolsParam | FunctionTool], +) -> list[ChatCompletionToolsParam]: + normalized: list[ChatCompletionToolsParam] = [] + for tool in tools: + if isinstance(tool, ChatCompletionToolsParam): + normalized.append(tool) + else: + normalized.append( + ChatCompletionToolsParam( + type="function", + function={ + "name": tool.name, + "description": tool.description, + "parameters": tool.parameters, + }, + ) + ) + return normalized + + def assert_tool_calls( actual_tool_calls: list[ToolCall], expected_tool_calls: list[ToolCall] ): @@ -1146,3 +1170,88 @@ def test_no_double_serialization_string_args(qwen3_tool_parser): args = json.loads(raw_arguments) assert args["message"] == "hello world" assert '\\"hello world\\"' not in raw_arguments + + +def test_get_vllm_registry_structural_tag_returns_structural_tag( + qwen3_tool_parser: Qwen3CoderToolParser, + sample_tools: list[ChatCompletionToolsParam], +) -> None: + request_tools = _as_chat_completion_tools(sample_tools) + req = ChatCompletionRequest( + messages=[], + model="m", + 
tools=request_tools, + tool_choice="auto", + ) + tag = qwen3_tool_parser.get_structural_tag(req) + assert isinstance(tag, StructuralTag) + + req = ChatCompletionRequest( + messages=[], + model="m", + tools=request_tools, + tool_choice="required", + ) + tag = qwen3_tool_parser.get_structural_tag(req) + assert isinstance(tag, StructuralTag) + + if request_tools: + tool = request_tools[0] + req = ChatCompletionRequest( + messages=[], + model="m", + tools=request_tools, + ) + req.tool_choice = ChatCompletionNamedToolChoiceParam( + function=ChatCompletionNamedFunction(name=tool.function.name) + ) + tag = qwen3_tool_parser.get_structural_tag(req) + assert isinstance(tag, StructuralTag) + + +@pytest.mark.parametrize("include_reasoning", [True, False]) +def test_adjust_request_auto_uses_vllm_registry_structural_tag( + monkeypatch: pytest.MonkeyPatch, + qwen3_tool_parser: Qwen3CoderToolParser, + sample_tools: list[ChatCompletionToolsParam], + include_reasoning: bool, +) -> None: + monkeypatch.setattr( + "vllm.tool_parsers.abstract_tool_parser.VLLM_ENFORCE_STRICT_TOOL_CALLING", + True, + ) + request_tools = _as_chat_completion_tools(sample_tools) + req = ChatCompletionRequest( + messages=[], + model="m", + tools=request_tools, + tool_choice="auto", + include_reasoning=include_reasoning, + ) + out = qwen3_tool_parser.adjust_request(req) + assert out.structured_outputs is not None + assert out.structured_outputs.structural_tag is not None + assert isinstance(out.structured_outputs.structural_tag, str) + loaded = json.loads(out.structured_outputs.structural_tag) + assert isinstance(loaded, dict) + + +def test_adjust_request_required_prefers_structural_tag( + monkeypatch: pytest.MonkeyPatch, + qwen3_tool_parser: Qwen3CoderToolParser, + sample_tools: list[ChatCompletionToolsParam], +) -> None: + monkeypatch.setattr( + "vllm.tool_parsers.abstract_tool_parser.VLLM_ENFORCE_STRICT_TOOL_CALLING", + True, + ) + request_tools = _as_chat_completion_tools(sample_tools) + req = 
ChatCompletionRequest( + messages=[], + model="m", + tools=request_tools, + tool_choice="required", + ) + out = qwen3_tool_parser.adjust_request(req) + assert out.structured_outputs is not None + assert out.structured_outputs.structural_tag is not None diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 9aac19e2fda5..da2ec10284c5 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -321,6 +321,21 @@ async def init_app_state( supported_tasks: tuple["SupportedTask", ...] | None = None, ) -> None: vllm_config = engine_client.vllm_config + + # Propagate enable_in_reasoning to the API-server process. The engine core + # runs in a separate process, so the contextvar that backs + # `get_current_vllm_config_or_none()` is None on this stack. Tool parsers + # call `get_enable_structured_outputs_in_reasoning()` during request + # handling and need to see the real flag, otherwise they silently fall + # back to False and mismatch the engine-side bitmask gating. 
+ from vllm.tool_parsers.structural_tag_registry import ( + set_enable_structured_outputs_in_reasoning, + ) + + set_enable_structured_outputs_in_reasoning( + vllm_config.structured_outputs_config.enable_in_reasoning + ) + if supported_tasks is None: warnings.warn( "The 'supported_tasks' parameter was not provided to " diff --git a/vllm/envs.py b/vllm/envs.py index 4191cd6a9743..acd5f7932f20 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -226,6 +226,7 @@ VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False VLLM_SYSTEM_START_DATE: str | None = None VLLM_TOOL_JSON_ERROR_AUTOMATIC_RETRY: bool = False + VLLM_ENFORCE_STRICT_TOOL_CALLING: bool = False VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False VLLM_NVTX_SCOPES_FOR_PROFILING: bool = False VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES: bool = True @@ -1591,6 +1592,12 @@ def _get_or_set_default() -> str: "VLLM_TOOL_JSON_ERROR_AUTOMATIC_RETRY": lambda: bool( int(os.getenv("VLLM_TOOL_JSON_ERROR_AUTOMATIC_RETRY", "0")) ), + # When 1, the model's structural tags are used to force the model + # output to conform to the model's tool-calling format and schema. + # Default 0 (off). 
+ "VLLM_ENFORCE_STRICT_TOOL_CALLING": lambda: bool( + int(os.getenv("VLLM_ENFORCE_STRICT_TOOL_CALLING", "0")) + ), # Add optional custom scopes for profiling, disable to avoid overheads "VLLM_CUSTOM_SCOPES_FOR_PROFILING": lambda: bool( int(os.getenv("VLLM_CUSTOM_SCOPES_FOR_PROFILING", "0")) diff --git a/vllm/tool_parsers/abstract_tool_parser.py b/vllm/tool_parsers/abstract_tool_parser.py index 75181d8dfac6..c3438082a72d 100644 --- a/vllm/tool_parsers/abstract_tool_parser.py +++ b/vllm/tool_parsers/abstract_tool_parser.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import importlib +import json import os from collections.abc import Callable, Sequence from functools import cached_property @@ -13,6 +14,7 @@ from openai.types.responses.function_tool import FunctionTool from vllm.entrypoints.openai.chat_completion.protocol import ( + ChatCompletionNamedToolChoiceParam, ChatCompletionRequest, ChatCompletionToolsParam, ) @@ -23,6 +25,7 @@ from vllm.entrypoints.openai.responses.protocol import ( ResponsesRequest, ) +from vllm.envs import VLLM_ENFORCE_STRICT_TOOL_CALLING from vllm.logger import init_logger from vllm.sampling_params import ( StructuredOutputsParams, @@ -83,13 +86,39 @@ def vocab(self) -> dict[str, int]: return self.model_tokenizer.get_vocab() def adjust_request( - self, request: ChatCompletionRequest | ResponsesRequest + self, + request: ChatCompletionRequest | ResponsesRequest, ) -> ChatCompletionRequest | ResponsesRequest: - """ - Static method that used to adjust the request parameters. - """ + # If there are no tools, return the request as is. if not request.tools: return request + + # Step 1 (highest priority for ChatCompletionRequest): apply + # vLLM-owned structural tag support for model-specific tool formats. 
+ if ( + isinstance(request, ChatCompletionRequest) + and VLLM_ENFORCE_STRICT_TOOL_CALLING + ): + need_tool_calling = ( + request.tool_choice == "auto" + or request.tool_choice == "required" + or isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam) + ) + if need_tool_calling: + structure_tag = self.get_structural_tag(request) + if structure_tag is not None: + if request.structured_outputs is None: + request.structured_outputs = StructuredOutputsParams( + structural_tag=json.dumps(structure_tag.model_dump()), + ) + else: + request.structured_outputs.structural_tag = json.dumps( + structure_tag.model_dump() + ) + return request + + # Step 2: set structured output params when tool constraints are + # derived from the tool schema. json_schema_from_tool = get_json_schema_from_tools( tool_choice=request.tool_choice, tools=request.tools ) @@ -121,6 +150,9 @@ def adjust_request( return request + def get_structural_tag(self, request: ChatCompletionRequest): + return None + def extract_tool_calls( self, model_output: str, request: ChatCompletionRequest ) -> ExtractedToolCallInformation: diff --git a/vllm/tool_parsers/deepseekv4_tool_parser.py b/vllm/tool_parsers/deepseekv4_tool_parser.py index 45a9c1302578..e32451cd8bbd 100644 --- a/vllm/tool_parsers/deepseekv4_tool_parser.py +++ b/vllm/tool_parsers/deepseekv4_tool_parser.py @@ -1,7 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from vllm.entrypoints.openai.chat_completion.protocol import ( + ChatCompletionRequest, +) from vllm.tool_parsers.deepseekv32_tool_parser import DeepSeekV32ToolParser +from vllm.tool_parsers.structural_tag_registry import ( + get_enable_structured_outputs_in_reasoning, + get_model_structural_tag, +) class DeepSeekV4ToolParser(DeepSeekV32ToolParser): @@ -14,3 +21,11 @@ class DeepSeekV4ToolParser(DeepSeekV32ToolParser): tool_call_start_token: str = "<|DSML|tool_calls>" tool_call_end_token: str = "" + + def 
get_structural_tag(self, request: ChatCompletionRequest): + return get_model_structural_tag( + model="deepseek_v4", + tools=request.tools, + tool_choice=request.tool_choice, + reasoning=get_enable_structured_outputs_in_reasoning(), + ) diff --git a/vllm/tool_parsers/qwen3coder_tool_parser.py b/vllm/tool_parsers/qwen3coder_tool_parser.py index 7b089ceffbc0..73850b2ab0c5 100644 --- a/vllm/tool_parsers/qwen3coder_tool_parser.py +++ b/vllm/tool_parsers/qwen3coder_tool_parser.py @@ -25,12 +25,18 @@ Tool, ToolParser, ) +from vllm.tool_parsers.structural_tag_registry import ( + get_enable_structured_outputs_in_reasoning, + get_model_structural_tag, +) from vllm.tool_parsers.utils import find_tool_properties logger = init_logger(__name__) class Qwen3CoderToolParser(ToolParser): + supports_required_and_named: bool = False + def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None): super().__init__(tokenizer, tools) @@ -681,3 +687,11 @@ def extract_tool_calls_streaming( return result return None + + def get_structural_tag(self, request: ChatCompletionRequest): + return get_model_structural_tag( + model="qwen_3_5", + tools=request.tools, + tool_choice=request.tool_choice, + reasoning=get_enable_structured_outputs_in_reasoning(), + ) diff --git a/vllm/tool_parsers/structural_tag_registry.py b/vllm/tool_parsers/structural_tag_registry.py new file mode 100644 index 000000000000..754cc52361c5 --- /dev/null +++ b/vllm/tool_parsers/structural_tag_registry.py @@ -0,0 +1,330 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Model-specific structural tag builders adapted from XGrammar's +# builtin structural tag implementations: +# https://github.com/mlc-ai/xgrammar/blob/main/python/xgrammar/builtin_structural_tag.py + +from collections.abc import Callable +from typing import Any, Literal + +from xgrammar import StructuralTag +from xgrammar.structural_tag import ( + AnyTextFormat, + 
ConstStringFormat, + JSONSchemaFormat, + SequenceFormat, + TagFormat, + TagsWithSeparatorFormat, + TriggeredTagsFormat, +) + +from vllm.entrypoints.openai.chat_completion.protocol import ( + ChatCompletionNamedToolChoiceParam, + ChatCompletionToolsParam, +) + +SimplifiedToolChoice = Literal["auto", "required", "forced"] +ToolChoice = ( + Literal["none", "auto", "required"] | ChatCompletionNamedToolChoiceParam | None +) +StructuralTagBuilder = Callable[ + [list[ChatCompletionToolsParam], SimplifiedToolChoice, bool], + StructuralTag, +] + +_structural_tag_registry: dict[str, StructuralTagBuilder] = {} + + +def register_model_structural_tag(name: str): + """Register a vLLM-owned model-specific structural tag builder.""" + + def decorator(func: StructuralTagBuilder) -> StructuralTagBuilder: + _structural_tag_registry[name] = func + return func + + return decorator + + +def get_model_structural_tag( + model: str, + tools: list[ChatCompletionToolsParam] | None, + tool_choice: ToolChoice, + reasoning: bool, +) -> StructuralTag | None: + """Build a structural tag from vLLM-owned model-specific builders.""" + + builder = _structural_tag_registry.get(model) + if builder is None: + supported = list(_structural_tag_registry.keys()) + raise ValueError(f"Unknown format type: {model}, supported types: {supported}") + + normalized_tools, simplified_tool_choice = _normalize_tool_choice( + tools=tools, + tool_choice=tool_choice, + ) + if not normalized_tools: + return None + + return builder(normalized_tools, simplified_tool_choice, reasoning) + + +def _normalize_tool_choice( + tools: list[ChatCompletionToolsParam] | None, + tool_choice: ToolChoice, +) -> tuple[list[ChatCompletionToolsParam], SimplifiedToolChoice]: + """Normalize vLLM ChatCompletion tool_choice for structural tag builders.""" + + if not tools: + return [], "auto" + + if tool_choice is None or tool_choice == "none": + return [], "auto" + + if tool_choice == "auto": + return tools, "auto" + + if tool_choice == 
"required": + return tools, "required" + + if isinstance(tool_choice, ChatCompletionNamedToolChoiceParam): + tool_name = tool_choice.function.name + filtered_tools = [tool for tool in tools if tool.function.name == tool_name] + if not filtered_tools: + raise ValueError( + f"The tool with name '{tool_name}' is not found in the tools list." + ) + return filtered_tools, "forced" + + raise ValueError(f"Unsupported tool_choice for structural tag: {tool_choice}") + + +def _get_function_parameters(function: Any) -> dict[str, Any] | bool: + """Return the JSON schema used for constrained tool arguments.""" + + if getattr(function, "strict", None) is False: + return True + if function.parameters is None: + return True + return function.parameters + + +_enable_structured_outputs_in_reasoning: bool = False + + +def set_enable_structured_outputs_in_reasoning(enabled: bool) -> None: + """Publish the engine's ``enable_in_reasoning`` flag to tool parsers. + + Called once during APIServer startup so request-time parsers can read + it without going through the EngineCore-only contextvar. + """ + + global _enable_structured_outputs_in_reasoning + _enable_structured_outputs_in_reasoning = bool(enabled) + + +def get_enable_structured_outputs_in_reasoning() -> bool: + """Whether structured outputs are active during the reasoning phase. + + When ``True``, the structural tag will cover the reasoning part: + ``<think>...</think>`` prefix (if available); when ``False`` (default), the tag only + constrains the post-reasoning suffix. 
+ """ + + return _enable_structured_outputs_in_reasoning + + +@register_model_structural_tag("deepseek_v4") +def get_deepseek_v4_structural_tag( + tools: list[ChatCompletionToolsParam], + tool_choice: SimplifiedToolChoice, + reasoning: bool, +) -> StructuralTag: + """Build DeepSeek V4 structural tags.""" + + invoke_begin_prefix = '<|DSML|invoke name="' + invoke_begin_suffix = '">\n' + invoke_end = "\n" + tool_calls_prefix = "\n\n" + function_calls_begin = "<|DSML|tool_calls>\n" + function_calls_end = "" + function_calls_trigger = "<|DSML|tool_calls>" + think_tag_end = "" + think_exclude_tokens = ["", ""] + xml_style = "deepseek_xml" + + if tool_choice == "auto": + tags = [] + for tool in tools: + function = tool.function + parameters = _get_function_parameters(function) + tags.append( + TagFormat( + begin=invoke_begin_prefix + function.name + invoke_begin_suffix, + content=JSONSchemaFormat( + json_schema=parameters, + style=xml_style, + ), + end=invoke_end, + ) + ) + + if tags: + function_calling_tags = TagsWithSeparatorFormat( + tags=tags, + separator="\n", + at_least_one=True, + ) + suffix_tag = TriggeredTagsFormat( + triggers=[function_calls_trigger], + tags=[ + TagFormat( + begin=function_calls_begin, + content=function_calling_tags, + end=function_calls_end, + ) + ], + excludes=think_exclude_tokens, + ) + else: + suffix_tag = AnyTextFormat(excludes=think_exclude_tokens) + + elif tool_choice == "forced": + if not tools: + raise ValueError("Forced tool choice must resolve to exactly one tool.") + function = tools[0].function + suffix_tag = SequenceFormat( + elements=[ + ConstStringFormat(value=tool_calls_prefix + function_calls_begin), + TagFormat( + begin=invoke_begin_prefix + function.name + invoke_begin_suffix, + content=JSONSchemaFormat( + json_schema=_get_function_parameters(function), + style=xml_style, + ), + end=invoke_end, + ), + ConstStringFormat(value=function_calls_end), + ] + ) + + elif tool_choice == "required": + tags = [] + for tool in tools: + 
function = tool.function + parameters = _get_function_parameters(function) + tags.append( + TagFormat( + begin=invoke_begin_prefix + function.name + invoke_begin_suffix, + content=JSONSchemaFormat( + json_schema=parameters, + style=xml_style, + ), + end=invoke_end, + ) + ) + assert len(tags) > 0 + suffix_tag = SequenceFormat( + elements=[ + ConstStringFormat(value=tool_calls_prefix + function_calls_begin), + TagsWithSeparatorFormat( + tags=tags, + separator="\n", + at_least_one=True, + ), + ConstStringFormat(value=function_calls_end), + ] + ) + + if not reasoning: + return StructuralTag(format=suffix_tag) + + prefix_tag = TagFormat(begin="", content=AnyTextFormat(), end=think_tag_end) + return StructuralTag(format=SequenceFormat(elements=[prefix_tag, suffix_tag])) + + +@register_model_structural_tag("qwen_3_5") +def get_qwen_3_5_structural_tag( + tools: list[ChatCompletionToolsParam], + tool_choice: SimplifiedToolChoice, + reasoning: bool, +) -> StructuralTag: + """Build Qwen XML structural tags. + + This format is used for Qwen3-Coder/Qwen3.5/Qwen3.6 and is compatible with + Qwen variants that use the same XML tool-call format. 
+ """ + tool_call_begin_prefix = "\n", ""] + + if tool_choice == "auto": + tags = [] + for tool in tools: + function = tool.function + parameters = _get_function_parameters(function) + tags.append( + TagFormat( + begin=f"{tool_call_begin_prefix}{function.name}{tool_call_begin_suffix}", + content=JSONSchemaFormat(json_schema=parameters, style="qwen_xml"), + end=tool_call_end, + ) + ) + + if tags: + suffix_tag = TriggeredTagsFormat( + triggers=[tool_call_trigger], + tags=tags, + excludes=think_exclude_tokens, + ) + else: + suffix_tag = AnyTextFormat(excludes=think_exclude_tokens) + + elif tool_choice == "forced": + if not tools: + raise ValueError("Forced tool choice must resolve to exactly one tool.") + function = tools[0].function + suffix_tag = TagFormat( + begin=f"{tool_call_begin_prefix}{function.name}{tool_call_begin_suffix}", + content=JSONSchemaFormat( + json_schema=_get_function_parameters(function), + style="qwen_xml", + ), + end=tool_call_end, + ) + + elif tool_choice == "required": + tags = [] + for tool in tools: + function = tool.function + parameters = _get_function_parameters(function) + tags.append( + TagFormat( + begin=f"{tool_call_begin_prefix}{function.name}{tool_call_begin_suffix}", + content=JSONSchemaFormat(json_schema=parameters, style="qwen_xml"), + end=tool_call_end, + ) + ) + assert len(tags) > 0 + suffix_tag = TagsWithSeparatorFormat( + tags=tags, + separator="", + at_least_one=True, + ) + + if not reasoning: + result = StructuralTag(format=suffix_tag) + else: + prefix_tag = SequenceFormat( + elements=[ + TagFormat(begin="", content=AnyTextFormat(), end=think_tag_end), + ConstStringFormat(value=think_suffix), + ] + ) + result = StructuralTag(format=SequenceFormat(elements=[prefix_tag, suffix_tag])) + + return result