Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions include/minja/chat-template.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,17 +191,27 @@ class chat_template {
}},
};
};
auto make_tool_call_response = [](const std::string & tool_call_id, const std::string & tool_name, const std::string & content) {
return json {
{"role", "tool"},
{"name", tool_name},
{"content", content},
{"tool_call_id", tool_call_id},
};
};
const json dummy_args_obj {{"argument_needle", "print('Hello, World!')"}};

// Note: the arguments are rendered in both cases, but may be double-escaped, which we don't want.
out = try_raw_render(json::array({
dummy_user_msg,
make_tool_calls_msg(json::array({make_tool_call("ipython", dummy_args_obj.dump())})),
make_tool_call_response("call_1___", "ipython", "Hello, World!"),
}), {}, false);
auto tool_call_renders_str_arguments = contains(out, "<parameter=argument_needle>") || contains(out, "\"argument_needle\":") || contains(out, "'argument_needle':");
out = try_raw_render(json::array({
dummy_user_msg,
make_tool_calls_msg(json::array({make_tool_call("ipython", dummy_args_obj)})),
make_tool_call_response("call_1___", "ipython", "Hello, World!"),
}), {}, false);
auto tool_call_renders_obj_arguments = contains(out, "<parameter=argument_needle>") || contains(out, "\"argument_needle\":") || contains(out, "'argument_needle':");

Expand All @@ -215,18 +225,14 @@ class chat_template {
auto out = try_raw_render(json::array({
dummy_user_msg,
make_tool_calls_msg(json::array({tc1, tc2})),
dummy_user_msg,
}), {}, false);
caps_.supports_parallel_tool_calls = contains(out, "test_tool1") && contains(out, "test_tool2");

out = try_raw_render(json::array({
dummy_user_msg,
make_tool_calls_msg(json::array({tc1})),
{
{"role", "tool"},
{"name", "test_tool1"},
{"content", "Some response!"},
{"tool_call_id", "call_911_"},
}
make_tool_call_response("call_911_", "test_tool1", "Some response!"),
}), {}, false);
caps_.supports_tool_responses = contains(out, "Some response!");
caps_.supports_tool_call_id = contains(out, "call_911_");
Expand Down
17 changes: 11 additions & 6 deletions scripts/fetch_templates_and_goldens.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,17 +185,26 @@ def make_tool_call(tool_name, arguments):
"name": tool_name,
}
}
def make_tool_call_response(tool_call_id, tool_name, content):
return {
"role": "tool",
"name": tool_name,
"content": content,
"tool_call_id": tool_call_id,
}

dummy_args_obj = {"argument_needle": "print('Hello, World!')"}

out = self.try_raw_render([
dummy_user_msg,
make_tool_calls_msg([make_tool_call("ipython", json.dumps(dummy_args_obj))]),
make_tool_call_response("call_1___", "ipython", "Hello, world!"),
])
tool_call_renders_str_arguments = "<parameter=argument_needle>" in out or '"argument_needle":' in out or "'argument_needle':" in out
out = self.try_raw_render([
dummy_user_msg,
make_tool_calls_msg([make_tool_call("ipython", dummy_args_obj)]),
make_tool_call_response("call_1___", "ipython", "Hello, world!"),
])
tool_call_renders_obj_arguments = "<parameter=argument_needle>" in out or '"argument_needle":' in out or "'argument_needle':" in out

Expand All @@ -209,18 +218,14 @@ def make_tool_call(tool_name, arguments):
out = self.try_raw_render([
dummy_user_msg,
make_tool_calls_msg([tc1, tc2]),
dummy_user_msg,
])
caps.supports_parallel_tool_calls = "test_tool1" in out and "test_tool2" in out

out = self.try_raw_render([
dummy_user_msg,
make_tool_calls_msg([tc1]),
{
"role": "tool",
"name": "test_tool1",
"content": "Some response!",
"tool_call_id": "call_911_",
}
make_tool_call_response("call_911_", "test_tool1", "Some response!"),
])
caps.supports_tool_responses = "Some response!" in out
caps.supports_tool_call_id = "call_911_" in out
Expand Down
1 change: 1 addition & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ set(MODEL_IDS
nvidia/Eagle2-1B
nvidia/Eagle2-9B
nvidia/Llama-3.1-Nemotron-70B-Instruct-HF
nvidia/NVIDIA-Nemotron-Nano-9B-v2
OnlyCheeini/greesychat-turbo
onnx-community/DeepSeek-R1-Distill-Qwen-1.5B-ONNX
open-thoughts/OpenThinker-7B
Expand Down
12 changes: 12 additions & 0 deletions tests/test-capabilities.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,18 @@ TEST(CapabilitiesTest, NousResearchHermes2ProLlama3_8BToolUse) {
EXPECT_FALSE(caps.requires_typed_content);
}

TEST(CapabilitiesTest, NvidiaNemotronNano_9BToolUse) {
auto caps = get_caps("tests/nvidia-NVIDIA-Nemotron-Nano-9B-v2.jinja");
EXPECT_TRUE(caps.supports_system_role);
EXPECT_TRUE(caps.supports_tools);
EXPECT_TRUE(caps.supports_tool_calls);
EXPECT_TRUE(caps.supports_tool_responses);
EXPECT_TRUE(caps.supports_parallel_tool_calls);
EXPECT_FALSE(caps.requires_object_arguments);
// EXPECT_TRUE(caps.requires_non_null_content);
EXPECT_FALSE(caps.requires_typed_content);
}

TEST(CapabilitiesTest, CommandRPlusDefault) {
auto caps = get_caps("tests/CohereForAI-c4ai-command-r-plus-default.jinja");
EXPECT_TRUE(caps.supports_system_role);
Expand Down