Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
62 commits
Select commit Hold shift + click to select a range
7884b0e
sampling : add support for backend sampling
danbev Nov 17, 2025
9fe9a00
llama-cli : add backend sampler configuration
danbev Nov 17, 2025
f1f3e68
server : add backend sampling options/configuration
danbev Nov 17, 2025
a3eb847
webui : add backend sampling options
danbev Nov 17, 2025
67d3b8e
ggml : add initial cumsum implementation for CUDA
danbev Nov 17, 2025
71574f9
sampling : enable all backend sampler tests
danbev Nov 18, 2025
4b52e59
graph : do not include llama-model.h
ggerganov Nov 18, 2025
82957a9
sampling : always expose sampled_ids
danbev Nov 18, 2025
311c1a3
sampling : ensure at most one output token per seq
danbev Nov 18, 2025
26be108
CUDA: Optimize argsort for gpu-based token sampling
ORippler Nov 18, 2025
0da7e7d
sampling : remove version from sampler chain
danbev Nov 19, 2025
51fee29
sampling : always populate logits for sampled probs
danbev Nov 19, 2025
7e98ebc
sampling : simplify backend sampling logic decode
danbev Nov 19, 2025
d74eb61
squash! sampling : simplify backend sampling logic decode
danbev Nov 19, 2025
38f408c
common : fix regression caused by extra memory allocations during sam…
ggerganov Nov 19, 2025
18ed4d8
squash! sampling : simplify backend sampling logic decode
danbev Nov 19, 2025
0c660e7
Merge remote-tracking branch 'upstream/master' into backend-sampling
danbev Nov 20, 2025
ed4345b
squash! common : fix regression caused by extra memory allocations du…
danbev Nov 20, 2025
0d28b16
sampling : introduce sampling_info struct
danbev Nov 20, 2025
c162562
sampling : return early if backend sampling is disabled
danbev Nov 21, 2025
61ffe41
sampling : use pinned memory for backend sampling buffers
danbev Nov 21, 2025
9b24393
common, tools : refactor model loading to support backend samplers
danbev Nov 21, 2025
79b8cf2
Merge remote-tracking branch 'upstream/master' into backend-sampling
danbev Nov 21, 2025
65500d0
sampling : add stride variable for clarity
danbev Nov 23, 2025
ae23d2d
sampling: clarify candidate ids usage in comments
danbev Nov 23, 2025
9e273f7
sampling : fix copying both sampled tokens and logits/probs from backend
danbev Nov 23, 2025
50d21aa
tests : cleanup test-backend-sampler.cpp
danbev Nov 24, 2025
7816f0b
Merge remote-tracking branch 'upstream/master' into backend-sampling
danbev Nov 24, 2025
d88ba18
common : remove build-info.cpp from commit [no ci]
danbev Nov 24, 2025
4a90583
sampling : cleanup and clarify output_reserve
danbev Nov 24, 2025
8eb9b47
sampling : remove redundant checks for stride and size [no ci]
danbev Nov 24, 2025
25f3380
sampling : add debug log when backend sampler selects token
danbev Nov 24, 2025
d0bea21
examples : update batched to use backend sampling
danbev Nov 24, 2025
e2d4f08
llama-cli : fix dangling reference to sampler config
ggerganov Nov 24, 2025
b26c706
common : initialize backend samplers
ggerganov Nov 24, 2025
883a870
samplers : add missing cont
ggerganov Nov 24, 2025
a02adf4
sampling : add assertions for contiguous tensors in async copy functions
danbev Nov 24, 2025
2b4c792
Merge remote-tracking branch 'upstream/master' into backend-sampling
danbev Nov 25, 2025
0f17ccd
examples : add info about hybrid sampling in batched [no ci]
danbev Nov 25, 2025
53dca56
Merge remote-tracking branch 'upstream/master' into gpu-sampling
danbev Nov 25, 2025
9e5e09d
sampling : remove backend-dist option (wip)
danbev Nov 25, 2025
ec047e1
Merge remote-tracking branch 'upstream/master' into backend-sampling
danbev Nov 25, 2025
f23b306
CUDA: Add top-k implementation
ORippler Nov 21, 2025
b45d504
sampling : add min-p backend sampler
danbev Nov 26, 2025
4fea191
Use `FetchContent` over CPM as it's bundled with CMake
ORippler Nov 26, 2025
0f7805f
common : add get_active_samplers function to check enabled samplers
danbev Nov 26, 2025
90a3aff
cuda : fix editorconfig-checker warning
danbev Nov 26, 2025
7c2bfb3
Merge remote-tracking branch 'upstream/master' into backend-sampling
danbev Nov 26, 2025
d9d7361
sampling : use argmax for min-p sampling
danbev Nov 27, 2025
51107a0
sampling : fix temperature check to allow zero temperature
danbev Nov 27, 2025
5ea3be2
cuda : fix top-k compilation when CUB is unavailable
danbev Nov 27, 2025
172208a
sampling : add comments about backend sampler [no ci]
danbev Nov 27, 2025
e9d0709
sampling : remove backend sampling chain from common_sampler
danbev Nov 27, 2025
f9889cf
Fix top-k comp & behavior for non-CUB path
ORippler Nov 27, 2025
74be332
sampling : support intermixed backend/cpu samplers
danbev Nov 27, 2025
9ad6522
squash! sampling : support intermixed backend/cpu samplers
danbev Nov 28, 2025
459b7ae
squash! sampling : support intermixed backend/cpu samplers
danbev Nov 28, 2025
117e207
refactor : simplify and improve memory management
ggerganov Nov 28, 2025
333da80
Add initial version for top-p sampling
ORippler Nov 28, 2025
8cac9de
sampling : use logits directly for min-p filtering
danbev Nov 28, 2025
2464d1b
sampling : simplify
ggerganov Nov 28, 2025
fbc8f49
llama : simplify
ggerganov Nov 29, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions common/arg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1513,6 +1513,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.sampling.grammar = json_schema_to_grammar(json::parse(schema));
}
).set_sparam());
add_opt(common_arg(
{"--backend-sampling"},
"enable backend sampling (default: disabled)",
[](common_params & params) {
params.sampling.backend_sampling = true;
}
).set_sparam());
add_opt(common_arg(
{"--pooling"}, "{none,mean,cls,last,rank}",
"pooling type for embeddings, use model default if unspecified",
Expand Down
198 changes: 133 additions & 65 deletions common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -950,31 +950,40 @@ std::vector<common_file_info> fs_list_files(const std::string & path) {
// Model utils
//

static inline void common_init_sampler_from_model(
// TODO: move to common/sampling
static void common_init_sampler_from_model(
const llama_model * model,
common_params_sampling & sparams) {

const uint64_t config = sparams.user_sampling_config;

auto get_int32 = [&](const char * key, int32_t & dst, uint64_t user_config) {
if (config & user_config) return;
if (config & user_config) {
return;
}

char buf[64] = {0};
if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) {
char * end = nullptr;
int32_t v = strtol(buf, &end, 10);
if (end && end != buf) dst = v;
if (end && end != buf) {
dst = v;
}
}
};

auto get_float = [&](const char * key, float & dst, uint64_t user_config) {
if (config & user_config) return;
if (config & user_config) {
return;
}

char buf[128] = {0};
if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) {
char * end = nullptr;
float v = strtof(buf, &end);
if (end && end != buf) dst = v;
if (end && end != buf) {
dst = v;
}
}
};

Expand Down Expand Up @@ -1002,31 +1011,130 @@ static inline void common_init_sampler_from_model(
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_ETA), sparams.mirostat_eta, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA);
}

struct common_init_result common_init_from_params(common_params & params) {
common_init_result iparams;
auto mparams = common_model_params_to_llama(params);
// Private implementation (pimpl) of common_init_result.
// Owns the model, context, LoRA adapters and per-sequence samplers.
// C++ guarantees members are destroyed in reverse declaration order, so the
// sampler configs/samplers go away before the context, and the context
// before the model.
struct common_init_result::impl {
    impl() = default;
    ~impl() = default;

    // owning handle to the loaded model (null if loading failed)
    llama_model_ptr model;
    // owning handle to the created context (null if creation failed)
    llama_context_ptr context;

    // LoRA adapters loaded for this model, kept alive with the context
    std::vector<llama_adapter_lora_ptr> lora;

    // one sampler per sequence id; samplers_seq_config holds non-owning
    // references to the backend chains of `samplers`, so the two vectors
    // must stay the same length and in sync
    std::vector<common_sampler_ptr> samplers;
    std::vector<llama_sampler_seq_config> samplers_seq_config;
};

// Load the model, derive sampling defaults from model metadata, set up the
// per-sequence backend samplers and create the llama context.
// On failure this returns early and leaves the corresponding pimpl members
// null; callers must check model()/context() before use.
//
// NOTE(review): the original snippet contained diff residue — a stray
// `return iparams;` (invalid inside a constructor) and a duplicated early
// call to common_init_sampler_from_model before pimpl->model.reset() —
// both removed here.
common_init_result::common_init_result(common_params & params) :
    pimpl(new impl{}) {
    const auto mparams = common_model_params_to_llama(params);

    llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams);
    if (model == NULL) {
        LOG_ERR("%s: failed to load model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
                __func__, params.model.path.c_str());
        return;
    }

    pimpl->model.reset(model);

    const llama_vocab * vocab = llama_model_get_vocab(model);

    // updates params.sampling
    // TODO: fix naming
    common_init_sampler_from_model(model, params.sampling);

    auto cparams = common_context_params_to_llama(params);

    if (params.sampling.ignore_eos && llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) {
        LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__);
        params.sampling.ignore_eos = false;
    }

    // initialize once: collect logit biases for all end-of-generation tokens
    for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
        if (llama_vocab_is_eog(vocab, i)) {
            LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(vocab, i).c_str(), -INFINITY);
            params.sampling.logit_bias_eog.push_back({i, -INFINITY});
        }
    }

    if (params.sampling.ignore_eos) {
        // add EOG biases to the active set of logit biases
        params.sampling.logit_bias.insert(
                params.sampling.logit_bias.end(),
                params.sampling.logit_bias_eog.begin(), params.sampling.logit_bias_eog.end());
    }

    //if (params.sampling.penalty_last_n == -1) {
    //    LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
    //    params.sampling.penalty_last_n = llama_n_ctx(lctx);
    //}

    //if (params.sampling.dry_penalty_last_n == -1) {
    //    LOG_INF("%s: setting dry_penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
    //    params.sampling.dry_penalty_last_n = llama_n_ctx(lctx);
    //}

    // init the backend samplers as part of the context creation
    pimpl->samplers.resize(cparams.n_seq_max);
    pimpl->samplers_seq_config.resize(cparams.n_seq_max);

    for (int i = 0; i < (int) cparams.n_seq_max; ++i) {
        pimpl->samplers[i].reset(common_sampler_init(model, params.sampling));
        // non-owning pointer into the sampler — samplers[] must outlive the context
        llama_sampler * backend_chain = common_sampler_chain_backend(pimpl->samplers[i].get());
        pimpl->samplers_seq_config[i] = { i, backend_chain };
    }

    cparams.samplers   = pimpl->samplers_seq_config.data();
    cparams.n_samplers = pimpl->samplers_seq_config.size();

    llama_context * lctx = llama_init_from_model(model, cparams);
    if (lctx == NULL) {
        LOG_ERR("%s: failed to create context with model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
                __func__, params.model.path.c_str());
        return;
    }

    pimpl->context.reset(lctx);
}

