@@ -406,6 +406,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
406406 {
407407 n_fuse = ggml_metal_op_argsort (ctx, idx);
408408 } break ;
409+ case GGML_OP_TOP_K:
410+ {
411+ n_fuse = ggml_metal_op_top_k (ctx, idx);
412+ } break ;
409413 case GGML_OP_LEAKY_RELU:
410414 {
411415 n_fuse = ggml_metal_op_leaky_relu (ctx, idx);
@@ -3678,14 +3682,19 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
36783682 }
36793683
36803684 ggml_metal_kargs_argsort args = {
3681- /* .ne00 =*/ ne00,
3682- /* .ne01 =*/ ne01,
3683- /* .ne02 =*/ ne02,
3684- /* .ne03 =*/ ne03,
3685- /* .nb00 =*/ nb00,
3686- /* .nb01 =*/ nb01,
3687- /* .nb02 =*/ nb02,
3688- /* .nb03 =*/ nb03,
3685+ /* .ne00 =*/ ne00,
3686+ /* .ne01 =*/ ne01,
3687+ /* .ne02 =*/ ne02,
3688+ /* .ne03 =*/ ne03,
3689+ /* .nb00 =*/ nb00,
3690+ /* .nb01 =*/ nb01,
3691+ /* .nb02 =*/ nb02,
3692+ /* .nb03 =*/ nb03,
3693+ /* .ne0 =*/ ne0,
3694+ /* .ne1 =*/ ne1,
3695+ /* .ne2 =*/ ne2,
3696+ /* .ne3 =*/ ne3,
3697+ /* .top_k =*/ nth,
36893698 };
36903699
36913700 ggml_metal_encoder_set_pipeline (enc, pipeline);
@@ -3705,15 +3714,20 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
37053714 ggml_metal_op_concurrency_reset (ctx);
37063715
37073716 ggml_metal_kargs_argsort_merge args_merge = {
3708- .ne00 = ne00,
3709- .ne01 = ne01,
3710- .ne02 = ne02,
3711- .ne03 = ne03,
3712- .nb00 = nb00,
3713- .nb01 = nb01,
3714- .nb02 = nb02,
3715- .nb03 = nb03,
3716- .len = len,
3717+ /* .ne00 =*/ ne00,
3718+ /* .ne01 =*/ ne01,
3719+ /* .ne02 =*/ ne02,
3720+ /* .ne03 =*/ ne03,
3721+ /* .nb00 =*/ nb00,
3722+ /* .nb01 =*/ nb01,
3723+ /* .nb02 =*/ nb02,
3724+ /* .nb03 =*/ nb03,
3725+ /* .ne0 =*/ ne0,
3726+ /* .ne1 =*/ ne1,
3727+ /* .ne2 =*/ ne2,
3728+ /* .ne3 =*/ ne3,
3729+ /* .top_k =*/ ne00,
3730+ /* .len =*/ len,
37173731 };
37183732
37193733 // merges per row
@@ -3737,6 +3751,119 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
37373751 return 1 ;
37383752}
37393753
// Encode the GGML_OP_TOP_K operator for the Metal backend.
//
// Reuses the argsort pipelines: each threadgroup bitonic-sorts one chunk of a
// row, keeping only the best `top_k` candidates per chunk, and then a series of
// merge passes combines the per-chunk candidate lists until a single top-k
// result per row remains.
//
// ctx - op-encoding context (provides the node list, library and encoder)
// idx - index of the node to encode
//
// Returns 1 (the number of fused ops encoded).
int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) {
    ggml_tensor * op = ctx->node(idx);

    ggml_metal_library_t lib = ctx->lib;
    ggml_metal_encoder_t enc = ctx->enc;

    // the kernels index rows directly, so rows must be contiguous in memory
    GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));

    // ne0*/nb0* - dims/strides of the source tensor; ne*/nb* - of the result
    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);

    // same pipeline as GGML_OP_ARGSORT - top-k is a truncated argsort
    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_argsort(lib, op);

    // bitonic sort requires the number of elements to be power of 2
    int nth = 1;
    while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
        nth *= 2;
    }

    // number of chunks ("partial rows") each source row is split into
    const int npr = (ne00 + nth - 1)/nth;

    // Metal kernels require the buffer size to be multiple of 16 bytes
    // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
    const size_t smem = GGML_PAD(nth*sizeof(int32_t), 16);

    ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
    ggml_metal_buffer_id bid_dst  = ggml_metal_get_buffer_id(op);

    // scratch buffer for the ping-pong merge passes, placed right after the
    // i32 indices in the dst buffer
    // NOTE(review): this assumes the dst allocation is at least
    //   2*sizeof(int32_t)*ggml_nelements(src0) - verify against the backend's
    //   extra-size accounting for GGML_OP_TOP_K
    ggml_metal_buffer_id bid_tmp = bid_dst;
    bid_tmp.offs += sizeof(int32_t)*ggml_nelements(op->src[0]);

    // each merge pass swaps bid_dst/bid_tmp; with an odd number of passes
    // (ceil(log2(npr))), start swapped so the final pass lands in the real dst
    if ((int) ceil(std::log(npr) / std::log(2)) % 2 == 1) {
        std::swap(bid_dst, bid_tmp);
    }

    // number of results requested per row
    const int top_k = ne0;

    ggml_metal_kargs_argsort args = {
        /*.ne00  =*/ ne00,
        /*.ne01  =*/ ne01,
        /*.ne02  =*/ ne02,
        /*.ne03  =*/ ne03,
        /*.nb00  =*/ nb00,
        /*.nb01  =*/ nb01,
        /*.nb02  =*/ nb02,
        /*.nb03  =*/ nb03,
        /*.ne0   =*/ ne0,
        /*.ne1   =*/ ne1,
        /*.ne2   =*/ ne2,
        /*.ne3   =*/ ne3,
        // a chunk holds at most nth elements, so it cannot contribute more
        /*.top_k =*/ std::min(nth, top_k),
    };

    if (npr > 1) {
        // total number of candidates written by the first pass: full chunks
        // contribute args.top_k each, the (possibly short) last chunk at most
        // that many - this is the row length the merge passes operate on
        args.ne0 = (npr - 1)*args.top_k + std::min(ne00 - (npr - 1)*nth, args.top_k);
    }

    // pass 1: per-chunk bitonic sort (one threadgroup per chunk per row)
    ggml_metal_encoder_set_pipeline(enc, pipeline);
    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
    ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);
    ggml_metal_encoder_set_buffer  (enc, bid_dst,  2);

    ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);

    ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1);

    ggml_metal_pipeline_t pipeline_merge = ggml_metal_library_get_pipeline_argsort_merge(lib, op);

    // current length of each sorted run; doubles every merge pass
    int len = args.top_k;

    while (len < args.ne0) {
        // merge passes read what the previous dispatch wrote
        ggml_metal_op_concurrency_reset(ctx);

        ggml_metal_kargs_argsort_merge args_merge = {
            /*.ne00  =*/ ne00,
            /*.ne01  =*/ ne01,
            /*.ne02  =*/ ne02,
            /*.ne03  =*/ ne03,
            /*.nb00  =*/ nb00,
            /*.nb01  =*/ nb01,
            /*.nb02  =*/ nb02,
            /*.nb03  =*/ nb03,
            /*.ne0   =*/ args.ne0,
            /*.ne1   =*/ ne1,
            /*.ne2   =*/ ne2,
            /*.ne3   =*/ ne3,
            // only the final merge truncates to the requested top_k; earlier
            // passes must keep all candidates
            /*.top_k =*/ 2*len >= args.ne0 ? top_k : args.ne0,
            /*.len   =*/ len,
        };

        // merges per row
        const int nm = (args.ne0 + 2*len - 1) / (2*len);

        const int nth = std::min(512, std::min(len, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_merge)));

        ggml_metal_encoder_set_pipeline(enc, pipeline_merge);
        ggml_metal_encoder_set_bytes   (enc, &args_merge, sizeof(args_merge), 0);
        ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);
        ggml_metal_encoder_set_buffer  (enc, bid_dst,  2);
        ggml_metal_encoder_set_buffer  (enc, bid_tmp,  3);

        ggml_metal_encoder_dispatch_threadgroups(enc, nm*ne01, ne02, ne03, nth, 1, 1);

        // ping-pong: next pass reads what this one wrote
        std::swap(bid_dst, bid_tmp);

        len <<= 1;
    }

    return 1;
}
3866+
37403867int ggml_metal_op_leaky_relu (ggml_metal_op_t ctx, int idx) {
37413868 ggml_tensor * op = ctx->node (idx);
37423869
0 commit comments