From 654172c690b6eb2aadeb11941afc7207574498ca Mon Sep 17 00:00:00 2001 From: iOptimizeThings Date: Wed, 20 May 2026 21:54:54 -0700 Subject: [PATCH] [Bugfix] Validate JSON in kimi_k2 tool call arguments Signed-off-by: iOptimizeThings --- tests/tool_parsers/test_kimi_k2_tool_parser.py | 10 +++++----- vllm/tool_parsers/kimi_k2_tool_parser.py | 10 ++++++++++ 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/tests/tool_parsers/test_kimi_k2_tool_parser.py b/tests/tool_parsers/test_kimi_k2_tool_parser.py index b56032b91c17..3a99a908ad93 100644 --- a/tests/tool_parsers/test_kimi_k2_tool_parser.py +++ b/tests/tool_parsers/test_kimi_k2_tool_parser.py @@ -148,8 +148,8 @@ def test_extract_tool_calls( # id format: "something:digit" assert tc.id.split(":")[-1].isdigit() - def test_invalid_json_still_extracted(self, parser): - """Tool calls with invalid JSON are still returned (arguments as-is).""" + def test_invalid_json_skipped(self, parser): + """Tool calls with invalid JSON are skipped; valid ones kept.""" model_output = ( "Help. " + SECTION_BEGIN @@ -158,9 +158,9 @@ def test_invalid_json_still_extracted(self, parser): + SECTION_END ) content, tool_calls = run_tool_extraction(parser, model_output, streaming=False) - assert len(tool_calls) == 2 - assert tool_calls[0].function.name == "bad" - assert tool_calls[1].function.name == "good" + assert len(tool_calls) == 1 + assert tool_calls[0].function.name == "good" + assert json.loads(tool_calls[0].function.arguments) == {"city": "Shanghai"} def test_invalid_funcall_id_skipped(self, parser): """Tool calls with malformed id (no colon+digit) are skipped.""" diff --git a/vllm/tool_parsers/kimi_k2_tool_parser.py b/vllm/tool_parsers/kimi_k2_tool_parser.py index 7ddd8fa7a80d..e1e20b6df596 100644 --- a/vllm/tool_parsers/kimi_k2_tool_parser.py +++ b/vllm/tool_parsers/kimi_k2_tool_parser.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import json from collections.abc import Sequence import regex as re @@ -96,6 +97,15 @@ def extract_tool_calls( function_id, function_args = match # function_id: functions.get_weather:0 or get_weather:0 function_name = function_id.split(":")[0].split(".")[-1] + try: + json.loads(function_args) + except (json.JSONDecodeError, TypeError): + logger.warning( + "Skipping tool call '%s' with invalid JSON: %s", + function_name, + function_args, + ) + continue tool_calls.append( ToolCall( id=function_id,