@@ -405,6 +405,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
             {
                 n_fuse = ggml_metal_op_argsort(ctx, idx);
             } break;
+        case GGML_OP_TOP_K:
+            {
+                n_fuse = ggml_metal_op_top_k(ctx, idx);
+            } break;
         case GGML_OP_LEAKY_RELU:
             {
                 n_fuse = ggml_metal_op_leaky_relu(ctx, idx);
@@ -3677,14 +3681,19 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
     }

     ggml_metal_kargs_argsort args = {
-        /*.ne00 =*/ ne00,
-        /*.ne01 =*/ ne01,
-        /*.ne02 =*/ ne02,
-        /*.ne03 =*/ ne03,
-        /*.nb00 =*/ nb00,
-        /*.nb01 =*/ nb01,
-        /*.nb02 =*/ nb02,
-        /*.nb03 =*/ nb03,
+        /*.ne00  =*/ ne00,
+        /*.ne01  =*/ ne01,
+        /*.ne02  =*/ ne02,
+        /*.ne03  =*/ ne03,
+        /*.nb00  =*/ nb00,
+        /*.nb01  =*/ nb01,
+        /*.nb02  =*/ nb02,
+        /*.nb03  =*/ nb03,
+        /*.ne0   =*/ ne0,
+        /*.ne1   =*/ ne1,
+        /*.ne2   =*/ ne2,
+        /*.ne3   =*/ ne3,
+        /*.top_k =*/ nth,
     };

     ggml_metal_encoder_set_pipeline(enc, pipeline);
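
Note: the positional initializers above only hint at the kernel-argument layout through their comments. As a reading aid, here is the layout they imply; this is a sketch inferred from the GGML_TENSOR_LOCALS types and the field comments, not copied from ggml-metal-impl.h, so the real definition may differ:

    // inferred layout (assumption): int32_t dims, uint64_t byte strides
    typedef struct {
        int32_t  ne00, ne01, ne02, ne03; // src0 dims
        uint64_t nb00, nb01, nb02, nb03; // src0 strides in bytes
        int32_t  ne0,  ne1,  ne2,  ne3;  // dst dims
        int32_t  top_k;                  // elements kept per row
    } kargs_argsort_sketch;
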
@@ -3704,15 +3713,20 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
     ggml_metal_op_concurrency_reset(ctx);

     ggml_metal_kargs_argsort_merge args_merge = {
-        .ne00 = ne00,
-        .ne01 = ne01,
-        .ne02 = ne02,
-        .ne03 = ne03,
-        .nb00 = nb00,
-        .nb01 = nb01,
-        .nb02 = nb02,
-        .nb03 = nb03,
-        .len  = len,
+        /*.ne00  =*/ ne00,
+        /*.ne01  =*/ ne01,
+        /*.ne02  =*/ ne02,
+        /*.ne03  =*/ ne03,
+        /*.nb00  =*/ nb00,
+        /*.nb01  =*/ nb01,
+        /*.nb02  =*/ nb02,
+        /*.nb03  =*/ nb03,
+        /*.ne0   =*/ ne0,
+        /*.ne1   =*/ ne1,
+        /*.ne2   =*/ ne2,
+        /*.ne3   =*/ ne3,
+        /*.top_k =*/ ne00,
+        /*.len   =*/ len,
     };

     // merges per row
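
Note: the merge stage walks each row in passes whose run length doubles every pass. A minimal standalone sketch of that scheduling, using the same merges-per-row formula as the op; the 1000-element row and 128-element starting run are hypothetical values chosen only for illustration:

    #include <cstdio>

    int main() {
        const int ne0 = 1000;  // hypothetical row length
        int len = 128;         // hypothetical initial sorted-run length
        while (len < ne0) {
            // merges per row: pairs of runs of length `len` are merged per pass
            const int nm = (ne0 + 2*len - 1)/(2*len);
            std::printf("len=%4d -> %d merge threadgroup(s) per row\n", len, nm);
            len *= 2;
        }
        return 0;
    }
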
@@ -3736,6 +3750,119 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
     return 1;
 }

+int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) {
+    ggml_tensor * op = ctx->node(idx);
+
+    ggml_metal_library_t lib = ctx->lib;
+    ggml_metal_encoder_t enc = ctx->enc;
+
+    GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
+
+    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
+
+    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_argsort(lib, op);
+
+    // bitonic sort requires the number of elements to be power of 2
+    int nth = 1;
+    while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
+        nth *= 2;
+    }
+
+    const int npr = (ne00 + nth - 1)/nth;
+
+    // Metal kernels require the buffer size to be multiple of 16 bytes
+    // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
+    const size_t smem = GGML_PAD(nth*sizeof(int32_t), 16);
+
+    ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
+    ggml_metal_buffer_id bid_dst  = ggml_metal_get_buffer_id(op);
+
+    ggml_metal_buffer_id bid_tmp = bid_dst;
+    bid_tmp.offs += sizeof(int32_t)*ggml_nelements(op->src[0]);
+
+    if ((int) ceil(std::log(npr) / std::log(2)) % 2 == 1) {
+        std::swap(bid_dst, bid_tmp);
+    }
+
+    const int top_k = ne0;
+
+    ggml_metal_kargs_argsort args = {
+        /*.ne00  =*/ ne00,
+        /*.ne01  =*/ ne01,
+        /*.ne02  =*/ ne02,
+        /*.ne03  =*/ ne03,
+        /*.nb00  =*/ nb00,
+        /*.nb01  =*/ nb01,
+        /*.nb02  =*/ nb02,
+        /*.nb03  =*/ nb03,
+        /*.ne0   =*/ ne0,
+        /*.ne1   =*/ ne1,
+        /*.ne2   =*/ ne2,
+        /*.ne3   =*/ ne3,
+        /*.top_k =*/ std::min(nth, top_k),
+    };
+
+    if (npr > 1) {
+        args.ne0 = (npr - 1)*args.top_k + std::min(ne00 - (npr - 1)*nth, args.top_k);
+    }
+
+    ggml_metal_encoder_set_pipeline(enc, pipeline);
+    ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
+    ggml_metal_encoder_set_buffer(enc, bid_src0, 1);
+    ggml_metal_encoder_set_buffer(enc, bid_dst,  2);
+
+    ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
+
+    ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1);
+
+    ggml_metal_pipeline_t pipeline_merge = ggml_metal_library_get_pipeline_argsort_merge(lib, op);
+
+    int len = args.top_k;
+
+    while (len < args.ne0) {
+        ggml_metal_op_concurrency_reset(ctx);
+
+        ggml_metal_kargs_argsort_merge args_merge = {
+            /*.ne00  =*/ ne00,
+            /*.ne01  =*/ ne01,
+            /*.ne02  =*/ ne02,
+            /*.ne03  =*/ ne03,
+            /*.nb00  =*/ nb00,
+            /*.nb01  =*/ nb01,
+            /*.nb02  =*/ nb02,
+            /*.nb03  =*/ nb03,
+            /*.ne0   =*/ args.ne0,
+            /*.ne1   =*/ ne1,
+            /*.ne2   =*/ ne2,
+            /*.ne3   =*/ ne3,
+            /*.top_k =*/ 2*len >= args.ne0 ? top_k : args.ne0,
+            /*.len   =*/ len,
+        };
+
+        // merges per row
+        const int nm = (args.ne0 + 2*len - 1)/(2*len);
+
+        const int nth = std::min(512, std::min(len, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_merge)));
+
+        ggml_metal_encoder_set_pipeline(enc, pipeline_merge);
+        ggml_metal_encoder_set_bytes(enc, &args_merge, sizeof(args_merge), 0);
+        ggml_metal_encoder_set_buffer(enc, bid_src0, 1);
+        ggml_metal_encoder_set_buffer(enc, bid_dst,  2);
+        ggml_metal_encoder_set_buffer(enc, bid_tmp,  3);
+
+        ggml_metal_encoder_dispatch_threadgroups(enc, nm*ne01, ne02, ne03, nth, 1, 1);
+
+        std::swap(bid_dst, bid_tmp);
+
+        len <<= 1;
+    }
+
+    return 1;
+}
+
 int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
     ggml_tensor * op = ctx->node(idx);

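
Note: one non-obvious detail in ggml_metal_op_top_k is the early std::swap(bid_dst, bid_tmp). The kernels ping-pong between the real destination and a temporary region placed right after it (bid_tmp.offs), and each merge pass ends with a swap, so the op apparently pre-swaps the buffers whenever the pass count is odd so the final pass still lands in the real destination. A small standalone sketch of that parity argument, assuming only the standard library and using the same ceil(log2(npr)) pass count as the op's check:

    #include <cmath>
    #include <cstdio>
    #include <initializer_list>

    int main() {
        // npr = number of power-of-two chunks a row was split into
        for (int npr : {1, 2, 3, 4, 5, 8, 9}) {
            const int passes = npr > 1 ? (int) std::ceil(std::log2((double) npr)) : 0;
            // with an odd pass count, start writing into the temporary region
            // so the last swap leaves the result in the real dst buffer
            std::printf("npr=%d -> %d merge pass(es), start in %s\n",
                        npr, passes, passes % 2 ? "tmp" : "dst");
        }
        return 0;
    }
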