Skip to content

Commit da038d8

Browse files
committed
completed top nsigma sampler implementation
1 parent ddc3c22 commit da038d8

File tree

5 files changed

+117
-84
lines changed

5 files changed

+117
-84
lines changed

common/arg.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -899,6 +899,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
899899
params.sampling.min_p = std::stof(value);
900900
}
901901
).set_sparam());
902+
add_opt(common_arg(
903+
{"--top-nsigma"}, "N",
904+
string_format("top-n-sigma sampling (default: %d, -1 = disabled)", params.sampling.top_n_sigma),
905+
[](common_params & params, const std::string & value) {
906+
params.sampling.top_n_sigma = std::stof(value);
907+
}
908+
).set_sparam());
902909
add_opt(common_arg(
903910
{"--xtc-probability"}, "N",
904911
string_format("xtc probability (default: %.1f, 0.0 = disabled)", (double)params.sampling.xtc_probability),

common/common.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,6 @@ enum common_sampler_type {
9595
COMMON_SAMPLER_TYPE_XTC = 8,
9696
COMMON_SAMPLER_TYPE_INFILL = 9,
9797
COMMON_SAMPLER_TYPE_PENALTIES = 10,
98-
COMMON_SAMPLER_TYPE_TOP_N_SIGMA = 11
9998
};
10099

101100
// dimensionality reduction methods, used by cvector-generator
@@ -129,7 +128,7 @@ struct common_params_sampling {
129128
int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty
130129
int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
131130
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
132-
int32_t top_n_sigma = 2;
131+
int32_t top_n_sigma = -1; // -1 = disabled
133132
float mirostat_tau = 5.00f; // target entropy
134133
float mirostat_eta = 0.10f; // learning rate
135134
bool ignore_eos = false;
@@ -148,7 +147,6 @@ struct common_params_sampling {
148147
COMMON_SAMPLER_TYPE_MIN_P,
149148
COMMON_SAMPLER_TYPE_XTC,
150149
COMMON_SAMPLER_TYPE_TEMPERATURE,
151-
COMMON_SAMPLER_TYPE_TOP_N_SIGMA,
152150
};
153151

154152
std::string grammar; // optional BNF-like grammar to constrain sampling

common/sampling.cpp

Lines changed: 46 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -131,11 +131,11 @@ std::string common_params_sampling::print() const {
131131
snprintf(result, sizeof(result),
132132
"\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
133133
"\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n"
134-
"\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, temp = %.3f\n"
135-
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
134+
"\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, top_n_sigma = %d, temp = %.3f\n"
135+
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f,",
136136
penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
137137
dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n,
138-
top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, temp,
138+
top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, top_n_sigma, temp,
139139
mirostat, mirostat_eta, mirostat_tau);
140140

141141
return std::string(result);
@@ -162,49 +162,50 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
162162
params.logit_bias.data()));
163163

164164
if (params.mirostat == 0) {
165-
for (const auto & cnstr : params.samplers) {
166-
switch (cnstr) {
167-
case COMMON_SAMPLER_TYPE_DRY:
168-
{
169-
std::vector<const char *> c_breakers;
170-
c_breakers.reserve(params.dry_sequence_breakers.size());
171-
for (const auto & str : params.dry_sequence_breakers) {
172-
c_breakers.push_back(str.c_str());
165+
if(params.top_n_sigma >= 0) {
166+
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
167+
llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma(params.top_n_sigma));
168+
} else {
169+
for (const auto & cnstr : params.samplers) {
170+
switch (cnstr) {
171+
case COMMON_SAMPLER_TYPE_DRY:
172+
{
173+
std::vector<const char *> c_breakers;
174+
c_breakers.reserve(params.dry_sequence_breakers.size());
175+
for (const auto & str : params.dry_sequence_breakers) {
176+
c_breakers.push_back(str.c_str());
177+
}
178+
179+
llama_sampler_chain_add(result->chain, llama_sampler_init_dry (model, params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
173180
}
174-
175-
llama_sampler_chain_add(result->chain, llama_sampler_init_dry (model, params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
176-
}
177-
break;
178-
case COMMON_SAMPLER_TYPE_TOP_K:
179-
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
180-
break;
181-
case COMMON_SAMPLER_TYPE_TOP_P:
182-
llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep));
183-
break;
184-
case COMMON_SAMPLER_TYPE_MIN_P:
185-
llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
186-
break;
187-
case COMMON_SAMPLER_TYPE_XTC:
188-
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
189-
break;
190-
case COMMON_SAMPLER_TYPE_TYPICAL_P:
191-
llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep));
192-
break;
193-
case COMMON_SAMPLER_TYPE_TEMPERATURE:
194-
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
195-
break;
196-
case COMMON_SAMPLER_TYPE_INFILL:
197-
llama_sampler_chain_add(result->chain, llama_sampler_init_infill (model));
198-
break;
199-
case COMMON_SAMPLER_TYPE_PENALTIES:
200-
llama_sampler_chain_add(result->chain, llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
201-
break;
202-
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA:
203-
// llama_sampler_chain_add(result->chain, )
204-
llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma(params.top_n_sigma))
205-
break;
206-
default:
207-
GGML_ASSERT(false && "unknown sampler type");
181+
break;
182+
case COMMON_SAMPLER_TYPE_TOP_K:
183+
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
184+
break;
185+
case COMMON_SAMPLER_TYPE_TOP_P:
186+
llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep));
187+
break;
188+
case COMMON_SAMPLER_TYPE_MIN_P:
189+
llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
190+
break;
191+
case COMMON_SAMPLER_TYPE_XTC:
192+
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
193+
break;
194+
case COMMON_SAMPLER_TYPE_TYPICAL_P:
195+
llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep));
196+
break;
197+
case COMMON_SAMPLER_TYPE_TEMPERATURE:
198+
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
199+
break;
200+
case COMMON_SAMPLER_TYPE_INFILL:
201+
llama_sampler_chain_add(result->chain, llama_sampler_init_infill (model));
202+
break;
203+
case COMMON_SAMPLER_TYPE_PENALTIES:
204+
llama_sampler_chain_add(result->chain, llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
205+
break;
206+
default:
207+
GGML_ASSERT(false && "unknown sampler type");
208+
}
208209
}
209210
}
210211
llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
@@ -411,7 +412,6 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
411412
case COMMON_SAMPLER_TYPE_XTC: return 'x';
412413
case COMMON_SAMPLER_TYPE_INFILL: return 'i';
413414
case COMMON_SAMPLER_TYPE_PENALTIES: return 'e';
414-
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: return 's';
415415
default : return '?';
416416
}
417417
}
@@ -427,7 +427,6 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
427427
case COMMON_SAMPLER_TYPE_XTC: return "xtc";
428428
case COMMON_SAMPLER_TYPE_INFILL: return "infill";
429429
case COMMON_SAMPLER_TYPE_PENALTIES: return "penalties";
430-
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: return "top_n_sigma";
431430
default : return "";
432431
}
433432
}
@@ -443,7 +442,6 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
443442
{ "xtc", COMMON_SAMPLER_TYPE_XTC },
444443
{ "infill", COMMON_SAMPLER_TYPE_INFILL },
445444
{ "penalties", COMMON_SAMPLER_TYPE_PENALTIES },
446-
{ "top_n_sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
447445
};
448446

