Skip to content

Commit 4dea5dd

Browse files
committed
metal : add top_k support
1 parent a283069 commit 4dea5dd

File tree

9 files changed

+250
-43
lines changed

9 files changed

+250
-43
lines changed

ggml/src/ggml-cpu/ops.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7888,7 +7888,7 @@ static void ggml_compute_forward_top_k_f32(
78887888

78897889
const int64_t nr = ggml_nrows(src0);
78907890

7891-
const int k = ggml_get_op_params_i32(dst, 0);
7891+
const int top_k = ne0;
78927892

78937893
int32_t * tmp = (int32_t *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
78947894

@@ -7899,11 +7899,11 @@ static void ggml_compute_forward_top_k_f32(
78997899
tmp[j] = j;
79007900
}
79017901

7902-
std::partial_sort(tmp, tmp + k, tmp + ne00, cmp_top_k{src_data});
7902+
std::partial_sort(tmp, tmp + top_k, tmp + ne00, cmp_top_k{src_data});
79037903

79047904
int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
79057905

7906-
std::copy(tmp, tmp + k, dst_data);
7906+
std::copy(tmp, tmp + top_k, dst_data);
79077907
}
79087908
}
79097909

ggml/src/ggml-metal/ggml-metal-device.m

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -905,6 +905,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
905905
case GGML_OP_LEAKY_RELU:
906906
return op->src[0]->type == GGML_TYPE_F32;
907907
case GGML_OP_ARGSORT:
908+
case GGML_OP_TOP_K:
908909
case GGML_OP_ARANGE:
909910
return true;
910911
case GGML_OP_FLASH_ATTN_EXT:

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -832,14 +832,19 @@ typedef struct {
832832
} ggml_metal_kargs_leaky_relu;
833833

834834
typedef struct {
835-
int64_t ne00;
836-
int64_t ne01;
837-
int64_t ne02;
838-
int64_t ne03;
835+
int32_t ne00;
836+
int32_t ne01;
837+
int32_t ne02;
838+
int32_t ne03;
839839
uint64_t nb00;
840840
uint64_t nb01;
841841
uint64_t nb02;
842842
uint64_t nb03;
843+
int32_t ne0;
844+
int32_t ne1;
845+
int32_t ne2;
846+
int32_t ne3;
847+
int32_t top_k;
843848
} ggml_metal_kargs_argsort;
844849

