Skip to content

Commit 5d413c3

Browse files
committed
ggml : cleanup
1 parent 48f1225 commit 5d413c3

File tree

5 files changed

+78
-18
lines changed

5 files changed

+78
-18
lines changed

ggml/src/ggml-metal/ggml-metal-device.cpp

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1009,6 +1009,64 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort_merge(ggml_metal_l
10091009
return res;
10101010
}
10111011

1012+
// note: reuse the argsort kernel for top_k
1013+
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k(ggml_metal_library_t lib, const ggml_tensor * op) {
1014+
assert(op->op == GGML_OP_TOP_K);
1015+
1016+
char base[256];
1017+
char name[256];
1018+
1019+
// note: the top_k kernel is always descending order
1020+
ggml_sort_order order = GGML_SORT_ORDER_DESC;
1021+
1022+
const char * order_str = "undefined";
1023+
switch (order) {
1024+
case GGML_SORT_ORDER_ASC: order_str = "asc"; break;
1025+
case GGML_SORT_ORDER_DESC: order_str = "desc"; break;
1026+
default: GGML_ABORT("fatal error");
1027+
};
1028+
1029+
snprintf(base, 256, "kernel_argsort_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str);
1030+
snprintf(name, 256, "%s", base);
1031+
1032+
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
1033+
if (res) {
1034+
return res;
1035+
}
1036+
1037+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1038+
1039+
return res;
1040+
}
1041+
1042+
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k_merge(ggml_metal_library_t lib, const ggml_tensor * op) {
1043+
assert(op->op == GGML_OP_TOP_K);
1044+
1045+
char base[256];
1046+
char name[256];
1047+
1048+
ggml_sort_order order = GGML_SORT_ORDER_DESC;
1049+
1050+
const char * order_str = "undefined";
1051+
switch (order) {
1052+
case GGML_SORT_ORDER_ASC: order_str = "asc"; break;
1053+
case GGML_SORT_ORDER_DESC: order_str = "desc"; break;
1054+
default: GGML_ABORT("fatal error");
1055+
};
1056+
1057+
snprintf(base, 256, "kernel_argsort_merge_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str);
1058+
snprintf(name, 256, "%s", base);
1059+
1060+
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
1061+
if (res) {
1062+
return res;
1063+
}
1064+
1065+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1066+
1067+
return res;
1068+
}
1069+
10121070
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
10131071
ggml_metal_library_t lib,
10141072
const struct ggml_tensor * op,

ggml/src/ggml-metal/ggml-metal-device.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id (ggml_me
128128
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax (ggml_metal_library_t lib, const struct ggml_tensor * op);
129129
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort (ggml_metal_library_t lib, const struct ggml_tensor * op);
130130
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
131+
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k (ggml_metal_library_t lib, const struct ggml_tensor * op);
132+
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
131133
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
132134
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
133135
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);

ggml/src/ggml-metal/ggml-metal-ops.cpp

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3764,18 +3764,17 @@ int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) {
37643764
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
37653765
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
37663766

3767-
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_argsort(lib, op);
3767+
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_top_k(lib, op);
37683768

37693769
// bitonic sort requires the number of elements to be power of 2
37703770
int nth = 1;
37713771
while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
37723772
nth *= 2;
37733773
}
37743774

3775+
// blocks per row
37753776
const int npr = (ne00 + nth - 1)/nth;
37763777

3777-
// Metal kernels require the buffer size to be multiple of 16 bytes
3778-
// https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
37793778
const size_t smem = GGML_PAD(nth*sizeof(int32_t), 16);
37803779

37813780
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
@@ -3803,7 +3802,7 @@ int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) {
38033802
/*.ne1 =*/ ne1,
38043803
/*.ne2 =*/ ne2,
38053804
/*.ne3 =*/ ne3,
3806-
/*.top_k =*/ std::min(nth, top_k),
3805+
/*.top_k =*/ std::min(nth, top_k), // for each block, keep just the top_k indices
38073806
};
38083807

38093808
if (npr > 1) {
@@ -3819,13 +3818,18 @@ int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) {
38193818

38203819
ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1);
38213820

3822-
ggml_metal_pipeline_t pipeline_merge = ggml_metal_library_get_pipeline_argsort_merge(lib, op);
3821+
ggml_metal_pipeline_t pipeline_merge = ggml_metal_library_get_pipeline_top_k_merge(lib, op);
38233822

38243823
int len = args.top_k;
38253824

38263825
while (len < args.ne0) {
38273826
ggml_metal_op_concurrency_reset(ctx);
38283827

3828+
// merges per row
3829+
const int nm = (args.ne0 + 2*len - 1) / (2*len);
3830+
3831+
const int nth = std::min(512, std::min(len, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_merge)));
3832+
38293833
ggml_metal_kargs_argsort_merge args_merge = {
38303834
/*.ne00 =*/ ne00,
38313835
/*.ne01 =*/ ne01,
@@ -3839,15 +3843,10 @@ int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) {
38393843
/*.ne1 =*/ ne1,
38403844
/*.ne2 =*/ ne2,
38413845
/*.ne3 =*/ ne3,
3842-
/*.top_k =*/ 2*len >= args.ne0 ? top_k : args.ne0,
3846+
/*.top_k =*/ nm == 1 ? top_k : args.ne0, // the final merge outputs top_k elements
38433847
/*.len =*/ len,
38443848
};
38453849

3846-
// merges per row
3847-
const int nm = (args.ne0 + 2*len - 1) / (2*len);
3848-
3849-
const int nth = std::min(512, std::min(len, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_merge)));
3850-
38513850
ggml_metal_encoder_set_pipeline(enc, pipeline_merge);
38523851
ggml_metal_encoder_set_bytes (enc, &args_merge, sizeof(args_merge), 0);
38533852
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4779,7 +4779,11 @@ kernel void kernel_argsort_merge_f32_i32(
47794779
const int chunk = (total + ntg.x - 1) / ntg.x;
47804780

47814781
const int k0 = tpitg.x * chunk;
4782-
const int k1 = min(k0 + chunk, total);
4782+
const int k1 = MIN(MIN(k0 + chunk, total), args.top_k);
4783+
4784+
if (k0 >= args.top_k) {
4785+
return;
4786+
}
47834787

47844788
if (k0 >= total) {
47854789
return;
@@ -4830,16 +4834,16 @@ kernel void kernel_argsort_merge_f32_i32(
48304834
val1 = src0_row[idx1];
48314835
}
48324836

4833-
for (int k = k0; k < k1 && k < args.top_k; ++k) {
4837+
for (int k = k0; k < k1; ++k) {
48344838
int32_t out_idx;
48354839

48364840
if (i >= len0) {
4837-
while (k < k1 && k < args.top_k) {
4841+
while (k < k1) {
48384842
dst[k++] = tmp1[j++];
48394843
}
48404844
break;
48414845
} else if (j >= len1) {
4842-
while (k < k1 && k < args.top_k) {
4846+
while (k < k1) {
48434847
dst[k++] = tmp0[i++];
48444848
}
48454849
break;

ggml/src/ggml.c

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5158,9 +5158,6 @@ struct ggml_tensor * ggml_top_k(
51585158

51595159
struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_I32, k, a->ne[1], a->ne[2], a->ne[3]);
51605160

5161-
// TODO: tmp
5162-
ggml_set_op_params_i32(result, 0, (int32_t) GGML_SORT_ORDER_DESC);
5163-
51645161
result->op = GGML_OP_TOP_K;
51655162
result->src[0] = a;
51665163

0 commit comments

Comments
 (0)