// Borrowing accessor: returns the loaded model, or nullptr if loading failed.
// Ownership stays with this object.
llama_model * common_init_result::model() {
    auto & owner = pimpl->model;
    return owner.get();
}

// Borrowing accessor: returns the created context, or nullptr if context
// creation failed (or after free_context()). Ownership stays with this object.
llama_context * common_init_result::context() {
    auto & owner = pimpl->context;
    return owner.get();
}

// Borrowing accessor for the sampler of sequence `seq_id`.
// No bounds checking — assumes seq_id < n_seq_max used at construction
// (NOTE(review): confirm callers guarantee this).
common_sampler * common_init_result::sampler(llama_seq_id seq_id) {
    auto & owner = pimpl->samplers[seq_id];
    return owner.get();
}

// Mutable access to the list of loaded LoRA adapters (owned by the pimpl).
std::vector<llama_adapter_lora_ptr> & common_init_result::lora() {
    auto & adapters = pimpl->lora;
    return adapters;
}

// Destroy the context early, before this object's destructor runs,
// while keeping the model (and samplers/adapters) alive.
void common_init_result::free_context() {
    pimpl->context = nullptr;
}

common_init_result_ptr common_init_from_params(common_params & params) {
common_init_result_ptr res(new common_init_result(params));

llama_model * model = res->model();
if (model == NULL) {
LOG_ERR("%s: failed to load model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
__func__, params.model.path.c_str());
llama_model_free(model);
return iparams;
return res;
}

llama_context * lctx = res->context();
if (lctx == NULL) {
LOG_ERR("%s: failed to create context with model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
__func__, params.model.path.c_str());
return res;
}

const llama_vocab * vocab = llama_model_get_vocab(model);

if (params.ctx_shift && !llama_memory_can_shift(llama_get_memory(lctx))) {
LOG_WRN("%s: KV cache shifting is not supported for this context, disabling KV cache shifting\n", __func__);
params.ctx_shift = false;
Expand All @@ -1038,10 +1146,7 @@ struct common_init_result common_init_from_params(common_params & params) {

const auto cvec = common_control_vector_load(params.control_vectors);
if (cvec.n_embd == -1) {
llama_free(lctx);
llama_model_free(model);

return iparams;
return res;
}

int err = llama_apply_adapter_cvec(
Expand All @@ -1052,10 +1157,7 @@ struct common_init_result common_init_from_params(common_params & params) {
params.control_vector_layer_start,
params.control_vector_layer_end);
if (err) {
llama_free(lctx);
llama_model_free(model);

return iparams;
return res;
}
}

Expand All @@ -1079,10 +1181,7 @@ struct common_init_result common_init_from_params(common_params & params) {
}

if (!ok) {
llama_free(lctx);
llama_model_free(model);

return iparams;
return res;
}
}

Expand All @@ -1092,9 +1191,7 @@ struct common_init_result common_init_from_params(common_params & params) {
lora.reset(llama_adapter_lora_init(model, la.path.c_str()));
if (lora == nullptr) {
LOG_ERR("%s: failed to apply lora adapter '%s'\n", __func__, la.path.c_str());
llama_free(lctx);
llama_model_free(model);
return iparams;
return res;
}

char buf[1024];
Expand All @@ -1103,43 +1200,13 @@ struct common_init_result common_init_from_params(common_params & params) {
la.task_name = buf;
llama_adapter_meta_val_str(la.ptr, "adapter.lora.prompt_prefix", buf, sizeof(buf));
la.prompt_prefix = buf;
iparams.lora.emplace_back(std::move(lora)); // copy to list of loaded adapters
res->lora().emplace_back(std::move(lora)); // copy to list of loaded adapters
}

if (!params.lora_init_without_apply) {
common_set_adapter_lora(lctx, params.lora_adapters);
}

if (params.sampling.ignore_eos && llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) {
LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__);
params.sampling.ignore_eos = false;
}

// initialize once
for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
if (llama_vocab_is_eog(vocab, i)) {
LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY);
params.sampling.logit_bias_eog.push_back({i, -INFINITY});
}
}