845850
typedef struct {
@@ -851,6 +856,11 @@ typedef struct {
851856
uint64_t nb01;
852857
uint64_t nb02;
853858
uint64_t nb03;
859+
int32_t ne0;
860+
int32_t ne1;
861+
int32_t ne2;
862+
int32_t ne3;
863+
int32_t top_k;
854864
int32_t len;
855865
} ggml_metal_kargs_argsort_merge;
856866

ggml/src/ggml-metal/ggml-metal-ops.cpp

Lines changed: 144 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
405405
{
406406
n_fuse = ggml_metal_op_argsort(ctx, idx);
407407
} break;
408+
case GGML_OP_TOP_K:
409+
{
410+
n_fuse = ggml_metal_op_top_k(ctx, idx);
411+
} break;
408412
case GGML_OP_LEAKY_RELU:
409413
{
410414
n_fuse = ggml_metal_op_leaky_relu(ctx, idx);
@@ -3677,14 +3681,19 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
36773681
}
36783682

36793683
ggml_metal_kargs_argsort args = {
3680-
/*.ne00 =*/ ne00,
3681-
/*.ne01 =*/ ne01,
3682-
/*.ne02 =*/ ne02,
3683-
/*.ne03 =*/ ne03,
3684-
/*.nb00 =*/ nb00,
3685-
/*.nb01 =*/ nb01,
3686-
/*.nb02 =*/ nb02,
3687-
/*.nb03 =*/ nb03,
3684+
/*.ne00 =*/ ne00,
3685+
/*.ne01 =*/ ne01,
3686+
/*.ne02 =*/ ne02,
3687+
/*.ne03 =*/ ne03,
3688+
/*.nb00 =*/ nb00,
3689+
/*.nb01 =*/ nb01,
3690+
/*.nb02 =*/ nb02,
3691+
/*.nb03 =*/ nb03,
3692+
/*.ne0 =*/ ne0,
3693+
/*.ne1 =*/ ne1,
3694+
/*.ne2 =*/ ne2,
3695+
/*.ne3 =*/ ne3,
3696+
/*.top_k =*/ nth,
36883697
};
36893698

36903699
ggml_metal_encoder_set_pipeline(enc, pipeline);
@@ -3704,15 +3713,20 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
37043713
ggml_metal_op_concurrency_reset(ctx);
37053714

37063715
ggml_metal_kargs_argsort_merge args_merge = {
3707-
.ne00 = ne00,
3708-
.ne01 = ne01,
3709-
.ne02 = ne02,
3710-
.ne03 = ne03,
3711-
.nb00 = nb00,
3712-
.nb01 = nb01,
3713-
.nb02 = nb02,
3714-
.nb03 = nb03,
3715-
.len = len,
3716+
/*.ne00 =*/ ne00,
3717+
/*.ne01 =*/ ne01,
3718+
/*.ne02 =*/ ne02,
3719+
/*.ne03 =*/ ne03,
3720+
/*.nb00 =*/ nb00,
3721+
/*.nb01 =*/ nb01,
3722+
/*.nb02 =*/ nb02,
3723+
/*.nb03 =*/ nb03,
3724+
/*.ne0 =*/ ne0,
3725+
/*.ne1 =*/ ne1,
3726+
/*.ne2 =*/ ne2,
3727+
/*.ne3 =*/ ne3,
3728+
/*.top_k =*/ ne00,
3729+
/*.len =*/ len,
37163730
};
37173731

37183732
// merges per row
@@ -3736,6 +3750,119 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
37363750
return 1;
37373751
}
37383752

