Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 12 additions & 11 deletions src/llama-grammar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -347,10 +347,10 @@ const char * llama_grammar_parser::parse_sequence(
size_t last_sym_start = rule.size();
const char * pos = src;

// use UINT64_MAX as the empty value because we aligned to the proper unsigned long type so -1 can't be used
// use UINT64_MAX as the empty value because we aligned to the proper uint64_t type so -1 can't be used
// (though it's technically the same as -1 now)
auto handle_repetitions = [&](unsigned long min_times, unsigned long max_times) {

auto handle_repetitions = [&](uint64_t min_times, uint64_t max_times) {
bool no_max = max_times == UINT64_MAX;
if (last_sym_start == rule.size()) {
throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos);
}
Expand All @@ -377,20 +377,20 @@ const char * llama_grammar_parser::parse_sequence(
rule.resize(last_sym_start);
} else {
// Repeat the previous elements (min_times - 1) times
for (unsigned long i = 1; i < min_times; i++) {
for (uint64_t i = 1; i < min_times; i++) {
rule.insert(rule.end(), prev_rule.begin(), prev_rule.end());
}
}

uint32_t last_rec_rule_id = 0;
auto n_opt = max_times == UINT64_MAX ? 1 : max_times - min_times;
auto n_opt = no_max ? 1 : max_times - min_times;

llama_grammar_rule rec_rule(prev_rule);
for (unsigned long i = 0; i < n_opt; i++) {
for (uint64_t i = 0; i < n_opt; i++) {
rec_rule.resize(prev_rule.size());
uint32_t rec_rule_id = generate_symbol_id( rule_name);
if (i > 0 || max_times == UINT64_MAX) {
rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times == UINT64_MAX ? rec_rule_id : last_rec_rule_id});
if (i > 0 || no_max) {
rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, no_max ? rec_rule_id : last_rec_rule_id});
}
rec_rule.push_back({LLAMA_GRETYPE_ALT, 0});
rec_rule.push_back({LLAMA_GRETYPE_END, 0});
Expand Down Expand Up @@ -482,10 +482,10 @@ const char * llama_grammar_parser::parse_sequence(
throw std::runtime_error(std::string("expecting an int at ") + pos);
}
const char * int_end = parse_int(pos);
unsigned long min_times = std::stoul(std::string(pos, int_end - pos));
uint64_t min_times = std::stoul(std::string(pos, int_end - pos));
pos = parse_space(int_end, is_nested);

unsigned long max_times = UINT64_MAX;
uint64_t max_times = UINT64_MAX; // default: no max limit

if (*pos == '}') {
max_times = min_times;
Expand All @@ -506,7 +506,8 @@ const char * llama_grammar_parser::parse_sequence(
} else {
throw std::runtime_error(std::string("expecting ',' at ") + pos);
}
if (min_times > MAX_REPETITION_THRESHOLD || (max_times != UINT64_MAX && max_times > MAX_REPETITION_THRESHOLD)) {
bool has_max = max_times != UINT64_MAX;
if (min_times > MAX_REPETITION_THRESHOLD || (has_max && max_times > MAX_REPETITION_THRESHOLD)) {
throw std::runtime_error(std::string("number of repetitions exceeds sane defaults, please reduce the number of repetitions"));
}
handle_repetitions(min_times, max_times);
Expand Down
Loading