449447
// since samplers names are written multiple ways
@@ -458,9 +456,6 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
458456
{ "typ", COMMON_SAMPLER_TYPE_TYPICAL_P },
459457
{ "min-p", COMMON_SAMPLER_TYPE_MIN_P },
460458
{ "temp", COMMON_SAMPLER_TYPE_TEMPERATURE },
461-
{ "top-n-sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
462-
{ "top-nsigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
463-
{ "top_nsigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
464459
};
465460

466461
std::vector<common_sampler_type> samplers;
@@ -494,7 +489,6 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
494489
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC },
495490
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL), COMMON_SAMPLER_TYPE_INFILL },
496491
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_PENALTIES), COMMON_SAMPLER_TYPE_PENALTIES },
497-
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_N_SIGMA), COMMON_SAMPLER_TYPE_TOP_N_SIGMA}
498492
};
499493

500494
std::vector<common_sampler_type> samplers;

include/llama.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1133,6 +1133,9 @@ extern "C" {
11331133
/// @details XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335
11341134
LLAMA_API struct llama_sampler * llama_sampler_init_xtc (float p, float t, size_t min_keep, uint32_t seed);
11351135

1136+
/// @details Top n sigma sampling as described in academic paper "Top-nσ: Not All Logits Are You Need" https://arxiv.org/pdf/2411.07641
1137+
LLAMA_API struct llama_sampler * llama_sampler_init_top_n_sigma(int32_t n);
1138+
11361139
/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
11371140
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
11381141
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.

src/llama-sampling.cpp

Lines changed: 60 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,7 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k)
301301
cur_p->size = k;
302302
}
303303

