Skip to content

Commit a283069

Browse files
committed
cont : add ggml_argsort_top_k
1 parent 20f1050 commit a283069

File tree

3 files changed

+53
-29
lines changed

3 files changed

+53
-29
lines changed

ggml/include/ggml.h

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2259,18 +2259,24 @@ extern "C" {
22592259
struct ggml_tensor * a,
22602260
enum ggml_sort_order order);
22612261

2262-
GGML_API struct ggml_tensor * ggml_arange(
2262+
// same as ggml_top_k but implemented as `argsort` + `view`
2263+
GGML_API struct ggml_tensor * ggml_argsort_top_k(
22632264
struct ggml_context * ctx,
2264-
float start,
2265-
float stop,
2266-
float step);
2265+
struct ggml_tensor * a,
2266+
int k);
22672267

22682268
// top k elements per row
22692269
GGML_API struct ggml_tensor * ggml_top_k(
22702270
struct ggml_context * ctx,
22712271
struct ggml_tensor * a,
22722272
int k);
22732273

2274+
GGML_API struct ggml_tensor * ggml_arange(
2275+
struct ggml_context * ctx,
2276+
float start,
2277+
float stop,
2278+
float step);
2279+
22742280
#define GGML_KQ_MASK_PAD 64
22752281

22762282
// q: [n_embd_k, n_batch, n_head, ne3 ]

ggml/src/ggml.c

Lines changed: 40 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5038,28 +5038,6 @@ struct ggml_tensor * ggml_roll(
50385038
return result;
50395039
}
50405040

5041-
// ggml_arange
5042-
5043-
struct ggml_tensor * ggml_arange(
5044-
struct ggml_context * ctx,
5045-
float start,
5046-
float stop,
5047-
float step) {
5048-
GGML_ASSERT(stop > start);
5049-
5050-
const int64_t steps = (int64_t) ceilf((stop - start) / step);
5051-
5052-
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, steps);
5053-
5054-
ggml_set_op_params_f32(result, 0, start);
5055-
ggml_set_op_params_f32(result, 1, stop);
5056-
ggml_set_op_params_f32(result, 2, step);
5057-
5058-
result->op = GGML_OP_ARANGE;
5059-
5060-
return result;
5061-
}
5062-
50635041
// ggml_timestep_embedding
50645042

50655043
struct ggml_tensor * ggml_timestep_embedding(
@@ -5152,6 +5130,24 @@ struct ggml_tensor * ggml_argsort(
51525130
return result;
51535131
}
51545132

5133+
// ggml_argsort_top_k
5134+
5135+
struct ggml_tensor * ggml_argsort_top_k(
5136+
struct ggml_context * ctx,
5137+
struct ggml_tensor * a,
5138+
int k) {
5139+
GGML_ASSERT(a->ne[0] >= k);
5140+
5141+
struct ggml_tensor * result = ggml_argsort(ctx, a, GGML_SORT_ORDER_DESC);
5142+
5143+
result = ggml_view_4d(ctx, result,
5144+
k, result->ne[1], result->ne[2], result->ne[3],
5145+
result->nb[1], result->nb[2], result->nb[3],
5146+
0);
5147+
5148+
return result;
5149+
}
5150+
51555151
// ggml_top_k
51565152

51575153
struct ggml_tensor * ggml_top_k(
@@ -5170,6 +5166,28 @@ struct ggml_tensor * ggml_top_k(
51705166
return result;
51715167
}
51725168

5169+
// ggml_arange
5170+
5171+
struct ggml_tensor * ggml_arange(
5172+
struct ggml_context * ctx,
5173+
float start,
5174+
float stop,
5175+
float step) {
5176+
GGML_ASSERT(stop > start);
5177+
5178+
const int64_t steps = (int64_t) ceilf((stop - start) / step);
5179+
5180+
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, steps);
5181+
5182+
ggml_set_op_params_f32(result, 0, start);
5183+
ggml_set_op_params_f32(result, 1, stop);
5184+
ggml_set_op_params_f32(result, 2, step);
5185+
5186+
result->op = GGML_OP_ARANGE;
5187+
5188+
return result;
5189+
}
5190+
51735191
// ggml_flash_attn_ext
51745192

51755193
struct ggml_tensor * ggml_flash_attn_ext(

src/llama-graph.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -961,14 +961,14 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
961961
// organize experts into n_expert_groups
962962
ggml_tensor * selection_groups = ggml_reshape_3d(ctx0, selection_probs, n_exp_per_group, hparams.n_expert_groups, n_tokens); // [n_exp_per_group, n_expert_groups, n_tokens]
963963

964-
ggml_tensor * group_scores = ggml_top_k(ctx0, selection_groups, 2); // [2, n_expert_groups, n_tokens]
964+
ggml_tensor * group_scores = ggml_argsort_top_k(ctx0, selection_groups, 2); // [2, n_expert_groups, n_tokens]
965965
group_scores = ggml_get_rows(ctx0, ggml_reshape_4d(ctx0, selection_groups, 1, selection_groups->ne[0], selection_groups->ne[1], selection_groups->ne[2]), group_scores); // [1, 2, n_expert_groups, n_tokens]
966966

967967
// get top n_group_used expert groups
968968
group_scores = ggml_sum_rows(ctx0, ggml_reshape_3d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2], group_scores->ne[3])); // [1, n_expert_groups, n_tokens]
969969
group_scores = ggml_reshape_2d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2]); // [n_expert_groups, n_tokens]
970970

971-
ggml_tensor * expert_groups = ggml_top_k(ctx0, group_scores, hparams.n_group_used); // [n_group_used, n_tokens]
971+
ggml_tensor * expert_groups = ggml_argsort_top_k(ctx0, group_scores, hparams.n_group_used); // [n_group_used, n_tokens]
972972
cb(expert_groups, "ffn_moe_group_topk", il);
973973

974974
// mask out the other groups
@@ -979,7 +979,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
979979
}
980980

981981
// select experts
982-
ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
982+
ggml_tensor * selected_experts = ggml_argsort_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
983983
cb(selected_experts->src[0], "ffn_moe_argsort", il);
984984
cb(selected_experts, "ffn_moe_topk", il);
985985

0 commit comments

Comments
 (0)