3753+
int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) {
3754+
ggml_tensor * op = ctx->node(idx);
3755+
3756+
ggml_metal_library_t lib = ctx->lib;
3757+
ggml_metal_encoder_t enc = ctx->enc;
3758+
3759+
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
3760+
3761+
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3762+
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3763+
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3764+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3765+
3766+
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_argsort(lib, op);
3767+
3768+
// bitonic sort requires the number of elements to be power of 2
3769+
int nth = 1;
3770+
while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
3771+
nth *= 2;
3772+
}
3773+
3774+
const int npr = (ne00 + nth - 1)/nth;
3775+
3776+
// Metal kernels require the buffer size to be multiple of 16 bytes
3777+
// https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
3778+
const size_t smem = GGML_PAD(nth*sizeof(int32_t), 16);
3779+
3780+
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
3781+
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
3782+
3783+
ggml_metal_buffer_id bid_tmp = bid_dst;
3784+
bid_tmp.offs += sizeof(int32_t)*ggml_nelements(op->src[0]);
3785+
3786+
if ((int) ceil(std::log(npr) / std::log(2)) % 2 == 1) {
3787+
std::swap(bid_dst, bid_tmp);
3788+
}
3789+
3790+
const int top_k = ne0;
3791+
3792+
ggml_metal_kargs_argsort args = {
3793+
/*.ne00 =*/ ne00,
3794+
/*.ne01 =*/ ne01,
3795+
/*.ne02 =*/ ne02,
3796+
/*.ne03 =*/ ne03,
3797+
/*.nb00 =*/ nb00,
3798+
/*.nb01 =*/ nb01,
3799+
/*.nb02 =*/ nb02,
3800+
/*.nb03 =*/ nb03,
3801+
/*.ne0 =*/ ne0,
3802+
/*.ne1 =*/ ne1,
3803+
/*.ne2 =*/ ne2,
3804+
/*.ne3 =*/ ne3,
3805+
/*.top_k =*/ std::min(nth, top_k),
3806+
};
3807+
3808+
if (npr > 1) {
3809+
args.ne0 = (npr - 1)*args.top_k + std::min(ne00 - (npr - 1)*nth, args.top_k);
3810+
}
3811+
3812+
ggml_metal_encoder_set_pipeline(enc, pipeline);
3813+
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3814+
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
3815+
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
3816+
3817+
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
3818+
3819+
ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1);
3820+
3821+
ggml_metal_pipeline_t pipeline_merge = ggml_metal_library_get_pipeline_argsort_merge(lib, op);
3822+
3823+
int len = args.top_k;
3824+
3825+
while (len < args.ne0) {
3826+
ggml_metal_op_concurrency_reset(ctx);
3827+
3828+
ggml_metal_kargs_argsort_merge args_merge = {
3829+
/*.ne00 =*/ ne00,
3830+
/*.ne01 =*/ ne01,
3831+
/*.ne02 =*/ ne02,
3832+
/*.ne03 =*/ ne03,
3833+
/*.nb00 =*/ nb00,
3834+
/*.nb01 =*/ nb01,
3835+
/*.nb02 =*/ nb02,
3836+
/*.nb03 =*/ nb03,
3837+
/*.ne0 =*/ args.ne0,
3838+
/*.ne1 =*/ ne1,
3839+
/*.ne2 =*/ ne2,
3840+
/*.ne3 =*/ ne3,
3841+
/*.top_k =*/ 2*len >= args.ne0 ? top_k : args.ne0,
3842+
/*.len =*/ len,
3843+
};
3844+
3845+
// merges per row
3846+
const int nm = (args.ne0 + 2*len - 1) / (2*len);
3847+
3848+
const int nth = std::min(512, std::min(len, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_merge)));
3849+
3850+
ggml_metal_encoder_set_pipeline(enc, pipeline_merge);
3851+
ggml_metal_encoder_set_bytes (enc, &args_merge, sizeof(args_merge), 0);
3852+
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
3853+
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
3854+
ggml_metal_encoder_set_buffer (enc, bid_tmp, 3);
3855+
3856+
ggml_metal_encoder_dispatch_threadgroups(enc, nm*ne01, ne02, ne03, nth, 1, 1);
3857+
3858+
std::swap(bid_dst, bid_tmp);
3859+
3860+
len <<= 1;
3861+
}
3862+
3863+
return 1;
3864+
}
3865+
37393866
int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
37403867
ggml_tensor * op = ctx->node(idx);
37413868

ggml/src/ggml-metal/ggml-metal-ops.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ int ggml_metal_op_arange (ggml_metal_op_t ctx, int idx);
8181
int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx);
8282
int ggml_metal_op_argmax (ggml_metal_op_t ctx, int idx);
8383
int ggml_metal_op_argsort (ggml_metal_op_t ctx, int idx);
84+
int ggml_metal_op_top_k (ggml_metal_op_t ctx, int idx);
8485
int ggml_metal_op_leaky_relu (ggml_metal_op_t ctx, int idx);
8586
int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx);
8687
int ggml_metal_op_opt_step_sgd (ggml_metal_op_t ctx, int idx);

ggml/src/ggml-metal/ggml-metal.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,10 @@ static size_t ggml_backend_metal_buffer_type_get_alloc_size(ggml_backend_buffer_
202202
{
203203
res *= 2;
204204
} break;
205+
case GGML_OP_TOP_K:
206+
{
207+
res = 2*sizeof(int32_t)*ggml_nelements(tensor->src[0]);
208+
} break;
205209
default:
206210
break;
207211
}

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4670,11 +4670,12 @@ kernel void kernel_argsort_f32_i32(
46704670
ushort3 ntg[[threads_per_threadgroup]]) {
46714671
// bitonic sort
46724672
const int col = tpitg[0];
4673+
const int ib = tgpig[0] / args.ne01;
46734674

4674-
const int i00 = (tgpig[0]/args.ne01)*ntg.x;
4675-
const int i01 = tgpig[0]%args.ne01;
4676-
const int i02 = tgpig[1];
4677-
const int i03 = tgpig[2];
4675+
const int i00 = ib*ntg.x;
4676+
const int i01 = tgpig[0] % args.ne01;
4677+
const int i02 = tgpig[1];
4678+
const int i03 = tgpig[2];
46784679

46794680
device const float * src0_row = (device const float *) (src0 + args.nb01*i01 + args.nb02*i02 + args.nb03*i03);
46804681

@@ -4710,9 +4711,11 @@ kernel void kernel_argsort_f32_i32(
47104711
}
47114712
}
47124713

