88from exllamav2 .ext import exllamav2_ext as ext_c , none_tensor
99from copy import copy
1010import threading
11+ from functools import lru_cache
12+ import re
1113# import line_profiler
1214
1315_tl_tensors = threading .local ()
@@ -37,6 +39,12 @@ def _get_output_probs(shape, dtype):
3739 return _tl_tensors .output_probs
3840
3941
@dataclass
class NgramNode:
    # Node of the n-gram trie built by ExLlamaV2Sampler.apply_dry (DRY sampling).
    # value: number of times the n-gram ending at this node has been observed
    # in the indexed portion of the context.
    # children: maps the next token ID to the deeper trie node.
    # NOTE(review): the self-referential annotation relies on lazy annotations
    # (`from __future__ import annotations` at the top of this file) — confirm.
    value: int = 0
    children: dict[int, NgramNode] = field(default_factory = dict)
4048class ExLlamaV2Sampler :
4149
4250 @dataclass
@@ -74,6 +82,15 @@ class Settings:
7482
7583 post_sampling_hooks : list [ExLlamaV2PostSamplingHook ] = field (default_factory = list )
7684
85+ dry_allowed_length : int = 0 # 0 to disable
86+ dry_base : float = 2.0
87+ dry_multiplier : float = 2.0
88+ dry_sequence_breakers : set [int ] | None = None
89+ dry_max_ngram : int = 20
90+
91+         ngram_trie : NgramNode | None = None
92+ ngram_index : int = 0
93+
7794 @staticmethod
7895 def greedy (** kwargs ) -> ExLlamaV2Sampler .Settings :
7996 defaults = {
@@ -101,6 +118,11 @@ def greedy_clone(self):
101118 c .token_frequency_penalty = self .token_frequency_penalty
102119 c .token_presence_penalty = self .token_presence_penalty
103120 c .token_bias = None
121+ c .dry_allowed_length = self .dry_allowed_length
122+         c .dry_base = self .dry_base
123+ c .dry_multiplier = self .dry_multiplier
124+ c .dry_sequence_breakers = self .dry_sequence_breakers
125+ c .dry_max_ngram = self .dry_max_ngram
104126 c .filters = []
105127 return c
106128
@@ -139,6 +161,82 @@ def allow_tokens(
139161 raise ValueError ("Incorrect type in allow_tokens list" )
140162
141163
164+ @staticmethod
165+ @lru_cache (10 )
166+ def get_dry_default_sequence_breaker_tokens (
167+ tokenizer : ExLlamaV2Tokenizer
168+ ) -> set [int ]:
169+ result = set ()
170+ dry_default_sequence_breaker_chars = r".,!?<>\[\]\(\)\{\}\n\t\""
171+ pattern = re .compile (r"[" + dry_default_sequence_breaker_chars + "]" )
172+ pieces = tokenizer .get_id_to_piece_list (include_special_tokens = True )
173+ for t in range (len (pieces )):
174+ if bool (pattern .search (pieces [t ])):
175+ result .add (t )
176+ for t in tokenizer .extended_id_to_piece .keys ():
177+ result .add (t )
178+ return result
179+
180+
@staticmethod
def apply_dry(
    settings: ExLlamaV2Sampler.Settings,
    tokenizer: ExLlamaV2Tokenizer,
    sequence_ids: torch.Tensor,
    logits: torch.Tensor
):
    """
    Apply the DRY (Don't Repeat Yourself) repetition penalty to `logits`,
    in place.

    An n-gram trie over the context is maintained incrementally on `settings`
    (`ngram_trie` / `ngram_index`), so each call only inserts n-grams whose
    endpoints were not indexed by a previous call. The longest suffix of the
    current context that has occurred before (at least `dry_allowed_length`
    tokens long, at most `dry_max_ngram`) is then looked up, and every token
    that previously followed that suffix is penalized by
    `dry_multiplier * dry_base ** (match_len - dry_allowed_length)`,
    scaled by its occurrence count.

    NOTE(review): only `sequence_ids[0]` is consulted, and the index/penalty
    tensors are built with shape (1, 1, k) on the default device — this
    assumes batch size 1 and a CPU-resident logits tensor; confirm against
    callers.
    """

    # Lazily initialize the per-Settings incremental trie state.
    if settings.ngram_trie is None:
        settings.ngram_trie = NgramNode(0, {})
        settings.ngram_index = 0

    # Fall back to the tokenizer-derived default breaker set on first use.
    if settings.dry_sequence_breakers is None:
        settings.dry_sequence_breakers = \
            ExLlamaV2Sampler.get_dry_default_sequence_breaker_tokens(tokenizer)

    # Convert sequence IDs to list once since .item() is slow
    sequence_list = sequence_ids[0].tolist()

    # Update trie with new ngrams. The final context token is excluded from
    # indexing here (seq_len = len - 1) but still participates in the suffix
    # match below. Restarting up to dry_max_ngram positions before the last
    # indexed endpoint extends n-grams that straddle the previous call's
    # boundary; the `j >= ngram_index` guard keeps counts from being
    # incremented twice for already-indexed endpoints.
    seq_len = max(len(sequence_list) - 1, 0)
    for i in range(max(settings.ngram_index - settings.dry_max_ngram, 0), seq_len):
        node = settings.ngram_trie
        for j in range(i, min(i + settings.dry_max_ngram, seq_len)):
            t = sequence_list[j]
            # Sequence breakers terminate any n-gram.
            if t in settings.dry_sequence_breakers:
                break
            if t not in node.children:
                node.children[t] = NgramNode(0, {})
            if j >= settings.ngram_index:
                node.children[t].value += 1
            node = node.children[t]
    settings.ngram_index = seq_len

    # Find longest ngram: try suffix start positions from longest candidate
    # (earliest start) to the shortest allowed (dry_allowed_length tokens).
    # The for-else fires only when the entire suffix [i:] walks the trie
    # without a miss.
    seq_len = len(sequence_list)
    beg = max(seq_len - settings.dry_max_ngram, 0)
    end = max(seq_len - settings.dry_allowed_length + 1, 0)
    penalty_tokens = None
    for i in range(beg, end):
        node = settings.ngram_trie
        for j in range(i, seq_len):
            t = sequence_list[j]
            if t not in node.children:
                break
            node = node.children[t]
        else:
            # Children of the matched node are all tokens that previously
            # followed this suffix, with their occurrence counts.
            penalty_tokens = node.children
            ngram_prefix_length = j - i + 1
            break

    # Apply penalties if a node with children was reached at the end of the context, in which case
    # those children count all ngrams of length > ngram_prefix_length
    if penalty_tokens:
        indices = torch.tensor([[list(penalty_tokens.keys())]], dtype = torch.long)
        # Negative penalty grows exponentially with the excess match length.
        exc_length = ngram_prefix_length - settings.dry_allowed_length
        penalty = -settings.dry_multiplier * settings.dry_base ** exc_length
        penalties = torch.tensor([[[penalty * node.value for node in penalty_tokens.values()]]], dtype = torch.float)
        logits.scatter_add_(-1, indices, penalties)
239+
142240 @staticmethod
143241 # @profile
144242 def sample (
@@ -264,6 +362,11 @@ def prep_logit_filter(lf):
264362 # logits = logits + settings.token_bias
265363 ext_c .fast_fadd_cpu (logits , settings .token_bias )
266364
365+ # DRY
366+
367+ if settings .dry_allowed_length :
368+ ExLlamaV2Sampler .apply_dry (settings , tokenizer , sequence_ids , logits )
369+
267370 # Evaluate filters
268371
269372 if len (filters ) > 0 :
@@ -285,8 +388,8 @@ def prep_logit_filter(lf):
285388 # Special case if a single token passes
286389 if len (pass_tokens ) == 1 and return_top_tokens == 0 and prefix_token is None :
287390 single_passed_token = next (iter (pass_tokens ))
288- output_tokens = torch .tensor ([[single_passed_token ]], dtype = torch .long )
289- output_probs = torch .tensor ([[1 ]], dtype = torch .float )
391+ output_tokens = torch .tensor ([[single_passed_token ]], dtype = torch .long )
392+ output_probs = torch .tensor ([[1 ]], dtype = torch .float )
290393 output_ktokens = none_tensor
291394 output_kprobs = none_tensor
292395 end_filter = (single_passed_token in end_tokens )
0 commit comments