@@ -3764,18 +3764,17 @@ int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) {
37643764 GGML_TENSOR_LOCALS ( int32_t , ne, op, ne);
37653765 GGML_TENSOR_LOCALS (uint64_t , nb, op, nb);
37663766
3767- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_argsort (lib, op);
3767+ ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_top_k (lib, op);
37683768
37693769 // bitonic sort requires the number of elements to be power of 2
37703770 int nth = 1 ;
37713771 while (nth < ne00 && 2 *nth <= ggml_metal_pipeline_max_theads_per_threadgroup (pipeline)) {
37723772 nth *= 2 ;
37733773 }
37743774
3775+ // blocks per row
37753776 const int npr = (ne00 + nth - 1 )/nth;
37763777
3777- // Metal kernels require the buffer size to be multiple of 16 bytes
3778- // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
37793778 const size_t smem = GGML_PAD (nth*sizeof (int32_t ), 16 );
37803779
37813780 ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id (op->src [0 ]);
@@ -3803,7 +3802,7 @@ int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) {
38033802 /* .ne1 =*/ ne1,
38043803 /* .ne2 =*/ ne2,
38053804 /* .ne3 =*/ ne3,
3806- /* .top_k =*/ std::min (nth, top_k),
3805+ /* .top_k =*/ std::min (nth, top_k), // for each block, keep just the top_k indices
38073806 };
38083807
38093808 if (npr > 1 ) {
@@ -3819,13 +3818,18 @@ int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) {
38193818
38203819 ggml_metal_encoder_dispatch_threadgroups (enc, npr*ne01, ne02, ne03, nth, 1 , 1 );
38213820
3822- ggml_metal_pipeline_t pipeline_merge = ggml_metal_library_get_pipeline_argsort_merge (lib, op);
3821+ ggml_metal_pipeline_t pipeline_merge = ggml_metal_library_get_pipeline_top_k_merge (lib, op);
38233822
38243823 int len = args.top_k ;
38253824
38263825 while (len < args.ne0 ) {
38273826 ggml_metal_op_concurrency_reset (ctx);
38283827
3828+ // merges per row
3829+ const int nm = (args.ne0 + 2 *len - 1 ) / (2 *len);
3830+
3831+ const int nth = std::min (512 , std::min (len, ggml_metal_pipeline_max_theads_per_threadgroup (pipeline_merge)));
3832+
38293833 ggml_metal_kargs_argsort_merge args_merge = {
38303834 /* .ne00 =*/ ne00,
38313835 /* .ne01 =*/ ne01,
@@ -3839,15 +3843,10 @@ int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) {
38393843 /* .ne1 =*/ ne1,
38403844 /* .ne2 =*/ ne2,
38413845 /* .ne3 =*/ ne3,
3842- /* .top_k =*/ 2 *len >= args. ne0 ? top_k : args.ne0 ,
3846+ /* .top_k =*/ nm == 1 ? top_k : args.ne0 , // the final merge outputs top_k elements
38433847 /* .len =*/ len,
38443848 };
38453849
3846- // merges per row
3847- const int nm = (args.ne0 + 2 *len - 1 ) / (2 *len);
3848-
3849- const int nth = std::min (512 , std::min (len, ggml_metal_pipeline_max_theads_per_threadgroup (pipeline_merge)));
3850-
38513850 ggml_metal_encoder_set_pipeline (enc, pipeline_merge);
38523851 ggml_metal_encoder_set_bytes (enc, &args_merge, sizeof (args_merge), 0 );
38533852 ggml_metal_encoder_set_buffer (enc, bid_src0, 1 );
0 commit comments