1 file changed: +12 -9 lines changed
Original file line number | Diff line number | Diff line change
1111import llama_cpp
1212import llama_cpp ._internals as internals
1313
14+ from typing import (
15+ List ,
16+ Dict ,
17+ )
18+
1419
1520MODEL = "./vendor/llama.cpp/models/ggml-vocab-llama-spm.gguf"
1621
@@ -81,7 +86,6 @@ def test_real_model(llama_cpp_model_path):
8186 cparams .n_ubatch = 16
8287 cparams .n_threads = multiprocessing .cpu_count ()
8388 cparams .n_threads_batch = multiprocessing .cpu_count ()
84- cparams .logits_all = False
8589 cparams .flash_attn = True
8690 cparams .swa_full = True
8791
@@ -153,15 +157,13 @@ def test_real_llama(llama_cpp_model_path):
153157 assert output ["choices" ][0 ]["text" ] == "true"
154158
155159 suffix = b"rot"
160+
156161 tokens = model .tokenize (suffix , add_bos = True , special = True )
157- def logit_processor_func (input_ids , logits ):
158- for token in tokens :
159- logits [token ] *= 1000
160- return logits
161162
162- logit_processors = llama_cpp .LogitsProcessorList (
163- [logit_processor_func ]
164- )
163+ logit_bias : Dict [int , float ] = {}
164+
165+ for token_id in tokens :
166+ logit_bias [token_id ] = 1000
165167
166168 output = model .create_completion (
167169 "The capital of france is par" ,
@@ -170,8 +172,9 @@ def logit_processor_func(input_ids, logits):
170172 top_p = 0.9 ,
171173 temperature = 0.8 ,
172174 seed = 1337 ,
173- logits_processor = logit_processors
175+ logit_bias = logit_bias
174176 )
177+
175178 assert output ["choices" ][0 ]["text" ].lower ().startswith ("rot" )
176179
177180 model .set_seed (1337 )
You can’t perform that action at this time.
0 commit comments