Skip to content

Commit a22250d

Browse files
ikawrakow (Iwan Kawrakow) and a co-author authored
llama-bench: enable having different number of threads for tg and pp (#284)
* llama-bench: enable having different number of threads for tg and pp * Add -tgb to usage --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
1 parent 279b7d3 commit a22250d

File tree

1 file changed

+34
-10
lines changed

1 file changed

+34
-10
lines changed

examples/llama-bench/llama-bench.cpp

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,12 @@ static uint64_t get_time_ns() {
4141
return std::chrono::nanoseconds(clock::now().time_since_epoch()).count();
4242
}
4343

44+
// Formats a std::pair as "{first, second}" on an output stream.
// Added so that thread counts stored as (tg, pp) pairs can be printed by
// the same generic helpers (e.g. join() below) that handle scalar values.
template <typename T1, typename T2>
std::ostream& operator<<(std::ostream& out, const std::pair<T1, T2>& pr) {
    return out << '{' << pr.first << ", " << pr.second << '}';
}
49+
4450
template<class T>
4551
static std::string join(const std::vector<T> & values, const std::string & delim) {
4652
std::ostringstream str;
@@ -228,7 +234,7 @@ struct cmd_params {
228234
std::vector<int> n_ubatch;
229235
std::vector<ggml_type> type_k;
230236
std::vector<ggml_type> type_v;
231-
std::vector<int> n_threads;
237+
std::vector<std::pair<int,int>> n_threads;
232238
std::vector<int> n_gpu_layers;
233239
std::vector<std::string> rpc_servers;
234240
std::vector<llama_split_mode> split_mode;
@@ -263,7 +269,7 @@ static const cmd_params cmd_params_defaults = {
263269
/* n_ubatch */ {512},
264270
/* type_k */ {GGML_TYPE_F16},
265271
/* type_v */ {GGML_TYPE_F16},
266-
/* n_threads */ {cpu_get_num_math()},
272+
/* n_threads */ {{cpu_get_num_math(), cpu_get_num_math()}},
267273
/* n_gpu_layers */ {99},
268274
/* rpc_servers */ {""},
269275
/* split_mode */ {LLAMA_SPLIT_MODE_LAYER},
@@ -303,6 +309,7 @@ static void print_usage(int /* argc */, char ** argv) {
303309
printf(" -ctk, --cache-type-k <t> (default: %s)\n", join(transform_to_str(cmd_params_defaults.type_k, ggml_type_name), ",").c_str());
304310
printf(" -ctv, --cache-type-v <t> (default: %s)\n", join(transform_to_str(cmd_params_defaults.type_v, ggml_type_name), ",").c_str());
305311
printf(" -t, --threads <n> (default: %s)\n", join(cmd_params_defaults.n_threads, ",").c_str());
312+
printf(" -tgb, --threads-gen-batch <n1,n2> (default: %s)\n", join(cmd_params_defaults.n_threads, ",").c_str());
306313
printf(" -ngl, --n-gpu-layers <n> (default: %s)\n", join(cmd_params_defaults.n_gpu_layers, ",").c_str());
307314
printf(" -rpc, --rpc <rpc_servers> (default: %s)\n", join(cmd_params_defaults.rpc_servers, ",").c_str());
308315
printf(" -sm, --split-mode <none|layer|row> (default: %s)\n", join(transform_to_str(cmd_params_defaults.split_mode, split_mode_str), ",").c_str());
@@ -538,7 +545,23 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
538545
break;
539546
}
540547
auto p = string_split<int>(argv[i], split_delim);
541-
params.n_threads.insert(params.n_threads.end(), p.begin(), p.end());
548+
params.n_threads.reserve(params.n_threads.size() + p.size());
549+
for (auto t : p) params.n_threads.push_back({t, t});
550+
//params.n_threads.insert(params.n_threads.end(), p.begin(), p.end());
551+
} else if (arg == "-tgb" || arg == "--threads-gen-batch") {
552+
if (++i >= argc) {
553+
invalid_param = true;
554+
break;
555+
}
556+
auto ps = string_split<std::string>(argv[i], ';');
557+
for (auto& s : ps) {
558+
auto p = string_split<int>(s.c_str(), ',');
559+
if (p.size() != 2) {
560+
invalid_param = true;
561+
break;
562+
}
563+
params.n_threads.push_back({p[0], p[1]});
564+
}
542565
} else if (arg == "-ngl" || arg == "--n-gpu-layers") {
543566
if (++i >= argc) {
544567
invalid_param = true;
@@ -775,7 +798,7 @@ struct cmd_params_instance {
775798
int n_ubatch;
776799
ggml_type type_k;
777800
ggml_type type_v;
778-
int n_threads;
801+
std::pair<int,int> n_threads;
779802
int n_gpu_layers;
780803
std::string rpc_servers;
781804
llama_split_mode split_mode;
@@ -1024,7 +1047,7 @@ struct test {
10241047
uint64_t model_n_params;
10251048
int n_batch;
10261049
int n_ubatch;
1027-
int n_threads;
1050+
std::pair<int,int> n_threads;
10281051
bool has_rpc;
10291052
ggml_type type_k;
10301053
ggml_type type_v;
@@ -1218,14 +1241,15 @@ struct test {
12181241
str << ser.first << ',' << ser.second;
12191242
return str.str();
12201243
};
1244+
bool is_gen = n_gen > 0;
12211245
std::vector<std::string> values = {
12221246
build_commit, std::to_string(build_number),
12231247
std::to_string(cuda), std::to_string(vulkan), std::to_string(vulkan),
12241248
std::to_string(metal), std::to_string(sycl), std::to_string(has_rpc), std::to_string(gpu_blas), std::to_string(blas),
12251249
cpu_info, gpu_info,
12261250
model_filename, model_type, std::to_string(model_size), std::to_string(model_n_params),
12271251
std::to_string(n_batch), std::to_string(n_ubatch),
1228-
std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v),
1252+
std::to_string(is_gen ? n_threads.first : n_threads.second), ggml_type_name(type_k), ggml_type_name(type_v),
12291253
std::to_string(n_gpu_layers), split_mode_str(split_mode),
12301254
std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn),
12311255
std::to_string(mla_attn), std::to_string(attn_max_batch), ser_to_string(ser),
@@ -1787,10 +1811,10 @@ int main(int argc, char ** argv) {
17871811
if (params.warmup) {
17881812
if (t.n_prompt > 0) {
17891813
//test_prompt(ctx, std::min(t.n_batch, std::min(t.n_prompt, 32)), 0, t.n_batch, t.n_threads);
1790-
test_prompt(ctx, 1, 0, t.n_batch, t.n_threads);
1814+
test_prompt(ctx, 1, 0, t.n_batch, t.n_threads.second);
17911815
}
17921816
if (t.n_gen > 0) {
1793-
test_gen(ctx, 1, 0, t.n_threads);
1817+
test_gen(ctx, 1, 0, t.n_threads.first);
17941818
}
17951819
}
17961820

@@ -1800,11 +1824,11 @@ int main(int argc, char ** argv) {
18001824
uint64_t t_start = get_time_ns();
18011825

18021826
if (t.n_prompt > 0) {
1803-
test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads);
1827+
test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads.second);
18041828
}
18051829
if (t.test_kind == TEST_KIND_GP) t_start = get_time_ns();
18061830
if (t.n_gen > 0) {
1807-
test_gen(ctx, t.n_gen, t.n_prompt, t.n_threads);
1831+
test_gen(ctx, t.n_gen, t.n_prompt, t.n_threads.first);
18081832
}
18091833

18101834
uint64_t t_ns = get_time_ns() - t_start;

0 commit comments

Comments (0)