diff --git a/tests/entrypoints/openai/test_cli_args.py b/tests/entrypoints/openai/test_cli_args.py
index 58dd328b325a..8adee5a12b38 100644
--- a/tests/entrypoints/openai/test_cli_args.py
+++ b/tests/entrypoints/openai/test_cli_args.py
@@ -166,6 +166,32 @@ def test_enable_auto_choice_passes_with_tool_call_parser(serve_parser):
     validate_parsed_serve_args(args)


+def test_deepseek_v4_agentic_flags_pass_validation(monkeypatch):
+    import vllm.platforms as platforms
+    from vllm.platforms.cpu import CpuPlatform
+
+    monkeypatch.setattr(platforms, "_current_platform", CpuPlatform())
+    serve_parser = _build_vllm_parsers()["vllm serve"]
+
+    args = serve_parser.parse_args(
+        args=[
+            "--tokenizer-mode",
+            "deepseek_v4",
+            "--tool-call-parser",
+            "deepseek_v4",
+            "--enable-auto-tool-choice",
+            "--reasoning-parser",
+            "deepseek_v4",
+        ]
+    )
+
+    validate_parsed_serve_args(args)
+    assert args.tokenizer_mode == "deepseek_v4"
+    assert args.tool_call_parser == "deepseek_v4"
+    assert args.enable_auto_tool_choice
+    assert args.reasoning_parser == "deepseek_v4"
+
+
 def test_enable_auto_choice_fails_with_enable_reasoning(serve_parser):
     """Ensure validation fails if reasoning is enabled with auto tool choice"""
     args = serve_parser.parse_args(
diff --git a/tests/tokenizers_/test_deepseek_v4.py b/tests/tokenizers_/test_deepseek_v4.py
index 358732eabf40..bcf21ff96ae7 100644
--- a/tests/tokenizers_/test_deepseek_v4.py
+++ b/tests/tokenizers_/test_deepseek_v4.py
@@ -183,6 +183,59 @@ def test_deepseek_v4_renders_parsed_history_tool_arguments():
     assert 'parameter name="arguments"' not in prompt


+def test_deepseek_v4_escapes_arguments_tool_schema_name():
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "echo_args",
+                "description": "Echo arguments",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "arguments": {"type": "string"},
+                    },
+                    "required": ["arguments"],
+                },
+            },
+        }
+    ]
+
+    prompt = _tokenizer().apply_chat_template(
+        [{"role": "user", "content": "Echo this"}],
+        tools=tools,
+        tokenize=False,
+    )
+
+    assert "__vllm_param_arguments__" in prompt
+    assert '"required": ["__vllm_param_arguments__"]' in prompt
+    assert '"arguments": {"type": "string"}' not in prompt
+
+
+def test_deepseek_v4_escapes_arguments_history_tool_call_name():
+    prompt = _tokenizer().apply_chat_template(
+        [
+            {"role": "user", "content": "Echo this"},
+            {
+                "role": "assistant",
+                "tool_calls": [
+                    {
+                        "type": "function",
+                        "function": {
+                            "name": "echo_args",
+                            "arguments": '{"arguments": "hello"}',
+                        },
+                    }
+                ],
+            },
+        ],
+        tokenize=False,
+    )
+
+    assert 'parameter name="__vllm_param_arguments__" string="true">hello' in prompt
+    assert 'parameter name="arguments"' not in prompt
+
+
 @pytest.mark.parametrize("reasoning_effort", ["minimal", "low", "medium", "high"])
 def test_deepseek_v4_accepts_openai_reasoning_effort_values(reasoning_effort):
     prompt = _tokenizer().apply_chat_template(
diff --git a/tests/tool_parsers/test_deepseekv4_tool_parser.py b/tests/tool_parsers/test_deepseekv4_tool_parser.py
index cc77a1f77756..087ac16ceb90 100644
--- a/tests/tool_parsers/test_deepseekv4_tool_parser.py
+++ b/tests/tool_parsers/test_deepseekv4_tool_parser.py
@@ -76,6 +76,16 @@ def make_request(tools=None) -> MagicMock:
     return req


+def make_tool(name: str, properties: dict[str, dict]) -> MagicMock:
+    tool = MagicMock()
+    tool.function.name = name
+    tool.function.parameters = {
+        "type": "object",
+        "properties": properties,
+    }
+    return tool
+
+
 def build_tool_call(func_name: str, params: dict[str, str]) -> str:
     param_strs = "".join(
         f'{PARAM_START}{k}" string="true">{v}{PARAM_END}\n' for k, v in params.items()
@@ -86,6 +96,7 @@ def build_tool_call(func_name: str, params: dict[str, str]) -> str:
 def stream(parser: DeepSeekV4ToolParser, full_text: str, chunk_size: int = 7):
     deltas = []
     previous_text = ""
+    request = make_request()
     for start in range(0, len(full_text), chunk_size):
         delta_text = full_text[start : start + chunk_size]
         current_text = previous_text + delta_text
@@ -96,7 +107,7 @@ def stream(parser: DeepSeekV4ToolParser, full_text: str, chunk_size: int = 7):
             previous_token_ids=[],
             current_token_ids=[],
             delta_token_ids=[1],
-            request=make_request(),
+            request=request,
         )
         previous_text = current_text
         if delta is not None:
@@ -203,3 +214,127 @@ def test_get_vllm_registry_structural_tag_returns_structural_tag(
     )
     tag = parser.get_structural_tag(req)
     assert isinstance(tag, StructuralTag)
+
+
+def test_streaming_split_start_token_does_not_leak_dsml_markers():
+    parser = make_parser()
+    full_text = "I will check." + build_tool_call("search", {"query": "vllm"})
+
+    deltas = stream(parser, full_text, chunk_size=1)
+
+    content = "".join(delta.content or "" for delta in deltas)
+    assert content == "I will check."
+    assert "DSML" not in content
+    assert json.loads(reconstruct_args(deltas)) == {"query": "vllm"}
+
+
+def test_streaming_plain_text_trailing_angle_bracket_is_flushed():
+    parser = make_parser()
+    request = make_request()
+    previous_text = "2 <"
+
+    delta = parser.extract_tool_calls_streaming(
+        previous_text=previous_text,
+        current_text=previous_text,
+        delta_text="",
+        previous_token_ids=[],
+        current_token_ids=[],
+        delta_token_ids=[1],
+        request=request,
+    )
+
+    assert delta is not None
+    assert delta.content == "2 <"
+    assert not delta.tool_calls
+
+
+def test_extract_tool_calls_non_streaming_preserves_typed_arguments():
+    parser = make_parser()
+    request = make_request(
+        [
+            make_tool(
+                "plan_trip",
+                {
+                    "days": {"type": "integer"},
+                    "flexible": {"type": "boolean"},
+                    "cities": {"type": "array"},
+                    "notes": {"type": "string"},
+                },
+            )
+        ]
+    )
+    model_output = (
+        f"{TC_START}"
+        f'{INV_START}plan_trip">'
+        f'{PARAM_START}days" string="false">3{PARAM_END}'
+        f'{PARAM_START}flexible" string="false">false{PARAM_END}'
+        f'{PARAM_START}cities" string="false">["Beijing", "Shanghai"]{PARAM_END}'
+        f'{PARAM_START}notes" string="true">window seat{PARAM_END}'
+        f"{INV_END}"
+        f"{TC_END}"
+    )
+
+    result = parser.extract_tool_calls(model_output, request)
+
+    assert result.tools_called
+    assert json.loads(result.tool_calls[0].function.arguments) == {
+        "days": 3,
+        "flexible": False,
+        "cities": ["Beijing", "Shanghai"],
+        "notes": "window seat",
+    }
+
+
+def test_extract_tool_calls_repairs_arguments_wrapper_object():
+    parser = make_parser()
+    request = make_request([make_tool("get_weather", {"location": {"type": "string"}})])
+    model_output = (
+        f"{TC_START}"
+        f'{INV_START}get_weather">'
+        f'{PARAM_START}arguments" string="false">{{"location": "Beijing"}}{PARAM_END}'
+        f"{INV_END}"
+        f"{TC_END}"
+    )
+
+    result = parser.extract_tool_calls(model_output, request)
+
+    assert result.tools_called
+    assert json.loads(result.tool_calls[0].function.arguments) == {
+        "location": "Beijing"
+    }
+
+
+def test_extract_tool_calls_repairs_input_wrapper_string():
+    parser = make_parser()
+    request = make_request([make_tool("get_weather", {"location": {"type": "string"}})])
+    model_output = (
+        f"{TC_START}"
+        f'{INV_START}get_weather">'
+        f'{PARAM_START}input" string="true">{{"location": "Beijing"}}{PARAM_END}'
+        f"{INV_END}"
+        f"{TC_END}"
+    )
+
+    result = parser.extract_tool_calls(model_output, request)
+
+    assert result.tools_called
+    assert json.loads(result.tool_calls[0].function.arguments) == {
+        "location": "Beijing"
+    }
+
+
+def test_extract_tool_calls_unescapes_arguments_field_name():
+    parser = make_parser()
+    request = make_request([make_tool("echo_args", {"arguments": {"type": "string"}})])
+    model_output = (
+        f"{TC_START}"
+        f'{INV_START}echo_args">'
+        f'{PARAM_START}__vllm_param_arguments__" string="true">hello{PARAM_END}'
+        f"{INV_END}"
+        f"{TC_END}"
+    )
+
+    result = parser.extract_tool_calls(model_output, request)
+
+    assert result.tools_called
+    assert json.loads(result.tool_calls[0].function.arguments) == {"arguments": "hello"}
diff --git a/vllm/tokenizers/deepseek_v4_encoding.py b/vllm/tokenizers/deepseek_v4_encoding.py
index 6895771e2f59..c2e06971815f 100644
--- a/vllm/tokenizers/deepseek_v4_encoding.py
+++ b/vllm/tokenizers/deepseek_v4_encoding.py
@@ -62,6 +62,7 @@
     "<{dsml_token}{tc_block_name}>\n{tool_calls}\n"
 )
 tool_calls_block_name: str = "tool_calls"
