diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 95966ce1d8e..2e5dbbcd635 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -718,6 +718,7 @@ struct vk_device_struct { vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16, pipeline_rope_neox_f32_f16; vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16; vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16; + vk_pipeline pipeline_get_rel_pos_f32, pipeline_get_rel_pos_f16; vk_pipeline pipeline_argsort_f32[num_argsort_pipelines]; vk_pipeline pipeline_argsort_large_f32[num_argsort_pipelines]; vk_pipeline pipeline_topk_f32[num_topk_pipelines]; @@ -4022,6 +4023,9 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, "rope_neox_f32_f16", rope_neox_f32_f16_len, rope_neox_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); } + ggml_vk_create_pipeline(device, device->pipeline_get_rel_pos_f32, "get_rel_pos_f32", get_rel_pos_f32_len, get_rel_pos_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rel_pos_f16, "get_rel_pos_f16", get_rel_pos_f16_len, get_rel_pos_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + for (uint32_t i = 0; i < num_argsort_pipelines; ++i) { uint32_t BLOCK_SIZE = 1u << std::min(i, device->max_workgroup_size_log2); if (i <= device->max_workgroup_size_log2 && @@ -8867,6 +8871,14 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_fill_f32; } return nullptr; + case GGML_OP_GET_REL_POS: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_get_rel_pos_f32; + } + if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { + return ctx->device->pipeline_get_rel_pos_f16; + } + return nullptr; default: return nullptr; } @@ -9149,6 +9161,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co case GGML_OP_UNARY: case GGML_OP_GLU: case GGML_OP_CONV_2D_DW: + case GGML_OP_GET_REL_POS: { uint32_t ne = ggml_nelements(dst); if (op == GGML_OP_CPY && ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) { @@ -10254,6 +10267,11 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons ggml_vk_make_rope_constants(cgraph->nodes[node_idx], src0, src2 != nullptr, backprop, set_rows_stride)); } +static void ggml_vk_get_rel_pos(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { + vk_op_unary_push_constants pc = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst)); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_GET_REL_POS, pc); +} + static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { const uint32_t * op_params = (const uint32_t *)dst->op_params; @@ -12063,6 +12081,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_ROPE_BACK: ggml_vk_rope(ctx, compute_ctx, cgraph, node_idx, true); + break; + case GGML_OP_GET_REL_POS: + ggml_vk_get_rel_pos(ctx, compute_ctx, src0, node); + break; case GGML_OP_ARGSORT: if (ctx->num_additional_fused_ops) { @@ -14082,6 +14104,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_TRI: return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && op->type == op->src[0]->type; + case GGML_OP_GET_REL_POS: + return (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) || + (op->src[0]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16); case GGML_OP_ARGSORT: { if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/get_rel_pos.comp b/ggml/src/ggml-vulkan/vulkan-shaders/get_rel_pos.comp new file mode 100644 index 00000000000..3657ce8e110 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/get_rel_pos.comp @@ -0,0 +1,35 @@ +#version 450 + +#include "types.glsl" +#include "generic_unary_head.glsl" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint idx = get_idx(); + if (idx >= p.ne) { + return; + } + + const uint i13 = fastdiv(idx, p.ne1_012mp, p.ne1_012L); + const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10; + const uint i12 = fastdiv(idx - i13_offset, p.ne1_01mp, p.ne1_01L); + const uint i12_offset = i12*p.ne11*p.ne10; + const uint i11 = fastdiv(idx - i13_offset - i12_offset, p.ne1_0mp, p.ne1_0L); + const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10; + + const float kh = float(p.ne11); + const float qh = float(p.ne12); + const float k_scale = max(qh / kh, 1.0f); + const float q_scale = max(kh / qh, 1.0f); + + // Add a small epsilon to avoid floating point precision issues + const float epsilon = 0.0001f; + const int pos = int(float(i12) * q_scale - float(i11) * k_scale + (kh - 1.0f) * k_scale + epsilon); + + const uint src_idx = pos*p.nb01 + i10*p.nb00; + const uint dst_idx = i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + i10*p.nb10; + + data_d[get_doffset() + dst_idx] = D_TYPE(data_a[get_aoffset() + src_idx]); +} + diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 92bae088b20..2237506e03a 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -919,6 +919,9 @@ void process_shaders() { string_to_spv("rope_vision_f16", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}}); string_to_spv("rope_vision_f16_rte", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}}); + string_to_spv("get_rel_pos_f32", "get_rel_pos.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("get_rel_pos_f16", "get_rel_pos.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}}); string_to_spv("argsort_large_f32", "argsort_large.comp", {{"A_TYPE", "float"}});