Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions include/minja/chat-template.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ struct chat_template_caps {
bool requires_object_arguments = false;
// CohereForAI/c4ai-command-r-plus simple variant
bool requires_non_null_content = false;
// mistralai/Ministral-3-14B-Reasoning-2512
bool requires_non_empty_content = false;
// MiniMaxAI/MiniMax-Text-01 special
bool requires_typed_content = false;
};
Expand Down Expand Up @@ -171,13 +173,24 @@ class chat_template {
};
auto out_empty = render_with_content("");
auto out_null = render_with_content(json());
caps_.requires_non_null_content = contains(out_empty, user_needle) && !contains(out_null, user_needle);

auto out_nonempty = render_with_content(" ");
caps_.requires_non_empty_content = contains(out_nonempty, user_needle) && !contains(out_empty, user_needle) && !contains(out_null, user_needle);
caps_.requires_non_null_content = caps_.requires_non_empty_content || (contains(out_empty, user_needle) && !contains(out_null, user_needle));

json j_null;
auto assistant_content = [&](const json & content) {
if (content.is_null() && caps_.requires_non_null_content) {
return json("");
}
if ((content.is_null() || (content.is_string() && content.empty())) && caps_.requires_non_empty_content) {
return json(" ");
}
return content;
};
auto make_tool_calls_msg = [&](const json & tool_calls) {
return json {
{"role", "assistant"},
{"content", caps_.requires_non_null_content? "" : j_null},
{"content", assistant_content(j_null)},
{"tool_calls", tool_calls},
};
};
Expand Down Expand Up @@ -249,7 +262,7 @@ class chat_template {
};
const json tool_call_msg {
{"role", "assistant"},
{"content", caps_.requires_non_null_content ? "" : j_null},
{"content", assistant_content(j_null)},
{"tool_calls", json::array({
{
// TODO: detect if requires numerical id or fixed length == 6 like Nemo
Expand Down
22 changes: 18 additions & 4 deletions scripts/fetch_templates_and_goldens.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ class TemplateCaps:
supports_tool_call_id: bool = False
requires_object_arguments: bool = False
requires_non_null_content: bool = False
requires_non_empty_content: bool = False
requires_typed_content: bool = False

def to_json(self):
Expand All @@ -96,6 +97,7 @@ def to_json(self):
"supports_tool_call_id": self.supports_tool_call_id,
"requires_object_arguments": self.requires_object_arguments,
# "requires_non_null_content": self.requires_non_null_content,
# "requires_non_empty_content": self.requires_non_empty_content,
"requires_typed_content": self.requires_typed_content,
}, indent=2)

Expand Down Expand Up @@ -171,14 +173,26 @@ def __init__(self, template, env=None, filters=None, global_functions=None):
}])
caps.supports_tools = "some_tool" in out

caps.requires_non_null_content = \
(user_needle in self.try_raw_render([dummy_user_msg, {"role": "assistant", "content": ''}])) \
caps.requires_non_empty_content = \
(user_needle in self.try_raw_render([dummy_user_msg, {"role": "assistant", "content": ' '}])) \
and (user_needle not in self.try_raw_render([dummy_user_msg, {"role": "assistant", "content": ''}])) \
and (user_needle not in self.try_raw_render([dummy_user_msg, {"role": "assistant", "content": None}]))
caps.requires_non_null_content = caps.requires_non_empty_content or (
(user_needle in self.try_raw_render([dummy_user_msg, {"role": "assistant", "content": ''}]))
and (user_needle not in self.try_raw_render([dummy_user_msg, {"role": "assistant", "content": None}]))
)

def assistant_content(content=None):
if content is None and caps.requires_non_null_content:
return ""
if not content and caps.requires_non_empty_content:
return " "
return content

