@@ -17,17 +17,26 @@ using json = nlohmann::ordered_json;
1717
1818namespace minja {
1919
// Capabilities of a chat template, as detected by chat_template's constructor,
// which renders small dummy conversations and checks what shows up in the output.
// Every flag defaults to false until a probe sets it.
struct chat_template_caps {
    // A tool definition passed alongside the messages shows up in the rendered prompt.
    bool supports_tools{false};
    // The arguments of an assistant tool call show up in the rendered prompt
    // (whether they were passed as a string or as an object).
    bool supports_tool_calls{false};
    // The content of a "tool" role message shows up in the rendered prompt.
    bool supports_tool_responses{false};
    // The content of a leading "system" role message shows up in the rendered prompt.
    bool supports_system_role{false};
    // Two tool calls carried by a single assistant message are both rendered.
    bool supports_parallel_tool_calls{false};
    // The tool_call_id of a "tool" role message shows up in the rendered prompt.
    bool supports_tool_call_id{false};
    // Tool call arguments only render when given as an object, not as a string.
    // meta-llama/Llama-3.1-8B-Instruct expects an object; most other templates
    // (and OpenAI's API) expect the arguments object to be stringified.
    bool requires_object_arguments{false};
    // An assistant message renders with empty-string content but not with null content
    // (CohereForAI/c4ai-command-r-plus simple variant).
    bool requires_non_null_content{false};
    // Plain string content is not rendered, but a [{"type": "text", "text": ...}] array is
    // (MiniMaxAI/MiniMax-Text-01).
    bool requires_typed_content{false};
};
35+
2036class chat_template {
21- public:
2237
2338 private:
24- bool supports_tools_ = true ;
25- // Meta-Llama-3.1-8B-Instruct's template expects arguments to be an object.
26- // Most other templates (and OpenAI's API) expect the arguments object to be stringified.
27- bool requires_object_arguments_ = false ;
28- bool requires_typed_content_ = false ;
29- bool supports_system_role_ = true ;
30- bool supports_parallel_tool_calls_ = false ;
39+ chat_template_caps caps_;
3140 std::string source_;
3241 std::string bos_token_;
3342 std::string eos_token_;
@@ -41,15 +50,16 @@ class chat_template {
4150 {
4251 try {
4352 auto prompt = apply (messages, tools, add_generation_prompt, extra_context, /* adjust_inputs= */ false );
44- // fprintf(stderr, "Prompt : %s\n", prompt.c_str());
53+ // fprintf(stderr, "try_raw_render : %s\n", prompt.c_str());
4554 return prompt;
4655 } catch (const std::exception & e) {
47- // fprintf(stderr, "Error : %s\n", e.what());
56+ // fprintf(stderr, "try_raw_render error : %s\n", e.what());
4857 return " " ;
4958 }
5059 }
5160
5261 public:
62+
5363 chat_template (const std::string & source, const std::string & bos_token, const std::string & eos_token)
5464 : source_(source), bos_token_(bos_token), eos_token_(eos_token)
5565 {
@@ -58,69 +68,120 @@ class chat_template {
5868 /* .lstrip_blocks = */ true ,
5969 /* .keep_trailing_newline = */ false ,
6070 });
61- supports_tools_ = source.find (" tools" ) != std::string::npos;
6271
63- auto renders_string_arguments =
64- try_raw_render ({
65- {
66- {" role" , " user" },
67- {" content" , " Hey" }
68- },
69- {
70- {" role" , " assistant" },
71- {" tool_calls" , json::array ({
72- {
73- {" id" , " call_1___" },
74- {" type" , " function" },
75- {" function" , {
76- {" arguments" , " {\" code\" : \" print('Hello, World!')\" }" },
77- {" name" , " ipython" },
72+ auto contains = [](const std::string & haystack, const std::string & needle) {
73+ return haystack.find (needle) != std::string::npos;
74+ };
75+
76+ const std::string user_needle = " <User Needle>" ;
77+ const std::string sys_needle = " <System Needle>" ;
78+ const json dummy_str_user_msg = {{" role" , " user" }, {" content" , user_needle}};
79+ const json dummy_typed_user_msg = {{" role" , " user" }, {" content" , json::array ({{{" type" , " text" }, {" text" , user_needle}}})}};
80+
81+ caps_.requires_typed_content =
82+ !contains (try_raw_render (json::array ({dummy_str_user_msg}), {}, false ), user_needle)
83+ && contains (try_raw_render (json::array ({dummy_typed_user_msg}), {}, false ), user_needle);
84+
85+ const auto dummy_user_msg = caps_.requires_typed_content
86+ ? dummy_typed_user_msg
87+ : dummy_str_user_msg;
88+ const json needle_system_msg = {
89+ {" role" , " system" },
90+ {" content" , caps_.requires_typed_content ? json::array ({{{" type" , " text" }, {" text" , sys_needle}}}) : json (sys_needle)},
91+ };
92+
93+ caps_.supports_system_role = contains (try_raw_render ({needle_system_msg, dummy_user_msg,}, {}, false ), sys_needle);
94+
95+ auto out = try_raw_render (json::array ({
96+ dummy_user_msg
97+ }), json::array ({
98+ {
99+ {" name" , " some_tool" },
100+ {" type" , " function" },
101+ {" function" , {
102+ {" name" , " some_tool" },
103+ {" description" , " Some tool." },
104+ {" parameters" , {
105+ {" type" , " object" },
106+ {" properties" , {
107+ {" arg" , {
108+ {" type" , " string" },
109+ {" description" , " Some argument." },
78110 }},
79- },
80- })},
81- }
82- }, {}, false ).find (" {\" code\" : \" print" ) != std::string::npos;
83- if (!renders_string_arguments) {
84- auto renders_object_arguments =
85- try_raw_render ({
86- {
87- {" role" , " user" },
88- {" content" , " Hey" }
89- },
90- {
91- {" role" , " assistant" },
92- {" tool_calls" , json::array ({
93- {
94- {" id" , " call_1___" },
95- {" type" , " function" },
96- {" function" , {
97- {" arguments" , {
98- {" code" , " print('Hello, World!')" },
99- }},
100- {" name" , " ipython" },
101- }},
102- },
103- })},
104- }
105- }, {}, false ).find (" {\" code\" : \" print" ) != std::string::npos;
106- requires_object_arguments_ = renders_object_arguments;
107- }
108- supports_parallel_tool_calls_ = source.find (" tool_call_id" ) != std::string::npos;
111+ }},
112+ {" required" , json::array ({ " arg" })},
113+ }},
114+ }},
115+ },
116+ }), false );
117+ caps_.supports_tools = contains (out, " some_tool" );
109118
110- supports_system_role_ = try_raw_render ({
111- {{" role" , " system" }, {" content" , " <System Needle>" }},
112- {{" role" , " user" }, {" content" , " Hey" }}
113- }, {}, false ).find (" <System Needle>" ) != std::string::npos;
119+ auto make_tool_calls_msg = [&](const json & tool_calls) {
120+ return json {
121+ {" role" , " assistant" },
122+ {" content" , nullptr },
123+ {" tool_calls" , tool_calls},
124+ };
125+ };
126+ auto make_tool_call = [](const std::string & tool_name, const json & arguments) {
127+ return json {
128+ {" id" , " call_1___" },
129+ {" type" , " function" },
130+ {" function" , {
131+ {" arguments" , arguments},
132+ {" name" , tool_name},
133+ }},
134+ };
135+ };
136+ const json dummy_args_obj {{" argument_needle" , " print('Hello, World!')" }};
137+
138+ // Note: the arguments are rendered in both cases, but may be double-escaped, which we don't want.
139+ out = try_raw_render (json::array ({
140+ dummy_user_msg,
141+ make_tool_calls_msg (json::array ({make_tool_call (" ipython" , dummy_args_obj.dump ())})),
142+ }), {}, false );
143+ auto tool_call_renders_str_arguments = contains (out, " \" argument_needle\" :" ) || contains (out, " 'argument_needle':" );
144+ out = try_raw_render (json::array ({
145+ dummy_user_msg,
146+ make_tool_calls_msg (json::array ({make_tool_call (" ipython" , dummy_args_obj)})),
147+ }), {}, false );
148+ auto tool_call_renders_obj_arguments = contains (out, " \" argument_needle\" :" ) || contains (out, " 'argument_needle':" );
149+
150+ caps_.supports_tool_calls = tool_call_renders_str_arguments || tool_call_renders_obj_arguments;
151+ caps_.requires_object_arguments = !tool_call_renders_str_arguments && tool_call_renders_obj_arguments;
152+ auto out_empty = try_raw_render (json::array ({dummy_user_msg, {{" role" , " assistant" }, {" content" , " " }}}), {}, false );
153+ auto out_null = try_raw_render (json::array ({dummy_user_msg, {{" role" , " assistant" }, {" content" , nullptr }}}), {}, false );
154+ caps_.requires_non_null_content = contains (out_empty, user_needle) && !contains (out_null, user_needle);
155+
156+ if (caps_.supports_tool_calls ) {
157+ auto dummy_args = caps_.requires_object_arguments ? dummy_args_obj : json (dummy_args_obj.dump ());
158+ auto tc1 = make_tool_call (" test_tool1" , dummy_args);
159+ auto tc2 = make_tool_call (" test_tool2" , dummy_args);
160+ auto out = try_raw_render (json::array ({
161+ dummy_user_msg,
162+ make_tool_calls_msg (json::array ({tc1, tc2})),
163+ }), {}, false );
164+ caps_.supports_parallel_tool_calls = contains (out, " test_tool1" ) && contains (out, " test_tool2" );
114165
115- requires_typed_content_ = try_raw_render ({{{" role" , " user" }, {" content" , " Hey" }}}, {}, false ).find (" Hey" ) == std::string::npos
116- && try_raw_render ({{{" role" , " user" }, {" content" , {{{" type" , " text" }, {" text" , " Hey" }}}}}}, {}, false ).find (" Hey" ) != std::string::npos;
166+ out = try_raw_render (json::array ({
167+ dummy_user_msg,
168+ make_tool_calls_msg (json::array ({tc1})),
169+ {
170+ {" role" , " tool" },
171+ {" name" , " test_tool1" },
172+ {" content" , " Some response!" },
173+ {" tool_call_id" , " call_911_" },
174+ }
175+ }), {}, false );
176+ caps_.supports_tool_responses = contains (out, " Some response!" );
177+ caps_.supports_tool_call_id = contains (out, " call_911_" );
178+ }
117179 }
118180
119181 const std::string & source () const { return source_; }
120182 const std::string & bos_token () const { return bos_token_; }
121183 const std::string & eos_token () const { return eos_token_; }
122- bool supports_tools () const { return supports_tools_; }
123- bool supports_parallel_tool_calls () const { return supports_parallel_tool_calls_; }
184+ const chat_template_caps & original_caps () const { return caps_; }
124185
125186 std::string apply (
126187 const nlohmann::ordered_json & messages,
@@ -131,13 +192,19 @@ class chat_template {
131192 {
132193 json actual_messages;
133194
134- // First, "fix" messages so they have a chance to be rendered correctly by the template
135-
136- if (adjust_inputs && (requires_object_arguments_ || !supports_system_role_ || !supports_tools_ || requires_typed_content_)) {
195+ auto needs_adjustments = adjust_inputs && (false
196+ || !caps_.supports_system_role
197+ || !caps_.supports_tools
198+ || !caps_.supports_tool_responses
199+ || !caps_.supports_tool_calls
200+ || caps_.requires_object_arguments
201+ || caps_.requires_typed_content
202+ );
203+ if (needs_adjustments) {
137204 actual_messages = json::array ();
138205
139206 auto add_message = [&](const json & msg) {
140- if (requires_typed_content_ && msg.contains (" content" ) && !msg.at (" content" ).is_null () && msg.at (" content" ).is_string ()) {
207+ if (caps_. requires_typed_content && msg.contains (" content" ) && !msg.at (" content" ).is_null () && msg.at (" content" ).is_string ()) {
141208 actual_messages.push_back ({
142209 {" role" , msg.at (" role" )},
143210 {" content" , {{
@@ -160,24 +227,32 @@ class chat_template {
160227 pending_system.clear ();
161228 }
162229 };
163- for (const auto & message_ : messages) {
230+ auto needs_tools_in_system = !tools.is_null () && tools.size () > 0 && !caps_.supports_tools ;
231+
232+ for (const auto & message_ : needs_tools_in_system ? add_system (messages, " Available tools: " + tools.dump (2 )) : messages) {
164233 auto message = message_;
165234 if (!message.contains (" role" ) || !message.contains (" content" )) {
166235 throw std::runtime_error (" message must have 'role' and 'content' fields: " + message.dump ());
167236 }
168237 std::string role = message.at (" role" );
169238
170239 if (message.contains (" tool_calls" )) {
171- if (requires_object_arguments_ || !supports_tools_ ) {
240+ if (caps_. requires_object_arguments || !caps_. supports_tool_calls ) {
172241 for (auto & tool_call : message.at (" tool_calls" )) {
173242 if (tool_call[" type" ] == " function" ) {
174243 auto & function = tool_call.at (" function" );
175- std::string arguments = function.at (" arguments" );
176- function[" arguments" ] = json::parse (arguments);
244+ auto & arguments = function.at (" arguments" );
245+ if (arguments.is_string ()) {
246+ try {
247+ arguments = json::parse (arguments.get <std::string>());
248+ } catch (const std::exception & ecvt) {
249+ fprintf (stderr, " Failed to parse arguments: %s\n " , ecvt.what ());
250+ }
251+ }
177252 }
178253 }
179254 }
180- if (!supports_tools_ ) {
255+ if (!caps_. supports_tool_calls ) {
181256 auto content = message.at (" content" );
182257 auto tool_calls = json::array ();
183258 for (const auto & tool_call : message.at (" tool_calls" )) {
@@ -204,7 +279,7 @@ class chat_template {
204279 message.erase (" tool_calls" );
205280 }
206281 }
207- if (!supports_tools_ && role == " tool" ) {
282+ if (!caps_. supports_tool_responses && role == " tool" ) {
208283 message[" role" ] = " user" ;
209284 auto obj = json {
210285 {" tool_response" , {
@@ -219,7 +294,7 @@ class chat_template {
219294 message.erase (" name" );
220295 }
221296
222- if (!message[" content" ].is_null () && !supports_system_role_ ) {
297+ if (!message[" content" ].is_null () && !caps_. supports_system_role ) {
223298 std::string content = message.at (" content" );
224299 if (role == " system" ) {
225300 if (!pending_system.empty ()) pending_system += " \n " ;
@@ -238,7 +313,9 @@ class chat_template {
238313 }
239314 add_message (message);
240315 }
241- flush_sys ();
316+ if (!caps_.supports_system_role ) {
317+ flush_sys ();
318+ }
242319 } else {
243320 actual_messages = messages;
244321 }
@@ -261,7 +338,28 @@ class chat_template {
261338 }
262339 }
263340
264- return template_root_->render (context);
341+ auto ret = template_root_->render (context);
342+ // fprintf(stderr, "actual_messages: %s\n", actual_messages.dump(2).c_str());
343+ // fprintf(stderr, "apply: %s\n\n", ret.c_str());
344+ return ret;
345+ }
346+
347+ static nlohmann::ordered_json add_system (const nlohmann::ordered_json & messages, const std::string & system_prompt) {
348+ json messages_with_system = messages;
349+
350+ if (messages_with_system.size () > 0 && messages_with_system[0 ].at (" role" ) == " system" ) {
351+ std::string existing_system = messages_with_system.at (0 ).at (" content" );
352+ messages_with_system[0 ] = json {
353+ {" role" , " system" },
354+ {" content" , existing_system + " \n " + system_prompt},
355+ };
356+ } else {
357+ messages_with_system.insert (messages_with_system.begin (), json {
358+ {" role" , " system" },
359+ {" content" , system_prompt},
360+ });
361+ }
362+ return messages_with_system;
265363 }
266364};
267365
0 commit comments