Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 48 additions & 1 deletion common/chat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,45 @@ static bool has_content_or_tool_calls(const common_chat_msg & msg) {
return !msg.content.empty() || !msg.tool_calls.empty();
}

std::vector<common_chat_msg_span> common_chat_split_by_role(const std::string & prompt, const std::vector<common_chat_msg_delimiter> & delims) {
if (delims.empty() || prompt.empty()) {
return {};
}

auto parser = build_peg_parser([&](common_peg_parser_builder & p) {
std::vector<std::string> all_delims;
std::vector<common_peg_parser> tagged_messages;

all_delims.reserve(delims.size());
tagged_messages.reserve(delims.size());
for (const auto & d : delims) {
all_delims.push_back(d.delimiter);
}

auto any_delim = p.until_one_of(all_delims);
for (const auto & d : delims) {
tagged_messages.push_back(p.tag(d.role, p.literal(d.delimiter) + any_delim));
}

return any_delim + p.zero_or_more(p.choice(tagged_messages)) + p.end();
});

common_peg_parse_context ctx(prompt);
const auto result = parser.parse(ctx);
if (!result.success()) {
return {};
}

std::vector<common_chat_msg_span> spans;
ctx.ast.visit(result, [&](const common_peg_ast_node & node) {
if (!node.tag.empty()) {
spans.push_back({ node.tag, node.start, node.end - node.start });
}
});

return spans;
}

json common_chat_msg::to_json_oaicompat(bool concat_typed_text) const {
if (!content.empty() && !content_parts.empty()) {
throw std::runtime_error("Cannot specify both content and content_parts");
Expand Down Expand Up @@ -973,7 +1012,15 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
}
}

data.prompt = prompt;
data.prompt = prompt;
data.message_spans = common_chat_split_by_role(prompt, {
{ "assistant", "<|start|>assistant" },
{ "user", "<|start|>user" },
{ "system", "<|start|>developer" },
{ "system", "<|start|>system" },
{ "tool", "<|start|>functions" },
});

data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
data.supports_thinking = true;

Expand Down
14 changes: 14 additions & 0 deletions common/chat.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,17 @@ struct common_chat_msg_diff {
}
};

struct common_chat_msg_span {
std::string role;
std::size_t pos = 0;
std::size_t len = 0;
};

struct common_chat_msg_delimiter {
std::string role;
std::string delimiter;
};

struct common_chat_tool {
std::string name;
std::string description;
Expand Down Expand Up @@ -187,6 +198,7 @@ struct common_chat_params {
std::vector<std::string> preserved_tokens;
std::vector<std::string> additional_stops;
std::string parser;
std::vector<common_chat_msg_span> message_spans;
};

// per-message parsing syntax
Expand Down Expand Up @@ -275,3 +287,5 @@ std::optional<common_chat_params> common_chat_try_specialized_template(
const common_chat_template & tmpl,
const std::string & src,
autoparser::generation_params & params);

std::vector<common_chat_msg_span> common_chat_split_by_role(const std::string & prompt, const std::vector<common_chat_msg_delimiter> & delims);
35 changes: 35 additions & 0 deletions tests/test-chat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1473,6 +1473,40 @@ static void test_msgs_oaicompat_json_conversion() {
}
}

static void test_split_by_role() {
LOG_DBG("%s\n", __func__);

// Empty inputs
assert_equals<size_t>(0, common_chat_split_by_role("", {}).size());
assert_equals<size_t>(0, common_chat_split_by_role("hello", {}).size());
assert_equals<size_t>(0, common_chat_split_by_role("", { { "user", "<|user|>" } }).size());

// Multi-role conversation, no leading/trailing content
{
const std::string prompt = "<|user|>Hi<|assistant|>Hello<|user|>Bye";
const auto splits = common_chat_split_by_role(prompt, {
{ "user", "<|user|>" },
{ "assistant", "<|assistant|>" },
});
assert_equals<size_t>(3, splits.size());

assert_equals<std::string>("user", splits[0].role);
assert_equals<size_t>(0, splits[0].pos);
assert_equals<size_t>(10, splits[0].len);
assert_equals<std::string>("<|user|>Hi", prompt.substr(splits[0].pos, splits[0].len));

assert_equals<std::string>("assistant", splits[1].role);
assert_equals<size_t>(10, splits[1].pos);
assert_equals<size_t>(18, splits[1].len);
assert_equals<std::string>("<|assistant|>Hello", prompt.substr(splits[1].pos, splits[1].len));

assert_equals<std::string>("user", splits[2].role);
assert_equals<size_t>(28, splits[2].pos);
assert_equals<size_t>(11, splits[2].len);
assert_equals<std::string>("<|user|>Bye", prompt.substr(splits[2].pos, splits[2].len));
}
}

static void test_tools_oaicompat_json_conversion() {
LOG_DBG("%s\n", __func__);
std::vector<common_chat_tool> tools{
Expand Down Expand Up @@ -4168,6 +4202,7 @@ int main(int argc, char ** argv) {
{
test_msg_diffs_compute();
test_msgs_oaicompat_json_conversion();
test_split_by_role();
test_tools_oaicompat_json_conversion();
test_developer_role_to_system_workaround();
test_template_output_peg_parsers(detailed_debug);
Expand Down