Skip to content

Commit 857f381

Browse files
committed
vulkan: Add N-ary search algorithm for topk
1 parent 1b63ed1 commit 857f381

File tree

4 files changed

+218
-4
lines changed

4 files changed

+218
-4
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -516,6 +516,7 @@ struct vk_device_struct {
516516
bool single_queue;
517517
bool support_async;
518518
uint32_t subgroup_size;
519+
uint32_t subgroup_size_log2;
519520
uint32_t shader_core_count;
520521
bool uma;
521522
bool prefer_host_memory;
@@ -3978,9 +3979,17 @@ static void ggml_vk_load_shaders(vk_device& device) {
39783979
for (uint32_t i = 0; i < num_topk_pipelines; ++i) {
39793980
const uint32_t BLOCK_SIZE = 1u << i;
39803981
const uint32_t NCOLS_PADDED_LOG2 = i;
3981-
if (i <= device->max_workgroup_size_log2 &&
3982-
2 * sizeof(int) * BLOCK_SIZE <= device->properties.limits.maxComputeSharedMemorySize) {
3983-
ggml_vk_create_pipeline2(device, device->pipeline_topk_f32[i], "topk_f32_"+std::to_string(i), topk_f32_len, topk_f32_data, "main", 2, sizeof(vk_op_topk_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, NCOLS_PADDED_LOG2}, 1, true);
3982+
if (i <= device->max_workgroup_size_log2) {
3983+
uint32_t nary_shmem = 2 * sizeof(int) * BLOCK_SIZE +
3984+
sizeof(int) * device->subgroup_size +
3985+
2 * sizeof(int) +
3986+
(BLOCK_SIZE / device->subgroup_size) * sizeof(int);
3987+
if (device->subgroup_arithmetic && device->subgroup_require_full_support && device->subgroup_shuffle && device->subgroup_ballot &&
3988+
nary_shmem <= device->properties.limits.maxComputeSharedMemorySize) {
3989+
ggml_vk_create_pipeline2(device, device->pipeline_topk_f32[i], "topk_f32_"+std::to_string(i), topk_nary_search_f32_len, topk_nary_search_f32_data, "main", 2, sizeof(vk_op_topk_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, device->subgroup_size, device->subgroup_size_log2}, 1, true, true, device->subgroup_size);
3990+
} else if (2 * sizeof(int) * BLOCK_SIZE <= device->properties.limits.maxComputeSharedMemorySize) {
3991+
ggml_vk_create_pipeline2(device, device->pipeline_topk_f32[i], "topk_f32_"+std::to_string(i), topk_argsort_f32_len, topk_argsort_f32_data, "main", 2, sizeof(vk_op_topk_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, NCOLS_PADDED_LOG2}, 1, true);
3992+
}
39843993
}
39853994
}
39863995

@@ -4353,6 +4362,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
43534362
device->suballocation_block_size = std::min(device->suballocation_block_size, device->max_memory_allocation_size);
43544363

43554364
device->subgroup_size = subgroup_props.subgroupSize;
4365+
device->subgroup_size_log2 = uint32_t(log2f(float(device->subgroup_size)));
43564366
device->uma = device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
43574367
if (sm_builtins) {
43584368
device->shader_core_count = sm_props.shaderSMCount;
@@ -10191,6 +10201,8 @@ static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, cons
1019110201
// But if K is larger, then we need a larger workgroup
1019210202
uint32_t max_pipeline = num_topk_pipelines - 3;
1019310203
uint32_t min_pipeline = (uint32_t)log2f(float(k)) + 1;
10204+
// require full subgroup
10205+
min_pipeline = std::max(min_pipeline, ctx->device->subgroup_size_log2);
1019410206

1019510207
uint32_t pipeline_idx = (uint32_t)ceilf(log2f(float(num_elements)));
1019610208
pipeline_idx = std::min(pipeline_idx, max_pipeline);
Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
#version 450
2+
#extension GL_EXT_control_flow_attributes : enable
3+
#extension GL_EXT_debug_printf : enable
4+
#extension GL_KHR_shader_subgroup_basic : enable
5+
#extension GL_KHR_shader_subgroup_ballot : enable
6+
#extension GL_KHR_shader_subgroup_arithmetic : enable
7+
#extension GL_KHR_shader_subgroup_shuffle : enable
8+
9+
#include "types.glsl"
10+
11+
layout(constant_id = 0) const int BLOCK_SIZE = 1024;
12+
layout(constant_id = 1) const int SUBGROUP_SIZE = 32;
13+
layout(constant_id = 2) const int SUBGROUP_SIZE_LOG2 = 5;
14+
15+
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
16+
17+
// Input can either be the source (A) or intermediate values (S).
18+
// Similarly, output can be either destination (D) or intermediate values (T).
19+
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
20+
layout (binding = 0) readonly buffer S {ivec2 data_s[];};
21+
layout (binding = 1) writeonly buffer D {int data_d[];};
22+
layout (binding = 1) writeonly buffer T {ivec2 data_t[];};
23+
24+
layout (push_constant) uniform parameter {
25+
uint orig_ncols;
26+
uint ncols_input;
27+
uint ncols_output;
28+
uint nrows;
29+
uint first_pass;
30+
uint last_pass;
31+
} p;
32+
33+
// pairs of (gid, value)
34+
shared ivec2 dst_row[BLOCK_SIZE];
35+
36+
shared int counts[SUBGROUP_SIZE];
37+
shared int sh_min_idx;
38+
shared uint sh_total;
39+
shared uint offset_partials[BLOCK_SIZE / SUBGROUP_SIZE];
40+
41+
// Map float values to uint such that comparisons still work.
42+
// Positive values set the high bit, negative values are inverted.
43+
// +0.0 -> 0x80000000, -0.0 -> 0x7FFFFFFF are in the correct places.
44+
uint f2ui(float x) {
45+
uint y = floatBitsToUint(x);
46+
if ((y & 0x80000000) != 0) {
47+
y ^= ~0;
48+
} else {
49+
y |= 0x80000000;
50+
}
51+
return y;
52+
}
53+
54+
void topk(const uint row) {
55+
const int tid = int(gl_LocalInvocationID.x);
56+
57+
// initialize indices
58+
if (gl_GlobalInvocationID.x < p.ncols_input) {
59+
if (p.first_pass != 0) {
60+
const uint row_offset = row * p.ncols_input;
61+
dst_row[tid] = ivec2(gl_GlobalInvocationID.x, floatBitsToInt(data_a[row_offset + gl_GlobalInvocationID.x]));
62+
} else {
63+
const uint row_offset = row * p.orig_ncols;
64+
dst_row[tid] = data_s[row_offset + gl_GlobalInvocationID.x];
65+
}
66+
} else {
67+
dst_row[tid] = ivec2(p.orig_ncols, 0xFF800000); // -inf
68+
}
69+
barrier();
70+
71+
if (p.ncols_output == 1) {
72+
// Fast path for single output - just do a max reduction
73+
[[unroll]] for (int s = BLOCK_SIZE / 2; s >= 1; s /= 2) {
74+
if (tid < s) {
75+
ivec2 a = dst_row[tid];
76+
ivec2 b = dst_row[tid + s];
77+
if (a.x >= p.orig_ncols ||
78+
b.x < p.orig_ncols && b.y > a.y) {
79+
dst_row[tid] = b;
80+
}
81+
}
82+
barrier();
83+
}
84+
} else {
85+
// Do an N-ary search to find the K-th largest value.
86+
// We remap the float values to be comparable as unsigned integers,
87+
// and split the range into 2^N smaller ranges where N is the
88+
// base-2 logarithm of the subgroup size. Count how many values are in each range, if the K-th
89+
// largest value is in the middle of one of these ranges then repeat
90+
// and split again.
91+
92+
// Mask is the current set of bits we're searching. Shift is the LSB index.
93+
int shift = 32 - SUBGROUP_SIZE_LOG2;
94+
uint mask = ((1 << SUBGROUP_SIZE_LOG2) - 1) << shift;
95+
96+
// The current range.
97+
uint range_min = 0;
98+
uint range_max = 0xFF800000;
99+
// How many are above the current range, and how many we need to find.
100+
uint total = 0;
101+
uint limit = min(p.ncols_output, p.ncols_input - gl_WorkGroupID.x * BLOCK_SIZE);
102+
103+
while (mask != 0) {
104+
barrier();
105+
// Initialize bucket counts to zero.
106+
if (tid < SUBGROUP_SIZE) {
107+
counts[tid] = 0;
108+
}
109+
barrier();
110+
// Count how many values are in each bucket.
111+
if (tid < p.ncols_input) {
112+
float y = intBitsToFloat(dst_row[tid].y);
113+
uint fy = f2ui(y);
114+
if (fy >= range_min && fy < range_max) {
115+
uint bucket = (fy & mask) >> shift;
116+
atomicAdd(counts[bucket], 1);
117+
}
118+
}
119+
barrier();
120+
121+
// On the first subgroup, do a scan to count (from the top down) how
122+
// many elements are in the top N buckets. Find the index of the first
123+
// that reaches the limit. Copy it to the other invocations through
124+
// shared memory.
125+
if (tid < SUBGROUP_SIZE) {
126+
uint partial_sum = counts[SUBGROUP_SIZE - 1 - tid];
127+
partial_sum = subgroupInclusiveAdd(partial_sum) + total;
128+
uint t = subgroupBallotFindLSB(subgroupBallot(partial_sum >= limit));
129+
int min_idx = int(SUBGROUP_SIZE - 1 - t);
130+
total = subgroupShuffle(partial_sum, t);
131+
sh_min_idx = min_idx;
132+
sh_total = total;
133+
}
134+
barrier();
135+
int min_idx = sh_min_idx;
136+
total = sh_total;
137+
138+
// Update the range, and break if we've found the K-th largest.
139+
if (min_idx != SUBGROUP_SIZE - 1) {
140+
range_max = range_min + ((min_idx + 1) << shift);
141+
}
142+
range_min = range_min + (min_idx << shift);
143+
144+
if (total == p.ncols_output) {
145+
break;
146+
}
147+
total -= counts[min_idx];
148+
mask >>= SUBGROUP_SIZE_LOG2;
149+
shift -= SUBGROUP_SIZE_LOG2;
150+
if (shift < 0) {
151+
shift = 0;
152+
}
153+
}
154+
155+
ivec2 v = dst_row[tid];
156+
157+
// We need to compact these values to the start of the dst_row array.
158+
// Have each subgroup count how many items it'll store, so other
159+
// subgroups can compute their base offset.
160+
bool top = f2ui(intBitsToFloat(v.y)) >= range_min;
161+
uvec4 b = subgroupBallot(top);
162+
uint bit_count = subgroupBallotBitCount(b);
163+
if ((tid % SUBGROUP_SIZE) == 0) {
164+
offset_partials[tid / SUBGROUP_SIZE] = bit_count;
165+
}
166+
barrier();
167+
168+
uint out_idx = 0;
169+
[[unroll]] for (int i = 0; i < BLOCK_SIZE / SUBGROUP_SIZE; ++i) {
170+
if (i < tid / SUBGROUP_SIZE) {
171+
out_idx += offset_partials[i];
172+
}
173+
}
174+
175+
uint bit_count_ex = subgroupBallotExclusiveBitCount(b);
176+
if (top) {
177+
// TODO: Copy directly to the output?
178+
dst_row[out_idx + bit_count_ex] = v;
179+
}
180+
181+
barrier();
182+
}
183+
184+
if (tid < p.ncols_output && gl_GlobalInvocationID.x < p.orig_ncols) {
185+
if (p.last_pass != 0) {
186+
const uint row_offset = row * p.ncols_output;
187+
data_d[row_offset + tid] = dst_row[tid].x;
188+
} else {
189+
const uint row_offset = row * p.orig_ncols + gl_WorkGroupID.x * p.ncols_output;
190+
data_t[row_offset + tid] = dst_row[tid];
191+
}
192+
}
193+
}
194+
195+
void main() {
196+
uint row = gl_WorkGroupID.y;
197+
while (row < p.nrows) {
198+
topk(row);
199+
row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
200+
}
201+
}

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -913,7 +913,8 @@ void process_shaders() {
913913
string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}});
914914
string_to_spv("argsort_large_f32", "argsort_large.comp", {{"A_TYPE", "float"}});
915915

916-
string_to_spv("topk_f32", "topk.comp", {{"A_TYPE", "float"}});
916+
string_to_spv("topk_argsort_f32", "topk_argsort.comp", {{"A_TYPE", "float"}});
917+
string_to_spv("topk_nary_search_f32", "topk_nary_search.comp", {{"A_TYPE", "float"}});
917918

918919
string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}}));
919920
string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));

0 commit comments

Comments
 (0)