@@ -152,7 +152,7 @@ def sample(
152152 blocked_tokens : list [int ] | None = None ,
153153 filters : list [ExLlamaV2Filter ] | None = None ,
154154 filter_prefer_eos : bool = False ,
155- sync : bool = False
155+ sync : bool = False ,
156156 ):
157157
158158 """
@@ -273,6 +273,9 @@ def prep_logit_filter(lf):
273273 for f in filters :
274274
275275 pt , et = f .next ()
276+ if len (filters ) > 1 and not isinstance (pt , set ):
277+ pt , et = set (pt ), set (et )
278+
276279 if pt is not None : pass_tokens = pt if pass_tokens is None else pass_tokens & pt
277280 if et is not None : end_tokens = et if end_tokens is None else end_tokens | et
278281
@@ -290,9 +293,15 @@ def prep_logit_filter(lf):
290293 return output_tokens , output_ktokens , output_kprobs , output_probs , end_filter
291294
292295 if filter_prefer_eos and tokenizer .eos_token_id in pass_tokens :
293- pass_tokens = { tokenizer .eos_token_id }
294- logit_filter = prep_logit_filter (logit_filter )
295- ext_c .logit_filter_exclusive (logit_filter , [sorted (list (pass_tokens ))])
296+ pass_tokens_list = [tokenizer .eos_token_id ]
297+ logit_filter = prep_logit_filter (logit_filter )
298+ ext_c .logit_filter_exclusive (logit_filter , pass_tokens_list )
299+ else :
300+ logit_filter = prep_logit_filter (logit_filter )
301+ if isinstance (pass_tokens , set ):
302+ ext_c .logit_filter_exclusive (logit_filter , [sorted (list (pass_tokens ))])
303+ else :
304+ ext_c .logit_filter_exclusive (logit_filter , [pass_tokens ])
296305
297306 # Healing
298307
0 commit comments