Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions src/llama-grammar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -347,9 +347,10 @@ const char * llama_grammar_parser::parse_sequence(
size_t last_sym_start = rule.size();
const char * pos = src;

// use UINT64_MAX as the empty value because we aligned to the proper unsigned long type so -1 can't be used
// use UINT64_MAX as the empty value because we aligned to the proper uint64_t type so -1 can't be used
// (though it's technically the same as -1 now)
auto handle_repetitions = [&](unsigned long min_times, unsigned long max_times) {
// ref: https://github.com/ggml-org/llama.cpp/pull/17381
auto handle_repetitions = [&](uint64_t min_times, uint64_t max_times) {

if (last_sym_start == rule.size()) {
throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos);
Expand Down Expand Up @@ -377,7 +378,7 @@ const char * llama_grammar_parser::parse_sequence(
rule.resize(last_sym_start);
} else {
// Repeat the previous elements (min_times - 1) times
for (unsigned long i = 1; i < min_times; i++) {
for (uint64_t i = 1; i < min_times; i++) {
rule.insert(rule.end(), prev_rule.begin(), prev_rule.end());
}
}
Expand All @@ -386,7 +387,7 @@ const char * llama_grammar_parser::parse_sequence(
auto n_opt = max_times == UINT64_MAX ? 1 : max_times - min_times;

llama_grammar_rule rec_rule(prev_rule);
for (unsigned long i = 0; i < n_opt; i++) {
for (uint64_t i = 0; i < n_opt; i++) {
rec_rule.resize(prev_rule.size());
uint32_t rec_rule_id = generate_symbol_id( rule_name);
if (i > 0 || max_times == UINT64_MAX) {
Expand Down Expand Up @@ -482,10 +483,10 @@ const char * llama_grammar_parser::parse_sequence(
throw std::runtime_error(std::string("expecting an int at ") + pos);
}
const char * int_end = parse_int(pos);
unsigned long min_times = std::stoul(std::string(pos, int_end - pos));
uint64_t min_times = std::stoul(std::string(pos, int_end - pos));
pos = parse_space(int_end, is_nested);

unsigned long max_times = UINT64_MAX;
uint64_t max_times = UINT64_MAX;

if (*pos == '}') {
max_times = min_times;
Expand All @@ -506,7 +507,7 @@ const char * llama_grammar_parser::parse_sequence(
} else {
throw std::runtime_error(std::string("expecting ',' at ") + pos);
}
if (min_times > MAX_REPETITION_THRESHOLD || (max_times != UINT64_MAX && max_times > MAX_REPETITION_THRESHOLD)) {
if (min_times > MAX_REPETITION_THRESHOLD || max_times > MAX_REPETITION_THRESHOLD) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is going to be wrong. The condition was "if min_times exceeds the threshold, or (max_times is defined and exceeds the threshold)". Now it is going to trigger whenever max_times is not defined (i.e. with X{n,} patterns).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we're not going to use std::optional like I wanted in the first approach, then I think this is the only way to do it. This is basically equivalent to checking for == -1 or, as the original test had it, < 0.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, OK, I thought one of the conditions was overlapping the other.

Using uint_max as a special value is fine; the main advantage of std::optional is its readability, which is easy to replicate: c0b9903

throw std::runtime_error(std::string("number of repetitions exceeds sane defaults, please reduce the number of repetitions"));
}
handle_repetitions(min_times, max_times);
Expand Down
Loading