Skip to content
Open
25 changes: 25 additions & 0 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -702,6 +702,7 @@ struct vk_device_struct {
vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16, pipeline_rope_neox_f32_f16;
vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16;
vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
vk_pipeline pipeline_get_rel_pos_f32, pipeline_get_rel_pos_f16;
vk_pipeline pipeline_argsort_f32[num_argsort_pipelines];
vk_pipeline pipeline_argsort_large_f32[num_argsort_pipelines];
vk_pipeline pipeline_sum_rows_f32;
Expand Down Expand Up @@ -3934,6 +3935,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, "rope_neox_f32_f16", rope_neox_f32_f16_len, rope_neox_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
}

ggml_vk_create_pipeline(device, device->pipeline_get_rel_pos_f32, "get_rel_pos_f32", get_rel_pos_f32_len, get_rel_pos_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, { 512 }, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rel_pos_f16, "get_rel_pos_f16", get_rel_pos_f16_len, get_rel_pos_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, { 512 }, 1);

for (uint32_t i = 0; i < num_argsort_pipelines; ++i) {
uint32_t BLOCK_SIZE = 1u << std::min(i, device->max_workgroup_size_log2);
if (i <= device->max_workgroup_size_log2 &&
Expand Down Expand Up @@ -8623,6 +8627,14 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_fill_f32;
}
return nullptr;
case GGML_OP_GET_REL_POS:
if (src0->type == GGML_TYPE_F32) {
return ctx->device->pipeline_get_rel_pos_f32;
}
if (src0->type == GGML_TYPE_F16) {
return ctx->device->pipeline_get_rel_pos_f16;
}
return nullptr;
default:
return nullptr;
}
Expand Down Expand Up @@ -8932,6 +8944,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
case GGML_OP_UNARY:
case GGML_OP_GLU:
case GGML_OP_CONV_2D_DW:
case GGML_OP_GET_REL_POS:
{
uint32_t ne = ggml_nelements(dst);
if (op == GGML_OP_CPY && ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) {
Expand Down Expand Up @@ -10030,6 +10043,11 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons
ggml_vk_make_rope_constants(cgraph->nodes[node_idx], src0, src2 != nullptr, backprop, set_rows_stride));
}

static void ggml_vk_get_rel_pos(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
vk_op_unary_push_constants pc = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_GET_REL_POS, pc);
}

static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
const uint32_t * op_params = (const uint32_t *)dst->op_params;

Expand Down Expand Up @@ -11715,6 +11733,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_ROPE_BACK:
ggml_vk_rope(ctx, compute_ctx, cgraph, node_idx, true);

break;
case GGML_OP_GET_REL_POS:
ggml_vk_get_rel_pos(ctx, compute_ctx, src0, node);

break;
case GGML_OP_ARGSORT:
if (ctx->num_additional_fused_ops) {
Expand Down Expand Up @@ -13737,6 +13759,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_LOG:
return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16;
case GGML_OP_GET_REL_POS:
return ggml_is_contiguous(op->src[0]) && ggml_vk_dim01_contiguous(op->src[0]) &&
(op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16);
case GGML_OP_ARGSORT:
{
if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) {
Expand Down
35 changes: 35 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/get_rel_pos.comp
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#version 450

#include "types.glsl"
#include "generic_unary_head.glsl"

layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

void main() {
const uint idx = get_idx();
if (idx >= p.ne) {
return;
}

const uint i13 = fastdiv(idx, p.ne1_012mp, p.ne1_012L);
const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10;
const uint i12 = fastdiv(idx - i13_offset, p.ne1_01mp, p.ne1_01L);
const uint i12_offset = i12*p.ne11*p.ne10;
const uint i11 = fastdiv(idx - i13_offset - i12_offset, p.ne1_0mp, p.ne1_0L);
const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10;

const float kh = float(p.ne11);
const float qh = float(p.ne12);
const float k_scale = max(qh / kh, 1.0f);
const float q_scale = max(kh / qh, 1.0f);

// Add a small epsilon to avoid floating point precision issues
const float epsilon = 0.0001f;
const int pos = int(float(i12) * q_scale - float(i11) * k_scale + (kh - 1.0f) * k_scale + epsilon);

const uint src_idx = pos*p.nb01 + i10*p.nb00;
const uint dst_idx = i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + i10*p.nb10;

data_d[get_doffset() + dst_idx] = D_TYPE(data_a[get_aoffset() + src_idx]);
}

3 changes: 3 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -910,6 +910,9 @@ void process_shaders() {
string_to_spv("rope_vision_f16", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}});
string_to_spv("rope_vision_f16_rte", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});

string_to_spv("get_rel_pos_f32", "get_rel_pos.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("get_rel_pos_f16", "get_rel_pos.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});

string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}});
string_to_spv("argsort_large_f32", "argsort_large.comp", {{"A_TYPE", "float"}});

Expand Down