
Commit 56d7750

wip fix tests

1 parent e652566 commit 56d7750

11 files changed: +163, -151 lines changed

common/sampling.cpp

Lines changed: 70 additions & 78 deletions
@@ -104,9 +104,10 @@ struct ring_buffer {
 struct common_sampler {
     common_params_sampling params;

-    struct llama_sampler * grmr;
     struct llama_sampler * chain;

+    bool grammar;
+
     ring_buffer<llama_token> prev;

     std::vector<llama_token_data> cur;
@@ -116,7 +117,6 @@ struct common_sampler {
     void reset() {
         prev.clear();

-        llama_sampler_reset(grmr);
         llama_sampler_reset(chain);
     }

@@ -184,10 +184,15 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co

     lparams.no_perf = params.no_perf;

-    struct llama_sampler * grmr;
+    llama_sampler * chain = llama_sampler_chain_init(lparams);
+
+    bool grammar = false;
+    std::vector<llama_sampler *> samplers;
+
     if (params.grammar.compare(0, 11, "%llguidance") == 0) {
 #ifdef LLAMA_USE_LLGUIDANCE
-        grmr = llama_sampler_init_llg(vocab, "lark", params.grammar.c_str());
+        samplers.push_back(llama_sampler_init_llg(vocab, "lark", params.grammar.c_str()));
+        grammar = true;
 #else
         GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
 #endif // LLAMA_USE_LLGUIDANCE
@@ -234,26 +239,20 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
             trigger_patterns_c.push_back(regex.c_str());
         }

-        grmr = params.grammar_lazy
-            ? llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
-                  trigger_patterns_c.data(), trigger_patterns_c.size(),
-                  trigger_tokens.data(), trigger_tokens.size())
-            : llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");
-        if (!grmr) {
-            return nullptr;
+        if (!params.grammar.empty()) {
+            if (params.grammar_lazy) {
+                samplers.push_back(
+                    llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
+                        trigger_patterns_c.data(), trigger_patterns_c.size(),
+                        trigger_tokens.data(), trigger_tokens.size()));
+            } else {
+                samplers.push_back(llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root"));
+            }
+
+            grammar = true;
         }
     }

-    auto * result = new common_sampler {
-        /* .params = */ params,
-        /* .grmr   = */ grmr,
-        /* .chain  = */ llama_sampler_chain_init(lparams),
-        /* .prev   = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
-        /* .cur    = */ {},
-        /* .cur_p  = */ {},
-    };
-
-    std::vector<llama_sampler *> samplers;
     if (params.has_logit_bias()) {
         samplers.push_back(llama_sampler_init_logit_bias(llama_vocab_n_tokens(vocab), params.logit_bias.size(), params.logit_bias.data()));
     }
@@ -316,15 +315,23 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
     }

     for (auto * smpl : samplers) {
-        llama_sampler_chain_add(result->chain, smpl);
+        llama_sampler_chain_add(chain, smpl);
     }

+    auto * result = new common_sampler {
+        /* .params  = */ params,
+        /* .chain   = */ chain,
+        /* .grammar = */ grammar,
+        /* .prev    = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
+        /* .cur     = */ {},
+        /* .cur_p   = */ {},
+    };
+
     return result;
 }

 void common_sampler_free(struct common_sampler * gsmpl) {
     if (gsmpl) {
-        llama_sampler_free(gsmpl->grmr);
         llama_sampler_free(gsmpl->chain);

         delete gsmpl;
@@ -334,11 +341,24 @@ void common_sampler_free(struct common_sampler * gsmpl) {
 void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
     const auto tm = gsmpl->tm();

-    if (accept_grammar) {
-        llama_sampler_accept(gsmpl->grmr, token);
-    }
+    if (gsmpl->grammar) {
+        const int n_smpl = llama_sampler_chain_n(gsmpl->chain);

-    llama_sampler_accept(gsmpl->chain, token);
+        for (int i = 0; i < n_smpl; i++) {
+            auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);
+
+            // the grammar sampler is always the first one
+            if (i == 0) {
+                if (accept_grammar) {
+                    llama_sampler_accept(smpl, token);
+                }
+            } else {
+                llama_sampler_accept(smpl, token);
+            }
+        }
+    } else {
+        llama_sampler_accept(gsmpl->chain, token);
+    }

     gsmpl->prev.push_back(token);
 }
@@ -349,12 +369,12 @@ void common_sampler_reset(struct common_sampler * gsmpl) {

 struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
     return new common_sampler {
-        /* .params = */ gsmpl->params,
-        /* .grmr   = */ llama_sampler_clone(gsmpl->grmr),
-        /* .chain  = */ llama_sampler_clone(gsmpl->chain),
-        /* .prev   = */ gsmpl->prev,
-        /* .cur    = */ gsmpl->cur,
-        /* .cur_p  = */ gsmpl->cur_p,
+        /* .params  = */ gsmpl->params,
+        /* .chain   = */ llama_sampler_clone(gsmpl->chain),
+        /* .grammar = */ gsmpl->grammar,
+        /* .prev    = */ gsmpl->prev,
+        /* .cur     = */ gsmpl->cur,
+        /* .cur_p   = */ gsmpl->cur_p,
     };
 }

@@ -407,77 +427,49 @@ struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl) {
     return gsmpl->chain;
 }

-llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
+llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx) {
+    llama_synchronize(ctx);
+
+    // start measuring sampling time after the llama_context synchronization in order to not measure any ongoing async operations
+    const auto tm = gsmpl->tm();
+
+    llama_token id = LLAMA_TOKEN_NULL;
+
     // Check if a backend sampler has already sampled a token in which case we
     // return that token id directly.
     {
-        const llama_token id = llama_get_sampled_token_ith(ctx, idx);
+        id = llama_get_sampled_token_ith(ctx, idx);

         if (id != LLAMA_TOKEN_NULL) {
             LOG_DBG("%s: Backend sampler selected token: '%d'. Will not run any CPU samplers\n", __func__, id);
+
             return id;
         }
     }

-    llama_synchronize(ctx);
-
-    // start measuring sampling time after the llama_context synchronization in order to not measure any ongoing async operations
-    const auto tm = gsmpl->tm();
-
     gsmpl->set_logits(ctx, idx);

-    auto & grmr  = gsmpl->grmr;
     auto & chain = gsmpl->chain;
     auto & cur_p = gsmpl->cur_p; // initialized by set_logits

-    if (grammar_first) {
-        llama_sampler_apply(grmr, &cur_p);
-    }
-
     llama_sampler_apply(chain, &cur_p);

     GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");

-    const llama_token id = cur_p.data[cur_p.selected].id;
-
-    if (grammar_first) {
-        return id;
-    }
-
-    // check if it the sampled token fits the grammar
-    {
-        llama_token_data single_token_data = { id, 1.0f, 0.0f };
-        llama_token_data_array single_token_data_array = { &single_token_data, 1, -1, false };
-
-        llama_sampler_apply(grmr, &single_token_data_array);
-
-        const bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
-        if (is_valid) {
-            return id;
-        }
-    }
-
-    // resampling:
-    // if the token is not valid, sample again, but first apply the grammar sampler and then the sampling chain
-    gsmpl->set_logits(ctx, idx);
-
-    llama_sampler_apply(grmr, &cur_p);
-    llama_sampler_apply(chain, &cur_p);
-
-    GGML_ASSERT(cur_p.selected != -1 && "no selected token during re-sampling - check your sampling configuration");
+    id = cur_p.data[cur_p.selected].id;

-    return cur_p.data[cur_p.selected].id;
+    return id;
 }

-std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft) {
     GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");

     std::vector<llama_token> result;
     result.reserve(idxs.size());

     size_t i = 0;
     for (; i < draft.size(); i++) {
-        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
+        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i]);

         common_sampler_accept(gsmpl, id, true);

@@ -489,7 +481,7 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
     }

     if (i == draft.size()) {
-        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
+        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i]);

         common_sampler_accept(gsmpl, id, true);

@@ -499,13 +491,13 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
     return result;
 }

-std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) {
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft) {
     std::vector<int> idxs(draft.size() + 1);
     for (size_t i = 0; i < idxs.size(); ++i) {
         idxs[i] = i;
     }

-    return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first);
+    return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft);
 }

 uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
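
The new common_sampler_accept path walks the chain with llama_sampler_chain_n / llama_sampler_chain_get and skips only the first sampler (the grammar one, when present) if accept_grammar is false. Below is a minimal standalone sketch of that iteration pattern, assuming only the public llama.h sampler API; the top_k/dist chain and the token id 42 are placeholders for illustration, not part of this commit.

```cpp
#include "llama.h"

int main() {
    // build a small chain; in common_sampler_init the grammar sampler (if any)
    // is pushed first, followed by the regular samplers
    llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
    llama_sampler_chain_add(chain, llama_sampler_init_top_k(40));
    llama_sampler_chain_add(chain, llama_sampler_init_dist(1234));

    const llama_token token   = 42;    // placeholder token id
    const bool accept_grammar = false; // e.g. a drafted token that was rejected

    // mirrors the new common_sampler_accept: when accept_grammar is false,
    // only the first sampler in the chain is skipped
    const int n_smpl = llama_sampler_chain_n(chain);
    for (int i = 0; i < n_smpl; i++) {
        llama_sampler * smpl = llama_sampler_chain_get(chain, i);
        if (i == 0 && !accept_grammar) {
            continue;
        }
        llama_sampler_accept(smpl, token);
    }

    llama_sampler_free(chain); // also frees the samplers added to the chain
    return 0;
}
```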

common/sampling.h

Lines changed: 3 additions & 6 deletions
@@ -57,10 +57,7 @@ struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl);
 // - check if the token fits the grammar (if any)
 // - if not: resample by first applying the grammar constraints and then sampling again (slower path)
 //
-// if grammar_first is true, the grammar is applied before the samplers (slower)
-// useful in cases where all the resulting candidates (not just the sampled one) must fit the grammar
-//
-llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);
+llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx);

 // generalized version of common_sampler_sample
 //
@@ -78,10 +75,10 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
 //
 // returns at least 1 token, up to idxs.size()
 //
-std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first = false);
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft);

 // assume idxs == [ 0, 1, 2, ..., draft.size() ]
-std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false);
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft);

 uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);

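
With grammar_first gone from the public interface, call sites simply sample and accept; whether grammar constraints apply is decided inside the sampler chain. A hedged caller-side sketch against the declarations above (sample_and_accept is a hypothetical helper, not part of this header):

```cpp
#include "sampling.h"

// hypothetical helper built on the simplified API above: sample the token at
// output position idx and immediately accept it into the sampler state
static llama_token sample_and_accept(common_sampler * gsmpl, llama_context * ctx, int idx) {
    const llama_token id = common_sampler_sample(gsmpl, ctx, idx); // no grammar_first argument anymore
    common_sampler_accept(gsmpl, id, /*accept_grammar=*/true);
    return id;
}
```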

common/speculative.cpp

Lines changed: 1 addition & 1 deletion
@@ -315,7 +315,7 @@ llama_tokens common_speculative_gen_draft(
     for (int i = 0; i < params.n_draft; ++i) {
         common_batch_clear(batch);

-        common_sampler_sample(smpl, ctx_dft, 0, true);
+        common_sampler_sample(smpl, ctx_dft, 0);

         const auto * cur_p = common_sampler_get_candidates(smpl, true);

examples/speculative/speculative.cpp

Lines changed: 2 additions & 2 deletions
@@ -242,7 +242,7 @@ int main(int argc, char ** argv) {
             bool accept = false;
             if (params.sampling.temp > 0) {
                 // stochastic verification
-                common_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft], true);
+                common_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft]);

                 auto & dist_tgt = *common_sampler_get_candidates(smpl, true);

@@ -491,7 +491,7 @@ int main(int argc, char ** argv) {
                 continue;
             }

-            common_sampler_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft, true);
+            common_sampler_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft);

             const auto * cur_p = common_sampler_get_candidates(drafts[s].smpl, true);

src/llama-context.cpp

Lines changed: 8 additions & 7 deletions
@@ -820,7 +820,7 @@ size_t llama_context::get_sampled_logits_count(int32_t idx) {
     output_reorder();

     if (sampling.logits == nullptr) {
-        return 0;
+        return model.vocab.n_tokens();
     }

     try {
@@ -2977,14 +2977,15 @@ float * llama_get_logits(llama_context * ctx) {
 float * llama_get_logits_ith(llama_context * ctx, int32_t i) {
     ctx->synchronize();

-    if (ctx->get_sampled_token_ith(i) != LLAMA_TOKEN_NULL) {
-        return nullptr;
-    }
-    if (ctx->get_sampled_probs_ith(i) != nullptr) {
-        return nullptr;
+    float * res = nullptr;
+
+    res = ctx->get_sampled_logits_ith(i);
+
+    if (!res) {
+        res = ctx->get_logits_ith(i);
     }

-    return ctx->get_logits_ith(i);
+    return res;
 }

 float * llama_get_embeddings(llama_context * ctx) {

src/llama-graph.cpp

Lines changed: 1 addition & 1 deletion
@@ -2109,7 +2109,7 @@ void llm_graph_context::build_sampling() const {
             ggml_build_forward_expand(gf, data.probs);
         }

-        if (data.logits != logits_seq) {
+        if (data.logits != nullptr) {
             ggml_set_output(data.logits);
             res->t_sampled_logits[seq_id] = data.logits;
             ggml_build_forward_expand(gf, res->t_sampled_logits[seq_id]);

src/llama-sampling.cpp

Lines changed: 16 additions & 0 deletions
@@ -366,23 +366,39 @@ const char * llama_sampler_name(const struct llama_sampler * smpl) {
 }

 void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) {
+    if (!smpl) {
+        return;
+    }
+
     if (smpl->iface->accept) {
         smpl->iface->accept(smpl, token);
     }
 }

 void llama_sampler_apply(struct llama_sampler * smpl, struct llama_token_data_array * cur_p) {
+    if (!smpl) {
+        return;
+    }
+
     GGML_ASSERT(smpl->iface->apply);
     smpl->iface->apply(smpl, cur_p);
 }

 void llama_sampler_reset(struct llama_sampler * smpl) {
+    if (!smpl) {
+        return;
+    }
+
     if (smpl->iface->reset) {
         smpl->iface->reset(smpl);
     }
 }

 struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) {
+    if (!smpl) {
+        return nullptr;
+    }
+
     if (smpl->iface->clone) {
         return smpl->iface->clone(smpl);
     }
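
These guards make the sampler entry points tolerant of a null sampler instead of dereferencing it. A minimal sketch of the resulting behavior, assuming only llama.h; the assertion just documents the expected no-op/nullptr results.

```cpp
#include "llama.h"
#include <cassert>

int main() {
    llama_sampler * smpl = nullptr;

    // with the added null checks these calls return early instead of crashing
    llama_sampler_accept(smpl, 0);
    llama_sampler_reset(smpl);

    // cloning a null sampler now yields a null sampler
    assert(llama_sampler_clone(smpl) == nullptr);

    return 0;
}
```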
