Changes from 1 commit
68 commits
7884b0e
sampling : add support for backend sampling
danbev Nov 17, 2025
9fe9a00
llama-cli : add backend sampler configuration
danbev Nov 17, 2025
f1f3e68
server : add backend sampling options/configuration
danbev Nov 17, 2025
a3eb847
webui : add backend sampling options
danbev Nov 17, 2025
67d3b8e
ggml : add initial cumsum implementation for CUDA
danbev Nov 17, 2025
71574f9
sampling : enable all backend sampler tests
danbev Nov 18, 2025
4b52e59
graph : do not include llama-model.h
ggerganov Nov 18, 2025
82957a9
sampling : always expose sampled_ids
danbev Nov 18, 2025
311c1a3
sampling : ensure at most one output token per seq
danbev Nov 18, 2025
26be108
CUDA: Optimize argsort for gpu-based token sampling
ORippler Nov 18, 2025
0da7e7d
sampling : remove version from sampler chain
danbev Nov 19, 2025
51fee29
sampling : always populate logits for sampled probs
danbev Nov 19, 2025
7e98ebc
sampling : simplify backend sampling logic decode
danbev Nov 19, 2025
d74eb61
squash! sampling : simplify backend sampling logic decode
danbev Nov 19, 2025
38f408c
common : fix regression caused by extra memory allocations during sam…
ggerganov Nov 19, 2025
18ed4d8
squash! sampling : simplify backend sampling logic decode
danbev Nov 19, 2025
0c660e7
Merge remote-tracking branch 'upstream/master' into backend-sampling
danbev Nov 20, 2025
ed4345b
squash! common : fix regression caused by extra memory allocations du…
danbev Nov 20, 2025
0d28b16
sampling : introduce sampling_info struct
danbev Nov 20, 2025
c162562
sampling : return early if backend sampling is disabled
danbev Nov 21, 2025
61ffe41
sampling : use pinned memory for backend sampling buffers
danbev Nov 21, 2025
9b24393
common, tools : refactor model loading to support backend samplers
danbev Nov 21, 2025
79b8cf2
Merge remote-tracking branch 'upstream/master' into backend-sampling
danbev Nov 21, 2025
65500d0
sampling : add stride variable for clarity
danbev Nov 23, 2025
ae23d2d
sampling: clarify candidate ids usage in comments
danbev Nov 23, 2025
9e273f7
sampling : fix copying both sampled tokens and logits/probs from backend
danbev Nov 23, 2025
50d21aa
tests : cleanup test-backend-sampler.cpp
danbev Nov 24, 2025
7816f0b
Merge remote-tracking branch 'upstream/master' into backend-sampling
danbev Nov 24, 2025
d88ba18
common : remove build-info.cpp from commit [no ci]
danbev Nov 24, 2025
4a90583
sampling : cleanup and clarify output_reserve
danbev Nov 24, 2025
8eb9b47
sampling : remove redundant checks for stride and size [no ci]
danbev Nov 24, 2025
25f3380
sampling : add debug log when backend sampler selects token
danbev Nov 24, 2025
d0bea21
examples : update batched to use backend sampling
danbev Nov 24, 2025
e2d4f08
llama-cli : fix dangling reference to sampler config
ggerganov Nov 24, 2025
b26c706
common : initialize backend samplers
ggerganov Nov 24, 2025
883a870
samplers : add missing cont
ggerganov Nov 24, 2025
a02adf4
sampling : add assertions for contiguous tensors in async copy functions
danbev Nov 24, 2025
2b4c792
Merge remote-tracking branch 'upstream/master' into backend-sampling
danbev Nov 25, 2025
0f17ccd
examples : add info about hybrid sampling in batched [no ci]
danbev Nov 25, 2025
53dca56
Merge remote-tracking branch 'upstream/master' into gpu-sampling
danbev Nov 25, 2025
9e5e09d
sampling : remove backend-dist option (wip)
danbev Nov 25, 2025
ec047e1
Merge remote-tracking branch 'upstream/master' into backend-sampling
danbev Nov 25, 2025
f23b306
CUDA: Add top-k implementation
ORippler Nov 21, 2025
b45d504
sampling : add min-p backend sampler
danbev Nov 26, 2025
4fea191
Use `FetchContent` over CPM as it's bundled with CMake
ORippler Nov 26, 2025
0f7805f
common : add get_active_samplers function to check enabled samplers
danbev Nov 26, 2025
90a3aff
cuda : fix editorconfig-checker warning
danbev Nov 26, 2025
7c2bfb3
Merge remote-tracking branch 'upstream/master' into backend-sampling
danbev Nov 26, 2025
d9d7361
sampling : use argmax for min-p sampling
danbev Nov 27, 2025
51107a0
sampling : fix temperature check to allow zero temperature
danbev Nov 27, 2025
5ea3be2
cuda : fix top-k compilation when CUB is unavailable
danbev Nov 27, 2025
172208a
sampling : add comments about backend sampler [no ci]
danbev Nov 27, 2025
e9d0709
sampling : remove backend sampling chain from common_sampler
danbev Nov 27, 2025
f9889cf
Fix top-k comp & behavior for non-CUB path
ORippler Nov 27, 2025
74be332
sampling : support intermixed backend/cpu samplers
danbev Nov 27, 2025
9ad6522
squash! sampling : support intermixed backend/cpu samplers
danbev Nov 28, 2025
459b7ae
squash! sampling : support intermixed backend/cpu samplers
danbev Nov 28, 2025
117e207
refactor : simplify and improve memory management
ggerganov Nov 28, 2025
333da80
Add initial version for top-p sampling
ORippler Nov 28, 2025
8cac9de
sampling : use logits directly for min-p filtering
danbev Nov 28, 2025
2464d1b
sampling : simplify
ggerganov Nov 28, 2025
fbc8f49
llama : simplify
ggerganov Nov 29, 2025
9028ebf
llama : cleanup + naming
ggerganov Nov 29, 2025
d8d98bb
Merge branch 'master' into HEAD
ggerganov Nov 29, 2025
ff7b0bf
llama : call backend_init once
ggerganov Nov 29, 2025
467746e
Merge branch 'master' into HEAD
ggerganov Nov 29, 2025
1760bd6
llama : reserve graphs with samplers
ggerganov Nov 29, 2025
c187003
llama : naming
ggerganov Nov 29, 2025
13 changes: 13 additions & 0 deletions common/sampling.cpp
@@ -173,6 +173,7 @@ static bool sampler_backend_supported(enum common_sampler_type type) {
switch (type) {
case COMMON_SAMPLER_TYPE_TOP_K:
case COMMON_SAMPLER_TYPE_TEMPERATURE:
case COMMON_SAMPLER_TYPE_MIN_P:
return true;
default:
return false;
@@ -325,6 +326,12 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
}
backend_idx++;
break;
case COMMON_SAMPLER_TYPE_MIN_P:
if (params.min_p > 0.0f) {
llama_sampler_chain_add(result->backend_chain, llama_sampler_backend_init_min_p(params.min_p));
}
backend_idx++;
break;
default:
GGML_ASSERT(false && "unsupported backend sampler");
}
@@ -468,6 +475,12 @@ struct llama_sampler * common_sampler_backend_init(const struct llama_model * mo
}
backend_idx++;
break;
case COMMON_SAMPLER_TYPE_MIN_P:
if (params.min_p > 0.0f) {
llama_sampler_chain_add(chain, llama_sampler_backend_init_min_p(params.min_p));
}
backend_idx++;
break;
default:
GGML_ASSERT(false && "unsupported backend sampler");
}
3 changes: 3 additions & 0 deletions include/llama.h
@@ -1405,6 +1405,9 @@ extern "C" {
/// @details Distribution sampling on backend - final sampling step that selects a token
LLAMA_API struct llama_sampler * llama_sampler_backend_init_dist(uint32_t seed);

/// @details Min-P filtering on backend - filter tokens with a probability less than p times the maximum probability.
LLAMA_API struct llama_sampler * llama_sampler_backend_init_min_p(float p);

// Returns the seed used by the sampler if applicable, LLAMA_DEFAULT_SEED otherwise
LLAMA_API uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl);

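The declaration above slots into the existing sampler-chain API. As a minimal sketch (not part of the PR itself), a backend chain that applies min-p filtering and then samples a token could be built like this, using only functions declared in llama.h; the 0.05f cutoff and the seed 42 are illustrative values:

#include "llama.h"

// Build a backend sampler chain: min-p filtering followed by backend
// distribution sampling, mirroring how common_sampler_init() wires the chain.
static struct llama_sampler * make_backend_chain(void) {
    struct llama_sampler_chain_params params = llama_sampler_chain_default_params();
    struct llama_sampler * chain = llama_sampler_chain_init(params);

    // Keep only tokens whose probability is at least 0.05 * p_max.
    llama_sampler_chain_add(chain, llama_sampler_backend_init_min_p(0.05f));

    // Final step: select a token on the backend (seeded for reproducibility).
    llama_sampler_chain_add(chain, llama_sampler_backend_init_dist(42));

    return chain;
}
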
115 changes: 115 additions & 0 deletions src/llama-backend-sampler.cpp
@@ -488,3 +488,118 @@ struct llama_sampler * llama_sampler_backend_init_logit_bias(int32_t n_vocab,

return sampler;
}

struct llama_sampler_backend_min_p_ctx {
float p;

// Only required for checking operation support and can be removed later.
ggml_backend_dev_t device;
};

static void llama_sampler_backend_min_p_init_ggml(
struct llama_sampler * smpl,
ggml_backend_buffer_type_t buft) {
auto * sctx = (llama_sampler_backend_min_p_ctx *) smpl->ctx;
sctx->device = ggml_backend_buft_get_device(buft);
}

static void llama_sampler_backend_min_p_apply_ggml(
struct llama_sampler * smpl,
struct ggml_context * ctx,
struct ggml_cgraph * gf,
struct llama_sampler_ggml_data * ggml_data) {
GGML_UNUSED(gf);

auto * sctx = (llama_sampler_backend_min_p_ctx *) smpl->ctx;

struct ggml_tensor * softmax = ggml_soft_max(ctx, ggml_data->logits);
ggml_set_name(softmax, "softmax");

// Get the sorted indices of the softmax probabilities in descending order.
struct ggml_tensor * sorted_idx = ggml_argsort(ctx, softmax, GGML_SORT_ORDER_DESC);
ggml_set_name(sorted_idx, "sorted_idx");

// Reshape so that each probability becomes its own row, as required by ggml_get_rows below.
struct ggml_tensor * softmax_rows = ggml_reshape_2d(ctx, softmax, 1, softmax->ne[0]);
ggml_set_name(softmax_rows, "softmax_rows");

// Get the sorted probabilities using the sorted indices so that we can get
// the max probability value, which will be the first entry in sorted_probs.
struct ggml_tensor * sorted_probs = ggml_get_rows(ctx, softmax_rows, sorted_idx);
ggml_set_name(sorted_probs, "sorted_probs");

// Get the max probability value from sorted_probs.
struct ggml_tensor * p_max = ggml_view_1d(ctx, sorted_probs, 1, 0);
ggml_set_name(p_max, "p_max");

// Calculate the threshold value.
struct ggml_tensor * threshold = ggml_scale(ctx, p_max, sctx->p);
ggml_set_name(threshold, "min_p_threshold");

// Broadcast the threshold to match the shape of softmax.
struct ggml_tensor * threshold_b = ggml_repeat(ctx, threshold, softmax);
ggml_set_name(threshold_b, "min_p_threshold_b");

// Subtract the threshold from softmax probabilities.
struct ggml_tensor * sub = ggml_sub(ctx, softmax, threshold_b);

// Create a mask where probabilities below the threshold are 0 (discard),
// and others are 1 (keep).
struct ggml_tensor * mask = ggml_step(ctx, sub);
ggml_set_name(mask, "min_p_mask");

// Use ggml_scale_bias (output = (a * s) + b), which in this case becomes:
// min_p_bias = (mask * 1e9f) - 1e9f.
// So entries for tokens we want to discard become -1e9f, and all others
// become 0 (meaning they will not affect the logits).
const float large_val = 1e9f;
struct ggml_tensor * min_p_bias = ggml_scale_bias(ctx, mask, large_val, -large_val);
ggml_set_name(min_p_bias, "min_p_bias");

// Add the min_p bias to the logits.
ggml_data->logits = ggml_add(ctx, ggml_data->logits, min_p_bias);
ggml_set_name(ggml_data->logits, "min_p_logits");
Contributor comment: Why can't we use the get_rows to return only the values where mask == 1?


ggml_build_forward_expand(gf, ggml_data->logits);
}

static const char * llama_sampler_backend_min_p_name(const struct llama_sampler *) {
return "backend-min-p";
}

static void llama_sampler_backend_min_p_free(struct llama_sampler * smpl) {
auto * sctx = (llama_sampler_backend_min_p_ctx *) smpl->ctx;
delete sctx;
}

static struct llama_sampler * llama_sampler_backend_min_p_clone(const struct llama_sampler * smpl) {
auto * sctx = (llama_sampler_backend_min_p_ctx *) smpl->ctx;
return llama_sampler_backend_init_min_p(sctx->p);
}

struct llama_sampler * llama_sampler_backend_init_min_p(float p) {
static const llama_sampler_i iface = {
/*.name =*/ llama_sampler_backend_min_p_name,
/*.accept =*/ nullptr,
/*.apply =*/ nullptr,
/*.reset =*/ nullptr,
/*.clone =*/ llama_sampler_backend_min_p_clone,
/*.free =*/ llama_sampler_backend_min_p_free,
/*.apply_ggml =*/ llama_sampler_backend_min_p_apply_ggml,
/*.accept_ggml =*/ nullptr,
/*.set_input_ggml =*/ nullptr,
/*.init_ggml =*/ llama_sampler_backend_min_p_init_ggml,
};

auto * sctx = new llama_sampler_backend_min_p_ctx {
/*.p =*/ p,
/*.device =*/ nullptr,
};

auto * sampler = new llama_sampler {
/*.iface =*/ &iface,
/*.ctx =*/ sctx,
};

return sampler;
}
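
To make the mask/bias trick concrete, here is a scalar walk-through (a standalone sketch, not part of the file) of the arithmetic the graph performs for p = 0.2; the logits and probabilities below are illustrative but mutually consistent:

#include <cstdio>

// Scalar walk-through of the min-p mask/bias trick, with p = 0.2 and
// softmax probabilities {0.6, 0.3, 0.08, 0.02} (illustrative values).
int main() {
    const float p = 0.2f;
    const float probs[4]  = {0.6f, 0.3f, 0.08f, 0.02f};
    const float logits[4] = {2.0f, 1.3f, 0.0f, -1.4f}; // hypothetical logits matching probs

    const float p_max     = probs[0]; // first entry after the descending sort
    const float threshold = p * p_max; // 0.12

    for (int i = 0; i < 4; ++i) {
        const float mask = probs[i] - threshold > 0.0f ? 1.0f : 0.0f; // ggml_step
        const float bias = mask * 1e9f - 1e9f;                        // ggml_scale_bias
        // Kept tokens (0.6, 0.3) get bias 0; discarded tokens (0.08, 0.02) get -1e9.
        printf("token %d: prob %.2f -> biased logit %.1f\n", i, probs[i], logits[i] + bias);
    }
    return 0;
}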
57 changes: 57 additions & 0 deletions tests/test-backend-sampler.cpp
@@ -416,6 +416,62 @@ static void test_backend_temp_sampling(const char * model_path) {

}

static void test_backend_min_p_sampling(const char * model_path) {
test_model_context test_ctx;

const int seq_id = 0;
const float p = 0.1f;
struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
llama_sampler_chain_add(backend_sampler_chain, llama_sampler_backend_init_min_p(p));
std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain }};

if (!test_ctx.setup(model_path, backend_sampler_configs)) {
return;
}

if (!test_ctx.decode({{seq_id, "Hello"}})) {
return;
}

int32_t batch_idx = test_ctx.idx_for_seq(seq_id);

float * logits = llama_get_backend_sampled_logits_ith(test_ctx.ctx, batch_idx);
uint32_t n_logits = llama_get_backend_sampled_logits_count_ith(test_ctx.ctx, batch_idx);

// Collect the logits that are above the min-p threshold (printing is optional)
std::vector<float> filtered_logits;
for (size_t i = 0; i < n_logits; ++i) {
if (logits[i] > -1e9f) {
filtered_logits.push_back(logits[i]);
//printf("min_p logit[%zu] = %.6f\n", i, logits[i]);
}
}
GGML_ASSERT(filtered_logits.size() < (size_t) test_ctx.n_vocab);

// Sample using a CPU sampler for verification that the results are reasonable
struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
llama_sampler_chain_add(chain, llama_sampler_init_dist(88));

llama_token token = llama_sampler_sample(chain, test_ctx.ctx, batch_idx);
const std::string token_str = test_ctx.token_to_piece(token, false);
printf("min-p cpu sampled token id:%d, string: '%s'\n", token, token_str.c_str());
GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);

// Decode and sample 10 more tokens
for (int i = 0; i < 10; i++) {
int32_t loop_idx = test_ctx.idx_for_seq(seq_id);
llama_token token = llama_sampler_sample(chain, test_ctx.ctx, loop_idx);
printf("min-p gen step %d: token id :%5.d, string: %s\n", i, token, test_ctx.token_to_piece(token, false).c_str());
test_ctx.decode_token(token, 0);
}

printf("min-p sampling test PASSED\n");

llama_sampler_free(chain);
}

static void test_backend_multi_sequence_sampling(const char * model_path) {
test_model_context test_ctx;

@@ -772,6 +828,7 @@ static const backend_test_case BACKEND_TESTS[] = {
{ "set_sampler", test_backend_set_sampler, true },
{ "max_outputs", test_backend_max_outputs, true },
{ "mixed", test_backend_mixed_sampling, true },
{ "min_p", test_backend_min_p_sampling, true },
};

struct backend_cli_args {
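One further check worth noting: because discarded entries carry the -1e9 bias while kept entries keep their original logits, the min-p invariant p_i / p_max >= p can be verified directly from the sampled logits via exp(l_i - l_max), a ratio that the softmax normalization cancels out of. A sketch of such an assertion (hypothetical helper, reusing logits, n_logits, and p as in the test above; the 1e-6f tolerance is an assumption):

#include <algorithm>
#include <cmath>
#include <cstdint>

// Verify that every logit surviving the min-p filter satisfies
// p_i / p_max = exp(l_i - l_max) >= p. Discarded entries carry the -1e9 bias,
// so the maximum kept logit is the global maximum. GGML_ASSERT is assumed to
// be available via ggml.h, as in the test file.
static void check_min_p_invariant(const float * logits, uint32_t n_logits, float p) {
    float l_max = -INFINITY;
    for (uint32_t i = 0; i < n_logits; ++i) {
        if (logits[i] > -1e9f) {
            l_max = std::max(l_max, logits[i]);
        }
    }
    for (uint32_t i = 0; i < n_logits; ++i) {
        if (logits[i] > -1e9f) {
            GGML_ASSERT(std::exp(logits[i] - l_max) >= p - 1e-6f);
        }
    }
}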