if (params.sampling.ignore_eos) {
// add EOG biases to the active set of logit biases
params.sampling.logit_bias.insert(
params.sampling.logit_bias.end(),
params.sampling.logit_bias_eog.begin(), params.sampling.logit_bias_eog.end());
}

if (params.sampling.penalty_last_n == -1) {
LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
params.sampling.penalty_last_n = llama_n_ctx(lctx);
}

if (params.sampling.dry_penalty_last_n == -1) {
LOG_INF("%s: setting dry_penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
params.sampling.dry_penalty_last_n = llama_n_ctx(lctx);
}

if (params.warmup) {
LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__);

Expand Down Expand Up @@ -1178,12 +1245,11 @@ struct common_init_result common_init_from_params(common_params & params) {
llama_set_warmup(lctx, false);
}

iparams.model.reset(model);
iparams.context.reset(lctx);

return iparams;
return res;
}

// Out-of-line defaulted destructor: required by the pimpl idiom because
// ~unique_ptr<impl> needs the complete definition of impl, which is only
// visible in this translation unit (not in the header).
common_init_result::~common_init_result() = default;

std::string get_model_endpoint() {
const char * model_endpoint_env = getenv("MODEL_ENDPOINT");
// We still respect the use of environment-variable "HF_ENDPOINT" for backward-compatibility.
Expand All @@ -1192,7 +1258,9 @@ std::string get_model_endpoint() {
std::string model_endpoint = "https://huggingface.co/";
if (endpoint_env) {
model_endpoint = endpoint_env;
if (model_endpoint.back() != '/') model_endpoint += '/';
if (model_endpoint.back() != '/') {
model_endpoint += '/';
}
}
return model_endpoint;
}
Expand Down
35 changes: 30 additions & 5 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,6 @@ struct common_params_sampling {

std::vector<std::string> dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY


std::vector<enum common_sampler_type> samplers = {
COMMON_SAMPLER_TYPE_PENALTIES,
COMMON_SAMPLER_TYPE_DRY,
Expand All @@ -213,6 +212,18 @@ struct common_params_sampling {
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
std::vector<llama_logit_bias> logit_bias_eog; // pre-calculated logit biases for EOG tokens

bool backend_sampling = false; // enable backend sampling

bool has_logit_bias() const {
return !logit_bias.empty();
}

bool is_disabled(enum common_sampler_type type) const;

// remove disabled samplers
// TODO: temporary until all samplers have llama_sampler_backend_ API [LLAMA_SAMPLER_BACKEND]
void filter_disabled();

// print the parameters into a string
std::string print() const;
};
Expand Down Expand Up @@ -648,15 +659,29 @@ std::vector<common_file_info> fs_list_files(const std::string & path);
// Model utils
//

struct common_sampler;

// note: defines object's lifetime
// RAII holder for everything produced by common_init_from_params: the model,
// the context, per-sequence samplers and loaded LoRA adapters. Accessors
// return borrowed (non-owning) pointers/references; ownership stays here.
struct common_init_result {
    // loads the model and creates the context; on failure the corresponding
    // members stay null — check model()/context() before use
    common_init_result(common_params & params);
    ~common_init_result();

    llama_model * model();      // borrowed; nullptr if model loading failed
    llama_context * context();  // borrowed; nullptr if context creation failed
    // sampler for the given sequence
    // NOTE(review): assumes seq_id < n_seq_max used at construction — confirm
    common_sampler * sampler(llama_seq_id seq_id);

    // LoRA adapters loaded for this model
    std::vector<llama_adapter_lora_ptr> & lora();

    // destroy the context early while keeping the model alive
    void free_context();

private:
    // pimpl: keeps implementation details (owned handles) out of this header
    struct impl;
    std::unique_ptr<impl> pimpl;
};

struct common_init_result common_init_from_params(common_params & params);
using common_init_result_ptr = std::unique_ptr<common_init_result>;

common_init_result_ptr common_init_from_params(common_params & params);

struct llama_model_params common_model_params_to_llama ( common_params & params);
struct llama_context_params common_context_params_to_llama(const common_params & params);
Expand Down
Loading