@@ -131,11 +131,11 @@ std::string common_params_sampling::print() const {
131131 snprintf (result, sizeof (result),
132132 " \t repeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n "
133133 " \t dry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n "
134- " \t top_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, temp = %.3f\n "
135- " \t mirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f" ,
134+ " \t top_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, top_n_sigma = %d, temp = %.3f\n "
135+ " \t mirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f, " ,
136136 penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
137137 dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n,
138- top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, temp,
138+ top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, top_n_sigma, temp,
139139 mirostat, mirostat_eta, mirostat_tau);
140140
141141 return std::string (result);
@@ -162,49 +162,50 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
162162 params.logit_bias .data ()));
163163
164164 if (params.mirostat == 0 ) {
165- for (const auto & cnstr : params.samplers ) {
166- switch (cnstr) {
167- case COMMON_SAMPLER_TYPE_DRY:
168- {
169- std::vector<const char *> c_breakers;
170- c_breakers.reserve (params.dry_sequence_breakers .size ());
171- for (const auto & str : params.dry_sequence_breakers ) {
172- c_breakers.push_back (str.c_str ());
165+ if (params.top_n_sigma >= 0 ) {
166+ llama_sampler_chain_add (result->chain , llama_sampler_init_temp (params.temp ));
167+ llama_sampler_chain_add (result->chain , llama_sampler_init_top_n_sigma (params.top_n_sigma ));
168+ } else {
169+ for (const auto & cnstr : params.samplers ) {
170+ switch (cnstr) {
171+ case COMMON_SAMPLER_TYPE_DRY:
172+ {
173+ std::vector<const char *> c_breakers;
174+ c_breakers.reserve (params.dry_sequence_breakers .size ());
175+ for (const auto & str : params.dry_sequence_breakers ) {
176+ c_breakers.push_back (str.c_str ());
177+ }
178+
179+ llama_sampler_chain_add (result->chain , llama_sampler_init_dry (model, params.dry_multiplier , params.dry_base , params.dry_allowed_length , params.dry_penalty_last_n , c_breakers.data (), c_breakers.size ()));
173180 }
174-
175- llama_sampler_chain_add (result->chain , llama_sampler_init_dry (model, params.dry_multiplier , params.dry_base , params.dry_allowed_length , params.dry_penalty_last_n , c_breakers.data (), c_breakers.size ()));
176- }
177- break ;
178- case COMMON_SAMPLER_TYPE_TOP_K:
179- llama_sampler_chain_add (result->chain , llama_sampler_init_top_k (params.top_k ));
180- break ;
181- case COMMON_SAMPLER_TYPE_TOP_P:
182- llama_sampler_chain_add (result->chain , llama_sampler_init_top_p (params.top_p , params.min_keep ));
183- break ;
184- case COMMON_SAMPLER_TYPE_MIN_P:
185- llama_sampler_chain_add (result->chain , llama_sampler_init_min_p (params.min_p , params.min_keep ));
186- break ;
187- case COMMON_SAMPLER_TYPE_XTC:
188- llama_sampler_chain_add (result->chain , llama_sampler_init_xtc (params.xtc_probability , params.xtc_threshold , params.min_keep , params.seed ));
189- break ;
190- case COMMON_SAMPLER_TYPE_TYPICAL_P:
191- llama_sampler_chain_add (result->chain , llama_sampler_init_typical (params.typ_p , params.min_keep ));
192- break ;
193- case COMMON_SAMPLER_TYPE_TEMPERATURE:
194- llama_sampler_chain_add (result->chain , llama_sampler_init_temp_ext (params.temp , params.dynatemp_range , params.dynatemp_exponent ));
195- break ;
196- case COMMON_SAMPLER_TYPE_INFILL:
197- llama_sampler_chain_add (result->chain , llama_sampler_init_infill (model));
198- break ;
199- case COMMON_SAMPLER_TYPE_PENALTIES:
200- llama_sampler_chain_add (result->chain , llama_sampler_init_penalties (params.penalty_last_n , params.penalty_repeat , params.penalty_freq , params.penalty_present ));
201- break ;
202- case COMMON_SAMPLER_TYPE_TOP_N_SIGMA:
203- // llama_sampler_chain_add(result->chain, )
204- llama_sampler_chain_add (result->chain , llama_sampler_init_top_n_sigma (params.top_n_sigma ))
205- break ;
206- default :
207- GGML_ASSERT (false && " unknown sampler type" );
181+ break ;
182+ case COMMON_SAMPLER_TYPE_TOP_K:
183+ llama_sampler_chain_add (result->chain , llama_sampler_init_top_k (params.top_k ));
184+ break ;
185+ case COMMON_SAMPLER_TYPE_TOP_P:
186+ llama_sampler_chain_add (result->chain , llama_sampler_init_top_p (params.top_p , params.min_keep ));
187+ break ;
188+ case COMMON_SAMPLER_TYPE_MIN_P:
189+ llama_sampler_chain_add (result->chain , llama_sampler_init_min_p (params.min_p , params.min_keep ));
190+ break ;
191+ case COMMON_SAMPLER_TYPE_XTC:
192+ llama_sampler_chain_add (result->chain , llama_sampler_init_xtc (params.xtc_probability , params.xtc_threshold , params.min_keep , params.seed ));
193+ break ;
194+ case COMMON_SAMPLER_TYPE_TYPICAL_P:
195+ llama_sampler_chain_add (result->chain , llama_sampler_init_typical (params.typ_p , params.min_keep ));
196+ break ;
197+ case COMMON_SAMPLER_TYPE_TEMPERATURE:
198+ llama_sampler_chain_add (result->chain , llama_sampler_init_temp_ext (params.temp , params.dynatemp_range , params.dynatemp_exponent ));
199+ break ;
200+ case COMMON_SAMPLER_TYPE_INFILL:
201+ llama_sampler_chain_add (result->chain , llama_sampler_init_infill (model));
202+ break ;
203+ case COMMON_SAMPLER_TYPE_PENALTIES:
204+ llama_sampler_chain_add (result->chain , llama_sampler_init_penalties (params.penalty_last_n , params.penalty_repeat , params.penalty_freq , params.penalty_present ));
205+ break ;
206+ default :
207+ GGML_ASSERT (false && " unknown sampler type" );
208+ }
208209 }
209210 }
210211 llama_sampler_chain_add (result->chain , llama_sampler_init_dist (params.seed ));
@@ -411,7 +412,6 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
411412 case COMMON_SAMPLER_TYPE_XTC: return ' x' ;
412413 case COMMON_SAMPLER_TYPE_INFILL: return ' i' ;
413414 case COMMON_SAMPLER_TYPE_PENALTIES: return ' e' ;
414- case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: return ' s' ;
415415 default : return ' ?' ;
416416 }
417417}
@@ -427,7 +427,6 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
427427 case COMMON_SAMPLER_TYPE_XTC: return " xtc" ;
428428 case COMMON_SAMPLER_TYPE_INFILL: return " infill" ;
429429 case COMMON_SAMPLER_TYPE_PENALTIES: return " penalties" ;
430- case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: return " top_n_sigma" ;
431430 default : return " " ;
432431 }
433432}
@@ -443,7 +442,6 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
443442 { " xtc" , COMMON_SAMPLER_TYPE_XTC },
444443 { " infill" , COMMON_SAMPLER_TYPE_INFILL },
445444 { " penalties" , COMMON_SAMPLER_TYPE_PENALTIES },
446- { " top_n_sigma" , COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
447445 };
448446
449447 // since samplers names are written multiple ways
@@ -458,9 +456,6 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
458456 { " typ" , COMMON_SAMPLER_TYPE_TYPICAL_P },
459457 { " min-p" , COMMON_SAMPLER_TYPE_MIN_P },
460458 { " temp" , COMMON_SAMPLER_TYPE_TEMPERATURE },
461- { " top-n-sigma" , COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
462- { " top-nsigma" , COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
463- { " top_nsigma" , COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
464459 };
465460
466461 std::vector<common_sampler_type> samplers;
@@ -494,7 +489,6 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
494489 { common_sampler_type_to_chr (COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC },
495490 { common_sampler_type_to_chr (COMMON_SAMPLER_TYPE_INFILL), COMMON_SAMPLER_TYPE_INFILL },
496491 { common_sampler_type_to_chr (COMMON_SAMPLER_TYPE_PENALTIES), COMMON_SAMPLER_TYPE_PENALTIES },
497- { common_sampler_type_to_chr (COMMON_SAMPLER_TYPE_TOP_N_SIGMA), COMMON_SAMPLER_TYPE_TOP_N_SIGMA}
498492 };
499493
500494 std::vector<common_sampler_type> samplers;
0 commit comments