
Commit 1179b8a

Fix ppl test for long seq lengths
Parent: 0122b11

File tree

2 files changed (+13, -3 lines)


exllamav2/model.py

Lines changed: 8 additions & 0 deletions
@@ -681,6 +681,7 @@ def forward(self,
                 return_last_state: bool = False,
                 position_offsets: torch.Tensor | None = None,
                 abort_event: threading.Event | None = None,
+                cpu_logits: bool = False,
                 **kwargs) \
         -> torch.Tensor | tuple[torch.Tensor, torch.Tensor] | None:
         """
@@ -717,6 +718,11 @@ def forward(self,
         :param abort_event:
             Optional event that, if set, will abort the forward pass. Function will return None if aborted.

+        :param cpu_logits:
+            If True, logits are collected and returned in system RAM. This is somewhat slower but can prevent
+            out-of-memory errors when computing logits for all positions in a long sequence, such as during a
+            perplexity test.
+
         :return:
             FP16 logits tensor, shape (batch_size, q_len, vocab_size)
             (optional) state tensor, shape (batch_size, q_len, hidden_size)
@@ -819,6 +825,8 @@ def forward(self,
             if abort_event and abort_event.is_set(): return

             if not _preprocess_only:
+                if cpu_logits:
+                    r["logits"] = r["logits"].cpu()
                 result = r["logits"] if result is None else torch.cat((result, r["logits"]), dim = 1)

             chunk_begin = chunk_end
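
For context on why the new flag helps: without it, the full FP16 logits tensor of shape (batch_size, q_len, vocab_size) has to fit in VRAM, and that grows linearly with sequence length. A rough back-of-the-envelope sketch in Python (the vocab size of 32000 is an assumed example value, not taken from this commit):

# Rough estimate of the VRAM needed to hold full-sequence FP16 logits.
# vocab_size = 32000 is an assumed example, not part of this commit.
def full_logits_bytes(batch_size, q_len, vocab_size = 32000):
    return batch_size * q_len * vocab_size * 2   # 2 bytes per FP16 element

print(full_logits_bytes(1, 32768) / 1024**3)   # ~1.95 GiB for a single 32k-token sequence

With cpu_logits = True, each chunk's logits are moved to system RAM as soon as they are produced, so only one chunk's worth ever sits on the GPU; the caller opts in when the sequence is long, as test_inference.py does below.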

test_inference.py

Lines changed: 5 additions & 3 deletions
@@ -292,6 +292,8 @@

 def ppl(input_ids__, logits__, lengths__, bins = False):

+    logits_device = model.modules[-1].device()
+
     if bins:
         num_bins = (max(lengths__) + 255) // 256
         logprob_sum_ = [0.0] * num_bins
@@ -317,8 +319,8 @@ def ppl(input_ids__, logits__, lengths__, bins = False):
         a_ = b_
         b_ = min(b_ + chunksize, logits_.shape[1])

-        logits_f = logits_[:, a_:b_, :].float() + 1e-10
-        target_ids = input_ids_[:, a_ + 1:b_ + 1].to(logits_.device)
+        logits_f = logits_[:, a_:b_, :].to(logits_device).float() + 1e-10
+        target_ids = input_ids_[:, a_ + 1:b_ + 1].to(logits_f.device)

         log_probs = F.log_softmax(logits_f, dim=-1)
         token_log_probs = log_probs.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1)
@@ -398,7 +400,7 @@ def ppl(input_ids__, logits__, lengths__, bins = False):

        input_ids = input_ids[:, :]
        if cache is not None: cache.current_seq_len = 0
-       logits = model.forward(input_ids, cache)
+       logits = model.forward(input_ids, cache, cpu_logits = input_ids.numel() > 2048)
        logits = logits[:, :-1, :]

        logprob_sum__, logprob_count__ = ppl(input_ids, logits, eval_len[i:i+1], args.eval_context_lens)
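
For reference, the accumulated logprob_sum__ and logprob_count__ values are normally turned into a perplexity figure by exponentiating the negative mean token log-probability. A minimal sketch of that final step, assuming the standard definition (shown for illustration, not copied from test_inference.py):

import math

# Perplexity from a summed token log-probability and a token count.
# Standard definition, assumed here; the script's own aggregation may differ in detail.
def perplexity(logprob_sum, logprob_count):
    mean_log_prob = logprob_sum / logprob_count
    return math.exp(-mean_log_prob)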
