Skip to content

Commit 790cc09

Browse files
committed
Try to use logit_bias instead of logits_processor in test_llama
1 parent 90ed7a6 commit 790cc09

File tree

1 file changed

+12
-9
lines changed

1 file changed

+12
-9
lines changed

tests/test_llama.py

Lines changed: 12 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,11 @@
1111
import llama_cpp
1212
import llama_cpp._internals as internals
1313

14+
from typing import (
15+
List,
16+
Dict,
17+
)
18+
1419

1520
MODEL = "./vendor/llama.cpp/models/ggml-vocab-llama-spm.gguf"
1621

@@ -81,7 +86,6 @@ def test_real_model(llama_cpp_model_path):
8186
cparams.n_ubatch = 16
8287
cparams.n_threads = multiprocessing.cpu_count()
8388
cparams.n_threads_batch = multiprocessing.cpu_count()
84-
cparams.logits_all = False
8589
cparams.flash_attn = True
8690
cparams.swa_full = True
8791

@@ -153,15 +157,13 @@ def test_real_llama(llama_cpp_model_path):
153157
assert output["choices"][0]["text"] == "true"
154158

155159
suffix = b"rot"
160+
156161
tokens = model.tokenize(suffix, add_bos=True, special=True)
157-
def logit_processor_func(input_ids, logits):
158-
for token in tokens:
159-
logits[token] *= 1000
160-
return logits
161162

162-
logit_processors = llama_cpp.LogitsProcessorList(
163-
[logit_processor_func]
164-
)
163+
logit_bias: Dict[int, float] = {}
164+
165+
for token_id in tokens:
166+
logit_bias[token_id] = 1000
165167

166168
output = model.create_completion(
167169
"The capital of france is par",
@@ -170,8 +172,9 @@ def logit_processor_func(input_ids, logits):
170172
top_p=0.9,
171173
temperature=0.8,
172174
seed=1337,
173-
logits_processor=logit_processors
175+
logit_bias=logit_bias
174176
)
177+
175178
assert output["choices"][0]["text"].lower().startswith("rot")
176179

177180
model.set_seed(1337)

0 commit comments

Comments (0)