Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions common/chat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2064,7 +2064,7 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
// Trigger on tool calls that appear in the commentary channel
data.grammar_triggers.push_back({
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
"<\\|channel\\|>(commentary|analysis) to"
"<\\|channel\\|>(?:commentary|analysis) to"
});

// Trigger tool calls that appear in the role section, either at the
Expand Down Expand Up @@ -2397,17 +2397,17 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat
(inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call));
// Trigger on some common known "good bad" outputs (only from the start and with a json that's about a specific argument name to avoid false positives)
data.grammar_triggers.push_back({
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL,
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
// If thinking_forced_open, then we capture the </think> tag in the grammar,
// (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar)
std::string(data.thinking_forced_open ? "[\\s\\S]*?(</think>\\s*)" : "(?:<think>[\\s\\S]*?</think>\\s*)?") + (
std::string(data.thinking_forced_open ? "(</think>\\s*)" : "") + (
"\\s*("
"(?:<tool_call>"
"|<function"
"|(?:```(?:json|xml)?\n\\s*)?(?:<function_call>|<tools>|<xml><json>|<response>)?"
"\\s*\\{\\s*\"name\"\\s*:\\s*\"(?:" + string_join(escaped_names, "|") + ")\""
")"
")[\\s\\S]*"
")"
),
});
data.preserved_tokens = {
Expand Down
26 changes: 13 additions & 13 deletions common/regex-partial.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ common_regex_match common_regex::search(const std::string & input, size_t pos, b
return res;
}
std::match_results<std::string::const_reverse_iterator> srmatch;
if (std::regex_match(input.rbegin(), input.rend() - pos, srmatch, rx_reversed_partial)) {
if (std::regex_search(input.rbegin(), input.rend() - pos, srmatch, rx_reversed_partial, std::regex_constants::match_continuous)) {
auto group = srmatch[1].str();
if (group.length() != 0) {
auto it = srmatch[1].second.base();
Expand Down Expand Up @@ -55,18 +55,18 @@ common_regex_match common_regex::search(const std::string & input, size_t pos, b
to see if a string ends with a partial regex match, but it's not in std::regex yet.
Instead, we'll convert the regex into a partial match regex operating as an anchored match on the reverse iterators of the input.

- /abcd/ -> (dcba|cba|ba|a).* -> ((?:(?:(?:(?:d)?c)?b)?a).*
- /a|b/ -> (a|b).*
- /abcd/ -> ^(dcba|cba|ba|a) -> ^((?:(?:(?:(?:d)?c)?b)?a)
- /a|b/ -> ^(a|b)
- /a*?/ -> error, could match ""
- /a*b/ -> ((?:b)?a*+).* (final repetitions become eager)
- /.*?ab/ -> ((?:b)?a).* (merge .*)
- /a.*?b/ -> ((?:b)?.*?a).* (keep reluctant matches)
- /a(bc)d/ -> ((?:(?:d)?(?:(?:c)?b))?a).*
- /a(bc|de)/ -> ((?:(?:(?:e)?d)?|(?:(?:c)?b)?)?a).*
- /ab{2,4}c/ -> abbb?b?c -> ((?:(?:(?:(?:(?:c)?b)?b)?b?)?b?)?a).*
- /a*b/ -> ^((?:b)?a*+) (final repetitions become eager)
- /.*?ab/ -> ^((?:b)?a) (omit .*)
- /a.*?b/ -> ^((?:b)?.*?a) (keep reluctant matches)
- /a(bc)d/ -> ^((?:(?:d)?(?:(?:c)?b))?a)
- /a(bc|de)/ -> ^((?:(?:(?:e)?d)?|(?:(?:c)?b)?)?a)
- /ab{2,4}c/ -> ^cbbb?b?a -> ^((?:(?:(?:(?:(?:c)?b)?b)?b?)?b?)?a)

The regex will match a reversed string fully, and the end of the first (And only) capturing group will indicate the reversed start of the original partial pattern
(i.e. just where the final .* starts in the inverted pattern; all other groups are turned into non-capturing groups, and reluctant quantifiers are ignored)
The regex will match a prefix of the reversed string (i.e. a suffix of the original input), and the end of the first (and only) capturing group will indicate the reversed start of the original partial pattern.
All other groups are turned into non-capturing groups, and reluctant quantifiers are ignored.
*/
std::string regex_to_reversed_partial_regex(const std::string & pattern) {
auto it = pattern.begin();
Expand Down Expand Up @@ -177,7 +177,7 @@ std::string regex_to_reversed_partial_regex(const std::string & pattern) {
}
}

// /abcd/ -> (dcba|cba|ba|a).* -> ((?:(?:(?:d)?c)?b)?a).*
// /abcd/ -> ^(dcba|cba|ba|a) -> ^((?:(?:(?:d)?c)?b)?a)
// if n(=4) parts, opening n-1(=3) non-capturing groups after the 1 capturing group
// We'll add the outermost capturing group and the leading ^ anchor in the enclosing function.
std::vector<std::string> res_alts;
Expand All @@ -200,5 +200,5 @@ std::string regex_to_reversed_partial_regex(const std::string & pattern) {
throw std::runtime_error("Unmatched '(' in pattern");
}

return "(" + res + ")[\\s\\S]*";
return "^(" + res + ")";
}
18 changes: 10 additions & 8 deletions common/sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,24 +179,30 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
#endif // LLAMA_USE_LLGUIDANCE
} else {
std::vector<std::string> trigger_patterns;
std::vector<std::string> patterns_anywhere;
std::vector<llama_token> trigger_tokens;
for (const auto & trigger : params.grammar_triggers) {
switch (trigger.type) {
case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
{
const auto & word = trigger.value;
patterns_anywhere.push_back(regex_escape(word));
trigger_patterns.push_back(regex_escape(word));
break;
}
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
{
patterns_anywhere.push_back(trigger.value);
trigger_patterns.push_back(trigger.value);
break;
}
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL:
{
trigger_patterns.push_back(trigger.value);
const auto & pattern = trigger.value;
std::string anchored = "^$";
if (!pattern.empty()) {
anchored = (pattern.front() != '^' ? "^" : "")
+ pattern
+ (pattern.back() != '$' ? "$" : "");
}
trigger_patterns.push_back(anchored);
break;
}
case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN:
Expand All @@ -210,10 +216,6 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
}
}

if (!patterns_anywhere.empty()) {
trigger_patterns.push_back("^[\\s\\S]*?(" + string_join(patterns_anywhere, "|") + ")[\\s\\S]*");
}

std::vector<const char *> trigger_patterns_c;
trigger_patterns_c.reserve(trigger_patterns.size());
for (const auto & regex : trigger_patterns) {
Expand Down
53 changes: 40 additions & 13 deletions src/llama-grammar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,44 @@ static void print_rule(
fprintf(file, "\n");
}

//
// Regex utilities
//

// Locates where the trigger begins inside `input`.
//
// Returns the offset of the first non-empty capturing group of the match, or
// the offset of the whole match when no capturing group matched any text.
// Returns std::string::npos when `regex` does not match `input` at all.
size_t llama_grammar_trigger_pattern::find(const std::string & input) const {
    // Maps a successful match to the offset the trigger starts at.
    const auto trigger_offset = [](const std::smatch & m) -> size_t {
        for (size_t group = 1; group < m.size(); ++group) {
            if (m.length(group) > 0) {
                return static_cast<size_t>(m.position(group));
            }
        }
        // No capturing group captured anything: fall back to the whole match.
        return static_cast<size_t>(m.position(0));
    };

    std::smatch m;

    // A pattern written as ^...$ is first tried against the entire input.
    const bool fully_anchored =
        !pattern.empty() && pattern.front() == '^' && pattern.back() == '$';
    if (fully_anchored && std::regex_match(input, m, regex)) {
        return trigger_offset(m);
    }

    // Otherwise (or if the whole-input match failed) look for a match anywhere.
    if (std::regex_search(input, m, regex)) {
        return trigger_offset(m);
    }

    return std::string::npos;
}


//
// implementation
//
Expand Down Expand Up @@ -1312,21 +1350,10 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
grammar.trigger_buffer_positions.push_back(std::make_pair(token, position));
grammar.trigger_buffer += piece;

std::smatch match;
for (const auto & trigger_pattern : grammar.trigger_patterns) {
if (std::regex_match(grammar.trigger_buffer, match, trigger_pattern.regex)) {
auto start = trigger_pattern.find(grammar.trigger_buffer);
if (start != std::string::npos) {
grammar.awaiting_trigger = false;
// get from the first matched capturing group to the end of the string
size_t start = std::string::npos;
for (auto i = 1u; i < match.size(); i++) {
if (match.length(i) > 0) {
start = match.position(i);
break;
}
}
if (start == std::string::npos) {
start = match.position(0);
}

// replay tokens that overlap with [start, end)
for (const auto & [tok, tok_pos] : grammar.trigger_buffer_positions) {
Expand Down
2 changes: 2 additions & 0 deletions src/llama-grammar.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ struct llama_grammar_parser {
// Pairs a trigger pattern's text with the regex object used to match it.
struct llama_grammar_trigger_pattern {
// Pattern text. find() inspects its first and last characters ('^' ... '$')
// to decide whether to attempt a whole-input match before searching.
std::string pattern;
// Regex object that find() matches input against.
// NOTE(review): presumably compiled from `pattern` — confirm at the construction site.
std::regex regex;

// Returns the offset in `input` where the trigger starts: the position of the
// first non-empty capturing group, or of the whole match if no group captured
// any text. Returns std::string::npos when `regex` does not match.
size_t find(const std::string & input) const;
};

struct llama_grammar {
Expand Down
28 changes: 14 additions & 14 deletions tests/test-regex-partial.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -232,52 +232,52 @@ static void test_regex_to_reversed_partial_regex() {
printf("[%s]\n", __func__);

assert_equals<std::string>(
"((?:(?:c)?b)?a)[\\s\\S]*",
"^((?:(?:c)?b)?a)",
regex_to_reversed_partial_regex("abc"));

assert_equals<std::string>(
"(a+)[\\s\\S]*",
"^(a+)",
regex_to_reversed_partial_regex("a+"));

assert_equals<std::string>(
"(a*)[\\s\\S]*",
"^(a*)",
regex_to_reversed_partial_regex("a*"));

assert_equals<std::string>(
"(a?)[\\s\\S]*",
"^(a?)",
regex_to_reversed_partial_regex("a?"));

assert_equals<std::string>(
"([a-z])[\\s\\S]*",
"^([a-z])",
regex_to_reversed_partial_regex("[a-z]"));

assert_equals<std::string>(
"((?:\\w+)?[a-z])[\\s\\S]*",
"^((?:\\w+)?[a-z])",
regex_to_reversed_partial_regex("[a-z]\\w+"));

assert_equals<std::string>(
"((?:a|b))[\\s\\S]*",
"^((?:a|b))",
regex_to_reversed_partial_regex("(?:a|b)"));
assert_equals<std::string>(
"((?:(?:(?:d)?c)?b)?a)[\\s\\S]*",
"^((?:(?:(?:d)?c)?b)?a)",
regex_to_reversed_partial_regex("abcd"));
assert_equals<std::string>(
"((?:b)?a*)[\\s\\S]*", // TODO: ((?:b)?a*+).* ??
"^((?:b)?a*)", // TODO: ((?:b)?a*+).* ??
regex_to_reversed_partial_regex("a*b"));
assert_equals<std::string>(
"((?:(?:b)?a)?.*)[\\s\\S]*",
"^((?:(?:b)?a)?.*)",
regex_to_reversed_partial_regex(".*?ab"));
assert_equals<std::string>(
"((?:(?:b)?.*)?a)[\\s\\S]*",
"^((?:(?:b)?.*)?a)",
regex_to_reversed_partial_regex("a.*?b"));
assert_equals<std::string>(
"((?:(?:d)?(?:(?:c)?b))?a)[\\s\\S]*",
"^((?:(?:d)?(?:(?:c)?b))?a)",
regex_to_reversed_partial_regex("a(bc)d"));
assert_equals<std::string>(
"((?:(?:(?:c)?b|(?:e)?d))?a)[\\s\\S]*",
"^((?:(?:(?:c)?b|(?:e)?d))?a)",
regex_to_reversed_partial_regex("a(bc|de)"));
assert_equals<std::string>(
"((?:(?:(?:(?:(?:c)?b?)?b?)?b)?b)?a)[\\s\\S]*",
"^((?:(?:(?:(?:(?:c)?b?)?b?)?b)?b)?a)",
regex_to_reversed_partial_regex("ab{2,4}c"));
}

Expand Down
Loading