Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 35 additions & 4 deletions common/chat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -380,15 +380,46 @@ std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const json & messa
return msgs;
}

json common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text) {
static json render_message_to_json(const std::vector<common_chat_msg> & msgs, const jinja::caps & c) {
if (!c.supports_string_content && !c.supports_typed_content) {
LOG_WRN("%s: Neither string content nor typed content is supported by the template. This is unexpected and may lead to issues.\n", __func__);
}

bool only_string_accepted = c.supports_string_content && !c.supports_typed_content;
bool only_typed_accepted = !c.supports_string_content && c.supports_typed_content;

json messages = json::array();
for (const auto & msg : msgs) {
json jmsg = msg.to_json_oaicompat(concat_typed_text);
messages.push_back(jmsg);
if (only_string_accepted) {
json jmsg = msg.to_json_oaicompat(/* concat_typed_text= */ true);
messages.push_back(jmsg);
} else if (only_typed_accepted) {
json jmsg = msg.to_json_oaicompat(/* concat_typed_text= */ false);
if (jmsg.at("content").is_string()) {
jmsg["content"] = json::array({
json{
{"type", "text"},
{"text", jmsg.at("content").get<std::string>()},
}
});
}
messages.push_back(jmsg);
} else {
json jmsg = msg.to_json_oaicompat(/* concat_typed_text= */ false);
messages.push_back(jmsg);
}
}
return messages;
}

// DEPRECATED: only used in tests
json common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text) {
jinja::caps c;
c.supports_string_content = true;
c.supports_typed_content = !concat_typed_text;
return render_message_to_json(msgs, c);
}

std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const json & tools) {
std::vector<common_chat_tool> result;

Expand Down Expand Up @@ -3020,7 +3051,7 @@ static common_chat_params common_chat_templates_apply_jinja(
: *tmpls->template_default;
const auto & src = tmpl.source();
const auto & caps = tmpl.original_caps();
params.messages = common_chat_msgs_to_json_oaicompat(inputs.messages, /* concat_text= */ !tmpl.original_caps().requires_typed_content);
params.messages = render_message_to_json(inputs.messages, tmpl.original_caps());
params.add_generation_prompt = inputs.add_generation_prompt;
params.tool_choice = inputs.tool_choice;
params.reasoning_format = inputs.reasoning_format;
Expand Down
2 changes: 2 additions & 0 deletions common/chat.h
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,8 @@ bool common_chat_templates_support_enable_thinking(const common_chat_templates *

// Parses a JSON array of messages in OpenAI's chat completion API format.
std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const nlohmann::ordered_json & messages);

// DEPRECATED: only used in tests
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't it be a good idea just to replace common_chat_msgs_to_json_oaicompat with the body of the new render_message_to_json (with the extended signature with full caps) and make the old common_chat_msgs_to_json_oaicompat call it with just the single cap set?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea, I implemented in 6edb330 , test-chat is OK

nlohmann::ordered_json common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text = false);

std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const nlohmann::ordered_json & tools);
Expand Down
13 changes: 9 additions & 4 deletions common/jinja/caps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ static void caps_print_stats(value & v, const std::string & path) {

std::map<std::string, bool> caps::to_map() const {
return {
{"requires_typed_content", requires_typed_content},
{"supports_string_content", supports_string_content},
{"supports_typed_content", supports_typed_content},
{"supports_tools", supports_tools},
{"supports_tool_calls", supports_tool_calls},
{"supports_parallel_tool_calls", supports_parallel_tool_calls},
Expand All @@ -89,7 +90,7 @@ caps caps_get(jinja::program & prog) {
return v->stats.ops.find(op_name) != v->stats.ops.end();
};

// case: typed content requirement
// case: typed content support
caps_try_execute(
prog,
[&]() {
Expand All @@ -105,12 +106,16 @@ caps caps_get(jinja::program & prog) {
// tools
return json{nullptr};
},
[&](bool, value & messages, value &) {
[&](bool success, value & messages, value &) {
auto & content = messages->at(0)->at("content");
caps_print_stats(content, "messages[0].content");
if (has_op(content, "selectattr") || has_op(content, "array_access")) {
// accessed as an array
result.requires_typed_content = true;
result.supports_typed_content = true;
}
if (!success) {
// failed to execute with content as string
result.supports_string_content = false;
}
}
);
Expand Down
4 changes: 3 additions & 1 deletion common/jinja/caps.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ struct caps {
bool supports_parallel_tool_calls = true;
bool supports_preserve_reasoning = false; // support assistant message with reasoning_content

bool requires_typed_content = false; // default: use string content
// one of the 2 content capabilities must be true
bool supports_string_content = true;
bool supports_typed_content = false;

// for reporting on server
std::map<std::string, bool> to_map() const;
Expand Down
6 changes: 6 additions & 0 deletions common/jinja/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,12 @@ value for_statement::execute_impl(context & ctx) {

value iterable_val = iter_expr->execute(scope);

// mark the variable being iterated as used for stats
if (ctx.is_get_stats) {
iterable_val->stats.used = true;
iterable_val->stats.ops.insert("array_access");
}

if (iterable_val->is_undefined()) {
JJ_DEBUG("%s", "For loop iterable is undefined, skipping loop");
iterable_val = mk_val<value_array>();
Expand Down
Loading