Skip to content

Commit d35f62e

Browse files
EricLBuehler and p-e-w
authored
Implement DRY penalty (#637)
* Implement dry penalty * Add dry sampling params to requests * Handle it * Clippy * Review: "Implement DRY penalty" (#645) * Silence bogus Clippy warning Clippy's suggestion cannot be implemented because of borrowing issues * Get rid of unnecessary type annotations Interesting that Clippy doesn't catch this * Store default sequence breakers in a slice It's nicer when the length is not hardcoded * Make default sequence breakers private No need to leak this as it's not used elsewhere * Limit match length Avoids quadratic runtime and potential DoS with adversarial inputs Ref oobabooga/text-generation-webui#6047 * "Fix" sequence breaker tokenization Most tokenizers encode punctuation tokens differently depending on where they occur in the input, and which tokens surround them. With the default sequence breakers, the appropriate encoding usually corresponds to the encoding produced when the token occurs after a word, rather than by itself. To emulate this, prefix the token with "a" before encoding, and extract the final token of the result. See LostRuins/koboldcpp#982 for a correct solution to this problem. * Nicer * Even better * Complete merge * Fix saturating sub * Handle when no context * Make context the entire sequence and refactor * Remove slicing for all * Fix the bug with penalty Credit to @p-e-w for finding this! Co-authored-by: Philipp Emanuel Weidmann <pew@worldwidemann.com> * Add custom logits processor API (#702) * Add custom logits processor api * Typos * Nicer interface and update example * Fix doctest * Update docs * Update exports * Add Gemma 2 PagedAttention support (#704) * Add gemma2 paged attn support * Non cuda support? 
* Remove error * It works * Faster RmsNorm in gemma/gemma2 (#703) * Fix bug in metal isq (#706) * Support GGUF BF16 tensors (#691) * Support GGUF bf16 tensors * Fix loading of bf16 ggml tensor * Fix dequant of bf16 * Use merged rev * Softcapping, real batching + sliding window support for Flash Attention (#707) * Flash attention varlen kind of works * Seems to work * Now it's nice * Sliding window support and clippy * Remove warning * Support smollm * Update rev to match merged * Remove some usages of 'pub' in models (#708) * Support the Phi 3.5 V model (#710) * Update image_seq_len * Update the examples * Format * Implement the Phi 3.5 MoE model (#709) * Copy the model * Add most of it * Add the blocksparse moe parts * Clippy * Fix mscales * A batch of fixes * Correctly cast it * Handle isq on gate * Even more progress * Runs now * Clippy * Fix to use layernorm * Remove unused * Add docs * Add more docs * Apply review comments * Update readme --------- Co-authored-by: Philipp Emanuel Weidmann <pew@worldwidemann.com>
1 parent 91a423e commit d35f62e

File tree

12 files changed

+335
-31
lines changed

12 files changed

+335
-31
lines changed

README.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,10 +101,12 @@ Mistal.rs supports several model categories:
101101
- [Paper](https://arxiv.org/abs/2405.19076)
102102
- [Docs](docs/ANYMOE.md)
103103
- PagedAttention: [docs](docs/PAGED_ATTENTION.md)
104-
- Various sampling techniques:
104+
- Various sampling and penalty techniques:
105105
- Top K
106106
- Top P
107107
- Min P
108+
- [Dry Penalty](https://github.com/oobabooga/text-generation-webui/pull/5677)
109+
- Frequency and Presence Penalty
108110
- Please suggest more by raising an issue!
109111
- Tool calling: [docs](docs/TOOL_CALLING.md)
110112
- Prompt chunking (only without PagedAttention for now): handle larger prompts where the activation size would cause an OOM by sending chunks

mistralrs-bench/src/main.rs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@ use clap::Parser;
33
use cli_table::{format::Justify, print_stdout, Cell, CellStruct, Style, Table};
44
use mistralrs_core::{
55
initialize_logging, paged_attn_supported, Constraint, DefaultSchedulerMethod,
6-
DeviceLayerMapMetadata, DeviceMapMetadata, Loader, LoaderBuilder, MemoryGpuConfig, MistralRs,
7-
MistralRsBuilder, ModelDType, ModelSelected, NormalRequest, PagedAttentionConfig, Request,
8-
RequestMessage, Response, SamplingParams, SchedulerConfig, TokenSource, Usage,
6+
DeviceLayerMapMetadata, DeviceMapMetadata, DrySamplingParams, Loader, LoaderBuilder,
7+
MemoryGpuConfig, MistralRs, MistralRsBuilder, ModelDType, ModelSelected, NormalRequest,
8+
PagedAttentionConfig, Request, RequestMessage, Response, SamplingParams, SchedulerConfig,
9+
TokenSource, Usage,
910
};
1011
use std::sync::Arc;
1112
use std::{fmt::Display, num::NonZeroUsize};
@@ -64,6 +65,7 @@ fn run_bench(
6465
stop_toks: None,
6566
logits_bias: None,
6667
n_choices: 1,
68+
dry_params: Some(DrySamplingParams::default()),
6769
};
6870
let sender = mistralrs.get_sender().unwrap();
6971
let (tx, mut rx) = channel(10_000);
@@ -227,6 +229,7 @@ fn warmup_run(mistralrs: Arc<MistralRs>) {
227229
stop_toks: None,
228230
logits_bias: None,
229231
n_choices: 1,
232+
dry_params: Some(DrySamplingParams::default()),
230233
};
231234
let sender = mistralrs.get_sender().unwrap();
232235
let (tx, mut rx) = channel(10_000);

mistralrs-core/src/engine/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -647,11 +647,13 @@ impl Engine {
647647
tokenizer,
648648
request.sampling_params.frequency_penalty,
649649
request.sampling_params.presence_penalty,
650+
request.sampling_params.dry_params,
650651
topk,
651652
topp,
652653
minp,
653654
request.logits_processors.unwrap_or_default(),
654655
);
656+
let sampler = handle_seq_error!(sampler, request.response);
655657

656658
if request.sampling_params.n_choices == 0 {
657659
request

mistralrs-core/src/lib.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,9 @@ pub use pipeline::{
7878
pub use request::{Constraint, MessageContent, NormalRequest, Request, RequestMessage};
7979
pub use response::Response;
8080
pub use response::*;
81-
pub use sampler::{CustomLogitsProcessor, SamplingParams, StopTokens, TopLogprob};
81+
pub use sampler::{
82+
CustomLogitsProcessor, DrySamplingParams, SamplingParams, StopTokens, TopLogprob,
83+
};
8284
pub use scheduler::{DefaultSchedulerMethod, SchedulerConfig};
8385
use serde::Serialize;
8486
use tokio::runtime::Runtime;

mistralrs-core/src/pipeline/amoe.rs

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -354,8 +354,19 @@ impl AnyMoePipelineMixin for AnyMoePipeline {
354354

355355
// Create several dummy objects for the sequences. No custom logits processors.
356356
let (dummy_sender, _) = tokio::sync::mpsc::channel(10000);
357-
let dummy_sampler =
358-
Sampler::new(None, 0, tokenizer.clone(), None, None, -1, 0.0, 0.0, vec![]);
357+
let dummy_sampler = Sampler::new(
358+
None,
359+
0,
360+
tokenizer.clone(),
361+
None,
362+
None,
363+
None,
364+
-1,
365+
0.0,
366+
0.0,
367+
vec![],
368+
)
369+
.map_err(|e| candle_core::Error::Msg(e.to_string()))?;
359370

360371
let dummy_group = Arc::new(tokio::sync::Mutex::new(SequenceGroup::new(
361372
1, false, false, 0,

mistralrs-core/src/sampler.rs

Lines changed: 187 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
22

33
use std::{
4-
collections::HashMap,
4+
collections::{HashMap, HashSet},
55
iter::zip,
66
sync::{Arc, Mutex},
77
};
@@ -12,9 +12,14 @@ use pyo3::pyclass;
1212

1313
use rand::distributions::{Distribution, WeightedIndex};
1414
use rand_isaac::Isaac64Rng;
15+
use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator};
1516
use serde::{Deserialize, Serialize};
17+
use std::sync::LazyLock;
1618
use tokenizers::Tokenizer;
1719

20+
static DRY_SEQUENCE_BREAKERS: LazyLock<Vec<String>> =
21+
LazyLock::new(|| ["\n", ":", "\"", "*"].map(String::from).to_vec());
22+
1823
#[derive(Clone, Debug)]
1924
/// Stop sequences or ids.
2025
pub enum StopTokens {
@@ -36,6 +41,7 @@ pub struct SamplingParams {
3641
pub max_len: Option<usize>,
3742
pub logits_bias: Option<HashMap<u32, f32>>,
3843
pub n_choices: usize,
44+
pub dry_params: Option<DrySamplingParams>,
3945
}
4046

4147
impl Default for SamplingParams {
@@ -52,10 +58,91 @@ impl Default for SamplingParams {
5258
max_len: None,
5359
logits_bias: None,
5460
n_choices: 1,
61+
dry_params: None,
5562
}
5663
}
5764
}
5865

66+
#[derive(Clone, Debug)]
67+
pub struct DrySamplingParams {
68+
pub sequence_breakers: Vec<String>,
69+
pub multiplier: f32,
70+
pub base: f32,
71+
pub allowed_length: usize,
72+
}
73+
74+
impl DrySamplingParams {
75+
pub fn new_with_defaults(
76+
multiplier: f32,
77+
sequence_breakers: Option<Vec<String>>,
78+
base: Option<f32>,
79+
allowed_length: Option<usize>,
80+
) -> anyhow::Result<Self> {
81+
Ok(Self {
82+
base: base.unwrap_or(1.75),
83+
allowed_length: allowed_length.unwrap_or(2),
84+
sequence_breakers: sequence_breakers.unwrap_or(DRY_SEQUENCE_BREAKERS.clone()),
85+
multiplier,
86+
})
87+
}
88+
}
89+
90+
impl Default for DrySamplingParams {
91+
fn default() -> Self {
92+
Self {
93+
multiplier: 0.0,
94+
base: 1.75,
95+
allowed_length: 2,
96+
sequence_breakers: DRY_SEQUENCE_BREAKERS.clone(),
97+
}
98+
}
99+
}
100+
101+
#[derive(Clone, Debug)]
102+
struct DrySamplingParamsInner {
103+
pub sequence_breakers: HashSet<u32>,
104+
pub multiplier: f32,
105+
pub base: f32,
106+
pub allowed_length: usize,
107+
}
108+
109+
impl DrySamplingParamsInner {
110+
pub fn from(other: DrySamplingParams, tokenizer: &Tokenizer) -> anyhow::Result<Self> {
111+
Ok(Self {
112+
base: other.base,
113+
allowed_length: other.allowed_length,
114+
sequence_breakers: HashSet::from_iter(
115+
other
116+
.sequence_breakers
117+
.into_iter()
118+
.map(|breaker| {
119+
tokenizer
120+
// Prefix with 'a' to get the correct encoding of the token at the end of a text.
121+
//
122+
// FIXME: This is a hack. See https://github.com/LostRuins/koboldcpp/pull/982
123+
// for the correct solution which covers multi-token sequence breakers
124+
// and ambiguous encodings.
125+
.encode(["a", &breaker].concat(), true)
126+
.map_err(anyhow::Error::msg)
127+
.map(|enc| {
128+
let ids = enc.get_ids();
129+
if !ids.is_empty() {
130+
None
131+
} else {
132+
Some(ids[ids.len() - 1])
133+
}
134+
})
135+
})
136+
.collect::<anyhow::Result<Vec<_>>>()?
137+
.into_iter()
138+
.flatten()
139+
.collect::<Vec<_>>(),
140+
),
141+
multiplier: other.multiplier,
142+
})
143+
}
144+
}
145+
59146
/// Customizable logtis processor
60147
pub trait CustomLogitsProcessor: Send + Sync {
61148
/// Logits and sequence context (prompt and generated tokens), returning modified tokens.
@@ -76,6 +163,7 @@ pub struct Sampler {
76163
tokenizer: Arc<Tokenizer>,
77164
frequency_penalty: Option<f32>,
78165
presence_penalty: Option<f32>,
166+
dry_params: Option<DrySamplingParamsInner>,
79167
top_k: i64,
80168
top_p: f64,
81169
min_p: f64,
@@ -112,27 +200,34 @@ impl Sampler {
112200
tokenizer: Arc<Tokenizer>,
113201
frequency_penalty: Option<f32>,
114202
presence_penalty: Option<f32>,
203+
dry_params: Option<DrySamplingParams>,
115204
top_k: i64,
116205
top_p: f64,
117206
min_p: f64,
118207
logits_processors: Vec<Arc<dyn CustomLogitsProcessor>>,
119-
) -> Self {
208+
) -> anyhow::Result<Self> {
120209
let temperature = if temperature.map_or(true, |v| v < 1e-7) {
121210
None
122211
} else {
123212
temperature
124213
};
125-
Self {
214+
let dry_params = dry_params.map(|params| DrySamplingParamsInner::from(params, &tokenizer));
215+
let dry_params = match dry_params {
216+
Some(fallible) => Some(fallible?),
217+
None => None,
218+
};
219+
Ok(Self {
126220
temperature,
127221
top_n_logprobs,
128222
tokenizer,
129223
frequency_penalty,
130224
presence_penalty,
225+
dry_params,
131226
top_k,
132227
top_p,
133228
min_p,
134229
logits_processors,
135-
}
230+
})
136231
}
137232

138233
fn get_top_logprobs(
@@ -372,6 +467,21 @@ impl Sampler {
372467
}
373468

374469
fn apply_penalties(&self, mut logits: Vec<f32>, context: &[u32]) -> Result<Tensor> {
470+
if context.is_empty() {
471+
candle_core::bail!("Penalty context is empty, this should not happen.");
472+
}
473+
474+
// Dry penalty
475+
self.apply_dry_penalty(&mut logits, context)?;
476+
477+
// Frequency and Presence penalty
478+
self.apply_freq_presc_penalty(&mut logits, context)?;
479+
480+
let vocab_size = logits.len();
481+
Tensor::from_vec(logits, vocab_size, &Device::Cpu)
482+
}
483+
484+
fn apply_freq_presc_penalty(&self, logits: &mut [f32], context: &[u32]) -> Result<()> {
375485
if self.frequency_penalty.is_some() || self.presence_penalty.is_some() {
376486
let frequency_penalty = self.frequency_penalty.unwrap_or(0.);
377487
let presence_penalty = self.presence_penalty.unwrap_or(0.);
@@ -390,8 +500,71 @@ impl Sampler {
390500
- if count > 0.0 { 1. } else { 0. } * presence_penalty;
391501
}
392502
}
393-
let vocab_size = logits.len();
394-
Tensor::from_vec(logits, vocab_size, &Device::Cpu)
503+
Ok(())
504+
}
505+
506+
fn apply_dry_penalty(&self, logits: &mut [f32], context: &[u32]) -> Result<()> {
507+
if let Some(ref params) = self.dry_params {
508+
let match_indices = context
509+
.par_iter()
510+
.enumerate()
511+
.take(context.len() - 1)
512+
.filter(|(_i, x)| *context.last().unwrap() == **x)
513+
.map(|(i, _)| i)
514+
.collect::<Vec<_>>();
515+
516+
let mut match_lengths = HashMap::new();
517+
518+
for i in match_indices {
519+
let next_token = context[i + 1];
520+
521+
if params.sequence_breakers.contains(&next_token) {
522+
continue;
523+
}
524+
525+
let mut match_length = 1;
526+
527+
// Limit match length to avoid quadratic runtime and potential DoS with adversarial inputs.
528+
while match_length < 50 {
529+
if match_length > i {
530+
// Start of input
531+
break;
532+
}
533+
534+
let j = i - match_length;
535+
536+
let prev_tok = context[context.len() - (match_length + 1)];
537+
if context[j] != prev_tok {
538+
// Start of match reached
539+
break;
540+
}
541+
542+
if params.sequence_breakers.contains(&prev_tok) {
543+
// Seq breaking tok reached
544+
break;
545+
}
546+
547+
match_length += 1;
548+
}
549+
550+
#[allow(clippy::map_entry)]
551+
if match_lengths.contains_key(&next_token) {
552+
match_lengths.insert(next_token, match_length.max(match_lengths[&next_token]));
553+
} else {
554+
match_lengths.insert(next_token, match_length);
555+
}
556+
}
557+
558+
// Actually apply penalties
559+
for (tok, match_len) in match_lengths {
560+
if match_len >= params.allowed_length {
561+
let penalty = params.multiplier
562+
* params.base.powf((match_len - params.allowed_length) as f32);
563+
logits[tok as usize] -= penalty;
564+
}
565+
}
566+
}
567+
Ok(())
395568
}
396569

397570
/// Sample the provided tokens.
@@ -406,7 +579,8 @@ impl Sampler {
406579
rng: Arc<Mutex<Isaac64Rng>>,
407580
sample_speculative: bool,
408581
) -> Result<Logprobs> {
409-
let mut logits = self.apply_penalties(logits.to_vec1()?, context)?;
582+
let logits = logits.to_vec1()?;
583+
let mut logits = self.apply_penalties(logits, context)?;
410584
for processor in &self.logits_processors {
411585
logits = processor.apply(&logits, context)?;
412586
}
@@ -487,11 +661,13 @@ mod tests {
487661
get_tokenizer().into(),
488662
None,
489663
None,
664+
None,
490665
32,
491666
0.1,
492667
0.05,
493668
vec![],
494-
);
669+
)
670+
.unwrap();
495671
let logits = Tensor::arange(0f32, 1024f32, &Device::Cpu).unwrap();
496672
let rng = Arc::new(Mutex::new(Isaac64Rng::seed_from_u64(42)));
497673
let res = sampler
@@ -517,11 +693,13 @@ mod tests {
517693
get_tokenizer().into(),
518694
None,
519695
None,
696+
None,
520697
32,
521698
0.1,
522699
0.05,
523700
vec![],
524-
);
701+
)
702+
.unwrap();
525703
let logits = Tensor::arange(0f32, 1024f32, &Device::Cpu).unwrap();
526704
let rng = Arc::new(Mutex::new(Isaac64Rng::seed_from_u64(42)));
527705
let res = sampler

0 commit comments

Comments
 (0)