diff --git a/common/chat.cpp b/common/chat.cpp index aba26e97a10..50d818dfbd4 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -2780,6 +2780,15 @@ static void system_message_not_supported(json & messages) { } } +static void requires_non_null_content(json & messages) { + GGML_ASSERT(messages.is_array()); + for (auto & message : messages) { + if (message.contains("tool_calls") && !message.contains("content")) { + message["content"] = ""; + } + } +} + static void func_args_not_string(json & messages) { GGML_ASSERT(messages.is_array()); for (auto & message : messages) { @@ -2885,6 +2894,13 @@ static common_chat_params common_chat_templates_apply_jinja( workaround::system_message_not_supported(params.messages); } + if (tmpl.original_caps().supports_tool_calls) { + // some templates will require the content field in tool call messages + // to still be non-null, this puts an empty string everywhere where the + // content field is null + workaround::requires_non_null_content(params.messages); + } + params.extra_context = json::object(); for (auto el : inputs.chat_template_kwargs) { params.extra_context[el.first] = json::parse(el.second); diff --git a/common/jinja/caps.cpp b/common/jinja/caps.cpp index f27490f1fb7..6b6c095be6e 100644 --- a/common/jinja/caps.cpp +++ b/common/jinja/caps.cpp @@ -160,7 +160,7 @@ caps caps_get(jinja::program & prog) { {"content", "Assistant message"}, {"tool_calls", json::array({ { - {"id", "call1"}, + {"id", "call00001"}, {"type", "function"}, {"function", { {"name", "tool1"}, @@ -170,10 +170,10 @@ caps caps_get(jinja::program & prog) { }} }, { - {"id", "call2"}, + {"id", "call00002"}, {"type", "function"}, {"function", { - {"name", "tool2"}, + {"name", "tool1"}, {"arguments", { {"arg", "value"} }} @@ -194,7 +194,7 @@ caps caps_get(jinja::program & prog) { {"name", "tool"}, {"type", "function"}, {"function", { - {"name", "tool"}, + {"name", "tool1"}, {"description", "Tool description"}, {"parameters", { {"type", "object"}, diff --git a/tests/test-jinja.cpp b/tests/test-jinja.cpp index 54d3a0923bd..314a31d7d01 100644 --- a/tests/test-jinja.cpp +++ b/tests/test-jinja.cpp @@ -949,6 +949,19 @@ static void test_tests(testing & t) { {{"x", {{"a", 1}}}}, "yes" ); + + test_template(t, "something in undefined", + "{% if x in y %}yes{% else %}no{% endif %}", + {{"x", 1}}, + "no" + ); + + test_template(t, "null is undefined", + "{% if null is not defined %}yes{% else %}no{% endif %}", + json::object(), + "yes" + ); + } static void test_string_methods(testing & t) {