Skip to content

Commit 61bde8e

Browse files
authored
vulkan: Reduce temporary memory usage for TOP_K (ggml-org#17623)
- Compute row size for the temp buffer based on the output of the first pass. - Update shader addressing math to use the output row size - Pass the output row size as "ncols_output", what used to be "ncols_output" is now "k" For the common case of K=40 and src0=(200000,1,1,1), this reduces the temporary buffer from about 3.2MB to 500KB.
1 parent e251e5e commit 61bde8e

File tree

3 files changed

+54
-27
lines changed

3 files changed

+54
-27
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1227,6 +1227,7 @@ struct vk_op_topk_push_constants {
12271227
uint32_t orig_ncols;
12281228
uint32_t ncols_input;
12291229
uint32_t ncols_output;
1230+
uint32_t k;
12301231
uint32_t nrows;
12311232
uint32_t first_pass;
12321233
uint32_t last_pass;
@@ -1673,6 +1674,14 @@ class vk_perf_logger {
16731674
timings[name.str()].push_back(time);
16741675
return;
16751676
}
1677+
if (node->op == GGML_OP_TOP_K) {
1678+
std::stringstream name;
1679+
name << ggml_op_name(node->op) <<
1680+
" K=" << node->ne[0] <<
1681+
" (" << node->src[0]->ne[0] << "," << node->src[0]->ne[1] << "," << node->src[0]->ne[2] << "," << node->src[0]->ne[3] << ")";
1682+
timings[name.str()].push_back(time);
1683+
return;
1684+
}
16761685
timings[ggml_op_name(node->op)].push_back(time);
16771686
}
16781687
private:
@@ -10345,17 +10354,8 @@ static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, cons
1034510354
uint32_t nrows = ggml_nrows(src0);
1034610355
uint32_t k = dst->ne[0];
1034710356

10348-
vk_op_topk_push_constants pc { ncols, ncols, k, nrows, 0, 0 };
10357+
vk_op_topk_push_constants pc { ncols, ncols, ncols, k, nrows, 0, 0 };
1034910358

10350-
// Reserve space for ivec2 per element, double buffered
10351-
const size_t dbl_buf_size = size_t{ncols} * nrows * 2 * sizeof(int);
10352-
const size_t x_sz = dbl_buf_size * 2;
10353-
uint32_t dbl_buf_index = 0;
10354-
10355-
if (ctx->prealloc_size_x < x_sz) {
10356-
ctx->prealloc_size_x = x_sz;
10357-
ggml_vk_preallocate_buffers(ctx, subctx);
10358-
}
1035910359
if (ctx->prealloc_x_need_sync) {
1036010360
ggml_vk_sync_buffers(ctx, subctx);
1036110361
}
@@ -10370,8 +10370,9 @@ static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, cons
1037010370
// largest elements. Repeat until we have the top K elements.
1037110371
// Need to do at least one iteration to write out the results.
1037210372
bool done_one_iter = false;
10373+
uint32_t dbl_buf_index = 0;
10374+
size_t dbl_buf_size;
1037310375
while (num_elements > k || !done_one_iter) {
10374-
done_one_iter = true;
1037510376

1037610377
// Prefer going as small as num_topk_pipelines - 3 for perf reasons.
1037710378
// But if K is larger, then we need a larger workgroup
@@ -10411,6 +10412,21 @@ static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, cons
1041110412
// Number of elements remaining after this pass
1041210413
uint32_t num_dst_elements = (num_elements / pipeline->wg_denoms[0]) * k + std::min(k, num_elements % pipeline->wg_denoms[0]);
1041310414

10415+
pc2.ncols_output = num_dst_elements;
10416+
10417+
if (!done_one_iter) {
10418+
// Reserve space for ivec2 per element, double buffered
10419+
// K per workgroup per row
10420+
dbl_buf_size = num_dst_elements * nrows * 2 * sizeof(int);
10421+
dbl_buf_size = ROUNDUP_POW2(dbl_buf_size, ctx->device->properties.limits.minStorageBufferOffsetAlignment);
10422+
const size_t x_sz = dbl_buf_size * 2;
10423+
10424+
if (ctx->prealloc_size_x < x_sz) {
10425+
ctx->prealloc_size_x = x_sz;
10426+
ggml_vk_preallocate_buffers(ctx, subctx);
10427+
}
10428+
}
10429+
1041410430
vk_subbuffer src_buf;
1041510431
vk_subbuffer dst_buf;
1041610432

@@ -10436,6 +10452,7 @@ static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, cons
1043610452
if (num_elements > k) {
1043710453
ggml_vk_sync_buffers(ctx, subctx);
1043810454
}
10455+
done_one_iter = true;
1043910456
}
1044010457
ctx->prealloc_x_need_sync = true;
1044110458
}

ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ layout (push_constant) uniform parameter {
1919
uint orig_ncols;
2020
uint ncols_input;
2121
uint ncols_output;
22+
uint k;
2223
uint nrows;
2324
uint first_pass;
2425
uint last_pass;
@@ -36,15 +37,15 @@ void topk(bool needs_bounds_check, const uint row) {
3637
const uint row_offset = row * p.ncols_input;
3738
dst_row[col] = ivec2(gl_GlobalInvocationID.x, floatBitsToInt(data_a[row_offset + gl_GlobalInvocationID.x]));
3839
} else {
39-
const uint row_offset = row * p.orig_ncols;
40+
const uint row_offset = row * p.ncols_input;
4041
dst_row[col] = data_s[row_offset + gl_GlobalInvocationID.x];
4142
}
4243
} else {
4344
dst_row[col] = ivec2(p.orig_ncols, 0);
4445
}
4546
barrier();
4647

