77from exllamav2 .generator .hooks import ExLlamaV2PostSamplingHook
88from exllamav2 .ext import exllamav2_ext as ext_c , none_tensor
99from copy import copy
10+ import threading
1011# import line_profiler
1112
# Per-thread cache of scratch tensors. Each thread that calls the getters below
# gets its own set of buffers, so buffers are never shared between threads.
_tl_tensors = threading.local()


def _get_thread_local_tensor(attr, shape, dtype):
    """Return the calling thread's cached tensor stored under ``attr``.

    A new tensor is allocated with ``torch.empty`` only when this thread has no
    tensor cached under ``attr`` yet, or when the cached one differs from the
    request in shape or dtype. Otherwise the very same tensor object is
    returned again.

    Note that the returned tensor is uninitialized on allocation and is reused
    on later calls, so it may still hold data written by a previous user.
    Callers must fill it before reading and must not rely on it staying
    untouched across calls.

    Args:
        attr: Attribute name on the thread-local store that holds the tensor.
        shape: Requested shape; an int, tuple, list or ``torch.Size``.
        dtype: Requested ``torch.dtype``.

    Returns:
        A tensor of the requested shape and dtype.
    """
    # Normalise so that an equivalent list/tuple/int shape matches the cached
    # torch.Size; comparing torch.Size against a list is always unequal and
    # would force a reallocation on every call.
    wanted = torch.Size((shape,) if isinstance(shape, int) else shape)

    cached = getattr(_tl_tensors, attr, None)
    if cached is None or cached.shape != wanted or cached.dtype != dtype:
        cached = torch.empty(shape, dtype=dtype)
        setattr(_tl_tensors, attr, cached)
    return cached


def _get_logit_filter(shape, dtype):
    """Return this thread's reusable ``logit_filter`` buffer.

    See ``_get_thread_local_tensor`` for the reuse and initialization caveats.
    """
    return _get_thread_local_tensor("logit_filter", shape, dtype)


def _get_output_tokens(shape, dtype):
    """Return this thread's reusable ``output_tokens`` buffer.

    See ``_get_thread_local_tensor`` for the reuse and initialization caveats.
    """
    return _get_thread_local_tensor("output_tokens", shape, dtype)


def _get_output_probs(shape, dtype):
    """Return this thread's reusable ``output_probs`` buffer.

    See ``_get_thread_local_tensor`` for the reuse and initialization caveats.
    """
    return _get_thread_local_tensor("output_probs", shape, dtype)
38+
39+
1240class ExLlamaV2Sampler :
1341
1442 @dataclass
@@ -186,7 +214,7 @@ def sample(
186214 else :
187215 assert batch_size == 1 or len (filters ) == 0 , "Filters not implemented for batch size > 1"
188216
189- logits = logits .squeeze ( 1 )
217+ # logits = logits.view(batch_size, vocab_size )
190218
191219 # Sync
192220
@@ -203,8 +231,13 @@ def sample(
203231
204232 # Prepare filter
205233
206- logit_filter = torch .empty ((batch_size , vocab_size ), dtype = torch .bool )
207- ext_c .fast_fill_cpu_ones_bool (logit_filter )
234+ logit_filter = None
235+ def prep_logit_filter (lf ):
236+ if lf is not None :
237+ return lf
238+ lf = _get_logit_filter ((batch_size , vocab_size ), torch .bool )
239+ ext_c .fast_fill_cpu_ones_bool (lf )
240+ return lf
208241
209242 # Repetition penalty
210243
@@ -223,7 +256,7 @@ def sample(
223256 # Temporarily ban individual tokens
224257
225258 if blocked_tokens :
226- logits [:, blocked_tokens ] = - 1e30
259+ logits [:, :, blocked_tokens ] = - 1e30
227260
228261 # Token bias
229262
@@ -247,7 +280,7 @@ def sample(
247280 assert pass_tokens , "Filter excluded all tokens"
248281 if filter_prefer_eos and tokenizer .eos_token_id in pass_tokens :
249282 pass_tokens = { tokenizer .eos_token_id }
250- # TODO: pass pass_tokens as a numpy array or Python set
283+ logit_filter = prep_logit_filter ( logit_filter )
251284 ext_c .logit_filter_exclusive (logit_filter , [sorted (list (pass_tokens ))])
252285
253286 # Healing
@@ -260,6 +293,7 @@ def sample(
260293 for i in range (batch_size ):
261294 valid_token_lists .append (prefix_id_to_ids [prefix_token [i , 0 ].item ()])
262295
296+ logit_filter = prep_logit_filter (logit_filter )
263297 ext_c .logit_filter_exclusive (logit_filter , valid_token_lists )
264298
265299 # Begin Mirostat
@@ -272,20 +306,20 @@ def sample(
272306
273307 vs = tokenizer .get_vocab_size ()
274308 if vs < logits .shape [- 1 ]:
275- logits [:, vs :] = float ("-inf" )
309+ logits [:, :, vs :] = float ("-inf" )
276310
277311 # Sampling
278312
279- batch_size = logits . shape [ 0 ]
280-
281- output_tokens = torch .empty ((batch_size , 1 ), device = "cpu" , dtype = torch .long )
282- output_probs = torch . empty ((batch_size , 1 ), device = "cpu" , dtype = torch .float )
313+ output_tokens = torch . empty (( batch_size , 1 ), dtype = torch . long )
314+ # output_tokens = _get_output_tokens((batch_size, 1), torch.long)
315+ output_probs = torch .empty ((batch_size , 1 ), dtype = torch .float )
316+ # output_probs = _get_output_probs ((batch_size, 1), torch.float)
283317 if return_top_tokens == 0 :
284318 output_ktokens = none_tensor
285319 output_kprobs = none_tensor
286320 else :
287- output_ktokens = torch .empty ((batch_size , 1 , return_top_tokens ), device = "cpu" , dtype = torch .long )
288- output_kprobs = torch .empty ((batch_size , 1 , return_top_tokens ), device = "cpu" , dtype = torch .float )
321+ output_ktokens = torch .empty ((batch_size , 1 , return_top_tokens ), dtype = torch .long )
322+ output_kprobs = torch .empty ((batch_size , 1 , return_top_tokens ), dtype = torch .float )
289323
290324 m = ext_c .sample_basic (
291325 logits ,
@@ -301,7 +335,7 @@ def sample(
301335 output_probs ,
302336 output_kprobs ,
303337 output_ktokens ,
304- logit_filter ,
338+ logit_filter if logit_filter is not None else none_tensor ,
305339 settings .mirostat ,
306340 settings .mirostat_mu if settings .mirostat else [],
307341 settings .mirostat_tau ,
0 commit comments