Skip to content

Commit f7d92d9

Browse files
authored
Review: "Implement DRY penalty" (#645)
* Silence bogus Clippy warning

  Clippy's suggestion cannot be implemented because of borrowing issues.

* Get rid of unnecessary type annotations

  Interesting that Clippy doesn't catch this.

* Store default sequence breakers in a slice

  It's nicer when the length is not hardcoded.

* Make default sequence breakers private

  No need to leak this as it's not used elsewhere.

* Limit match length

  Avoids quadratic runtime and potential DoS with adversarial inputs.

  Ref oobabooga/text-generation-webui#6047

* "Fix" sequence breaker tokenization

  Most tokenizers encode punctuation tokens differently depending on where they occur in the input, and which tokens surround them. With the default sequence breakers, the appropriate encoding usually corresponds to the encoding produced when the token occurs after a word, rather than by itself.

  To emulate this, prefix the token with "a" before encoding, and extract the final token of the result.

  See LostRuins/koboldcpp#982 for a correct solution to this problem.
1 parent 8650d9c commit f7d92d9

File tree

1 file changed

+14
-14
lines changed

1 file changed

+14
-14
lines changed

mistralrs-core/src/sampler.rs

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIter
1616
use serde::{Deserialize, Serialize};
1717
use tokenizers::Tokenizer;
1818

19-
pub const SEQUENCE_BREAKERS: [&str; 4] = ["\n", ":", "\\", "*"];
19+
const SEQUENCE_BREAKERS: &[&str] = &["\n", ":", "\\", "*"];
2020

2121
#[derive(Clone, Debug)]
2222
/// Stop sequences or ids.
@@ -79,12 +79,8 @@ impl DrySamplingParams {
7979
Ok(Self {
8080
base: base.unwrap_or(1.75),
8181
allowed_length: allowed_length.unwrap_or(2),
82-
sequence_breakers: sequence_breakers.unwrap_or(
83-
SEQUENCE_BREAKERS
84-
.map(|x| x.to_string())
85-
.into_iter()
86-
.collect::<Vec<_>>(),
87-
),
82+
sequence_breakers: sequence_breakers
83+
.unwrap_or(SEQUENCE_BREAKERS.iter().map(|x| x.to_string()).collect()),
8884
multiplier,
8985
})
9086
}
@@ -96,10 +92,7 @@ impl Default for DrySamplingParams {
9692
multiplier: 1.0,
9793
base: 1.75,
9894
allowed_length: 2,
99-
sequence_breakers: SEQUENCE_BREAKERS
100-
.map(|x| x.to_string())
101-
.into_iter()
102-
.collect::<Vec<_>>(),
95+
sequence_breakers: SEQUENCE_BREAKERS.iter().map(|x| x.to_string()).collect(),
10396
}
10497
}
10598
}
@@ -123,14 +116,19 @@ impl DrySamplingParamsInner {
123116
.into_iter()
124117
.map(|breaker| {
125118
tokenizer
126-
.encode(breaker.clone(), true)
119+
// Prefix with 'a' to get the correct encoding of the token at the end of a text.
120+
//
121+
// FIXME: This is a hack. See https://github.com/LostRuins/koboldcpp/pull/982
122+
// for the correct solution which covers multi-token sequence breakers
123+
// and ambiguous encodings.
124+
.encode(format!("a{breaker}"), true)
127125
.map_err(anyhow::Error::msg)
128126
.map(|enc| {
129127
let ids = enc.get_ids();
130128
if !ids.is_empty() {
131129
None
132130
} else {
133-
Some(ids[0])
131+
Some(ids[ids.len() - 1])
134132
}
135133
})
136134
})
@@ -505,7 +503,8 @@ impl Sampler {
505503

506504
let mut match_length = 1;
507505

508-
loop {
506+
// Limit match length to avoid quadratic runtime and potential DoS with adversarial inputs.
507+
while match_length < 50 {
509508
if match_length > i {
510509
// Start of input
511510
break;
@@ -527,6 +526,7 @@ impl Sampler {
527526
match_length += 1;
528527
}
529528

529+
#[allow(clippy::map_entry)]
530530
if match_lengths.contains_key(&next_token) {
531531
match_lengths.insert(next_token, match_length.max(match_lengths[&next_token]));
532532
} else {

0 commit comments

Comments
 (0)