@@ -106,7 +106,6 @@ struct common_sampler {
 
     struct llama_sampler * grmr;
     struct llama_sampler * chain;
-    struct llama_sampler * chain_backend;
 
     ring_buffer<llama_token> prev;
 
@@ -119,7 +118,6 @@ struct common_sampler {
 
         llama_sampler_reset(grmr);
         llama_sampler_reset(chain);
-        llama_sampler_reset(chain_backend);
     }
 
     void set_logits(struct llama_context * ctx, int idx) {
@@ -247,13 +245,12 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
     }
 
     auto * result = new common_sampler {
-        /* .params        = */ params,
-        /* .grmr          = */ grmr,
-        /* .chain         = */ llama_sampler_chain_init(lparams),
-        /* .chain_backend = */ llama_sampler_chain_init(lparams),
-        /* .prev          = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
-        /* .cur           = */ {},
-        /* .cur_p         = */ {},
+        /* .params = */ params,
+        /* .grmr   = */ grmr,
+        /* .chain  = */ llama_sampler_chain_init(lparams),
+        /* .prev   = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
+        /* .cur    = */ {},
+        /* .cur_p  = */ {},
     };
 
     std::vector<llama_sampler *> samplers;
@@ -318,15 +315,8 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
         GGML_ASSERT(false && "unknown mirostat version");
     }
 
-    bool is_backend = params.backend_sampling;
-
-    // split in two chains: backend -> CPU
     for (auto * smpl : samplers) {
-        if (!smpl->iface->backend_apply) {
-            is_backend = false;
-        }
-
-        llama_sampler_chain_add(is_backend ? result->chain_backend : result->chain, smpl);
+        llama_sampler_chain_add(result->chain, smpl);
     }
 
     return result;
@@ -336,7 +326,6 @@ void common_sampler_free(struct common_sampler * gsmpl) {
     if (gsmpl) {
         llama_sampler_free(gsmpl->grmr);
         llama_sampler_free(gsmpl->chain);
-        llama_sampler_free(gsmpl->chain_backend);
 
         delete gsmpl;
     }
@@ -360,13 +349,12 @@ void common_sampler_reset(struct common_sampler * gsmpl) {
 
 struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
     return new common_sampler {
-        /* .params        = */ gsmpl->params,
-        /* .grmr          = */ llama_sampler_clone(gsmpl->grmr),
-        /* .chain         = */ llama_sampler_clone(gsmpl->chain),
-        /* .chain_backend = */ llama_sampler_clone(gsmpl->chain_backend),
-        /* .prev          = */ gsmpl->prev,
-        /* .cur           = */ gsmpl->cur,
-        /* .cur_p         = */ gsmpl->cur_p,
+        /* .params = */ gsmpl->params,
+        /* .grmr   = */ llama_sampler_clone(gsmpl->grmr),
+        /* .chain  = */ llama_sampler_clone(gsmpl->chain),
+        /* .prev   = */ gsmpl->prev,
+        /* .cur    = */ gsmpl->cur,
+        /* .cur_p  = */ gsmpl->cur_p,
     };
 }
 
@@ -415,20 +403,22 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam
     }
 }
 
-struct llama_sampler * common_sampler_chain_backend(const struct common_sampler * gsmpl) {
-    return gsmpl->chain_backend;
+struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl) {
+    return gsmpl->chain;
 }
 
 llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
     // Check if a backend sampler has already sampled a token in which case we
     // return that token id directly.
     {
         const llama_token id = llama_get_sampled_token_ith(ctx, idx);
+
         if (id != LLAMA_TOKEN_NULL) {
             LOG_DBG("%s: Backend sampler selected token: '%d'. Will not run any CPU samplers\n", __func__, id);
             return id;
         }
     }
+
     llama_synchronize(ctx);
 
     // start measuring sampling time after the llama_context synchronization in order to not measure any ongoing async operations
@@ -556,16 +546,12 @@ llama_token common_sampler_last(const struct common_sampler * gsmpl) {
 }
 
 std::string common_sampler_print(const struct common_sampler * gsmpl) {
-    std::string result = llama_sampler_chain_n(gsmpl->chain_backend) > 0 ? "*logits " : "logits ";
-
-    for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain_backend); i++) {
-        const auto * smpl = llama_sampler_chain_get(gsmpl->chain_backend, i);
-        result += std::string("-> *") + llama_sampler_name(smpl) + " ";
-    }
+    std::string result = "logits ";
 
     for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) {
         const auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);
-        result += std::string("-> ") + llama_sampler_name(smpl) + " ";
+        result += std::string("-> ");
+        result += std::string(llama_sampler_name(smpl)) + " ";
     }
 
     return result;
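
For reference, a minimal sketch of how a caller drives the now single-chain sampler after this change. Only the common_sampler_* functions visible in this diff are taken as given; common_sampler_accept and LOG_INF are assumed from llama.cpp's common library, and the model/context setup, llama_decode loop, and end-of-generation handling are elided:

// Usage sketch of the simplified sampler API (assumptions noted above).
#include "sampling.h"
#include "log.h"

void generate_n(struct llama_model * model, struct llama_context * ctx, int n_predict) {
    common_params_sampling sparams; // default sampling parameters

    struct common_sampler * smpl = common_sampler_init(model, sparams);

    // prints a single chain, e.g. "logits -> penalties -> top-k -> ... -> dist"
    LOG_INF("chain: %s\n", common_sampler_print(smpl).c_str());

    for (int i = 0; i < n_predict; ++i) {
        // If a backend sampler already picked a token for this position,
        // common_sampler_sample returns it directly; otherwise the CPU
        // chain (and optionally the grammar) runs on the logits at idx.
        const llama_token id = common_sampler_sample(smpl, ctx, /*idx =*/ -1, /*grammar_first =*/ false);

        common_sampler_accept(smpl, id, /*accept_grammar =*/ true); // assumed helper

        // ... feed `id` back via llama_decode and stop on an end-of-generation token ...
    }

    common_sampler_free(smpl);
}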