@@ -67,6 +67,7 @@ class LLAMAModel : public Napi::ObjectWrap<LLAMAModel> {
6767 }
6868 }
6969
70+ llama_backend_init (false );
7071 model = llama_load_model_from_file (modelPath.c_str (), params);
7172
7273 if (model == NULL ) {
@@ -124,7 +125,18 @@ class LLAMAContext : public Napi::ObjectWrap<LLAMAContext> {
124125
125126 // Decode each token and accumulate the result.
126127 for (size_t i = 0 ; i < tokens.ElementLength (); i++) {
127- const char * str = llama_token_to_str (ctx, (llama_token)tokens[i]);
128+ // source: https://github.com/ggerganov/llama.cpp/blob/232caf3c1581a6cb023571780ff41dc2d66d1ca0/llama.cpp#L799-L811
129+ std::vector<char > result (8 , 0 );
130+ const int n_tokens = llama_token_to_str (ctx, (llama_token)tokens[i], result.data (), result.size ());
131+ if (n_tokens < 0 ) {
132+ result.resize (-n_tokens);
133+ int check = llama_token_to_str (ctx, (llama_token)tokens[i], result.data (), result.size ());
134+ GGML_ASSERT (check == -n_tokens);
135+ } else {
136+ result.resize (n_tokens);
137+ }
138+
139+ const char * str = result.data ();
128140 if (str == nullptr ) {
129141 Napi::Error::New (info.Env (), " Invalid token" ).ThrowAsJavaScriptException ();
130142 return info.Env ().Undefined ();
@@ -134,6 +146,15 @@ class LLAMAContext : public Napi::ObjectWrap<LLAMAContext> {
134146
135147 return Napi::String::New (info.Env (), ss.str ());
136148 }
149+ Napi::Value TokenBos (const Napi::CallbackInfo& info) {
150+ return Napi::Number::From (info.Env (), llama_token_bos (ctx));
151+ }
152+ Napi::Value TokenEos (const Napi::CallbackInfo& info) {
153+ return Napi::Number::From (info.Env (), llama_token_eos (ctx));
154+ }
155+ Napi::Value GetMaxContextSize (const Napi::CallbackInfo& info) {
156+ return Napi::Number::From (info.Env (), llama_n_ctx (ctx));
157+ }
137158 Napi::Value Eval (const Napi::CallbackInfo& info);
138159 static void init (Napi::Object exports) {
139160 exports.Set (" LLAMAContext" ,
@@ -142,6 +163,9 @@ class LLAMAContext : public Napi::ObjectWrap<LLAMAContext> {
142163 {
143164 InstanceMethod (" encode" , &LLAMAContext::Encode),
144165 InstanceMethod (" decode" , &LLAMAContext::Decode),
166+ InstanceMethod (" tokenBos" , &LLAMAContext::TokenBos),
167+ InstanceMethod (" tokenEos" , &LLAMAContext::TokenEos),
168+ InstanceMethod (" getMaxContextSize" , &LLAMAContext::GetMaxContextSize),
145169 InstanceMethod (" eval" , &LLAMAContext::Eval),
146170 }));
147171 }
@@ -151,7 +175,6 @@ class LLAMAContext : public Napi::ObjectWrap<LLAMAContext> {
151175class LLAMAContextEvalWorker : Napi::AsyncWorker, Napi::Promise::Deferred {
152176 LLAMAContext* ctx;
153177 std::vector<llama_token> tokens;
154- std::vector<llama_token> restriction;
155178 llama_token result;
156179
157180 public:
@@ -160,13 +183,6 @@ class LLAMAContextEvalWorker : Napi::AsyncWorker, Napi::Promise::Deferred {
160183 Napi::Uint32Array tokens = info[0 ].As <Napi::Uint32Array>();
161184 this ->tokens .reserve (tokens.ElementLength ());
162185 for (size_t i = 0 ; i < tokens.ElementLength (); i++) { this ->tokens .push_back (static_cast <llama_token>(tokens[i])); }
163-
164- if (info.Length () > 1 && info[1 ].IsTypedArray ()) {
165- Napi::Uint32Array restriction = info[1 ].As <Napi::Uint32Array>();
166- this ->restriction .reserve (restriction.ElementLength ());
167- for (size_t i = 0 ; i < restriction.ElementLength (); i++) { this ->restriction .push_back (static_cast <llama_token>(restriction[i])); }
168- std::sort (this ->restriction .begin (), this ->restriction .end ());
169- }
170186 }
171187 ~LLAMAContextEvalWorker () { ctx->Unref (); }
172188 using Napi::AsyncWorker::Queue;
@@ -175,39 +191,30 @@ class LLAMAContextEvalWorker : Napi::AsyncWorker, Napi::Promise::Deferred {
175191 protected:
176192 void Execute () {
177193 // Perform the evaluation using llama_eval.
178- int r = llama_eval (ctx->ctx , tokens.data (), tokens.size (), llama_get_kv_cache_token_count (ctx->ctx ), 6 );
194+ int r = llama_eval (ctx->ctx , tokens.data (), int ( tokens.size () ), llama_get_kv_cache_token_count (ctx->ctx ), 6 );
179195 if (r != 0 ) {
180196 SetError (" Eval has failed" );
181197 return ;
182198 }
183199
200+ llama_token new_token_id = 0 ;
201+
184202 // Select the best prediction.
185- float * logits = llama_get_logits (ctx->ctx );
186- int n_vocab = llama_n_vocab (ctx->ctx );
187- llama_token re;
188- if (restriction.empty ()) {
189- float max = logits[0 ];
190- re = 0 ;
191- for (llama_token id = 1 ; id < n_vocab; id++) {
192- float logit = logits[id];
193- if (logit > max) {
194- max = logit;
195- re = id;
196- }
197- }
198- } else {
199- float max = logits[restriction[0 ]];
200- re = 0 ;
201- for (size_t i = 1 ; i < restriction.size (); i++) {
202- llama_token id = restriction[i];
203- float logit = logits[id];
204- if (logit > max) {
205- max = logit;
206- re = id;
207- }
208- }
203+ auto logits = llama_get_logits (ctx->ctx );
204+ auto n_vocab = llama_n_vocab (ctx->ctx );
205+
206+ std::vector<llama_token_data> candidates;
207+ candidates.reserve (n_vocab);
208+
209+ for (llama_token token_id = 0 ; token_id < n_vocab; token_id++) {
210+ candidates.emplace_back (llama_token_data{ token_id, logits[token_id], 0 .0f });
209211 }
210- result = re;
212+
213+ llama_token_data_array candidates_p = { candidates.data (), candidates.size (), false };
214+
215+ new_token_id = llama_sample_token_greedy (ctx->ctx , &candidates_p);
216+
217+ result = new_token_id;
211218 }
212219 void OnOK () {
213220 Napi::Env env = Napi::AsyncWorker::Env ();
@@ -223,15 +230,11 @@ Napi::Value LLAMAContext::Eval(const Napi::CallbackInfo& info) {
223230 return worker->Promise ();
224231}
225232
226- Napi::Value tokenBos (const Napi::CallbackInfo& info) { return Napi::Number::From (info.Env (), llama_token_bos ()); }
227- Napi::Value tokenEos (const Napi::CallbackInfo& info) { return Napi::Number::From (info.Env (), llama_token_eos ()); }
228233Napi::Value systemInfo (const Napi::CallbackInfo& info) { return Napi::String::From (info.Env (), llama_print_system_info ()); }
229234
230235Napi::Object registerCallback (Napi::Env env, Napi::Object exports) {
231236 llama_backend_init (false );
232237 exports.DefineProperties ({
233- Napi::PropertyDescriptor::Function (" tokenBos" , tokenBos),
234- Napi::PropertyDescriptor::Function (" tokenEos" , tokenEos),
235238 Napi::PropertyDescriptor::Function (" systemInfo" , systemInfo),
236239 });
237240 LLAMAModel::init (exports);