47-
if (p.ncols_output == 1) {
48+
if (p.k == 1) {
4849
// Fast path for single output - just do a max reduction
4950
[[unroll]] for (int s = BLOCK_SIZE / 2; s >= 1; s /= 2) {
5051
if (col < s) {
@@ -84,13 +85,17 @@ void topk(bool needs_bounds_check, const uint row) {
8485
}
8586
}
8687

87-
if (col < p.ncols_output && gl_GlobalInvocationID.x < p.orig_ncols) {
88+
if (col < p.k) {
8889
if (p.last_pass != 0) {
89-
const uint row_offset = row * p.ncols_output;
90-
data_d[row_offset + col] = dst_row[col].x;
90+
if (gl_GlobalInvocationID.x < p.ncols_input) {
91+
const uint row_offset = row * p.k;
92+
data_d[row_offset + col] = dst_row[col].x;
93+
}
9194
} else {
92-
const uint row_offset = row * p.orig_ncols + gl_WorkGroupID.x * p.ncols_output;
93-
data_t[row_offset + col] = dst_row[col];
95+
if (gl_WorkGroupID.x * p.k + col < p.ncols_output) {
96+
const uint row_offset = row * p.ncols_output + gl_WorkGroupID.x * p.k;
97+
data_t[row_offset + col] = dst_row[col];
98+
}
9499
}
95100
}
96101
}

ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ layout (push_constant) uniform parameter {
2525
uint orig_ncols;
2626
uint ncols_input;
2727
uint ncols_output;
28+
uint k;
2829
uint nrows;
2930
uint first_pass;
3031
uint last_pass;
@@ -60,15 +61,15 @@ void topk(const uint row) {
6061
const uint row_offset = row * p.ncols_input;
6162
dst_row[tid] = ivec2(gl_GlobalInvocationID.x, floatBitsToInt(data_a[row_offset + gl_GlobalInvocationID.x]));
6263
} else {
63-
const uint row_offset = row * p.orig_ncols;
64+
const uint row_offset = row * p.ncols_input;
6465
dst_row[tid] = data_s[row_offset + gl_GlobalInvocationID.x];
6566
}
6667
} else {
6768
dst_row[tid] = ivec2(p.orig_ncols, 0xFF800000); // -inf
6869
}
6970
barrier();
7071

71-
if (p.ncols_output == 1) {
72+
if (p.k == 1) {
7273
// Fast path for single output - just do a max reduction
7374
[[unroll]] for (int s = BLOCK_SIZE / 2; s >= 1; s /= 2) {
7475
if (tid < s) {
@@ -98,7 +99,7 @@ void topk(const uint row) {
9899
uint range_max = 0xFF800000;
99100
// How many are above the current range, and how many we need to find.
100101
uint total = 0;
101-
uint limit = min(p.ncols_output, p.ncols_input - gl_WorkGroupID.x * BLOCK_SIZE);
102+
uint limit = min(p.k, p.ncols_input - gl_WorkGroupID.x * BLOCK_SIZE);
102103

103104
while (mask != 0) {
104105
barrier();
@@ -139,7 +140,7 @@ void topk(const uint row) {
139140
range_max = range_min + ((min_idx + 1) << shift);
140141
range_min = range_min + (min_idx << shift);
141142

142-
if (total == p.ncols_output) {
143+
if (total == p.k) {
143144
break;
144145
}
145146
total -= counts[min_idx];
@@ -179,13 +180,17 @@ void topk(const uint row) {
179180
barrier();
180181
}
181182

182-
if (tid < p.ncols_output && gl_GlobalInvocationID.x < p.orig_ncols) {
183+
if (tid < p.k) {
183184
if (p.last_pass != 0) {
184-
const uint row_offset = row * p.ncols_output;
185-
data_d[row_offset + tid] = dst_row[tid].x;
185+
if (gl_GlobalInvocationID.x < p.ncols_input) {
186+
const uint row_offset = row * p.k;
187+
data_d[row_offset + tid] = dst_row[tid].x;
188+
}
186189
} else {
187-
const uint row_offset = row * p.orig_ncols + gl_WorkGroupID.x * p.ncols_output;
188-
data_t[row_offset + tid] = dst_row[tid];
190+
if (gl_WorkGroupID.x * p.k + tid < p.ncols_output) {
191+
const uint row_offset = row * p.ncols_output + gl_WorkGroupID.x * p.k;
192+
data_t[row_offset + tid] = dst_row[tid];
193+
}
189194
}
190195
}
191196
}

0 commit comments

Comments
 (0)