diff --git a/common/common.cpp b/common/common.cpp index 287ec58f3..e7ade95d4 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1012,6 +1012,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.fused_moe_up_gate = true; return true; } + if (arg == "-ger" || arg == "--grouped-expert-routing") { + params.grouped_expert_routing = true; + return true; + } if (arg == "-no-fug" || arg == "--no-fused-up-gate") { params.fused_up_gate = false; return true; @@ -1800,6 +1804,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "*", "-mla, --mla-use", "enable MLA (default: %d)", params.mla_attn }); options.push_back({ "*", "-amb, --attention-max-batch", "max batch size for attention computations (default: %d)", params.attn_max_batch}); options.push_back({ "*", "-fmoe, --fused-moe", "enable fused MoE (default: %s)", params.fused_moe_up_gate ? "enabled" : "disabled" }); + options.push_back({ "*", "-ger, --grouped-expert-routing", "enable grouped expert routing (default: %s)", params.grouped_expert_routing ? "enabled" : "disabled" }); options.push_back({ "*", "-no-fug, --no-fused-up-gate", "disaable fused up-gate (default: %s)", params.fused_up_gate ? "enabled" : "disabled" }); options.push_back({ "*", "-ser, --smart-expert-reduction,","experts reduction (default: %d,%g)", params.min_experts, params.thresh_experts}); options.push_back({ "*", "-p, --prompt PROMPT", "prompt to start generation with\n" @@ -2755,6 +2760,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param cparams.mla_attn = params.mla_attn; cparams.attn_max_batch = params.attn_max_batch; cparams.fused_moe_up_gate = params.fused_moe_up_gate; + cparams.grouped_expert_routing = params.grouped_expert_routing; cparams.fused_up_gate = params.fused_up_gate; cparams.min_experts = params.min_experts; cparams.thresh_experts = params.thresh_experts; @@ -3871,6 +3877,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l fprintf(stream, "mla_attn: %d # default: 0\n", params.mla_attn); fprintf(stream, "attn_max_batch: %d # default: 0\n", params.attn_max_batch); fprintf(stream, "fused_moe: %s # default: false\n", params.fused_moe_up_gate ? "true" : "false"); + fprintf(stream, "grouped_expert_routing: %s # default: false\n", params.grouped_expert_routing ? "true" : "false"); fprintf(stream, "fused_up_gate: %s # default: true\n", params.fused_up_gate ? "true" : "false"); fprintf(stream, "ser: %d,%g # defaulr: -1,0\n", params.min_experts, params.thresh_experts); fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp); diff --git a/common/common.h b/common/common.h index 2b4d1540d..ddd507554 100644 --- a/common/common.h +++ b/common/common.h @@ -235,6 +235,7 @@ struct gpt_params { int attn_max_batch = 0; // Max batch size to use when computing attention (only applicable if flash_attn = false) bool fused_moe_up_gate = false; // fused up*unary(gate) op for MoE models bool fused_up_gate = true; // fused up*unary(gate) op + bool grouped_expert_routing = false; // if to use grouped expert routing (BailingMoeV2 arch) int min_experts = -1; float thresh_experts = 0; diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 6bb646bdc..c30151344 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -261,6 +261,7 @@ struct cmd_params { bool warmup; bool repack = false; bool fmoe = false; + bool ger = false; // ger = Grouped Expert Routing bool no_fug = false; bool use_thp = false; output_formats output_format; @@ -296,9 +297,10 @@ static const cmd_params cmd_params_defaults = { /* verbose */ false, /* warmup */ true, /* repack */ false, - /* use_thp */ false, /* fmoe */ false, + /* ger */ false, /* no_fug */ false, + /* use_thp */ false, /* output_format */ MARKDOWN, /* output_format_stderr */ NONE, }; @@ -341,6 +343,7 @@ static void print_usage(int /* argc */, char ** argv) { printf(" -thp, --transparent-huge-pages <0|1> (default: %s)\n", cmd_params_defaults.use_thp? "1" : "0"); printf(" -ot, --override-tensor pattern (default: none)\n"); printf(" -fmoe, --fused-moe <0|1> (default: %s)\n", cmd_params_defaults.fmoe? "1" : "0"); + printf(" -ger, --grouped-expert-routing <0|1>(default: %s)\n", cmd_params_defaults.ger ? "1" : "0"); printf(" -no-fug, --no-fused-up-gate <0|1> (default: %s)\n", cmd_params_defaults.no_fug? "1" : "0"); printf("\n"); printf("Multiple values can be given for each parameter by separating them with ',' or by specifying the parameter multiple times.\n"); @@ -739,6 +742,12 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { break; } params.fmoe = std::stoi(argv[i]); + } else if (arg == "-ger" || arg == "--grouped-expert-routing") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.ger = std::stoi(argv[i]); } else if (arg == "-no-fug" || arg == "--no-fused-up-gate") { if (++i >= argc) { invalid_param = true; @@ -829,6 +838,7 @@ struct cmd_params_instance { bool embeddings; bool repack = false; bool fmoe = false; + bool ger = false; bool no_fug = false; bool use_thp = false; const llama_model_tensor_buft_override* buft_overrides; @@ -876,6 +886,7 @@ struct cmd_params_instance { cparams.mla_attn = mla_attn; cparams.attn_max_batch = attn_max_batch; cparams.fused_moe_up_gate = fmoe; + cparams.grouped_expert_routing = ger; cparams.fused_up_gate = !no_fug; cparams.min_experts = ser.first; cparams.thresh_experts = ser.second; @@ -935,6 +946,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .embeddings = */ embd, /* .repack = */ params.repack, /* .fmoe = */ params.fmoe, + /* .ger = */ params.ger, /* .no_fug = */ params.no_fug, /* .use_thp = */ params.use_thp, /* .buft_overrides=*/ params.buft_overrides.data(), @@ -970,6 +982,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .embeddings = */ embd, /* .repack = */ params.repack, /* .fmoe = */ params.fmoe, + /* .ger = */ params.ger, /* .no_fug = */ params.no_fug, /* .use_thp = */ params.use_thp, /* .buft_overrides=*/ params.buft_overrides.data(), @@ -1005,6 +1018,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .embeddings = */ embd, /* .repack = */ params.repack, /* .fmoe = */ params.fmoe, + /* .ger = */ params.ger, /* .no_fug = */ params.no_fug, /* .use_thp = */ params.use_thp, /* .buft_overrides=*/ params.buft_overrides.data(), @@ -1040,6 +1054,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .embeddings = */ embd, /* .repack = */ params.repack, /* .fmoe = */ params.fmoe, + /* .ger = */ params.ger, /* .no_fug = */ params.no_fug, /* .use_thp = */ params.use_thp, /* .buft_overrides=*/ params.buft_overrides.data(), @@ -1086,6 +1101,7 @@ struct test { bool embeddings; bool repack = false; bool fmoe = false; + bool ger = false; bool no_fug = false; bool use_thp = false; int n_prompt; @@ -1120,6 +1136,8 @@ struct test { use_mmap = inst.use_mmap; embeddings = inst.embeddings; repack = inst.repack; + fmoe = inst.fmoe; + ger = inst.ger; no_fug = inst.no_fug; use_thp = inst.use_thp; n_prompt = inst.n_prompt; @@ -1212,7 +1230,7 @@ struct test { "n_threads", "type_k", "type_v", "n_gpu_layers", "split_mode", "main_gpu", "no_kv_offload", "flash_attn", "mla_attn", "attn_max_batch", "ser", - "tensor_split", "use_mmap", "embeddings", "repack", "fused_moe", "fused_up_gate", "use_thp", + "tensor_split", "use_mmap", "embeddings", "repack", "fused_moe", "grouped_er", "fused_up_gate", "use_thp", "n_prompt", "n_gen", "test_time", "avg_ns", "stddev_ns", "avg_ts", "stddev_ts", "test", @@ -1234,7 +1252,7 @@ struct test { if (field == "cuda" || field == "vulkan" || field == "kompute" || field == "metal" || field == "gpu_blas" || field == "blas" || field == "sycl" ||field == "f16_kv" || field == "no_kv_offload" || field == "flash_attn" || field == "use_mmap" || field == "embeddings" || field == "repack" || field == "use_thp" || - field == "fused_moe" || field == "fused_up_gate") { + field == "fused_moe" || field == "grouped_er" || field == "fused_up_gate") { return BOOL; } if (field == "avg_ts" || field == "stddev_ts") { @@ -1277,7 +1295,8 @@ struct test { std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn), std::to_string(mla_attn), std::to_string(attn_max_batch), ser_to_string(ser), tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings), - std::to_string(repack), std::to_string(fmoe), std::to_string(no_fug), std::to_string(use_thp), + std::to_string(repack), std::to_string(fmoe), std::to_string(ger), + std::to_string(no_fug), std::to_string(use_thp), std::to_string(n_prompt), std::to_string(n_gen), test_time, std::to_string(avg_ns()), std::to_string(stdev_ns()), std::to_string(avg_ts()), std::to_string(stdev_ts()), @@ -1461,6 +1480,9 @@ struct markdown_printer : public printer { if (field == "fused_moe") { return 4; } + if (field == "grouped_er") { + return 3; + } if (field == "fused_up_gate") { return 6; } @@ -1513,6 +1535,12 @@ struct markdown_printer : public printer { if (field == "fused_moe") { return "fmoe"; } + if (field == "grouped_er") { + return "ger"; + } + if (field == "grouped_er") { + return "ger"; + } if (field == "fused_up_gate") { return "no-fug"; } @@ -1589,6 +1617,9 @@ struct markdown_printer : public printer { if (params.fmoe != cmd_params_defaults.fmoe) { fields.emplace_back("fused_moe"); } + if (params.ger != cmd_params_defaults.ger) { + fields.emplace_back("grouped_er"); + } if (params.no_fug != cmd_params_defaults.no_fug) { fields.emplace_back("fused_up_gate"); } diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index a5fda8a4e..41d691bbb 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -650,6 +650,7 @@ extern "C" { GGML_OP_TIMESTEP_EMBEDDING, GGML_OP_ARGSORT, GGML_OP_ARGSORT_THRESH, + GGML_OP_GROUPED_TOPK, GGML_OP_LEAKY_RELU, GGML_OP_SOFTCAP, GGML_OP_SOFT_CAP_MAX, @@ -2265,6 +2266,13 @@ extern "C" { int k, int min_entries, float thresh); + GGML_API struct ggml_tensor * ggml_grouped_topk( + struct ggml_context * ctx, + struct ggml_tensor * a, + int num_groups, + int num_top_groups, + int nk, + int topk_experts); #define GGML_KQ_MASK_PAD 16 diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index bfd5e41e9..63c0b9950 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -4253,6 +4253,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "TIMESTEP_EMBEDDING", "ARGSORT", "ARGSORT_THRESH", + "GROUPED_TOPK", "LEAKY_RELU", "SOFTCAP", "SOFT_CAP_MAX", @@ -4288,7 +4289,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "GLU", }; -static_assert(GGML_OP_COUNT == 87, "GGML_OP_COUNT != 87"); +static_assert(GGML_OP_COUNT == 88, "GGML_OP_COUNT != 88"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -4356,6 +4357,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "timestep_embedding(timesteps, dim, max_period)", "argsort(x)", "argsort_thresh(x)", + "grouped_topk(x)", "leaky_relu(x)", "k2*tanh(k1*x)", "soft_max(k2*tanh(k1*x))", @@ -4391,7 +4393,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "glu(x)," }; -static_assert(GGML_OP_COUNT == 87, "GGML_OP_COUNT != 87"); +static_assert(GGML_OP_COUNT == 88, "GGML_OP_COUNT != 88"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -9439,6 +9441,39 @@ struct ggml_tensor * ggml_argsort_thresh( return result; } +struct ggml_tensor * ggml_grouped_topk( + struct ggml_context * ctx, + struct ggml_tensor * a, + int num_groups, + int num_top_groups, + int nk, + int topk_experts) { + + GGML_ASSERT(num_top_groups <= num_groups); + GGML_ASSERT(a->ne[0] % num_groups == 0); + GGML_ASSERT(a->ne[0] >= topk_experts); + int64_t n_per_group = a->ne[0] / num_groups; + GGML_ASSERT(n_per_group >= nk); + + bool is_node = false; + + int64_t ne[GGML_MAX_DIMS]; + for (int i = 1; i < GGML_MAX_DIMS; ++i) ne[i] = a->ne[i]; + ne[0] = topk_experts; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_I32, GGML_MAX_DIMS, ne); + + ggml_set_op_params_i32(result, 0, num_groups); + ggml_set_op_params_i32(result, 1, num_top_groups); + ggml_set_op_params_i32(result, 2, nk); + + result->op = GGML_OP_GROUPED_TOPK; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + + // ggml_top_k struct ggml_tensor * ggml_top_k( @@ -20024,6 +20059,24 @@ static void ggml_compute_forward_argsort_thresh( } } +static void ggml_compute_forward_grouped_topk( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + iqk_grouped_top_k(dst, params->ith, params->nth); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + // ggml_compute_forward_flash_attn_ext static void ggml_compute_forward_flash_attn_ext_f16( @@ -22521,6 +22574,10 @@ static int ggml_compute_forward(struct ggml_compute_params * params, struct ggml { ggml_compute_forward_argsort_thresh(params, tensor); } break; + case GGML_OP_GROUPED_TOPK: + { + ggml_compute_forward_grouped_topk(params, tensor); + } break; case GGML_OP_LEAKY_RELU: { ggml_compute_forward_leaky_relu(params, tensor); @@ -23539,6 +23596,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { GGML_ABORT("fatal error"); // TODO: not implemented } + case GGML_OP_GROUPED_TOPK: + { + GGML_ABORT("fatal error"); // TODO: not implemented + } case GGML_OP_LEAKY_RELU: { GGML_ABORT("fatal error"); // TODO: not implemented @@ -24281,6 +24342,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_ARGSORT: case GGML_OP_ARGSORT_THRESH: + case GGML_OP_GROUPED_TOPK: case GGML_OP_FLASH_ATTN_EXT: case GGML_OP_FLASH_ATTN_BACK: case GGML_OP_SSM_CONV: diff --git a/ggml/src/iqk/iqk_cpu_ops.cpp b/ggml/src/iqk/iqk_cpu_ops.cpp index 74de1479a..115f25ceb 100644 --- a/ggml/src/iqk/iqk_cpu_ops.cpp +++ b/ggml/src/iqk/iqk_cpu_ops.cpp @@ -10,8 +10,97 @@ #include #include #include +#include -void iqk_grouped_top_k([[maybe_unused]] ggml_tensor * dst, [[maybe_unused]] int ith, [[maybe_unused]] int nth) { +namespace { +// Playing around with group scores: use sum of probabilities in the group +inline float group_score(int n_per_group, const float * data) { + float sum = 0; + for (int j = 0; j < n_per_group; ++j) sum += data[j]; + return sum; +} +// Playing around with group scores: use max of probabilities in the group +inline float group_score_max(int n_per_group, const float * data) { + float max = data[0]; + for (int j = 1; j < n_per_group; ++j) max = std::max(max, data[j]); + return max; +} +// Actual top-nk group score: sum of top-nk probabilities in the group +inline float group_score(int n_per_group, int nk, const float * data, float * aux) { + for (int j = 0; j < n_per_group; ++j) aux[j] = data[j]; + std::partial_sort(aux, aux + nk, aux + n_per_group, std::greater{}); + float sum = 0; + for (int j = 0; j < nk; ++j) sum += aux[j]; + return sum; +} +inline std::vector> & get_work_buffer(size_t size) { + thread_local std::vector> buffer; + if (buffer.size() < size) buffer.resize(size); + return buffer; + +} +} + +void iqk_grouped_top_k(ggml_tensor * dst, int ith, int nth) { + auto src = dst->src[0]; + GGML_ASSERT(dst->type == GGML_TYPE_I32); + GGML_ASSERT(src->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_nrows(src) == ggml_nrows(dst)); + + auto nrows = ggml_nrows(src); + auto npt = (nrows + nth - 1)/nth; + auto first = npt*ith; + auto last = std::min(first + npt, nrows); + if (last <= first) return; + + int n_groups = dst->op_params[0]; + int n_top_groups = dst->op_params[1]; + int nk = dst->op_params[2]; + + int ne00 = src->ne[0]; + int ne0 = dst->ne[0]; + GGML_ASSERT(ne0 <= ne00); + GGML_ASSERT(ne00%n_groups == 0); + int n_per_group = ne00/n_groups; + GGML_ASSERT(nk <= n_per_group); + GGML_ASSERT(n_top_groups <= n_groups); + + size_t work_size = n_groups + n_per_group*n_top_groups; + auto& aux = get_work_buffer(work_size); + + auto groups = aux.data() + n_per_group*n_top_groups; + + for (int ir = first; ir < last; ++ir) { + auto data = (const float *)((const char *)src->data + ir*src->nb[1]); + auto result = (int32_t *)((char *)dst->data + ir*dst->nb[1]); + if (ne0 > n_per_group*n_top_groups) { + for (int j = 0; j < ne0; ++j) result[j] = j; + continue; + } + if (n_top_groups < n_groups) { + for (int ig = 0; ig < n_groups; ++ig) { + //groups[ig] = { group_score(n_per_group, data + ig*n_per_group), ig }; + //groups[ig] = { group_score_max(n_per_group, data + ig*n_per_group), ig }; + groups[ig] = { group_score(n_per_group, nk, data + ig*n_per_group, (float *)aux.data()), ig }; + } + std::partial_sort(groups, groups + n_top_groups, groups + n_groups, std::greater>{}); + + for (int ig = 0; ig < n_top_groups; ++ig) { + int i0 = n_per_group * ig; + int j0 = n_per_group * groups[ig].second; + for (int j = 0; j < n_per_group; ++j) aux[i0 + j] = { data[j0 + j], j0 + j }; + } + } else { + for (int j = 0; j < ne00; ++j) aux[j] = { data[j], j }; + } + if (ne0 < n_top_groups*n_per_group) { + std::partial_sort(aux.begin(), aux.begin() + ne0, aux.begin() + n_top_groups*n_per_group, std::greater>{}); + } else { + std::sort(aux.begin(), aux.begin() + ne0, std::greater>{}); + } + for (int j = 0; j < ne0; ++j) result[j] = aux[j].second; + + } } void iqk_argsort(ggml_tensor * dst, int ith, int nth) { @@ -30,8 +119,7 @@ void iqk_argsort(ggml_tensor * dst, int ith, int nth) { int nk = dst->op_params[1]; int ne00 = src->ne[0]; - thread_local std::vector> aux; - if ((int)aux.size() < ne00) aux.resize(ne00); + auto& aux = get_work_buffer(ne00); for (int ir = first; ir < last; ++ir) { auto data = (const float *)((const char *)src->data + ir*src->nb[1]); diff --git a/include/llama.h b/include/llama.h index 4f9fc9c85..d24de2306 100644 --- a/include/llama.h +++ b/include/llama.h @@ -420,6 +420,7 @@ extern "C" { int mla_attn; // whether to use MLA attention [EXPERIMENTAL] int attn_max_batch; // maximum batch size for attention computations [EXPERIMENTAL] bool fused_moe_up_gate; // whether to use fused MoE up/gate op + bool grouped_expert_routing; // whether to use grouped expert routing (BailingMoeV2 arch) bool fused_up_gate; // whether to use fused up/gate op [EXPERIMENTAL] int min_experts; float thresh_experts; diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp index e1817bae0..b3208dcd6 100644 --- a/src/llama-build-context.cpp +++ b/src/llama-build-context.cpp @@ -48,6 +48,7 @@ llm_build_context::llm_build_context( mla_attn (cparams.mla_attn), attn_max_batch (cparams.attn_max_batch), fused_moe_up_gate(cparams.fused_moe_up_gate), + grouped_expert_routing(cparams.grouped_expert_routing), fused_up_gate (cparams.fused_up_gate), min_experts (cparams.min_experts), thresh_experts (cparams.thresh_experts), @@ -820,42 +821,15 @@ llm_expert_gating_func_type gating_op, selection_probs = logits; } - if (false && lctx.model.arch == LLM_ARCH_BAILINGMOE2 && n_tokens > 0) { + // select experts + ggml_tensor * selected_experts; + if (lctx.cparams.grouped_expert_routing && lctx.model.arch == LLM_ARCH_BAILINGMOE2 && n_tokens > 0) { auto& hparams = lctx.model.hparams; - const int64_t n_exp_per_group = n_expert / hparams.n_expert_groups; - - // organize experts into n_expert_groups - ggml_tensor * selection_groups = ggml_view_2d(ctx, ggml_cont(ctx, ggml_transpose(ctx, selection_probs)), n_tokens * n_exp_per_group, hparams.n_expert_groups, n_tokens * n_exp_per_group * sizeof(float), 0); // [n_tokens * n_exp_per_group, n_expert_groups] -#if 0 - ggml_tensor * group_scores = ggml_top_k(ctx, selection_groups, 2); // [2, n_expert_groups] - group_scores = ggml_get_rows(ctx, ggml_reshape_3d(ctx, selection_groups, 1, selection_groups->ne[0], selection_groups->ne[1]), group_scores); // [1, 2, n_expert_groups] - - // get top n_group_used expert groups - group_scores = ggml_transpose(ctx, ggml_sum_rows(ctx, ggml_reshape_2d(ctx, group_scores, group_scores->ne[1], group_scores->ne[2]))); // [n_expert_groups, 1] -#else - // Replace top_k(2) with argmax due to backend limitations, ideally we should use something like argmax2 instead - ggml_tensor * group_scores = ggml_reshape_2d(ctx, ggml_argmax(ctx, selection_groups), 1, selection_groups->ne[1]); // [1, n_expert_groups] - group_scores = ggml_get_rows(ctx, ggml_reshape_3d(ctx, selection_groups, 1, selection_groups->ne[0], selection_groups->ne[1]), group_scores); // [1, 1, n_expert_groups] - - // get top n_group_used expert groups - group_scores = ggml_transpose(ctx, ggml_reshape_2d(ctx, group_scores, group_scores->ne[1], group_scores->ne[2])); // [n_expert_groups, 1] -#endif - ggml_tensor * expert_groups = ggml_top_k(ctx, ggml_cont(ctx, group_scores), hparams.n_group_used); // [n_group_used, 1] - cb(expert_groups->src[0], "ffn_moe_group_argsort", il); - cb(expert_groups, "ffn_moe_group_topk", il); - - // mask out the other groups - selection_probs = ggml_get_rows(ctx, selection_groups, expert_groups); // [n_tokens * n_exp_per_group, n_group_used] - selection_probs = ggml_set_rows(ctx, ggml_scale_bias(ctx, selection_groups, 0.0f, -INFINITY), selection_probs, expert_groups); // [n_tokens * n_exp_per_group, n_expert_groups] - selection_probs = ggml_view_2d(ctx, selection_probs, n_tokens, n_expert, n_tokens * sizeof(float), 0); // [n_tokens, n_expert] - selection_probs = ggml_cont(ctx, ggml_transpose(ctx, selection_probs)); // [n_expert, n_tokens] - cb(selection_probs, "ffn_moe_probs_masked", il); + selected_experts = ggml_grouped_topk(ctx, selection_probs, hparams.n_expert_groups, hparams.n_group_used, 2, n_expert_used); + } else { + selected_experts = ggml_top_k_thresh(ctx, selection_probs, n_expert_used, + lctx.cparams.min_experts, lctx.cparams.thresh_experts); // [n_expert_used, n_tokens] } - - // select experts - ggml_tensor * selected_experts = ggml_top_k_thresh(ctx, selection_probs, n_expert_used, - lctx.cparams.min_experts, lctx.cparams.thresh_experts); // [n_expert_used, n_tokens] - cb(selected_experts->src[0], "ffn_moe_argsort", il); cb(selected_experts, "ffn_moe_topk", il); ggml_tensor * weights = ggml_get_rows(ctx, ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens] diff --git a/src/llama-build-context.h b/src/llama-build-context.h index a1f0b8ae9..2381a6562 100644 --- a/src/llama-build-context.h +++ b/src/llama-build-context.h @@ -78,6 +78,7 @@ struct llm_build_context { const int mla_attn; const int attn_max_batch; const bool fused_moe_up_gate; + const bool grouped_expert_routing; const bool fused_up_gate; const int min_experts; const float thresh_experts; diff --git a/src/llama-cparams.h b/src/llama-cparams.h index e8ec0f745..cbfb4949e 100644 --- a/src/llama-cparams.h +++ b/src/llama-cparams.h @@ -31,6 +31,7 @@ struct llama_cparams { int mla_attn; int attn_max_batch; bool fused_moe_up_gate; + bool grouped_expert_routing; bool fused_up_gate; int min_experts; float thresh_experts; diff --git a/src/llama.cpp b/src/llama.cpp index 5dc11e477..928d66b03 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -3754,6 +3754,7 @@ struct llama_context_params llama_context_default_params() { /*.mla_attn =*/ 0, /*.attn_max_batch =*/ 0, /*.fused_moe_up_gate =*/ false, + /*.grouped_expert_routing =*/ false, /*.fused_up_gate =*/ true, /*.min_experts =*/ -1, /*.thtesh_experts =*/ 0.0f, @@ -3963,6 +3964,7 @@ struct llama_context * llama_new_context_with_model( cparams.mla_attn = params.mla_attn; cparams.attn_max_batch = params.attn_max_batch; cparams.fused_moe_up_gate= params.fused_moe_up_gate; + cparams.grouped_expert_routing = params.grouped_expert_routing; cparams.fused_up_gate = params.fused_up_gate; cparams.min_experts = params.min_experts; cparams.thresh_experts = params.thresh_experts; @@ -4043,6 +4045,7 @@ struct llama_context * llama_new_context_with_model( LLAMA_LOG_INFO("%s: mla_attn = %d\n", __func__, cparams.mla_attn); LLAMA_LOG_INFO("%s: attn_max_b = %d\n", __func__, cparams.attn_max_batch); LLAMA_LOG_INFO("%s: fused_moe = %d\n", __func__, cparams.fused_moe_up_gate); + LLAMA_LOG_INFO("%s: grouped er = %d\n", __func__, cparams.grouped_expert_routing); LLAMA_LOG_INFO("%s: fused_up_gate = %d\n", __func__, cparams.fused_up_gate); LLAMA_LOG_INFO("%s: ser = %d, %g\n", __func__, cparams.min_experts, cparams.thresh_experts); LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);