Commit b837773

Port speculative decoding from upstream to llama-server (#645)
* server : integrate speculative decoding
* server: Fix field names
* server: fix include, whitespace
* fix compile errors in speculative.cpp
* add llama_sampling_sample_and_accept_n to sampling
* finish porting speculative decoding in server
* port functions from common/speculative, common/sampling
* remove arg
* fix function names
* init params_dft to none
* correct value for n_ctx
* prefix kv cache tensors with model name to avoid conflict
* fix call arguments
* fix spec decoding args
* correct slot.id
* use n_max
* port the rest of sampling funcs
* fix func arguments
* slot.id starts at 1?
* Revert "prefix kv cache tensors with model name to avoid conflict" (this reverts commit fbd5dfd)
* disable draft logging
* disable logging in speculative.cpp: in mainline these would be LOG_DEBUG, but since ik_llama doesn't support it, logging is disabled entirely
* add more draft model parameters
* fix
* pass flash_attn
* add speculative params for parity
* set speculative params in launch_slot_with_task instead
1 parent 4239d25 commit b837773

8 files changed, +655 -41 lines changed


common/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -76,6 +76,7 @@ add_library(${TARGET} STATIC
     minja.hpp
     ngram-cache.h
     ngram-cache.cpp
+    speculative.cpp
     )
 
 if (BUILD_SHARED_LIBS)

common/common.cpp

Lines changed: 34 additions & 5 deletions
@@ -505,6 +505,11 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         params.n_ctx = std::stoi(argv[i]);
         return true;
     }
+    if (arg == "-cd" || arg == "--ctx-size-draft") {
+        CHECK_ARG
+        params.n_ctx_draft = std::stoi(argv[i]);
+        return true;
+    }
     if (arg == "--grp-attn-n" || arg == "-gan") {
         CHECK_ARG
         params.grp_attn_n = std::stoi(argv[i]);
@@ -725,7 +730,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
             }
         }
         return true;
-    }
+    }
     if (arg == "--cfg-negative-prompt") {
         CHECK_ARG
         sparams.cfg_negative_prompt = argv[i];
@@ -765,11 +770,21 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         params.n_keep = std::stoi(argv[i]);
         return true;
     }
-    if (arg == "--draft") {
+    if (arg == "--draft" || arg == "--draft-max" || arg == "--draft-n") {
         CHECK_ARG
         params.n_draft = std::stoi(argv[i]);
         return true;
     }
+    if (arg == "--draft-min" || arg == "--draft-n-min") {
+        CHECK_ARG
+        params.n_draft_min = std::stoi(argv[i]);
+        return true;
+    }
+    if (arg == "--draft-p-min") {
+        CHECK_ARG
+        params.p_draft_min = std::stof(argv[i]);
+        return true;
+    }
     if (arg == "--chunks") {
         CHECK_ARG
         params.n_chunks = std::stoi(argv[i]);
@@ -934,6 +949,14 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         params.cache_type_v = argv[++i];
         return true;
     }
+    if (arg == "-ctkd" || arg == "--cache-type-k-draft") {
+        params.cache_type_k_draft = argv[++i];
+        return true;
+    }
+    if (arg == "-ctvd" || arg == "--cache-type-v-draft") {
+        params.cache_type_v_draft = argv[++i];
+        return true;
+    }
     if (arg == "-mli" || arg == "--multiline-input") {
         params.multiline_input = true;
         return true;
@@ -1071,7 +1094,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
             size_t pos = 0;
             while ((pos = servers.find(",")) != std::string::npos) {
                 std::string server = servers.substr(0, pos);
-                ggml_backend_rpc_buffer_type(server.c_str());
+                ggml_backend_rpc_buffer_type(server.c_str());
                 servers.erase(0, pos + 1);
             }
             ggml_backend_rpc_buffer_type(servers.c_str());
@@ -1693,14 +1716,14 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
     options.push_back({ "speculative", "-td, --threads-draft N", "number of threads to use during generation (default: same as --threads)" });
     options.push_back({ "speculative", "-tbd, --threads-batch-draft N",
                         "number of threads to use during batch and prompt processing (default: same as --threads-draft)" });
-    options.push_back({ "speculative", " --draft N", "number of tokens to draft for speculative decoding (default: %d)", params.n_draft });
     options.push_back({ "speculative", "-ps, --p-split N", "speculative decoding split probability (default: %.1f)", (double)params.p_split });
     options.push_back({ "*", "-lcs, --lookup-cache-static FNAME",
                         "path to static lookup cache to use for lookup decoding (not updated by generation)" });
     options.push_back({ "*", "-lcd, --lookup-cache-dynamic FNAME",
                         "path to dynamic lookup cache to use for lookup decoding (updated by generation)" });
 
     options.push_back({ "*", "-c, --ctx-size N", "size of the prompt context (default: %d, 0 = loaded from model)", params.n_ctx });
+    options.push_back({ "*", "-cd, --ctx-size-draft N", "size of the prompt context for the draft model (default: %d, 0 = loaded from model)", params.n_ctx_draft });
     options.push_back({ "*", "-n, --predict N", "number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)", params.n_predict });
     options.push_back({ "*", "-b, --batch-size N", "logical maximum batch size (default: %d)", params.n_batch });
     options.push_back({ "*", "-ub, --ubatch-size N", "physical maximum batch size (default: %d)", params.n_ubatch });
@@ -1811,6 +1834,8 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
     options.push_back({ "*", "-nkvo, --no-kv-offload", "disable KV offload" });
     options.push_back({ "*", "-ctk, --cache-type-k TYPE", "KV cache data type for K (default: %s)", params.cache_type_k.c_str() });
     options.push_back({ "*", "-ctv, --cache-type-v TYPE", "KV cache data type for V (default: %s)", params.cache_type_v.c_str() });
+    options.push_back({ "*", "-ctkd, --cache-type-k-draft TYPE", "KV cache data type for K for the draft model" });
+    options.push_back({ "*", "-ctvd, --cache-type-v-draft TYPE", "KV cache data type for V for the draft model" });
 
     options.push_back({ "perplexity" });
     options.push_back({ "perplexity", " --all-logits", "return logits for all tokens in the batch (default: %s)", params.logits_all ? "true" : "false" });
@@ -1893,6 +1918,10 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
     options.push_back({ "*", "-hfr, --hf-repo REPO", "Hugging Face model repository (default: unused)" });
     options.push_back({ "*", "-hff, --hf-file FILE", "Hugging Face model file (default: unused)" });
     options.push_back({ "*", "-hft, --hf-token TOKEN", "Hugging Face access token (default: value from HF_TOKEN environment variable)" });
+    options.push_back({ "*", "--draft-max, --draft, --draft-n N",
+                        "number of tokens to draft for speculative decoding (default: %d)", params.n_draft });
+    options.push_back({ "*", "--draft-min, --draft-n-min N", "minimum number of draft tokens to use for speculative decoding" });
+    options.push_back({ "*", "--draft-p-min P", "minimum speculative decoding probability (greedy) (default: %.1f)", (double)params.p_draft_min });
 
     options.push_back({ "retrieval" });
     options.push_back({ "retrieval", " --context-file FNAME", "file to load context from (repeat to specify multiple files)" });
@@ -2052,7 +2081,7 @@ std::string string_join(const std::vector<std::string> & strs, const std::string
     if (strs.empty()) {
        return "";
     }
-
+
     std::ostringstream oss;
     for (size_t i = 0; i < strs.size(); ++i) {
         if (i > 0) {
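
Taken together, the new options let llama-server be pointed at a separate draft model with its own context size and KV cache types. A hypothetical invocation (the -m / -md model options and the GGUF file names stand in for the existing model flags and are not part of this diff; the values are only illustrative):

    llama-server -m target-model.gguf -md draft-model.gguf \
        -cd 4096 -ctkd q8_0 -ctvd q8_0 \
        --draft-max 16 --draft-min 1 --draft-p-min 0.8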

common/common.h

Lines changed: 6 additions & 1 deletion
@@ -83,10 +83,13 @@ struct gpt_params {
     int32_t n_threads_batch_draft = -1;
     int32_t n_predict = -1; // new tokens to predict
     int32_t n_ctx = 0; // context size
+    int32_t n_ctx_draft = 0; // context size for draft model
     int32_t n_batch = 2048; // logical batch size for prompt processing (must be >=32 to use BLAS)
     int32_t n_ubatch = 512; // physical batch size for prompt processing (must be >=32 to use BLAS)
     int32_t n_keep = 0; // number of tokens to keep from initial prompt
-    int32_t n_draft = 5; // number of tokens to draft during speculative decoding
+    int32_t n_draft = 16; // number of tokens to draft during speculative decoding
+    int32_t n_draft_min = 1; // minimum number of tokens to draft during speculative decoding
+    float p_draft_min = 0.8f; // minimum speculative decoding probability (greedy)
     int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
     int32_t n_parallel = 1; // number of parallel sequences to decode
     int32_t n_sequences = 1; // number of sequences to decode
@@ -207,6 +210,8 @@ struct gpt_params {
 
     std::string cache_type_k = "f16"; // KV cache data type for the K
     std::string cache_type_v = "f16"; // KV cache data type for the V
+    std::string cache_type_k_draft = ""; // KV cache data type for K for the draft model
+    std::string cache_type_v_draft = ""; // KV cache data type for V for the draft model
 
     // multimodal models (see examples/llava)
     std::string mmproj = ""; // path to multimodal projector
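
For reference, a default-constructed gpt_params now carries the following draft-related settings. The notes restate the defaults from the diff; the interpretation of each field follows the flag descriptions and the upstream speculative helper, so treat this as a reading aid rather than part of the change:

    gpt_params params;
    // params.n_draft            == 16    -> upper bound on drafted tokens per speculation step (--draft-max)
    // params.n_draft_min        == 1     -> smallest draft that is worth verifying (--draft-min)
    // params.p_draft_min        == 0.8f  -> drafting stops once the draft model's top-token
    //                                       probability falls below this value (--draft-p-min)
    // params.n_ctx_draft        == 0     -> 0 = use the context size loaded from the draft model (-cd)
    // params.cache_type_k_draft == ""    -> unset; presumably falls back to -ctk / -ctv
    // params.cache_type_v_draft == ""    //   (same for V)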

common/sampling.cpp

Lines changed: 47 additions & 1 deletion
@@ -442,7 +442,9 @@ static llama_token_data_array llama_sampling_prepare_impl(
         cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
     }
 
-    llama_token_data_array cur_p = { cur.data(), cur.size(), false };
+    ctx_sampling->cur_p = { cur.data(), cur.size(), false };
+
+    llama_token_data_array & cur_p = ctx_sampling->cur_p;
 
     // apply penalties
     const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
@@ -506,3 +508,47 @@ void llama_sampling_accept(
         llama_sampler_dry_accept(ctx_sampling->smpl, id);
     }
 }
+
+llama_token_data_array * llama_sampling_get_candidates(struct llama_sampling_context * ctx_sampling) {
+    return &ctx_sampling->cur_p;
+}
+
+std::vector<llama_token> llama_sampling_sample_and_accept_n(struct llama_sampling_context * gsmpl, struct llama_context * ctx, const std::vector<llama_token> & draft) {
+    std::vector<int> idxs(draft.size() + 1);
+    for (size_t i = 0; i < idxs.size(); ++i) {
+        idxs[i] = i;
+    }
+
+    return llama_sampling_sample_and_accept_n(gsmpl, ctx, idxs, draft);
+}
+
+std::vector<llama_token> llama_sampling_sample_and_accept_n(struct llama_sampling_context * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const std::vector<llama_token> & draft) {
+    GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");
+
+    std::vector<llama_token> result;
+    result.reserve(idxs.size());
+
+    size_t i = 0;
+    for (; i < draft.size(); i++) {
+        const llama_token id = llama_sampling_sample(gsmpl, ctx, nullptr, idxs[i]);
+
+        llama_sampling_accept(gsmpl, ctx, id, true);
+
+        result.push_back(id);
+
+        if (draft[i] != id) {
+            break;
+        }
+    }
+
+    if (i == draft.size()) {
+        const llama_token id = llama_sampling_sample(gsmpl, ctx, nullptr, idxs[i]);
+
+        llama_sampling_accept(gsmpl, ctx, id, true);
+
+        result.push_back(id);
+    }
+
+    return result;
+}
+
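
The two-argument overload above is what the server's speculative step drives: after the target model has evaluated the last accepted token plus all drafted tokens in a single batch, it samples position by position and keeps drafted tokens only while they agree. A minimal sketch of such a caller (ctx_tgt, ctx_sampling, id_last, n_past and draft are illustrative locals; llama_batch_init / llama_batch_add / llama_batch_free and llama_kv_cache_seq_rm are assumed to be the helpers already provided by llama.h and common.h):

    // verify a draft against the target model (sketch, not the actual server code)
    llama_batch batch = llama_batch_init((int32_t) draft.size() + 1, 0, 1);
    llama_batch_add(batch, id_last, n_past, { 0 }, true);               // logits requested at every position
    for (size_t i = 0; i < draft.size(); ++i) {
        llama_batch_add(batch, draft[i], n_past + 1 + (llama_pos) i, { 0 }, true);
    }
    llama_decode(ctx_tgt, batch);

    // samples at output indices 0..draft.size() and stops at the first drafted
    // token the target model disagrees with; always returns at least one token
    const std::vector<llama_token> ids = llama_sampling_sample_and_accept_n(ctx_sampling, ctx_tgt, draft);

    // ids.size() - 1 drafted tokens were accepted; the last element is the target
    // model's own continuation. The rejected tail of the draft is still in the KV
    // cache and has to be removed (e.g. llama_kv_cache_seq_rm) before the next step.
    n_past += (int) ids.size();
    llama_batch_free(batch);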

common/sampling.h

Lines changed: 10 additions & 0 deletions
@@ -101,6 +101,8 @@ struct llama_sampling_context {
 
     size_t n_valid; // Number of correct top tokens with correct probabilities.
 
+    llama_token_data_array cur_p; // current candidates
+
     std::mt19937 rng;
 };
 
@@ -176,3 +178,11 @@ void llama_sampling_accept(
         struct llama_context * ctx_main,
         llama_token id,
         bool apply_grammar);
+
+// returns at least 1 token, up to draft.size()
+// access the internal list of current candidate tokens
+llama_token_data_array * llama_sampling_get_candidates(struct llama_sampling_context * ctx_sampling);
+
+std::vector<llama_token> llama_sampling_sample_and_accept_n(struct llama_sampling_context * gsmpl, struct llama_context * ctx, const std::vector<llama_token> & draft);
+
+std::vector<llama_token> llama_sampling_sample_and_accept_n(struct llama_sampling_context * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const std::vector<llama_token> & draft);
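
llama_sampling_get_candidates exists mainly for the ported common/speculative draft loop: after greedily sampling a draft token, the loop can read that token's probability from cur_p and stop drafting early. A rough sketch of that use, modelled on the upstream helper this commit ports (smpl_dft, ctx_dft, draft and params are illustrative, and the check assumes the draft sampler chain leaves sorted, normalized probabilities in cur_p):

    while ((int) draft.size() < params.n_draft) {
        const llama_token id = llama_sampling_sample(smpl_dft, ctx_dft, nullptr, 0);

        // cur_p was filled during sampling; with greedy draft sampling, data[0]
        // is the token that was just chosen, together with its probability
        const llama_token_data_array * cur_p = llama_sampling_get_candidates(smpl_dft);

        // --draft-p-min: stop drafting once the draft model is no longer confident
        if (cur_p->data[0].p < params.p_draft_min) {
            break;
        }

        llama_sampling_accept(smpl_dft, ctx_dft, id, true);
        draft.push_back(id);

        // ... feed `id` back through the draft model to obtain logits for the next position ...
    }
    // the server is then expected to use the draft only when draft.size() >= params.n_draft_min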