def make_tool_calls_msg(tool_calls, content=None):
return {
"role": "assistant",
"content": "" if content is None and caps.requires_non_null_content else content,
"content": assistant_content(content),
"tool_calls": tool_calls,
}
def make_tool_call(tool_name, arguments):
Expand Down Expand Up @@ -243,7 +257,7 @@ def make_tool_call(tool_name, arguments):
args = {"arg1": "some_value"}
tool_call_msg = {
"role": "assistant",
"content": "" if caps.requires_non_null_content else None,
"content": assistant_content(),
"tool_calls": [
{
"id": "call_1___",
Expand Down
1 change: 1 addition & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ set(MODEL_IDS
mistralai/Mistral-Large-Instruct-2411
mistralai/Mistral-Nemo-Instruct-2407
mistralai/Mistral-Small-24B-Instruct-2501
mistralai/Ministral-3-14B-Reasoning-2512
mkurman/Qwen2.5-14B-DeepSeek-R1-1M
mlabonne/AlphaMonarch-7B
mlx-community/Josiefied-Qwen2.5-0.5B-Instruct-abliterated-v1-float32
Expand Down
32 changes: 32 additions & 0 deletions tests/test-capabilities.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ static minja::chat_template_caps get_caps(const std::string &path)
print("supports_parallel_tool_calls", caps.supports_parallel_tool_calls);
print("requires_object_arguments", caps.requires_object_arguments);
print("requires_non_null_content", caps.requires_non_null_content);
print("requires_non_empty_content", caps.requires_non_empty_content);
// print("requires_non_null_content", caps.requires_non_null_content);
print("requires_typed_content", caps.requires_typed_content);
std::cout << "}\n" << std::endl;
Expand All @@ -75,6 +76,7 @@ TEST(CapabilitiesTest, Gemma7b) {
EXPECT_FALSE(caps.supports_parallel_tool_calls);
EXPECT_FALSE(caps.requires_object_arguments);
EXPECT_FALSE(caps.requires_non_null_content);
EXPECT_FALSE(caps.requires_non_empty_content);
EXPECT_FALSE(caps.requires_typed_content);
}

Expand All @@ -88,6 +90,7 @@ TEST(CapabilitiesTest, QwQ32B) {
EXPECT_TRUE(caps.supports_parallel_tool_calls);
EXPECT_TRUE(caps.requires_object_arguments);
EXPECT_TRUE(caps.requires_non_null_content);
EXPECT_FALSE(caps.requires_non_empty_content);
EXPECT_FALSE(caps.requires_typed_content);
}

Expand All @@ -101,6 +104,7 @@ TEST(CapabilitiesTest, Qwen3Coder) {
EXPECT_TRUE(caps.supports_parallel_tool_calls);
EXPECT_TRUE(caps.requires_object_arguments);
EXPECT_FALSE(caps.requires_non_null_content);
EXPECT_FALSE(caps.requires_non_empty_content);
EXPECT_FALSE(caps.requires_typed_content);
}

Expand All @@ -115,6 +119,7 @@ TEST(CapabilitiesTest, DeepSeekR1Distill) {
EXPECT_TRUE(caps.supports_parallel_tool_calls);
EXPECT_FALSE(caps.requires_object_arguments);
EXPECT_FALSE(caps.requires_non_null_content);
EXPECT_FALSE(caps.requires_non_empty_content);
EXPECT_FALSE(caps.requires_typed_content);
}
#endif // _WIN32
Expand All @@ -129,6 +134,7 @@ TEST(CapabilitiesTest, FunctionaryMediumV3_2) {
EXPECT_TRUE(caps.supports_parallel_tool_calls);
EXPECT_FALSE(caps.requires_object_arguments);
EXPECT_FALSE(caps.requires_non_null_content);
EXPECT_FALSE(caps.requires_non_empty_content);
EXPECT_FALSE(caps.requires_typed_content);
}

Expand All @@ -142,6 +148,7 @@ TEST(CapabilitiesTest, MetaLlama3_1_8BInstruct) {
EXPECT_FALSE(caps.supports_parallel_tool_calls);
EXPECT_TRUE(caps.requires_object_arguments);
EXPECT_FALSE(caps.requires_non_null_content);
EXPECT_FALSE(caps.requires_non_empty_content);
EXPECT_FALSE(caps.requires_typed_content);
}

Expand All @@ -155,6 +162,7 @@ TEST(CapabilitiesTest, MetaLlama3_2_3BInstruct) {
EXPECT_FALSE(caps.supports_parallel_tool_calls);
EXPECT_TRUE(caps.requires_object_arguments);
EXPECT_FALSE(caps.requires_non_null_content);
EXPECT_FALSE(caps.requires_non_empty_content);
EXPECT_FALSE(caps.requires_typed_content);
}

Expand All @@ -168,6 +176,7 @@ TEST(CapabilitiesTest, MetaLlama3_3_70BInstruct) {
EXPECT_FALSE(caps.supports_parallel_tool_calls);
EXPECT_TRUE(caps.requires_object_arguments);
EXPECT_FALSE(caps.requires_non_null_content);
EXPECT_FALSE(caps.requires_non_empty_content);
EXPECT_FALSE(caps.requires_typed_content);
}

Expand All @@ -181,6 +190,7 @@ TEST(CapabilitiesTest, MiniMaxAIText01) {
EXPECT_FALSE(caps.supports_parallel_tool_calls);
EXPECT_FALSE(caps.requires_object_arguments);
EXPECT_FALSE(caps.requires_non_null_content);
EXPECT_FALSE(caps.requires_non_empty_content);
EXPECT_TRUE(caps.requires_typed_content);
}

Expand All @@ -194,6 +204,7 @@ TEST(CapabilitiesTest, Mistral7BInstruct) {
EXPECT_FALSE(caps.supports_parallel_tool_calls);
EXPECT_FALSE(caps.requires_object_arguments);
EXPECT_FALSE(caps.requires_non_null_content);
EXPECT_FALSE(caps.requires_non_empty_content);
EXPECT_FALSE(caps.requires_typed_content);
}

Expand All @@ -207,6 +218,21 @@ TEST(CapabilitiesTest, MistralNemoInstruct) {
EXPECT_TRUE(caps.supports_parallel_tool_calls);
EXPECT_TRUE(caps.requires_object_arguments);
EXPECT_FALSE(caps.requires_non_null_content);
EXPECT_FALSE(caps.requires_non_empty_content);
EXPECT_FALSE(caps.requires_typed_content);
}

TEST(CapabilitiesTest, MistralMinistral3Reasoning) {
auto caps = get_caps("tests/mistralai-Ministral-3-14B-Reasoning-2512.jinja");
EXPECT_TRUE(caps.supports_system_role);
EXPECT_TRUE(caps.supports_tools);
EXPECT_TRUE(caps.supports_tool_calls);
EXPECT_FALSE(caps.supports_tool_call_id);
EXPECT_TRUE(caps.supports_tool_responses);
EXPECT_TRUE(caps.supports_parallel_tool_calls);
EXPECT_FALSE(caps.requires_object_arguments);
EXPECT_TRUE(caps.requires_non_null_content);
EXPECT_TRUE(caps.requires_non_empty_content);
EXPECT_FALSE(caps.requires_typed_content);
}

Expand All @@ -220,6 +246,7 @@ TEST(CapabilitiesTest, NousResearchHermes3Llama3_1_70BToolUse) {
EXPECT_TRUE(caps.supports_parallel_tool_calls);
EXPECT_FALSE(caps.requires_object_arguments);
EXPECT_FALSE(caps.requires_non_null_content);
EXPECT_FALSE(caps.requires_non_empty_content);
EXPECT_FALSE(caps.requires_typed_content);
}

Expand All @@ -233,6 +260,7 @@ TEST(CapabilitiesTest, NousResearchHermes2ProLlama3_8BToolUse) {
EXPECT_TRUE(caps.supports_parallel_tool_calls);
EXPECT_FALSE(caps.requires_object_arguments);
EXPECT_FALSE(caps.requires_non_null_content);
EXPECT_FALSE(caps.requires_non_empty_content);
EXPECT_FALSE(caps.requires_typed_content);
}

Expand All @@ -246,6 +274,7 @@ TEST(CapabilitiesTest, CommandRPlusDefault) {
EXPECT_FALSE(caps.supports_parallel_tool_calls);
EXPECT_FALSE(caps.requires_object_arguments);
EXPECT_TRUE(caps.requires_non_null_content);
EXPECT_FALSE(caps.requires_non_empty_content);
EXPECT_FALSE(caps.requires_typed_content);
}

Expand All @@ -259,6 +288,7 @@ TEST(CapabilitiesTest, CommandRPlusRag) {
EXPECT_FALSE(caps.supports_parallel_tool_calls);
EXPECT_FALSE(caps.requires_object_arguments);
EXPECT_TRUE(caps.requires_non_null_content);
EXPECT_FALSE(caps.requires_non_empty_content);
EXPECT_FALSE(caps.requires_typed_content);
}

Expand All @@ -272,6 +302,7 @@ TEST(CapabilitiesTest, CommandRPlusToolUse) {
EXPECT_TRUE(caps.supports_parallel_tool_calls);
EXPECT_TRUE(caps.requires_object_arguments);
EXPECT_FALSE(caps.requires_non_null_content);
EXPECT_FALSE(caps.requires_non_empty_content);
EXPECT_FALSE(caps.requires_typed_content);
}

Expand All @@ -285,6 +316,7 @@ TEST(CapabilitiesTest, GLM46) {
EXPECT_TRUE(caps.supports_parallel_tool_calls);
EXPECT_TRUE(caps.requires_object_arguments);
EXPECT_FALSE(caps.requires_non_null_content);
EXPECT_FALSE(caps.requires_non_empty_content);
EXPECT_FALSE(caps.requires_typed_content);
}

Expand Down
Loading