292292
293293 def ppl (input_ids__ , logits__ , lengths__ , bins = False ):
294294
295+ logits_device = model .modules [- 1 ].device ()
296+
295297 if bins :
296298 num_bins = (max (lengths__ ) + 255 ) // 256
297299 logprob_sum_ = [0.0 ] * num_bins
@@ -317,8 +319,8 @@ def ppl(input_ids__, logits__, lengths__, bins = False):
317319 a_ = b_
318320 b_ = min (b_ + chunksize , logits_ .shape [1 ])
319321
320- logits_f = logits_ [:, a_ :b_ , :].float () + 1e-10
321- target_ids = input_ids_ [:, a_ + 1 :b_ + 1 ].to (logits_ .device )
322+ logits_f = logits_ [:, a_ :b_ , :].to ( logits_device ). float () + 1e-10
323+ target_ids = input_ids_ [:, a_ + 1 :b_ + 1 ].to (logits_f .device )
322324
323325 log_probs = F .log_softmax (logits_f , dim = - 1 )
324326 token_log_probs = log_probs .gather (- 1 , target_ids .unsqueeze (- 1 )).squeeze (- 1 )
@@ -398,7 +400,7 @@ def ppl(input_ids__, logits__, lengths__, bins = False):
398400
399401 input_ids = input_ids [:, :]
400402 if cache is not None : cache .current_seq_len = 0
401- logits = model .forward (input_ids , cache )
403+ logits = model .forward (input_ids , cache , cpu_logits = input_ids . numel () > 2048 )
402404 logits = logits [:, :- 1 , :]
403405
404406 logprob_sum__ , logprob_count__ = ppl (input_ids , logits , eval_len [i :i + 1 ], args .eval_context_lens )
0 commit comments