@@ -3763,18 +3763,17 @@ int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) {
37633763 GGML_TENSOR_LOCALS ( int32_t , ne, op, ne);
37643764 GGML_TENSOR_LOCALS (uint64_t , nb, op, nb);
37653765
3766- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_argsort (lib, op);
3766+ ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_top_k (lib, op);
37673767
37683768 // bitonic sort requires the number of elements to be power of 2
37693769 int nth = 1 ;
37703770 while (nth < ne00 && 2 *nth <= ggml_metal_pipeline_max_theads_per_threadgroup (pipeline)) {
37713771 nth *= 2 ;
37723772 }
37733773
3774+ // blocks per row
37743775 const int npr = (ne00 + nth - 1 )/nth;
37753776
3776- // Metal kernels require the buffer size to be multiple of 16 bytes
3777- // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
37783777 const size_t smem = GGML_PAD (nth*sizeof (int32_t ), 16 );
37793778
37803779 ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id (op->src [0 ]);
@@ -3802,7 +3801,7 @@ int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) {
38023801 /* .ne1 =*/ ne1,
38033802 /* .ne2 =*/ ne2,
38043803 /* .ne3 =*/ ne3,
3805- /* .top_k =*/ std::min (nth, top_k),
3804+ /* .top_k =*/ std::min (nth, top_k), // for each block, keep just the top_k indices
38063805 };
38073806
38083807 if (npr > 1 ) {
@@ -3818,13 +3817,18 @@ int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) {
38183817
38193818 ggml_metal_encoder_dispatch_threadgroups (enc, npr*ne01, ne02, ne03, nth, 1 , 1 );
38203819
3821- ggml_metal_pipeline_t pipeline_merge = ggml_metal_library_get_pipeline_argsort_merge (lib, op);
3820+ ggml_metal_pipeline_t pipeline_merge = ggml_metal_library_get_pipeline_top_k_merge (lib, op);
38223821
38233822 int len = args.top_k ;
38243823
38253824 while (len < args.ne0 ) {
38263825 ggml_metal_op_concurrency_reset (ctx);
38273826
3827+ // merges per row
3828+ const int nm = (args.ne0 + 2 *len - 1 ) / (2 *len);
3829+
3830+ const int nth = std::min (512 , std::min (len, ggml_metal_pipeline_max_theads_per_threadgroup (pipeline_merge)));
3831+
38283832 ggml_metal_kargs_argsort_merge args_merge = {
38293833 /* .ne00 =*/ ne00,
38303834 /* .ne01 =*/ ne01,
@@ -3838,15 +3842,10 @@ int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) {
38383842 /* .ne1 =*/ ne1,
38393843 /* .ne2 =*/ ne2,
38403844 /* .ne3 =*/ ne3,
3841- /* .top_k =*/ 2 *len >= args. ne0 ? top_k : args.ne0 ,
3845+ /* .top_k =*/ nm == 1 ? top_k : args.ne0 , // the final merge outputs top_k elements
38423846 /* .len =*/ len,
38433847 };
38443848
3845- // merges per row
3846- const int nm = (args.ne0 + 2 *len - 1 ) / (2 *len);
3847-
3848- const int nth = std::min (512 , std::min (len, ggml_metal_pipeline_max_theads_per_threadgroup (pipeline_merge)));
3849-
38503849 ggml_metal_encoder_set_pipeline (enc, pipeline_merge);
38513850 ggml_metal_encoder_set_bytes (enc, &args_merge, sizeof (args_merge), 0 );
38523851 ggml_metal_encoder_set_buffer (enc, bid_src0, 1 );
0 commit comments