Commit 6958d41

sampling : check backend support during init
1 parent: 1bde707

File tree: 8 files changed (+369, -178 lines)


common/common.cpp

Lines changed: 1 addition & 2 deletions

@@ -1098,8 +1098,7 @@ common_init_result::common_init_result(common_params & params) :

    for (int i = 0; i < (int) cparams.n_seq_max; ++i) {
        pimpl->samplers[i].reset(common_sampler_init(model, params.sampling));
-       llama_sampler * backend_chain = common_sampler_chain_backend(pimpl->samplers[i].get());
-       pimpl->samplers_seq_config[i] = { i, backend_chain };
+       pimpl->samplers_seq_config[i] = { i, common_sampler_get(pimpl->samplers[i].get()) };
    }

    cparams.samplers = pimpl->samplers_seq_config.data();
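
The same per-sequence wiring can be done from application code. Below is a minimal sketch (not part of this commit) that registers one sampler chain per sequence; it assumes the llama_context_params fields shown in the include/llama.h hunk further down, plus llama_context_default_params(), llama_init_from_model(), and the common_sampler_* helpers:

// sketch: create one common_sampler per sequence and hand its chain to the context.
// `samplers` and `configs` are owned by the caller and must outlive the returned
// context, since the context keeps pointers to the caller's sampler chains.
#include "llama.h"
#include "common.h"
#include "sampling.h"

#include <vector>

static llama_context * init_ctx_with_samplers(
        llama_model * model,
        const common_params_sampling & sparams,
        std::vector<common_sampler *> & samplers,
        std::vector<llama_sampler_seq_config> & configs,
        int n_seq_max) {
    samplers.resize(n_seq_max);
    configs.resize(n_seq_max);

    for (int i = 0; i < n_seq_max; ++i) {
        samplers[i] = common_sampler_init(model, sparams);
        // common_sampler_get() returns the llama_sampler chain owned by the common_sampler
        configs[i]  = { i, common_sampler_get(samplers[i]) };
    }

    llama_context_params cparams = llama_context_default_params();
    cparams.n_seq_max  = n_seq_max;
    cparams.samplers   = configs.data();
    cparams.n_samplers = configs.size();

    return llama_init_from_model(model, cparams);
}

Each common_sampler should later be released with common_sampler_free() once the context has been freed.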

common/sampling.cpp

Lines changed: 20 additions & 34 deletions

@@ -106,7 +106,6 @@ struct common_sampler {

    struct llama_sampler * grmr;
    struct llama_sampler * chain;
-   struct llama_sampler * chain_backend;

    ring_buffer<llama_token> prev;

@@ -119,7 +118,6 @@ struct common_sampler {

        llama_sampler_reset(grmr);
        llama_sampler_reset(chain);
-       llama_sampler_reset(chain_backend);
    }

    void set_logits(struct llama_context * ctx, int idx) {

@@ -247,13 +245,12 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
    }

    auto * result = new common_sampler {
-       /* .params = */ params,
-       /* .grmr = */ grmr,
-       /* .chain = */ llama_sampler_chain_init(lparams),
-       /* .chain_backend = */ llama_sampler_chain_init(lparams),
-       /* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
-       /* .cur = */ {},
-       /* .cur_p = */ {},
+       /* .params = */ params,
+       /* .grmr = */ grmr,
+       /* .chain = */ llama_sampler_chain_init(lparams),
+       /* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
+       /* .cur = */ {},
+       /* .cur_p = */ {},
    };

    std::vector<llama_sampler *> samplers;

@@ -318,15 +315,8 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
        GGML_ASSERT(false && "unknown mirostat version");
    }

-   bool is_backend = params.backend_sampling;
-
-   // split in two chains: backend -> CPU
    for (auto * smpl : samplers) {
-       if (!smpl->iface->backend_apply) {
-           is_backend = false;
-       }
-
-       llama_sampler_chain_add(is_backend ? result->chain_backend : result->chain, smpl);
+       llama_sampler_chain_add(result->chain, smpl);
    }

    return result;

@@ -336,7 +326,6 @@ void common_sampler_free(struct common_sampler * gsmpl) {
    if (gsmpl) {
        llama_sampler_free(gsmpl->grmr);
        llama_sampler_free(gsmpl->chain);
-       llama_sampler_free(gsmpl->chain_backend);

        delete gsmpl;
    }

@@ -360,13 +349,12 @@ void common_sampler_reset(struct common_sampler * gsmpl) {

struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
    return new common_sampler {
-       /* .params = */ gsmpl->params,
-       /* .grmr = */ llama_sampler_clone(gsmpl->grmr),
-       /* .chain = */ llama_sampler_clone(gsmpl->chain),
-       /* .chain_backend = */ llama_sampler_clone(gsmpl->chain_backend),
-       /* .prev = */ gsmpl->prev,
-       /* .cur = */ gsmpl->cur,
-       /* .cur_p = */ gsmpl->cur_p,
+       /* .params = */ gsmpl->params,
+       /* .grmr = */ llama_sampler_clone(gsmpl->grmr),
+       /* .chain = */ llama_sampler_clone(gsmpl->chain),
+       /* .prev = */ gsmpl->prev,
+       /* .cur = */ gsmpl->cur,
+       /* .cur_p = */ gsmpl->cur_p,
    };
}

@@ -415,20 +403,22 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam
    }
}

-struct llama_sampler * common_sampler_chain_backend(const struct common_sampler * gsmpl) {
-   return gsmpl->chain_backend;
+struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl) {
+   return gsmpl->chain;
}

llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
    // Check if a backend sampler has already sampled a token in which case we
    // return that token id directly.
    {
        const llama_token id = llama_get_sampled_token_ith(ctx, idx);
+
        if (id != LLAMA_TOKEN_NULL) {
            LOG_DBG("%s: Backend sampler selected token: '%d'. Will not run any CPU samplers\n", __func__, id);
            return id;
        }
    }
+
    llama_synchronize(ctx);

    // start measuring sampling time after the llama_context synchronization in order to not measure any ongoing async operations

@@ -556,16 +546,12 @@ llama_token common_sampler_last(const struct common_sampler * gsmpl) {
}

std::string common_sampler_print(const struct common_sampler * gsmpl) {
-   std::string result = llama_sampler_chain_n(gsmpl->chain_backend) > 0 ? "*logits " : "logits ";
-
-   for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain_backend); i++) {
-       const auto * smpl = llama_sampler_chain_get(gsmpl->chain_backend, i);
-       result += std::string("-> *") + llama_sampler_name(smpl) + " ";
-   }
+   std::string result = "logits ";

    for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) {
        const auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);
-       result += std::string("-> ") + llama_sampler_name(smpl) + " ";
+       result += std::string("-> ");
+       result += std::string(llama_sampler_name(smpl)) + " ";
    }

    return result;
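
For reference, a rough usage sketch (not from the commit) of the updated helpers in a generation loop. common_sampler_sample() first asks the context for a backend-sampled token via llama_get_sampled_token_ith() and only runs the CPU chain when none is available; llama_model_get_vocab(), llama_vocab_is_eog(), common_sampler_accept(), common_token_to_piece() and llama_batch_get_one() are assumed from the existing llama.cpp/common API:

#include "llama.h"
#include "common.h"
#include "sampling.h"

#include <cstdio>

// generate up to n_predict tokens; assumes the prompt has already been decoded into ctx
static void generate(llama_context * ctx, common_sampler * smpl, int n_predict) {
    const llama_vocab * vocab = llama_model_get_vocab(llama_get_model(ctx));

    for (int i = 0; i < n_predict; ++i) {
        // idx = -1: sample from the logits of the last token of the previous batch;
        // if a backend sampler already picked a token, it is returned directly
        llama_token id = common_sampler_sample(smpl, ctx, /*idx =*/ -1, /*grammar_first =*/ false);

        if (llama_vocab_is_eog(vocab, id)) {
            break;
        }

        // update sampler state (repetition history, grammar, ...) with the accepted token
        common_sampler_accept(smpl, id, /*accept_grammar =*/ true);

        printf("%s", common_token_to_piece(ctx, id).c_str());
        fflush(stdout);

        if (llama_decode(ctx, llama_batch_get_one(&id, 1)) != 0) {
            break;
        }
    }
}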

common/sampling.h

Lines changed: 1 addition & 1 deletion

@@ -48,7 +48,7 @@ struct common_sampler * common_sampler_clone (struct common_sampler * gsmpl);
// arguments can be nullptr to skip printing
void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl);

-struct llama_sampler * common_sampler_chain_backend(const struct common_sampler * gsmpl);
+struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl);

// extended sampling implementation:
//

include/llama.h

Lines changed: 10 additions & 3 deletions

@@ -369,7 +369,8 @@ extern "C" {
        // try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix
        // ref: https://github.com/ggml-org/llama.cpp/pull/14363

-       // backend sampler chain configuration (does not keep a reference, so make sure the caller keeps the samplers alive)
+       // backend sampler chain configuration (make sure the caller keeps the sampler chains alive)
+       // note: the samplers must be sampler chains (i.e. use llama_sampler_chain_init)
        struct llama_sampler_seq_config * samplers;
        size_t n_samplers;
    };

@@ -1193,21 +1194,27 @@ extern "C" {
        struct llama_sampler * (*clone) (const struct llama_sampler * smpl); // can be NULL if ctx is NULL
        void                   (*free)  (      struct llama_sampler * smpl); // can be NULL if ctx is NULL

-       // backend sampling interface
-       void (*backend_init)(struct llama_sampler * smpl, ggml_backend_buffer_type_t buft);
+       // backend sampling interface:

+       // return true if the backend supports all ops needed by the sampler
+       // note: call once per sampler
+       bool (*backend_init)(struct llama_sampler * smpl, ggml_backend_buffer_type_t buft);
+
+       // call after .backend_accept()
        void (*backend_accept)(
                struct llama_sampler * smpl,
                struct ggml_context * ctx,
                struct ggml_cgraph * gf,
                struct ggml_tensor * selected_token);

+       // call after .backend_init()
        void (*backend_apply)(
                struct llama_sampler * smpl,
                struct ggml_context * ctx,
                struct ggml_cgraph * gf,
                struct llama_sampler_data * data);

+       // call before .backend_apply()
        void (*backend_set_input)(struct llama_sampler * smpl);
    };
src/llama-context.cpp

Lines changed: 2 additions & 0 deletions

@@ -68,6 +68,8 @@ llama_context::llama_context(
    for (size_t i = 0; i < params.n_samplers; ++i) {
        const auto & config = params.samplers[i];

+       // TODO: assert this is a llama_sampler_chain instance
+
        if (set_sampler(config.seq_id, config.sampler)) {
            const int n_samplers = llama_sampler_chain_n(config.sampler);