Skip to content

Commit 3a68904

Browse files
committed
microoptimizations
1 parent 857f381 commit 3a68904

File tree

2 files changed

+19
-8
lines changed

2 files changed

+19
-8
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10208,6 +10208,17 @@ static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, cons
1020810208
pipeline_idx = std::min(pipeline_idx, max_pipeline);
1020910209
pipeline_idx = std::max(pipeline_idx, min_pipeline);
1021010210

10211+
if (num_elements > (1u << pipeline_idx)) {
10212+
// If we could finish on this loop iteration (i.e. a single workgroup)
10213+
// then do so. It's better than the overhead of another pass.
10214+
for (uint32_t i = pipeline_idx; i < num_topk_pipelines; ++i) {
10215+
if (num_elements <= (1u << i)) {
10216+
pipeline_idx = i;
10217+
break;
10218+
}
10219+
}
10220+
}
10221+
1021110222
vk_pipeline pipeline = ctx->device->pipeline_topk_f32[pipeline_idx];
1021210223
// If the device doesn't support a pipeline this large, use smaller
1021310224
while (!pipeline) {
@@ -10242,9 +10253,11 @@ static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, cons
1024210253

1024310254
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
1024410255
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src_buf, dst_buf }, pc2, elements);
10245-
ggml_vk_sync_buffers(ctx, subctx);
1024610256
num_elements = num_dst_elements;
1024710257
dbl_buf_index ^= 1;
10258+
if (num_elements > k) {
10259+
ggml_vk_sync_buffers(ctx, subctx);
10260+
}
1024810261
}
1024910262
ctx->prealloc_x_need_sync = true;
1025010263
}

ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -126,19 +126,17 @@ void topk(const uint row) {
126126
uint partial_sum = counts[SUBGROUP_SIZE - 1 - tid];
127127
partial_sum = subgroupInclusiveAdd(partial_sum) + total;
128128
uint t = subgroupBallotFindLSB(subgroupBallot(partial_sum >= limit));
129-
int min_idx = int(SUBGROUP_SIZE - 1 - t);
130-
total = subgroupShuffle(partial_sum, t);
131-
sh_min_idx = min_idx;
132-
sh_total = total;
129+
if (tid == t) {
130+
sh_min_idx = int(SUBGROUP_SIZE - 1 - t);
131+
sh_total = partial_sum;
132+
}
133133
}
134134
barrier();
135135
int min_idx = sh_min_idx;
136136
total = sh_total;
137137

138138
// Update the range, and break if we've found the K-th largest.
139-
if (min_idx != SUBGROUP_SIZE - 1) {
140-
range_max = range_min + ((min_idx + 1) << shift);
141-
}
139+
range_max = range_min + ((min_idx + 1) << shift);
142140
range_min = range_min + (min_idx << shift);
143141

144142
if (total == p.ncols_output) {

0 commit comments

Comments
 (0)