Skip to content

Commit 054a45c

Browse files
authored
grammar: fix regression caused by ggml-org#17381 (ggml-org#17412)
* grammar: fix regression caused by ggml-org#17381 * more readable
1 parent 4c91f26 commit 054a45c

File tree

1 file changed

+12
-11
lines changed

1 file changed

+12
-11
lines changed

src/llama-grammar.cpp

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -347,10 +347,10 @@ const char * llama_grammar_parser::parse_sequence(
347347
size_t last_sym_start = rule.size();
348348
const char * pos = src;
349349

350-
// use UINT64_MAX as the empty value because we aligned to the proper unsigned long type so -1 can't be used
350+
// use UINT64_MAX as the empty value because we aligned to the proper uint64_t type so -1 can't be used
351351
// (though it's technically the same as -1 now)
352-
auto handle_repetitions = [&](unsigned long min_times, unsigned long max_times) {
353-
352+
auto handle_repetitions = [&](uint64_t min_times, uint64_t max_times) {
353+
bool no_max = max_times == UINT64_MAX;
354354
if (last_sym_start == rule.size()) {
355355
throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos);
356356
}
@@ -377,20 +377,20 @@ const char * llama_grammar_parser::parse_sequence(
377377
rule.resize(last_sym_start);
378378
} else {
379379
// Repeat the previous elements (min_times - 1) times
380-
for (unsigned long i = 1; i < min_times; i++) {
380+
for (uint64_t i = 1; i < min_times; i++) {
381381
rule.insert(rule.end(), prev_rule.begin(), prev_rule.end());
382382
}
383383
}
384384

385385
uint32_t last_rec_rule_id = 0;
386-
auto n_opt = max_times == UINT64_MAX ? 1 : max_times - min_times;
386+
auto n_opt = no_max ? 1 : max_times - min_times;
387387

388388
llama_grammar_rule rec_rule(prev_rule);
389-
for (unsigned long i = 0; i < n_opt; i++) {
389+
for (uint64_t i = 0; i < n_opt; i++) {
390390
rec_rule.resize(prev_rule.size());
391391
uint32_t rec_rule_id = generate_symbol_id( rule_name);
392-
if (i > 0 || max_times == UINT64_MAX) {
393-
rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times == UINT64_MAX ? rec_rule_id : last_rec_rule_id});
392+
if (i > 0 || no_max) {
393+
rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, no_max ? rec_rule_id : last_rec_rule_id});
394394
}
395395
rec_rule.push_back({LLAMA_GRETYPE_ALT, 0});
396396
rec_rule.push_back({LLAMA_GRETYPE_END, 0});
@@ -482,10 +482,10 @@ const char * llama_grammar_parser::parse_sequence(
482482
throw std::runtime_error(std::string("expecting an int at ") + pos);
483483
}
484484
const char * int_end = parse_int(pos);
485-
unsigned long min_times = std::stoul(std::string(pos, int_end - pos));
485+
uint64_t min_times = std::stoul(std::string(pos, int_end - pos));
486486
pos = parse_space(int_end, is_nested);
487487

488-
unsigned long max_times = UINT64_MAX;
488+
uint64_t max_times = UINT64_MAX; // default: no max limit
489489

490490
if (*pos == '}') {
491491
max_times = min_times;
@@ -506,7 +506,8 @@ const char * llama_grammar_parser::parse_sequence(
506506
} else {
507507
throw std::runtime_error(std::string("expecting ',' at ") + pos);
508508
}
509-
if (min_times > MAX_REPETITION_THRESHOLD || (max_times != UINT64_MAX && max_times > MAX_REPETITION_THRESHOLD)) {
509+
bool has_max = max_times != UINT64_MAX;
510+
if (min_times > MAX_REPETITION_THRESHOLD || (has_max && max_times > MAX_REPETITION_THRESHOLD)) {
510511
throw std::runtime_error(std::string("number of repetitions exceeds sane defaults, please reduce the number of repetitions"));
511512
}
512513
handle_repetitions(min_times, max_times);

0 commit comments

Comments
 (0)