304+
304305
static uint32_t get_rng_seed(uint32_t seed) {
305306
if (seed == LLAMA_DEFAULT_SEED) {
306307
// use system clock if std::random_device is not a true RNG
@@ -1657,35 +1658,65 @@ static const char * llama_sampler_top_n_sigma_name(const struct llama_sampler *
16571658

16581659
static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
16591660
const auto * ctx = (llama_sampler_top_n_sigma *) smpl->ctx;
1660-
llama_sampler_top_n_sigma_impl(cur_p, ctx->n);
1661-
}
1662-
1663-
// static struct llama_sampler * llama_sampler_top_k_clone(const struct llama_sampler * smpl) {
1664-
// const auto * ctx = (const llama_sampler_top_k *) smpl->ctx;
1665-
// return llama_sampler_init_top_k(ctx->k);
1666-
// }
1667-
1668-
// static void llama_sampler_top_k_free(struct llama_sampler * smpl) {
1669-
// delete (llama_sampler_top_k *) smpl->ctx;
1670-
// }
1671-
1672-
// static struct llama_sampler_i llama_sampler_top_k_i = {
1673-
// /* .name = */ llama_sampler_top_k_name,
1674-
// /* .accept = */ nullptr,
1675-
// /* .apply = */ llama_sampler_top_k_apply,
1676-
// /* .reset = */ nullptr,
1677-
// /* .clone = */ llama_sampler_top_k_clone,
1678-
// /* .free = */ llama_sampler_top_k_free,
1679-
// };
1680-
1681-
// struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
1682-
// return new llama_sampler {
1683-
// /* .iface = */ &llama_sampler_top_k_i,
1684-
// /* .ctx = */ new llama_sampler_top_k {
1685-
// /* .k = */ k,
1686-
// },
1687-
// };
1688-
// }
1661+
// 1. Find max logit: M
1662+
// 2. Find standard deviation of logits: sig
1663+
// 3. Create a mask where m[i] = 1 if ith logit >= M - n (sig), else m[i] = 0
1664+
// 4. Apply mask: ith logit itself if m[i]==1, else ith logit = -inf
1665+
// 5. p = softmax(l)
1666+
1667+
// find max logit and calculate mean
1668+
int32_t max = cur_p->data[0].logit;
1669+
int32_t logits_sum = 0;
1670+
for (size_t i = 0; i < cur_p->size; ++i) {
1671+
if(cur_p->data[i].logit > max){
1672+
max = cur_p->data[i].logit;
1673+
}
1674+
logits_sum += cur_p->data[i].logit;
1675+
}
1676+
int32_t mean = logits_sum/cur_p->size;
1677+
1678+
// calculate standard deviation
1679+
int32_t acc = 0;
1680+
for(size_t i = 0; i < cur_p->size; ++i){
1681+
acc += (cur_p->data[i].logit - mean) * (cur_p->data[i].logit - mean);
1682+
}
1683+
int32_t std = sqrt(acc/cur_p->size);
1684+
1685+
//apply mask
1686+
for(size_t i = 0; i < cur_p->size; ++i){
1687+
if(cur_p->data[i].logit < max - (ctx->n * std)) {
1688+
cur_p->data[i].logit = -INFINITY;
1689+
}
1690+
}
1691+
llama_sampler_softmax_impl(cur_p);
1692+
}
1693+
1694+
static struct llama_sampler * llama_sampler_top_n_sigma_clone(const struct llama_sampler * smpl){
1695+
const auto * ctx = (const llama_sampler_top_n_sigma *) smpl->ctx;
1696+
return llama_sampler_init_top_n_sigma(ctx->n);
1697+
}
1698+
1699+
static void llama_sampler_top_n_sigma_free(struct llama_sampler * smpl) {
1700+
delete (llama_sampler_top_n_sigma *) smpl->ctx;
1701+
}
1702+
1703+
static struct llama_sampler_i llama_sampler_top_n_sigma_i = {
1704+
/* .name = */ llama_sampler_top_n_sigma_name,
1705+
/* .accept = */ nullptr,
1706+
/* .apply = */ llama_sampler_top_n_sigma_apply,
1707+
/* .reset = */ nullptr,
1708+
/* .clone = */ llama_sampler_top_n_sigma_clone,
1709+
/* .free = */ llama_sampler_top_n_sigma_free,
1710+
};
1711+
1712+
struct llama_sampler * llama_sampler_init_top_n_sigma(int32_t n) {
1713+
return new llama_sampler {
1714+
/* .iface = */ &llama_sampler_top_n_sigma_i,
1715+
/* .ctx = */ new llama_sampler_top_n_sigma {
1716+
/* .n = */ n,
1717+
},
1718+
};
1719+
}
16891720

16901721
// DRY
16911722

0 commit comments

Comments
 (0)