4714+
const int64_t i0 = ib*args.top_k;
4715+
47134716
// copy the result to dst without the padding
4714-
if (i00 + col < args.ne00) {
4715-
dst += i00 + args.ne00*i01 + args.ne00*args.ne01*i02 + args.ne00*args.ne01*args.ne02*i03;
4717+
if (i0 + col < args.ne0 && col < args.top_k) {
4718+
dst += i0 + args.ne0*i01 + args.ne0*args.ne1*i02 + args.ne0*args.ne1*args.ne2*i03;
47164719

47174720
dst[col] = shmem_i32[col];
47184721
}
@@ -4747,22 +4750,22 @@ kernel void kernel_argsort_merge_f32_i32(
47474750

47484751
const int start = im * (2 * args.len);
47494752

4750-
const int len0 = MIN(args.len, MAX(0, args.ne00 - (int)(start)));
4751-
const int len1 = MIN(args.len, MAX(0, args.ne00 - (int)(start + args.len)));
4753+
const int len0 = MIN(args.len, MAX(0, args.ne0 - (int)(start)));
4754+
const int len1 = MIN(args.len, MAX(0, args.ne0 - (int)(start + args.len)));
47524755

47534756
const int total = len0 + len1;
47544757

47554758
device const int32_t * tmp0 = tmp + start
4756-
+ i01*args.ne00
4757-
+ i02*args.ne00*args.ne01
4758-
+ i03*args.ne00*args.ne01*args.ne02;
4759+
+ i01*args.ne0
4760+
+ i02*args.ne0*args.ne01
4761+
+ i03*args.ne0*args.ne01*args.ne02;
47594762

47604763
device const int32_t * tmp1 = tmp0 + args.len;
47614764

47624765
dst += start
4763-
+ i01*args.ne00
4764-
+ i02*args.ne00*args.ne01
4765-
+ i03*args.ne00*args.ne01*args.ne02;
4766+
+ i01*args.top_k
4767+
+ i02*args.top_k*args.ne01
4768+
+ i03*args.top_k*args.ne01*args.ne02;
47664769

47674770
device const float * src0_row = (device const float *)(src0
47684771
+ args.nb01*i01
@@ -4827,16 +4830,16 @@ kernel void kernel_argsort_merge_f32_i32(
48274830
val1 = src0_row[idx1];
48284831
}
48294832

4830-
for (int k = k0; k < k1; ++k) {
4833+
for (int k = k0; k < k1 && k < args.top_k; ++k) {
48314834
int32_t out_idx;
48324835

48334836
if (i >= len0) {
4834-
while (k < k1) {
4837+
while (k < k1 && k < args.top_k) {
48354838
dst[k++] = tmp1[j++];
48364839
}
48374840
break;
48384841
} else if (j >= len1) {
4839-
while (k < k1) {
4842+
while (k < k1 && k < args.top_k) {
48404843
dst[k++] = tmp0[i++];
48414844
}
48424845
break;

ggml/src/ggml.c

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5158,7 +5158,8 @@ struct ggml_tensor * ggml_top_k(
51585158

51595159
struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_I32, k, a->ne[1], a->ne[2], a->ne[3]);
51605160

5161-
ggml_set_op_params_i32(result, 0, (int32_t) k);
5161+
// TODO: tmp
5162+
ggml_set_op_params_i32(result, 0, (int32_t) GGML_SORT_ORDER_DESC);
51625163

51635164
result->op = GGML_OP_TOP_K;
51645165
result->src[0] = a;

0 commit comments

Comments
 (0)