Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions common/chat-parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,20 +170,23 @@ std::string common_chat_msg_parser::consume_rest() {
}

// Tries to find the regex, consumes it (pos right after it) and gives the prelude (right before it) and the groups to the callback.
std::optional<common_chat_msg_parser::find_regex_result> common_chat_msg_parser::try_find_regex(const common_regex & regex, size_t from) {
std::optional<common_chat_msg_parser::find_regex_result> common_chat_msg_parser::try_find_regex(const common_regex & regex, size_t from, bool add_prelude_to_content) {
auto m = regex.search(input_, from == std::string::npos ? pos_ : from);
if (m.type == COMMON_REGEX_MATCH_TYPE_NONE) {
return std::nullopt;
}
auto prelude = input_.substr(pos_, m.groups[0].begin - pos_);
pos_ = m.groups[0].end;

if (add_prelude_to_content) {
add_content(prelude);
}
if (m.type == COMMON_REGEX_MATCH_TYPE_PARTIAL) {
if (is_partial()) {
throw common_chat_msg_partial_exception(regex.str());
}
return std::nullopt;
}
auto prelude = input_.substr(pos_, m.groups[0].begin - pos_);
pos_ = m.groups[0].end;

return find_regex_result{prelude, m.groups};
}

Expand Down
3 changes: 2 additions & 1 deletion common/chat-parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class common_chat_msg_parser {
const std::string & healing_marker() const { return healing_marker_; }
const bool & is_partial() const { return is_partial_; }
const common_chat_msg & result() const { return result_; }
const common_chat_syntax & syntax() const { return syntax_; }

void move_to(size_t pos) {
if (pos > input_.size()) {
Expand Down Expand Up @@ -77,7 +78,7 @@ class common_chat_msg_parser {
std::vector<common_string_range> groups;
};

std::optional<find_regex_result> try_find_regex(const common_regex & regex, size_t from = std::string::npos);
std::optional<find_regex_result> try_find_regex(const common_regex & regex, size_t from = std::string::npos, bool add_prelude_to_content = true);

bool try_consume_literal(const std::string & literal);

Expand Down
53 changes: 36 additions & 17 deletions common/chat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -659,7 +659,6 @@ static void parse_json_tool_calls(
}
from = std::string::npos;

builder.add_content(res->prelude);
auto maybe_raw_python = name == "python" && allow_raw_python;
if (builder.input()[builder.pos()] == '{' || !maybe_raw_python) {
if (auto arguments = builder.try_consume_json_with_dumped_args({{}})) {
Expand Down Expand Up @@ -689,7 +688,6 @@ static void parse_json_tool_calls(
};
if (block_open) {
if (auto res = builder.try_find_regex(*block_open)) {
builder.add_content(res->prelude);
parse_tool_calls();
} else {
builder.add_content(builder.consume_rest());
Expand All @@ -702,7 +700,6 @@ static void parse_json_tool_calls(
static void parse_prefixed_json_tool_call_array(common_chat_msg_parser & builder, const common_regex & prefix, size_t rstrip_prefix = 0) {
static const std::vector<std::vector<std::string>> args_paths = {{"arguments"}};
if (auto res = builder.try_find_regex(prefix)) {
builder.add_content(res->prelude);
builder.move_back(rstrip_prefix);
auto tool_calls = builder.consume_json_with_dumped_args(args_paths);
if (!builder.add_tool_calls(tool_calls.value) || tool_calls.is_partial) {
Expand Down Expand Up @@ -838,6 +835,10 @@ static common_chat_params common_chat_params_init_generic(const common_chat_temp
return data;
}
static void common_chat_parse_generic(common_chat_msg_parser & builder) {
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
static const std::vector<std::vector<std::string>> content_paths = {
{"response"},
};
Expand Down Expand Up @@ -910,6 +911,11 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat
return data;
}
static void common_chat_parse_mistral_nemo(common_chat_msg_parser & builder) {
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}

static const common_regex prefix(regex_escape("[TOOL_CALLS]"));
parse_prefixed_json_tool_call_array(builder, prefix);
}
Expand Down Expand Up @@ -1004,7 +1010,6 @@ static void common_chat_parse_command_r7b(common_chat_msg_parser & builder) {

if (auto res = builder.try_find_regex(start_action_regex)) {
// If we didn't extract thoughts, prelude includes them.
builder.add_content(res->prelude);
auto tool_calls = builder.consume_json_with_dumped_args({{"parameters"}});
for (const auto & tool_call : tool_calls.value) {
std::string name = tool_call.contains("tool_name") ? tool_call.at("tool_name") : "";
Expand All @@ -1019,10 +1024,8 @@ static void common_chat_parse_command_r7b(common_chat_msg_parser & builder) {
}
builder.consume_regex(end_action_regex);
} else if (auto res = builder.try_find_regex(start_response_regex)) {
// If we didn't extract thoughts, prelude includes them.
builder.add_content(res->prelude);
if (auto res = builder.try_find_regex(end_response_regex)) {
builder.add_content(res->prelude);
// If we didn't extract thoughts, prelude includes them.
} else {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since you're not using res any more and this block is a no-op, perhaps change it to

        if (!builder.try_find_regex(end_response_regex)) {
            ...
        }

builder.add_content(builder.consume_rest());
throw common_chat_msg_partial_exception(end_response_regex.str());
Expand Down Expand Up @@ -1131,6 +1134,11 @@ static common_chat_params common_chat_params_init_llama_3_x(const common_chat_te
return data;
}
static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool with_builtin_tools = false) {
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}

static const common_regex function_regex(
"\\s*\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"([^\"]+)\"\\s*,\\s*\"parameters\"\\s*: ");
static const common_regex close_regex("\\}\\s*");
Expand All @@ -1141,8 +1149,6 @@ static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool w
if (with_builtin_tools) {
static const common_regex builtin_call_regex("<\\|python_tag\\|>");
if (auto res = builder.try_find_regex(builtin_call_regex)) {
builder.add_content(res->prelude);

auto fun_res = builder.consume_regex(function_name_regex);
auto function_name = builder.str(fun_res.groups[1]);

Expand Down Expand Up @@ -1258,6 +1264,10 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_
}
static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) {
builder.try_parse_reasoning("<think>", "</think>");
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}

static const common_regex tool_calls_begin("(?:<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>|<|tool▁calls|>)");
static const common_regex tool_calls_end("<|tool▁calls▁end|>");
Expand Down Expand Up @@ -1319,6 +1329,10 @@ static common_chat_params common_chat_params_init_firefunction_v2(const common_c
return data;
}
static void common_chat_parse_firefunction_v2(common_chat_msg_parser & builder) {
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
static const common_regex prefix(regex_escape(" functools["));
parse_prefixed_json_tool_call_array(builder, prefix, /* rstrip_prefix= */ 1);
}
Expand Down Expand Up @@ -1460,11 +1474,14 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con
return data;
}
static void common_chat_parse_functionary_v3_1_llama_3_1(common_chat_msg_parser & builder) {
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
// This version of Functionary still supports the llama 3.1 tool call format for the python tool.
static const common_regex python_tag_regex(regex_escape("<|python_tag|>"));

if (auto res = builder.try_find_regex(python_tag_regex)) {
builder.add_content(res->prelude);
auto arguments = wrap_code_as_arguments(builder, builder.consume_rest());
builder.add_tool_call("python", "", arguments);
return;
Expand Down Expand Up @@ -1598,6 +1615,10 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat
}
static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
builder.try_parse_reasoning("<think>", "</think>");
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}

static const common_regex open_regex(
"(?:"
Expand All @@ -1619,8 +1640,6 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
);

if (auto res = builder.try_find_regex(open_regex)) {
builder.add_content(res->prelude);

const auto & block_start = res->groups[1];
std::string block_end = block_start.empty() ? "" : "```";

Expand Down Expand Up @@ -1856,10 +1875,10 @@ static void common_chat_parse_content_only(common_chat_msg_parser & builder) {
builder.add_content(builder.consume_rest());
}

static void common_chat_parse(common_chat_msg_parser & builder, common_chat_format format) {
LOG_DBG("Parsing input with format %s: %s\n", common_chat_format_name(format), builder.input().c_str());
static void common_chat_parse(common_chat_msg_parser & builder) {
LOG_DBG("Parsing input with format %s: %s\n", common_chat_format_name(builder.syntax().format), builder.input().c_str());

switch (format) {
switch (builder.syntax().format) {
case COMMON_CHAT_FORMAT_CONTENT_ONLY:
common_chat_parse_content_only(builder);
break;
Expand Down Expand Up @@ -1894,15 +1913,15 @@ static void common_chat_parse(common_chat_msg_parser & builder, common_chat_form
common_chat_parse_command_r7b(builder);
break;
default:
throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(format));
throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format));
}
builder.finish();
}

common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax) {
common_chat_msg_parser builder(input, is_partial, syntax);
try {
common_chat_parse(builder, syntax.format);
common_chat_parse(builder);
} catch (const common_chat_msg_partial_exception & ex) {
LOG_DBG("Partial parse: %s\n", ex.what());
if (!is_partial) {
Expand Down
1 change: 1 addition & 0 deletions common/chat.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ struct common_chat_syntax {
// Whether reasoning_content should be inlined in the content (e.g. for reasoning_format=deepseek in stream mode)
bool reasoning_in_content = false;
bool thinking_forced_open = false;
bool parse_tool_calls = true;
};

// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
Expand Down
Loading
Loading