@@ -104,9 +104,10 @@ struct ring_buffer {
 struct common_sampler {
     common_params_sampling params;
 
-    struct llama_sampler * grmr;
     struct llama_sampler * chain;
 
+    bool grammar;
+
     ring_buffer<llama_token> prev;
 
     std::vector<llama_token_data> cur;
@@ -116,7 +117,6 @@ struct common_sampler {
     void reset() {
         prev.clear();
 
-        llama_sampler_reset(grmr);
         llama_sampler_reset(chain);
     }
 
@@ -184,10 +184,15 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
 
     lparams.no_perf = params.no_perf;
 
-    struct llama_sampler * grmr;
+    llama_sampler * chain = llama_sampler_chain_init(lparams);
+
+    bool grammar = false;
+    std::vector<llama_sampler *> samplers;
+
     if (params.grammar.compare(0, 11, "%llguidance") == 0) {
 #ifdef LLAMA_USE_LLGUIDANCE
-        grmr = llama_sampler_init_llg(vocab, "lark", params.grammar.c_str());
+        samplers.push_back(llama_sampler_init_llg(vocab, "lark", params.grammar.c_str()));
+        grammar = true;
 #else
         GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
 #endif // LLAMA_USE_LLGUIDANCE
@@ -234,26 +239,20 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
             trigger_patterns_c.push_back(regex.c_str());
         }
 
-        grmr = params.grammar_lazy
-             ? llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
-                trigger_patterns_c.data(), trigger_patterns_c.size(),
-                trigger_tokens.data(), trigger_tokens.size())
-             : llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");
-        if (!grmr) {
-            return nullptr;
+        if (!params.grammar.empty()) {
+            if (params.grammar_lazy) {
+                samplers.push_back(
+                        llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
+                            trigger_patterns_c.data(), trigger_patterns_c.size(),
+                            trigger_tokens.data(), trigger_tokens.size()));
+            } else {
+                samplers.push_back(llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root"));
+            }
+
+            grammar = true;
         }
     }
 
-    auto * result = new common_sampler {
-        /* .params = */ params,
-        /* .grmr   = */ grmr,
-        /* .chain  = */ llama_sampler_chain_init(lparams),
-        /* .prev   = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
-        /* .cur    = */ {},
-        /* .cur_p  = */ {},
-    };
-
-    std::vector<llama_sampler *> samplers;
     if (params.has_logit_bias()) {
         samplers.push_back(llama_sampler_init_logit_bias(llama_vocab_n_tokens(vocab), params.logit_bias.size(), params.logit_bias.data()));
     }
@@ -316,15 +315,23 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
     }
 
     for (auto * smpl : samplers) {
-        llama_sampler_chain_add(result->chain, smpl);
+        llama_sampler_chain_add(chain, smpl);
     }
 
+    auto * result = new common_sampler {
+        /* .params  = */ params,
+        /* .chain   = */ chain,
+        /* .grammar = */ grammar,
+        /* .prev    = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
+        /* .cur     = */ {},
+        /* .cur_p   = */ {},
+    };
+
     return result;
 }
 
 void common_sampler_free(struct common_sampler * gsmpl) {
     if (gsmpl) {
-        llama_sampler_free(gsmpl->grmr);
         llama_sampler_free(gsmpl->chain);
 
         delete gsmpl;
@@ -334,11 +341,24 @@ void common_sampler_free(struct common_sampler * gsmpl) {
 void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
     const auto tm = gsmpl->tm();
 
-    if (accept_grammar) {
-        llama_sampler_accept(gsmpl->grmr, token);
-    }
+    if (gsmpl->grammar) {
+        const int n_smpl = llama_sampler_chain_n(gsmpl->chain);
 
-    llama_sampler_accept(gsmpl->chain, token);
+        for (int i = 0; i < n_smpl; i++) {
+            auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);
+
+            // the grammar sampler is always the first one
+            if (i == 0) {
+                if (accept_grammar) {
+                    llama_sampler_accept(smpl, token);
+                }
+            } else {
+                llama_sampler_accept(smpl, token);
+            }
+        }
+    } else {
+        llama_sampler_accept(gsmpl->chain, token);
+    }
 
     gsmpl->prev.push_back(token);
 }
@@ -349,12 +369,12 @@ void common_sampler_reset(struct common_sampler * gsmpl) {
 
 struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
     return new common_sampler {
-        /* .params = */ gsmpl->params,
-        /* .grmr   = */ llama_sampler_clone(gsmpl->grmr),
-        /* .chain  = */ llama_sampler_clone(gsmpl->chain),
-        /* .prev   = */ gsmpl->prev,
-        /* .cur    = */ gsmpl->cur,
-        /* .cur_p  = */ gsmpl->cur_p,
+        /* .params  = */ gsmpl->params,
+        /* .chain   = */ llama_sampler_clone(gsmpl->chain),
+        /* .grammar = */ gsmpl->grammar,
+        /* .prev    = */ gsmpl->prev,
+        /* .cur     = */ gsmpl->cur,
+        /* .cur_p   = */ gsmpl->cur_p,
     };
 }
 
@@ -407,77 +427,49 @@ struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl) {
     return gsmpl->chain;
 }
 
-llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
+llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx) {
+    llama_synchronize(ctx);
+
+    // start measuring sampling time after the llama_context synchronization in order to not measure any ongoing async operations
+    const auto tm = gsmpl->tm();
+
+    llama_token id = LLAMA_TOKEN_NULL;
+
     // Check if a backend sampler has already sampled a token in which case we
     // return that token id directly.
     {
-        const llama_token id = llama_get_sampled_token_ith(ctx, idx);
+        id = llama_get_sampled_token_ith(ctx, idx);
 
         if (id != LLAMA_TOKEN_NULL) {
             LOG_DBG("%s: Backend sampler selected token: '%d'. Will not run any CPU samplers\n", __func__, id);
+
             return id;
         }
     }
 
-    llama_synchronize(ctx);
-
-    // start measuring sampling time after the llama_context synchronization in order to not measure any ongoing async operations
-    const auto tm = gsmpl->tm();
-
     gsmpl->set_logits(ctx, idx);
 
-    auto & grmr  = gsmpl->grmr;
     auto & chain = gsmpl->chain;
     auto & cur_p = gsmpl->cur_p; // initialized by set_logits
 
-    if (grammar_first) {
-        llama_sampler_apply(grmr, &cur_p);
-    }
-
     llama_sampler_apply(chain, &cur_p);
 
     GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");
 
-    const llama_token id = cur_p.data[cur_p.selected].id;
-
-    if (grammar_first) {
-        return id;
-    }
-
-    // check if it the sampled token fits the grammar
-    {
-        llama_token_data       single_token_data       = { id, 1.0f, 0.0f };
-        llama_token_data_array single_token_data_array = { &single_token_data, 1, -1, false };
-
-        llama_sampler_apply(grmr, &single_token_data_array);
-
-        const bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
-        if (is_valid) {
-            return id;
-        }
-    }
-
-    // resampling:
-    // if the token is not valid, sample again, but first apply the grammar sampler and then the sampling chain
-    gsmpl->set_logits(ctx, idx);
-
-    llama_sampler_apply(grmr,  &cur_p);
-    llama_sampler_apply(chain, &cur_p);
-
-    GGML_ASSERT(cur_p.selected != -1 && "no selected token during re-sampling - check your sampling configuration");
+    id = cur_p.data[cur_p.selected].id;
 
-    return cur_p.data[cur_p.selected].id;
+    return id;
 }
 
-std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft) {
     GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");
 
     std::vector<llama_token> result;
     result.reserve(idxs.size());
 
     size_t i = 0;
     for (; i < draft.size(); i++) {
-        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
+        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i]);
 
         common_sampler_accept(gsmpl, id, true);
 
@@ -489,7 +481,7 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
     }
 
     if (i == draft.size()) {
-        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
+        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i]);
 
         common_sampler_accept(gsmpl, id, true);
 
@@ -499,13 +491,13 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
     return result;
 }
 
-std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) {
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft) {
     std::vector<int> idxs(draft.size() + 1);
     for (size_t i = 0; i < idxs.size(); ++i) {
         idxs[i] = i;
     }
 
-    return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first);
+    return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft);
 }
 
 uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
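
For reference, a minimal usage sketch of the updated call sites. This is illustrative only: sample_next is a hypothetical helper, the "sampling.h" include path and an already-decoded batch with fresh logits are assumptions. Callers no longer pass grammar_first; when a grammar is configured, the grammar sampler occupies slot 0 of the chain and is applied as part of the single llama_sampler_apply pass over the chain.

#include "sampling.h"

// sample_next: draw one token with the sampler chain and advance the samplers.
// gsmpl comes from common_sampler_init(model, sparams); ctx already holds logits.
static llama_token sample_next(common_sampler * gsmpl, llama_context * ctx) {
    // idx = -1 selects the logits of the last token in the decoded batch
    const llama_token id = common_sampler_sample(gsmpl, ctx, /*idx =*/ -1);

    // accept_grammar = true also advances the grammar sampler,
    // which (when enabled) is always the first sampler in the chain
    common_sampler_accept(gsmpl, id, /*accept_grammar =*/ true);

    return id;
}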