+ESCAPED_ARGUMENTS_PARAM_NAME = "__vllm_param_arguments__"

 tool_output_template: str = (
     "{content}"
@@ -117,6 +118,40 @@ def tools_from_openai_format(tools):
     return [tool["function"] for tool in tools]

+
+def _escape_param_name(name: str) -> str:
+    if name == "arguments":
+        return ESCAPED_ARGUMENTS_PARAM_NAME
+    return name
+
+
+def _unescape_param_name(name: str) -> str:
+    if name == ESCAPED_ARGUMENTS_PARAM_NAME:
+        return "arguments"
+    return name
+
+
+def _escape_tool_schema(tool: Dict[str, Any]) -> Dict[str, Any]:
+    escaped_tool = copy.deepcopy(tool)
+    parameters = escaped_tool.get("parameters")
+    if not isinstance(parameters, dict):
+        return escaped_tool
+
+    properties = parameters.get("properties")
+    if isinstance(properties, dict):
+        parameters["properties"] = {
+            _escape_param_name(key): value for key, value in properties.items()
+        }
+
+    required = parameters.get("required")
+    if isinstance(required, list):
+        parameters["required"] = [
+            _escape_param_name(name) if isinstance(name, str) else name
+            for name in required
+        ]
+
+    return escaped_tool
+

 def tool_calls_from_openai_format(tool_calls):
     """Convert OpenAI-format tool calls to internal format."""
     return [
@@ -155,15 +190,14 @@ def encode_arguments_to_dsml(tool_call: Dict[str, Any]) -> str:
     p_dsml_template = '<{dsml_token}parameter name="{key}" string="{is_str}">{value}'
     P_dsml_strs = []

-    if isinstance(tool_call["arguments"], str):
-        arguments = json.loads(tool_call["arguments"])
-    else:
-        arguments = tool_call["arguments"]
+    arguments = _normalize_tool_call_arguments(tool_call["arguments"])
+    if not isinstance(arguments, dict):
+        return ""

     for k, v in arguments.items():
         p_dsml_str = p_dsml_template.format(
             dsml_token=dsml_token,
-            key=k,
+            key=_escape_param_name(k),
             is_str="true" if isinstance(v, str) else "false",
             value=v if isinstance(v, str) else to_json(v),
         )
@@ -172,6 +206,39 @@
     return "\n".join(P_dsml_strs)

+
+def _normalize_tool_call_arguments(arguments: Any) -> Dict[str, Any] | None:
+    if isinstance(arguments, str):
+        try:
+            arguments = json.loads(arguments)
+        except json.JSONDecodeError:
+            return None
+
+    if not isinstance(arguments, dict):
+        return None
+
+    if set(arguments.keys()) == {"input"}:
+        inner = arguments["input"]
+        if isinstance(inner, str):
+            try:
+                inner = json.loads(inner)
+            except json.JSONDecodeError:
+                return arguments
+        if isinstance(inner, dict):
+            arguments = inner
+
+    if set(arguments.keys()) == {"arguments"}:
+        inner = arguments["arguments"]
+        if isinstance(inner, str):
+            try:
+                inner = json.loads(inner)
+            except json.JSONDecodeError:
+                return arguments
+        if isinstance(inner, dict):
+            return inner
+
+    return arguments
+

 def decode_dsml_to_arguments(tool_name: str, tool_args: Dict[str, Tuple[str, str]]) -> Dict[str, str]:
     """
     Decode DSML parameters back to a tool call dict.
@@ -188,7 +255,7 @@ def _decode_value(key: str, value: str, string: str):
             value = to_json(value)
         return f"{to_json(key)}: {value}"

-    tool_args_json = "{" + ", ".join([_decode_value(k, v, string=is_str) for k, (v, is_str) in tool_args.items()]) + "}"
+    tool_args_json = "{" + ", ".join([_decode_value(_unescape_param_name(k), v, string=is_str) for k, (v, is_str) in tool_args.items()]) + "}"

     return dict(name=tool_name, arguments=tool_args_json)

@@ -202,7 +269,7 @@ def render_tools(tools: List[Dict[str, Union[str, Dict[str, Any]]]]) -> str:
     Returns:
         Formatted tools section string.
     """
-    tools_json = [to_json(t) for t in tools]
+    tools_json = [to_json(_escape_tool_schema(t)) for t in tools]

     return TOOLS_TEMPLATE.format(
         tool_schemas="\n".join(tools_json),
diff --git a/vllm/tool_parsers/deepseekv4_tool_parser.py b/vllm/tool_parsers/deepseekv4_tool_parser.py
index e32451cd8bbd..fa0b1c105fda 100644
--- a/vllm/tool_parsers/deepseekv4_tool_parser.py
+++ b/vllm/tool_parsers/deepseekv4_tool_parser.py
@@ -1,15 +1,32 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

+import json
+
+import regex as re
+
 from vllm.entrypoints.openai.chat_completion.protocol import (
     ChatCompletionRequest,
 )
+from vllm.entrypoints.openai.engine.protocol import (
+    DeltaFunctionCall,
+    DeltaMessage,
+    DeltaToolCall,
+    ExtractedToolCallInformation,
+    FunctionCall,
+    ToolCall,
+)
+from vllm.logger import init_logger
 from vllm.tool_parsers.deepseekv32_tool_parser import DeepSeekV32ToolParser
 from vllm.tool_parsers.structural_tag_registry import (
     get_enable_structured_outputs_in_reasoning,
     get_model_structural_tag,
 )

+logger = init_logger(__name__)
+
+ESCAPED_ARGUMENTS_PARAM_NAME = "__vllm_param_arguments__"
+

 class DeepSeekV4ToolParser(DeepSeekV32ToolParser):
     """
@@ -29,3 +46,242 @@ def get_structural_tag(self, request: ChatCompletionRequest):
             tool_choice=request.tool_choice,
             reasoning=get_enable_structured_outputs_in_reasoning(),
         )
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.parameter_complete_regex = re.compile(
+            r'<|DSML|parameter\s+name="([^"]+)"\s+string="(true|false)"\s*>(.*?)',
+            re.DOTALL,
+        )
+
+    @staticmethod
+    def _function_name(tool) -> str | None:
+        if isinstance(tool, dict):
+            function = tool.get("function")
+            if isinstance(function, dict):
+                return function.get("name")
+            return getattr(function, "name", None)
+        return getattr(getattr(tool, "function", None), "name", None)
+
+    @staticmethod
+    def _function_parameters(tool):
+        if isinstance(tool, dict):
+            function = tool.get("function")
+            if isinstance(function, dict):
+                return function.get("parameters")
+            return getattr(function, "parameters", None)
+        return getattr(getattr(tool, "function", None), "parameters", None)
+
+    def _extract_param_name(self, param_name: str) -> str:
+        if param_name == ESCAPED_ARGUMENTS_PARAM_NAME:
+            return "arguments"
+        return param_name
+
+    def _get_param_config(
+        self,
+        request: ChatCompletionRequest | None,
+        function_name: str | None,
+    ) -> dict[str, dict]:
+        if not request or not request.tools or not function_name:
+            return {}
+
+        for tool in request.tools:
+            if self._function_name(tool) != function_name:
+                continue
+            params = self._function_parameters(tool)
+            if isinstance(params, dict):
+                properties = params.get("properties")
+                if isinstance(properties, dict):
+                    return properties
+            return {}
+
+        return {}
+
+    def _coerce_param_value(
+        self,
+        value: str,
+        *,
+        string_attr: str,
+        param_type,
+    ):
+        if string_attr == "true":
+            return value
+        if param_type:
+            return self._convert_param_value(value, param_type)
+        try:
+            return json.loads(value)
+        except json.JSONDecodeError:
+            return value
+
+    @staticmethod
+    def _repair_param_dict(
+        param_dict: dict,
+        param_config: dict[str, dict],
+    ) -> dict:
+        allowed = set(param_config.keys())
+        for wrapper in ("arguments", "input"):
+            if set(param_dict.keys()) != {wrapper} or wrapper in allowed:
+                continue
+            inner = param_dict[wrapper]
+            if isinstance(inner, str):
+                try:
+                    inner = json.loads(inner)
+                except json.JSONDecodeError:
+                    return param_dict
+            if isinstance(inner, dict) and set(inner.keys()).issubset(allowed):
+                return inner
+        return param_dict
+
+    def _parse_invoke_params(
+        self,
+        invoke_str: str,
+        request: ChatCompletionRequest | None = None,
+        function_name: str | None = None,
+    ) -> dict:
+        param_config = self._get_param_config(request, function_name)
+        param_dict = {}
+
+        for param_name, string_attr, param_val in self.parameter_complete_regex.findall(
+            invoke_str
+        ):
+            original_param_name = param_name
+            param_name = self._extract_param_name(param_name)
+            param_type = None
+            if (
+                original_param_name == ESCAPED_ARGUMENTS_PARAM_NAME
+                and "arguments" in param_config
+            ):
+                param_type = param_config["arguments"].get("type")
+            elif param_name in param_config and isinstance(
+                param_config[param_name], dict
+            ):
+                param_type = param_config[param_name].get("type")
+
+            param_dict[param_name] = self._coerce_param_value(
+                param_val,
+                string_attr=string_attr,
+                param_type=param_type,
+            )
+
+        return self._repair_param_dict(param_dict, param_config)
+
+    def extract_tool_calls(
+        self,
+        model_output: str,
+        request: ChatCompletionRequest,
+    ) -> ExtractedToolCallInformation:
+        """Extract DeepSeek V4 DSML tool calls from complete model output."""
+        if self.tool_call_start_token not in model_output:
+            return ExtractedToolCallInformation(
+                tools_called=False, tool_calls=[], content=model_output
+            )
+
+        try:
+            tool_calls = []
+            for tool_call_match in self.tool_call_complete_regex.findall(model_output):
+                for invoke_name, invoke_content in self.invoke_complete_regex.findall(
+                    tool_call_match
+                ):
+                    param_dict = self._parse_invoke_params(
+                        invoke_content,
+                        request=request,
+                        function_name=invoke_name,
+                    )
+                    tool_calls.append(
+                        ToolCall(
+                            type="function",
+                            function=FunctionCall(
+                                name=invoke_name,
+                                arguments=json.dumps(param_dict, ensure_ascii=False),
+                            ),
+                        )
+                    )
+
+            if not tool_calls:
+                return ExtractedToolCallInformation(
+                    tools_called=False, tool_calls=[], content=model_output
+                )
+
+            first_tool_idx = model_output.find(self.tool_call_start_token)
+            content = model_output[:first_tool_idx] if first_tool_idx > 0 else None
+
+            return ExtractedToolCallInformation(
+                tools_called=True, tool_calls=tool_calls, content=content
+            )
+
+        except Exception:
+            logger.exception("Error extracting DeepSeek V4 tool calls")
+            return ExtractedToolCallInformation(
+                tools_called=False, tool_calls=[], content=model_output
+            )
+
+    def _extract_delta_tool_calls(
+        self,
+        current_text: str,
+        request: ChatCompletionRequest | None,
+    ) -> list[DeltaToolCall]:
+        complete_invokes = self.invoke_complete_regex.findall(current_text)
+        delta_tool_calls: list[DeltaToolCall] = []
+
+        while len(complete_invokes) > self.current_tool_index:
+            invoke_name, invoke_body = complete_invokes[self.current_tool_index]
+            param_dict = self._parse_invoke_params(
+                invoke_body,
+                request=request,
+                function_name=invoke_name,
+            )
+            args_json = json.dumps(param_dict, ensure_ascii=False)
+            idx = self.current_tool_index
+            self.current_tool_index += 1
+
+            self.prev_tool_call_arr.append(
+                {"name": invoke_name, "arguments": param_dict}
+            )
+            self.streamed_args_for_tool.append(args_json)
+
+            delta_tool_calls.append(
+                DeltaToolCall(
+                    index=idx,
+                    id=self._generate_tool_call_id(),
+                    function=DeltaFunctionCall(
+                        name=invoke_name,
+                        arguments=args_json,
+                    ),
+                    type="function",
+                )
+            )
+
+        return delta_tool_calls
+
+    def extract_tool_calls_streaming(
+        self,
+        previous_text: str,
+        current_text: str,
+        delta_text: str,
+        previous_token_ids,
+        current_token_ids,
+        delta_token_ids,
+        request: ChatCompletionRequest,
+    ) -> DeltaMessage | None:
+        if not previous_text:
+            self._reset_streaming_state()
+
+        content = self._extract_content(current_text)
+        delta_tool_calls = self._extract_delta_tool_calls(current_text, request)
+
+        if (
+            not delta_text
+            and self.tool_call_start_token not in current_text
+            and self._sent_content_idx < len(current_text)
+        ):
+            held_content = current_text[self._sent_content_idx :]
+            self._sent_content_idx = len(current_text)
+            content = (content or "") + held_content
+
+        if delta_tool_calls or content:
+            return DeltaMessage(content=content, tool_calls=delta_tool_calls)
+
+        if not delta_text and delta_token_ids and self.prev_tool_call_arr:
+            return DeltaMessage(content="")
+
+        return None