@@ -16,6 +16,7 @@ std::string common_chat_format_name(common_chat_format format) {
1616 case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: return " Functionary v3.2" ;
1717 case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return " Functionary v3.1 Llama 3.1" ;
1818 case COMMON_CHAT_FORMAT_HERMES_2_PRO: return " Hermes 2 Pro" ;
19+ case COMMON_CHAT_FORMAT_COMMAND_R7B: return " Command R7B" ;
1920 default :
2021 throw std::runtime_error (" Unknown chat format" );
2122 }
@@ -317,6 +318,79 @@ static common_chat_msg common_chat_parse_mistral_nemo(const std::string & input)
317318 return parse_prefixed_json_tool_call_array (input, " [TOOL_CALLS]" );
318319}
319320
321+ static common_chat_params common_chat_params_init_command_r7b (const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
322+ common_chat_params data;
323+ data.grammar_lazy = inputs.tool_choice != " required" ;
324+ data.grammar = build_grammar ([&](const common_grammar_builder & builder) {
325+ auto schemas = json::array ();
326+ foreach_function (inputs.tools , [&](const json & tool) {
327+ const auto & function = tool[" function" ];
328+ schemas.push_back ({
329+ {" type" , " object" },
330+ {" properties" , {
331+ {" tool_call_id" , {
332+ {" type" , " string" },
333+ // Command-R's template expects an integer string.
334+ {" pattern" , " ^[0-9]{1,10}$" },
335+ }},
336+ {" tool_name" , {
337+ {" type" , " string" },
338+ {" const" , function[" name" ]},
339+ }},
340+ {" parameters" , function[" parameters" ]},
341+ }},
342+ {" required" , json::array ({" tool_call_id" , " tool_name" , " parameters" })},
343+ });
344+ });
345+ auto schema = json {
346+ {" type" , " array" },
347+ {" items" , schemas.size () == 1 ? schemas[0 ] : json {{" anyOf" , schemas}}},
348+ {" minItems" , 1 },
349+ };
350+ if (!inputs.parallel_tool_calls ) {
351+ schema[" maxItems" ] = 1 ;
352+ }
353+ builder.add_rule (" root" , " \" <|START_ACTION|>\" " + builder.add_schema (" tool_calls" , schema) + " \" <|END_ACTION|>\" " );
354+ }, grammar_options);
355+ data.grammar_triggers .push_back ({" <|START_ACTION|>" , /* .at_start = */ false });
356+ data.preserved_tokens = {
357+ " <|START_RESPONSE|>" ,
358+ " <|END_RESPONSE|>" ,
359+ " <|START_THINKING|>" ,
360+ " <|END_THINKING|>" ,
361+ " <|END_ACTION|>" ,
362+ };
363+ data.prompt = tmpl.apply (inputs.messages , inputs.tools .empty () ? json () : inputs.tools , inputs.add_generation_prompt );
364+ data.format = COMMON_CHAT_FORMAT_COMMAND_R7B;
365+ return data;
366+ }
367+ static common_chat_msg common_chat_parse_command_r7b (const std::string & input) {
368+ static std::regex response_regex (" <\\ |START_RESPONSE\\ |>(.*?)<\\ |END_RESPONSE\\ |>" );
369+ static std::regex thought_action_regex (" <\\ |START_THINKING\\ |>([\\ s\\ S\\ n\\ r]*?)<\\ |END_THINKING\\ |><\\ |START_ACTION\\ |>([\\ s\\ S\\ n\\ r]*?)<\\ |END_ACTION\\ |>" );
370+ std::smatch match;
371+
372+ common_chat_msg result;
373+ result.role = " assistant" ;
374+ if (std::regex_match (input, match, response_regex)) {
375+ result.content = match[1 ].str ();
376+ } else if (std::regex_match (input, match, thought_action_regex)) {
377+ result.tool_plan = match[1 ].str ();
378+ auto actions_str = match[2 ].str ();
379+ auto actions = json::parse (actions_str);
380+ for (const auto & action : actions) {
381+ result.tool_calls .push_back ({
382+ /* .name = */ action[" tool_name" ],
383+ /* .arguments = */ action[" parameters" ].dump (),
384+ /* .id = */ action[" tool_call_id" ],
385+ });
386+ }
387+ } else {
388+ LOG_ERR (" Failed to parse command_r output" );
389+ result.content = input;
390+ }
391+ return result;
392+ }
393+
320394static void expect_tool_parameters (const std::string & name, const json & parameters, const std::vector<std::string> & expected_properties) {
321395 if (!parameters.is_object () || !parameters.contains (" type" ) || parameters[" type" ] != " object" || !parameters.contains (" properties" ) || !parameters.contains (" required" )) {
322396 throw std::runtime_error (" Parameters of tool " + name + " must be an object w/ required properties" );
@@ -462,6 +536,10 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_
462536 " \" <|tool▁call▁begin|>function<|tool▁sep|>" + name + " \\ n```json\\ n\" " + args_rule + " \" ```<|tool▁call▁end|>\" " ));
463537 });
464538 data.grammar_triggers .push_back ({" <|tool▁calls▁begin|>" , /* .at_start = */ false });
539+ data.preserved_tokens = {
540+ " <|tool▁sep|>" ,
541+ " <|tool▁call▁end|>" ,
542+ };
465543 builder.add_rule (" root" , " \" <|tool▁calls▁begin|>\" (" + string_join (tool_rules, " | " ) + " )" + (inputs.parallel_tool_calls ? " *" : " " ) + " space" );
466544 }, grammar_options);
467545 data.prompt = tmpl.apply (inputs.messages , inputs.tools .empty () ? json () : inputs.tools , inputs.add_generation_prompt );
@@ -704,8 +782,7 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat
704782 auto tool_call = " \" <tool_call>\" space " + builder.add_rule (" tool_call" , string_join (tool_rules, " | " )) + " \" </tool_call>\" space" ;
705783 builder.add_rule (" root" , inputs.parallel_tool_calls ? " (" + tool_call + " )+" : tool_call);
706784 data.grammar_triggers .push_back ({" <tool_call>" , /* .at_start = */ false });
707- // Not really a trigger but need to print this special token to get a successful parse.
708- data.grammar_triggers .push_back ({" </tool_call>" , /* .at_start = */ false });
785+ data.preserved_tokens = { " </tool_call>" };
709786 }, grammar_options);
710787
711788 data.prompt = tmpl.apply (inputs.messages , inputs.tools .empty () ? json () : inputs.tools , inputs.add_generation_prompt );
@@ -822,6 +899,9 @@ common_chat_params common_chat_params_init(const common_chat_template & tmpl, co
822899 if (src.find (" [TOOL_CALLS]" ) != std::string::npos) {
823900 return common_chat_params_init_mistral_nemo (tmpl, inputs);
824901 }
902+ if (src.find (" <|END_THINKING|><|START_ACTION|>" ) != std::string::npos) {
903+ return common_chat_params_init_command_r7b (tmpl, inputs);
904+ }
825905 return common_chat_params_init_generic (tmpl, inputs);
826906}
827907
@@ -855,6 +935,8 @@ common_chat_msg common_chat_parse(const std::string & input, common_chat_format
855935 return common_chat_parse_hermes_2_pro (input);
856936 case COMMON_CHAT_FORMAT_FIREFUNCTION_V2:
857937 return common_chat_parse_firefunction_v2 (input);
938+ case COMMON_CHAT_FORMAT_COMMAND_R7B:
939+ return common_chat_parse_command_r7b (input);
858940 default :
859941 throw std::runtime_error (" Unsupported format: " + common_chat_format_name (format));
860942 }
0 commit comments