diff --git a/common/chat.cpp b/common/chat.cpp index 938872e82ee1d..93eadfdb19f29 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -2269,6 +2269,13 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat data.prompt = apply(tmpl, inputs, /* messages_override =*/ std::nullopt, /* tools_override= */ std::nullopt, extra_context); data.format = COMMON_CHAT_FORMAT_HERMES_2_PRO; + auto supports_thinking = tmpl.source().find("") != std::string::npos; + + // you should not be able to call enable_thinking if is not supported + if (!supports_thinking && extra_context["enable_thinking"]) { + extra_context["enable_thinking"] = false; + } + if (string_ends_with(data.prompt, "\n")) { if (!extra_context["enable_thinking"]) { data.prompt += ""; @@ -2331,9 +2338,27 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat tool_call_alts.push_back( "( \"```\\n\" | \"```json\\n\" | \"```xml\\n\" ) space " + wrappable_tool_call + " space \"```\" space "); auto tool_call = builder.add_rule("tool_call", string_join(tool_call_alts, " | ")); - builder.add_rule("root", - std::string(data.thinking_forced_open ? "( \"\" space )? " : "") + - (inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call)); + + // thinking grammar logic depending on if thinking_forced_open was to true (so already opened (and maybe closed)) and if thinking is even allowed + if (extra_context["enable_thinking"]) { + data.grammar_triggers.push_back({ + COMMON_GRAMMAR_TRIGGER_TYPE_WORD, + data.thinking_forced_open ? "" : "" + }); + std::string prelude = ""; + if (!data.thinking_forced_open) { + prelude = builder.add_rule("think-start", "\"\""); + } + prelude += " "; + prelude += builder.add_rule("think-content", "( [^<] | \"<\" [^/] | \"] )*"); + prelude += " "; + prelude += builder.add_rule("think-end", "\"\" space"); + prelude += " "; + builder.add_rule("root", prelude + "(" + tool_call + ")" + (inputs.parallel_tool_calls ? "*" : "?")); + } else { + builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call); + } + // Trigger on some common known "good bad" outputs (only from the start and with a json that's about a specific argument name to avoid false positives) data.grammar_triggers.push_back({ COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL, diff --git a/models/templates/Qwen-Qwen3-0.6B.jinja b/models/templates/Qwen-Qwen3-0.6B.jinja index 699ff8df401fe..01be9b307daa2 100644 --- a/models/templates/Qwen-Qwen3-0.6B.jinja +++ b/models/templates/Qwen-Qwen3-0.6B.jinja @@ -17,23 +17,27 @@ {%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %} {%- for message in messages[::-1] %} {%- set index = (messages|length - 1) - loop.index0 %} - {%- if ns.multi_step_tool and message.role == "user" and not(message.content.startswith('') and message.content.endswith('')) %} + {%- if ns.multi_step_tool and message.role == "user" and message.content is string and not(message.content.startswith('') and message.content.endswith('')) %} {%- set ns.multi_step_tool = false %} {%- set ns.last_query_index = index %} {%- endif %} {%- endfor %} {%- for message in messages %} + {%- if message.content is string %} + {%- set content = message.content %} + {%- else %} + {%- set content = '' %} + {%- endif %} {%- if (message.role == "user") or (message.role == "system" and not loop.first) %} - {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }} + {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }} {%- elif message.role == "assistant" %} - {%- set content = message.content %} {%- set reasoning_content = '' %} - {%- if message.reasoning_content is defined and message.reasoning_content is not none %} + {%- if message.reasoning_content is string %} {%- set reasoning_content = message.reasoning_content %} {%- else %} - {%- if '' in message.content %} - {%- set content = message.content.split('')[-1].lstrip('\n') %} - {%- set reasoning_content = message.content.split('')[0].rstrip('\n').split('')[-1].lstrip('\n') %} + {%- if '' in content %} + {%- set reasoning_content = content.split('')[0].rstrip('\n').split('')[-1].lstrip('\n') %} + {%- set content = content.split('')[-1].lstrip('\n') %} {%- endif %} {%- endif %} {%- if loop.index0 > ns.last_query_index %} @@ -70,7 +74,7 @@ {{- '<|im_start|>user' }} {%- endif %} {{- '\n\n' }} - {{- message.content }} + {{- content }} {{- '\n' }} {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %} {{- '<|im_end|>\n' }} diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index 4a8ba849b3f8c..619bc1f1062b9 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -1111,6 +1111,68 @@ static void test_template_output_parsers() { /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, })); } + { + auto tmpls = read_templates("models/templates/Qwen-Qwen3-0.6B.jinja"); + std::vector end_tokens{ "<|im_end|>" }; + + assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format); + assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, common_chat_templates_apply(tmpls.get(), inputs_tools).format); + + // Test that enable_thinking=false adds empty think tags + { + common_chat_templates_inputs inputs_no_thinking; + inputs_no_thinking.messages = {message_user}; + inputs_no_thinking.tools = tools; + inputs_no_thinking.tool_choice = COMMON_CHAT_TOOL_CHOICE_REQUIRED; + inputs_no_thinking.enable_thinking = false; + + auto params = common_chat_templates_apply(tmpls.get(), inputs_no_thinking); + assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, params.format); + // Verify the prompt contains empty think tags when thinking is disabled + assert_equals(true, params.prompt.find("\n\n") != std::string::npos); + } + + // Test that grammar allows thinking with REQUIRED tool choice + { + common_chat_templates_inputs inputs_with_thinking; + inputs_with_thinking.messages = {message_user}; + inputs_with_thinking.tools = tools; + inputs_with_thinking.tool_choice = COMMON_CHAT_TOOL_CHOICE_REQUIRED; + inputs_with_thinking.enable_thinking = true; + + auto params = common_chat_templates_apply(tmpls.get(), inputs_with_thinking); + assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, params.format); + + // The key fix: grammar should contain the thinking pattern even with REQUIRED + assert_equals(false, params.grammar.empty()); + assert_equals(true, params.grammar.find("") != std::string::npos); + + // Grammar should allow thinking before tool calls + assert_equals(true, params.grammar.find("think-") != std::string::npos || + params.grammar.find("") != std::string::npos); + } + + // Test parsing: tool call with thinking works correctly + assert_msg_equals(message_assist_call_thoughts, + common_chat_parse( + "I'm\nthinking\n" + "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", + /* is_partial= */ false, + { + /* .format = */ COMMON_CHAT_FORMAT_HERMES_2_PRO, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, + })); + + // Test that reasoning + tool calls work in template generation + test_templates(tmpls.get(), end_tokens, message_assist_call_thoughts, tools, + "", // Don't check exact delta, just verify it parses correctly + /* expect_grammar_triggered= */ true, + /* test_grammar_if_triggered= */ true, + COMMON_REASONING_FORMAT_DEEPSEEK); + + // Verify enable_thinking support + assert_equals(true, common_chat_templates_support_enable_thinking(tmpls.get())); + } { auto tmpls = read_templates("models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja"); std::vector end_tokens{ "<|eom_id|>", "<|eot_id|>" }; diff --git a/tools/server/tests/unit/test_tool_call.py b/tools/server/tests/unit/test_tool_call.py index b8f0f10863fb8..b191840cd2adf 100755 --- a/tools/server/tests/unit/test_tool_call.py +++ b/tools/server/tests/unit/test_tool_call.py @@ -623,3 +623,62 @@ def do_test_hello_world(server: ServerProcess, **kwargs): code = actual_arguments["code"] assert isinstance(code, str), f"Expected code to be a string, got {type(code)}: {json.dumps(code)}" assert re.match(r'''print\(("[Hh]ello,? [Ww]orld!?"|'[Hh]ello,? [Ww]orld!?')\)''', re.sub(r'#.*\n?', '', code)), f'Expected hello world, got {code}' + + + +@pytest.mark.slow +@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED]) +@pytest.mark.parametrize("tool,hf_repo,template_override,reasoning_format", [ + (PYTHON_TOOL, "unsloth/Qwen3-0.6B-GGUF:Q4_K_M", None, 'deepseek'), + (TEST_TOOL, "unsloth/Qwen3-0.6B-GGUF:Q4_K_M", None, 'deepseek'), +]) +def test_required_tool_with_reasoning(tool: dict, hf_repo: str, template_override: str | Tuple[str, str | None] | None, reasoning_format: Literal['deepseek', 'none'], stream: CompletionMode): + global server + n_predict = 512 + + # Set the reasoning format + server.reasoning_format = reasoning_format + + server.jinja = True + server.n_ctx = 8192 + server.n_predict = n_predict + server.model_hf_repo = hf_repo + server.model_hf_file = None + + + server.start(timeout_seconds=TIMEOUT_START_SLOW) + + # Make the request with "tool_choice": "required" + body = server.make_any_request("POST", "/v1/chat/completions", data={ + "max_tokens": n_predict, + "messages": [ + {"role": "system", "content": "You are a coding assistant."}, + {"role": "user", "content": "Write an example"}, # This prompt will force the tool use + ], + "tool_choice": "required", + "tools": [tool], + "parallel_tool_calls": False, + "stream": stream == CompletionMode.STREAMED, + "temperature": 0.0, + "top_k": 1, + "top_p": 1.0, + }, timeout=TIMEOUT_HTTP_REQUEST) + + choice = body["choices"][0] + + + reasoning_content:str = choice["message"].get("reasoning_content") + assert reasoning_content is not None, 'Expected reasoning content, but got None' + assert len(reasoning_content.strip()) > 3, 'Reasoning content is too small to be credible' + + tool_calls = choice["message"].get("tool_calls") + assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}' + tool_call = tool_calls[0] + expected_function_name = "python" if tool["type"] == "code_interpreter" else tool["function"]["name"] + assert expected_function_name == tool_call["function"]["name"] + + actual_arguments = json.loads(tool_call["function"]["arguments"]) + if tool is PYTHON_TOOL: + assert "code" in actual_arguments, f"tool arguments: {json.dumps(actual_arguments)}, expected: 'code'" + elif tool is TEST_TOOL: + assert "success" in actual_arguments, f"tool arguments: {json.dumps(actual_arguments)}, expected: 'success'" \ No newline at end of file