ggml/src/ggml-vulkan/ggml-vulkan.cpp (150 additions, 0 deletions)
@@ -409,6 +409,7 @@ enum shader_reduction_mode {
// argsort pipelines for up to 1<<10 invocations per workgroup
static constexpr uint32_t num_argsort_pipelines = 11;
static constexpr uint32_t num_topk_moe_pipelines = 10;
static constexpr uint32_t num_topk_pipelines = 11;

static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
@@ -515,6 +516,7 @@ struct vk_device_struct {
bool single_queue;
bool support_async;
uint32_t subgroup_size;
uint32_t subgroup_size_log2;
uint32_t shader_core_count;
bool uma;
bool prefer_host_memory;
@@ -704,6 +706,7 @@ struct vk_device_struct {
vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
vk_pipeline pipeline_argsort_f32[num_argsort_pipelines];
vk_pipeline pipeline_argsort_large_f32[num_argsort_pipelines];
vk_pipeline pipeline_topk_f32[num_topk_pipelines];
vk_pipeline pipeline_sum_rows_f32;
vk_pipeline pipeline_argmax_f32;
vk_pipeline pipeline_count_equal_i32;
@@ -1204,6 +1207,15 @@ struct vk_op_argsort_push_constants {
uint32_t inner_end;
};

struct vk_op_topk_push_constants {
uint32_t orig_ncols;
uint32_t ncols_input;
uint32_t ncols_output;
uint32_t nrows;
uint32_t first_pass;
uint32_t last_pass;
};

struct vk_op_im2col_push_constants {
uint64_t dst_addr;
uint32_t batch_offset; uint32_t offset_delta;
@@ -3964,6 +3976,23 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline2(device, device->pipeline_argsort_large_f32[i], "argsort_large_f32_"+std::to_string(i), argsort_large_f32_len, argsort_large_f32_data, "main", 3, sizeof(vk_op_argsort_push_constants), {BLOCK_SIZE * WG_UNROLL_FACTOR, 1, 1}, {BLOCK_SIZE, WG_UNROLL_FACTOR}, 1, true);
}

for (uint32_t i = 0; i < num_topk_pipelines; ++i) {
const uint32_t BLOCK_SIZE = 1u << i;
const uint32_t NCOLS_PADDED_LOG2 = i;
if (i <= device->max_workgroup_size_log2) {
uint32_t nary_shmem = 2 * sizeof(int) * BLOCK_SIZE +
sizeof(int) * device->subgroup_size +
2 * sizeof(int) +
(BLOCK_SIZE / device->subgroup_size) * sizeof(int);
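// e.g. BLOCK_SIZE = 1024 with a 32-lane subgroup: 8192 + 128 + 8 + 128 = 8456 bytes,
// well under the 16 KiB minimum the Vulkan spec guarantees for maxComputeSharedMemorySize.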
if (device->subgroup_arithmetic && device->subgroup_require_full_support && device->subgroup_shuffle && device->subgroup_ballot &&
nary_shmem <= device->properties.limits.maxComputeSharedMemorySize) {
ggml_vk_create_pipeline2(device, device->pipeline_topk_f32[i], "topk_f32_"+std::to_string(i), topk_nary_search_f32_len, topk_nary_search_f32_data, "main", 2, sizeof(vk_op_topk_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, device->subgroup_size, device->subgroup_size_log2}, 1, true, true, device->subgroup_size);
} else if (2 * sizeof(int) * BLOCK_SIZE <= device->properties.limits.maxComputeSharedMemorySize) {
ggml_vk_create_pipeline2(device, device->pipeline_topk_f32[i], "topk_f32_"+std::to_string(i), topk_argsort_f32_len, topk_argsort_f32_data, "main", 2, sizeof(vk_op_topk_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, NCOLS_PADDED_LOG2}, 1, true);
}
}
}

ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);

ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
@@ -4333,6 +4362,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
device->suballocation_block_size = std::min(device->suballocation_block_size, device->max_memory_allocation_size);

device->subgroup_size = subgroup_props.subgroupSize;
device->subgroup_size_log2 = uint32_t(log2f(float(device->subgroup_size)));
device->uma = device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
if (sm_builtins) {
device->shader_core_count = sm_props.shaderSMCount;
@@ -10134,6 +10164,104 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c
}
}

static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
uint32_t ncols = src0->ne[0];
uint32_t nrows = ggml_nrows(src0);
uint32_t k = dst->ne[0];

vk_op_topk_push_constants pc { ncols, ncols, k, nrows, 0, 0 };

// Reserve space for ivec2 per element, double buffered
const size_t dbl_buf_size = size_t{ncols} * nrows * 2 * sizeof(int);
const size_t x_sz = dbl_buf_size * 2;
uint32_t dbl_buf_index = 0;

if (ctx->prealloc_size_x < x_sz) {
ctx->prealloc_size_x = x_sz;
ggml_vk_preallocate_buffers(ctx, subctx);
}
if (ctx->prealloc_x_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}

std::array<uint32_t, 3> elements;
elements[1] = std::min(nrows, ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
elements[2] = 1;

uint32_t num_elements = ncols;

// Each iteration reduces a workgroup's worth of elements down to the K
// largest elements. Repeat until we have the top K elements.
// Need to do at least one iteration to write out the results.
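// For example, ncols = 8192 with K = 8 and the preferred 256-wide workgroups takes two passes:
// 8192 elements -> 32 workgroups keeping 8 each = 256 -> one workgroup keeps the final 8.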
bool done_one_iter = false;
while (num_elements > k || !done_one_iter) {
done_one_iter = true;

// For perf reasons, prefer a workgroup no larger than 1 << (num_topk_pipelines - 3).
// But if K is larger, then we need a larger workgroup.
uint32_t max_pipeline = num_topk_pipelines - 3;
uint32_t min_pipeline = (uint32_t)log2f(float(k)) + 1;
// require full subgroup
min_pipeline = std::max(min_pipeline, ctx->device->subgroup_size_log2);

uint32_t pipeline_idx = (uint32_t)ceilf(log2f(float(num_elements)));
pipeline_idx = std::min(pipeline_idx, max_pipeline);
pipeline_idx = std::max(pipeline_idx, min_pipeline);

if (num_elements > (1u << pipeline_idx)) {
// If we could finish on this loop iteration (i.e. a single workgroup)
// then do so. It's better than the overhead of another pass.
for (uint32_t i = pipeline_idx; i < num_topk_pipelines; ++i) {
if (num_elements <= (1u << i)) {
pipeline_idx = i;
break;
}
}
}

vk_pipeline pipeline = ctx->device->pipeline_topk_f32[pipeline_idx];
// If the device doesn't support a pipeline this large, use a smaller one.
while (!pipeline) {
pipeline_idx--;
GGML_ASSERT(pipeline_idx >= min_pipeline);
pipeline = ctx->device->pipeline_topk_f32[pipeline_idx];
}

vk_op_topk_push_constants pc2 = pc;
pc2.ncols_input = num_elements;

// Number of elements remaining after this pass
uint32_t num_dst_elements = (num_elements / pipeline->wg_denoms[0]) * k + std::min(k, num_elements % pipeline->wg_denoms[0]);

vk_subbuffer src_buf;
vk_subbuffer dst_buf;

if (num_elements == ncols) {
pc2.first_pass = 1;
src_buf = ggml_vk_tensor_subbuffer(ctx, src0);
} else {
src_buf = { ctx->prealloc_x, dbl_buf_index * dbl_buf_size, dbl_buf_size };
}
if (num_dst_elements == k) {
pc2.last_pass = 1;
dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
} else {
dst_buf = { ctx->prealloc_x, (dbl_buf_index ^ 1) * dbl_buf_size, dbl_buf_size };
}

elements[0] = num_elements;

ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src_buf, dst_buf }, pc2, elements);
num_elements = num_dst_elements;
dbl_buf_index ^= 1;
if (num_elements > k) {
ggml_vk_sync_buffers(ctx, subctx);
}
}
ctx->prealloc_x_need_sync = true;
}

static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, ggml_nelements(src0));
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SUM, p);
@@ -11741,6 +11869,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
ggml_vk_argsort(ctx, compute_ctx, src0, node);
}

break;
case GGML_OP_TOP_K:
ggml_vk_topk(ctx, compute_ctx, src0, node);

break;
case GGML_OP_SUM:
ggml_vk_sum(ctx, compute_ctx, src0, node);
@@ -13769,6 +13901,22 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
return op->ne[0] <= (1 << device->max_workgroup_size_log2);
}
}
case GGML_OP_TOP_K:
{
if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) {
return false;
}
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
auto device = ggml_vk_get_device(ctx->device);
// We could potentially support larger K, using argsort to sort the
// whole thing. Not clear if this is needed.
uint32_t min_pipeline = (uint32_t)log2f(float(op->ne[0])) + 1;
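// For example, K = op->ne[0] = 100 gives min_pipeline = 7 (a 128-wide workgroup);
// any K >= 1024 gives min_pipeline >= num_topk_pipelines and is rejected below.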
if (min_pipeline >= num_topk_pipelines ||
!device->pipeline_topk_f32[min_pipeline]) {
return false;
}
}
return true;
case GGML_OP_UPSCALE:
case GGML_OP_ACC:
case GGML_OP_CONCAT:
@@ -14432,6 +14580,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
tensor_clone = ggml_get_rows(ggml_ctx, src_clone[0], src_clone[1]);
} else if (tensor->op == GGML_OP_ARGSORT) {
tensor_clone = ggml_argsort(ggml_ctx, src_clone[0], (ggml_sort_order) *(int *)tensor->op_params);
} else if (tensor->op == GGML_OP_TOP_K) {
tensor_clone = ggml_top_k(ggml_ctx, src_clone[0], tensor->ne[0]);
} else if (tensor->op == GGML_OP_SUM) {
tensor_clone = ggml_sum(ggml_ctx, src_clone[0]);
} else if (tensor->op == GGML_OP_SUM_ROWS) {
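As a reading aid for the multi-pass loop in ggml_vk_topk above, here is a host-only sketch of the pass-size arithmetic. It is illustrative only: it assumes the dispatch width equals the pipeline's BLOCK_SIZE (wg_denoms[0], as set when the pipelines are created) and it skips the subgroup-size clamp on min_pipeline as well as the per-device pipeline availability fallback.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Simulates the reduction schedule of ggml_vk_topk: each pass keeps the top K
// of every BLOCK_SIZE-wide chunk until only K elements remain.
int main() {
    const uint32_t num_topk_pipelines = 11;
    const uint32_t ncols = 8192, k = 8;

    uint32_t num_elements = ncols;
    bool done_one_iter = false;
    while (num_elements > k || !done_one_iter) {
        done_one_iter = true;

        const uint32_t max_pipeline = num_topk_pipelines - 3;              // prefer 256-wide workgroups
        const uint32_t min_pipeline = (uint32_t)std::log2((float)k) + 1;   // workgroup must cover at least 2*K
        uint32_t idx = (uint32_t)std::ceil(std::log2((float)num_elements));
        idx = std::min(idx, max_pipeline);
        idx = std::max(idx, min_pipeline);
        // If a single, larger workgroup could cover all remaining elements, finish in this pass.
        if (num_elements > (1u << idx)) {
            for (uint32_t i = idx; i < num_topk_pipelines; ++i) {
                if (num_elements <= (1u << i)) { idx = i; break; }
            }
        }

        const uint32_t block = 1u << idx;
        const uint32_t next  = (num_elements / block) * k + std::min(k, num_elements % block);
        printf("pass: %5u elements, BLOCK_SIZE %4u -> %5u elements\n", num_elements, block, next);
        num_elements = next;
    }
    // For ncols = 8192, k = 8 this prints two passes: 8192 -> 256, then 256 -> 8.
    return 0;
}
```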
ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp (113 additions, 0 deletions)
@@ -0,0 +1,113 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable

#include "types.glsl"

layout(constant_id = 0) const int BLOCK_SIZE = 1024;
layout(constant_id = 1) const int NCOLS_PADDED_LOG2 = 10;

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

// Input can either be the source (A) or intermediate values (S).
// Similarly, output can be either destination (D) or intermediate values (S).
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 0) readonly buffer S {ivec2 data_s[];};
layout (binding = 1) writeonly buffer D {int data_d[];};
layout (binding = 1) writeonly buffer T {ivec2 data_t[];};

layout (push_constant) uniform parameter {
uint orig_ncols;
uint ncols_input;
uint ncols_output;
uint nrows;
uint first_pass;
uint last_pass;
} p;

// pairs of (gid, value)
shared ivec2 dst_row[BLOCK_SIZE];

void topk(bool needs_bounds_check, const uint row) {
const int col = int(gl_LocalInvocationID.x);

// initialize indices
if (gl_GlobalInvocationID.x < p.ncols_input) {
if (p.first_pass != 0) {
const uint row_offset = row * p.ncols_input;
dst_row[col] = ivec2(gl_GlobalInvocationID.x, floatBitsToInt(data_a[row_offset + gl_GlobalInvocationID.x]));
} else {
const uint row_offset = row * p.orig_ncols;
dst_row[col] = data_s[row_offset + gl_GlobalInvocationID.x];
}
} else {
dst_row[col] = ivec2(p.orig_ncols, 0);
}
barrier();

if (p.ncols_output == 1) {
// Fast path for single output - just do a max reduction
[[unroll]] for (int s = BLOCK_SIZE / 2; s >= 1; s /= 2) {
if (col < s) {
ivec2 a = dst_row[col];
ivec2 b = dst_row[col + s];
if (a.x >= p.orig_ncols ||
b.x < p.orig_ncols && b.y > a.y) {
dst_row[col] = b;
}
}
barrier();
}
} else {
// bitonic sort on this group of elements
uint num_outer_loop_iters = NCOLS_PADDED_LOG2;
for (uint k = 2, outer_idx = 0; outer_idx < num_outer_loop_iters; k *= 2, outer_idx++) {
uint num_inner_loop_iters = outer_idx + 1;
for (uint j = k / 2, inner_idx = 0; inner_idx < num_inner_loop_iters; j /= 2, inner_idx++) {
const int ixj = int(col ^ j);

int idx_0 = (col & k) == 0 ? col : ixj;
int idx_1 = (col & k) == 0 ? ixj : col;

ivec2 sh_idx_0 = dst_row[idx_0];
ivec2 sh_idx_1 = dst_row[idx_1];
bool idx_0_oob = needs_bounds_check ? sh_idx_0.x >= p.orig_ncols : false;
bool idx_1_oob = needs_bounds_check ? sh_idx_1.x >= p.orig_ncols : false;

if ((idx_0_oob ||
(!idx_1_oob && intBitsToFloat(sh_idx_0.y) < intBitsToFloat(sh_idx_1.y))) && (ixj > col)) {
dst_row[idx_0] = sh_idx_1;
dst_row[idx_1] = sh_idx_0;
}

barrier();
}
}
}

if (col < p.ncols_output && gl_GlobalInvocationID.x < p.orig_ncols) {
if (p.last_pass != 0) {
const uint row_offset = row * p.ncols_output;
data_d[row_offset + col] = dst_row[col].x;
} else {
const uint row_offset = row * p.orig_ncols + gl_WorkGroupID.x * p.ncols_output;
data_t[row_offset + col] = dst_row[col];
}
}
}

void main() {
// Fast path for fully occupied workgroups
if ((p.ncols_input % BLOCK_SIZE) == 0) {
uint row = gl_WorkGroupID.y;
while (row < p.nrows) {
topk(false, row);
row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
}
} else {
uint row = gl_WorkGroupID.y;
while (row < p.nrows) {
topk(true, row);
row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
}
}
}
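For readers unfamiliar with the sort network above, the following is a scalar C++ sketch of the same pass (the topk_reference helper name is hypothetical): it sorts (index, value) pairs with the identical k/j compare-and-swap schedule, descending by value with out-of-bounds padding last, then keeps the first ncols_output indices. It only illustrates the algorithm; the shader runs the network in shared memory with one invocation per element and a barrier after every step.

```cpp
#include <cstdint>
#include <utility>
#include <vector>

// Scalar reference for the bitonic top-k pass in topk_argsort.comp.
// Pairs are (original index, value); padding uses index == n so it sorts last.
std::vector<uint32_t> topk_reference(const std::vector<float> & vals, uint32_t ncols_output) {
    const uint32_t n = (uint32_t)vals.size();
    uint32_t n_pad = 1;
    while (n_pad < n) n_pad *= 2;                          // plays the role of BLOCK_SIZE

    std::vector<std::pair<uint32_t, float>> row(n_pad);
    for (uint32_t i = 0; i < n_pad; ++i) {
        row[i] = (i < n) ? std::make_pair(i, vals[i])
                         : std::make_pair(n, 0.0f);        // out-of-bounds padding
    }

    // Same loop structure as the shader: outer stride k doubles, inner stride j halves.
    for (uint32_t k = 2; k <= n_pad; k *= 2) {
        for (uint32_t j = k / 2; j >= 1; j /= 2) {
            for (uint32_t col = 0; col < n_pad; ++col) {
                const uint32_t ixj = col ^ j;
                if (ixj <= col) continue;                  // each pair is handled once
                // idx0 is the slot that must end up holding the larger value.
                const uint32_t idx0 = (col & k) == 0 ? col : ixj;
                const uint32_t idx1 = (col & k) == 0 ? ixj : col;
                const bool oob0 = row[idx0].first >= n;
                const bool oob1 = row[idx1].first >= n;
                if (oob0 || (!oob1 && row[idx0].second < row[idx1].second)) {
                    std::swap(row[idx0], row[idx1]);
                }
            }
        }
    }

    std::vector<uint32_t> out;
    for (uint32_t i = 0; i < ncols_output && i < n; ++i) {
        out.push_back(row[i].first);                       // indices of the largest values
    }
    return out;
}
```

When a workgroup is fully occupied (ncols_input is a multiple of BLOCK_SIZE), the shader's fast path passes needs_bounds_check = false, so the out-of-bounds tests become constant-false and can be optimized away; the sketch keeps them unconditionally for simplicity.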