
Commit 5fb632f

vulkan: implement ADD1

Signed-off-by: Giuseppe Scrivano <gscrivan@redhat.com>

1 parent a16954c

4 files changed: +77, -0 lines changed
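For context (not part of the diff): GGML_OP_ADD1 adds the single scalar held in src1 to every element of src0. A minimal host-side sketch of building the op with the public ggml API follows; the tensor shapes are arbitrary examples.

    // Sketch only: dst[i] = a[i] + b[0], where b holds exactly one value.
    #include "ggml.h"

    struct ggml_tensor * build_add1(struct ggml_context * ctx) {
        struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1024, 1024); // input tensor
        struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);          // scalar addend
        return ggml_add1(ctx, a, b); // result has the same shape as a
    }

The Vulkan backend maps this op onto the new add1 pipelines below, picking a variant by the src0/src1/dst type combination.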

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 45 additions & 0 deletions
@@ -665,6 +665,10 @@ struct vk_device_struct {
     vk_pipeline pipeline_hardswish[2];
     vk_pipeline pipeline_abs[2];
 
+    vk_pipeline pipeline_add1_f16_f16;
+    vk_pipeline pipeline_add1_f16_f32;
+    vk_pipeline pipeline_add1_f32_f32;
+
     vk_pipeline pipeline_geglu[2];
     vk_pipeline pipeline_reglu[2];
     vk_pipeline pipeline_swiglu[2];
@@ -3839,6 +3843,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_UNARY_RTE(exp)
 #undef CREATE_UNARY_RTE
 
+    ggml_vk_create_pipeline(device, device->pipeline_add1_f16_f16, "add1_f16_f16", add1_f16_f16_len, add1_f16_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_add1_f16_f32, "add1_f16_f32", add1_f16_f32_len, add1_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_add1_f32_f32, "add1_f32_f32", add1_f32_f32_len, add1_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
+
 #define CREATE_GLU(name) \
     if (device->float_controls_rte_fp16) { \
         ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32_rte", name ## _f32_rte_len, name ## _f32_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
@@ -8527,6 +8535,17 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             }
         }
         return nullptr;
+    case GGML_OP_ADD1:
+        if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+            return ctx->device->pipeline_add1_f16_f16;
+        }
+        if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
+            return ctx->device->pipeline_add1_f16_f32;
+        }
+        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_add1_f32_f32;
+        }
+        return nullptr;
     default:
         return nullptr;
     }
@@ -8817,6 +8836,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
     case GGML_OP_SUB:
     case GGML_OP_DIV:
     case GGML_OP_MUL:
+    case GGML_OP_ADD1:
     case GGML_OP_SCALE:
     case GGML_OP_SQR:
     case GGML_OP_SQRT:
@@ -9423,6 +9443,21 @@ static void ggml_vk_sqrt(ggml_backend_vk_context * ctx, vk_context& subctx, cons
     ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SQRT, vk_op_unary_push_constants_init(src0, dst));
 }
 
+static void ggml_vk_add1(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    const uint32_t src0_type_size = ggml_type_size(src0->type);
+    const uint32_t src1_type_size = ggml_type_size(src1->type);
+    const uint32_t dst_type_size = ggml_type_size(dst->type);
+
+    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_ADD1, {
+        (uint32_t)ggml_nelements(src0),
+        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+        (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
+        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
+        0,
+        0.0f, 0.0f, 0,
+    });
+}
+
 static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
     ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SIN, vk_op_unary_push_constants_init(src0, dst));
 }
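For readability, the push-constant payload assembled above can be pictured as the struct below. The field names are hypothetical (the real type is vk_op_binary_push_constants in ggml-vulkan.cpp); only the order and meaning mirror the initializer: total element count, ne/nb quadruples for src0, src1, and dst (strides in elements), an offset word, and three parameters that ADD1 leaves at zero.

    // Hypothetical field names -- an inferred view of the payload, not the real declaration.
    #include <cstdint>

    struct add1_push_constants_sketch {
        uint32_t ne;                 // ggml_nelements(src0)
        uint32_t ne0[4], nb0[4];     // src0 shape and strides (in elements)
        uint32_t ne1[4], nb1[4];     // src1 shape and strides; add1.comp only reads element 0
        uint32_t ne2[4], nb2[4];     // dst shape and strides
        uint32_t offsets;            // packed buffer offsets, 0 here
        float    param1, param2;     // unused by ADD1 (0.0f)
        int32_t  param3;             // unused by ADD1 (0)
    };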
@@ -11223,6 +11258,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_SUB:
     case GGML_OP_MUL:
     case GGML_OP_DIV:
+    case GGML_OP_ADD1:
     case GGML_OP_CONCAT:
     case GGML_OP_UPSCALE:
     case GGML_OP_SCALE:
@@ -11435,6 +11471,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_UPSCALE:
         ggml_vk_upscale(ctx, compute_ctx, src0, node);
 
+        break;
+    case GGML_OP_ADD1:
+        ggml_vk_add1(ctx, compute_ctx, src0, src1, node);
+
+        break;
         break;
     case GGML_OP_SCALE:
         ggml_vk_scale(ctx, compute_ctx, src0, node);
@@ -11721,6 +11762,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
     case GGML_OP_SUB:
     case GGML_OP_MUL:
     case GGML_OP_DIV:
+    case GGML_OP_ADD1:
     case GGML_OP_ADD_ID:
     case GGML_OP_CONCAT:
     case GGML_OP_UPSCALE:
@@ -13699,6 +13741,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
     case GGML_OP_UPSCALE:
     case GGML_OP_ACC:
     case GGML_OP_CONCAT:
+    case GGML_OP_ADD1:
     case GGML_OP_SCALE:
     case GGML_OP_PAD:
     case GGML_OP_ROLL:
@@ -14181,6 +14224,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
     } else if (tensor->op == GGML_OP_SCALE) {
         const float * params = (const float *)tensor->op_params;
         tensor_clone = ggml_scale_bias(ggml_ctx, src_clone[0], params[0], params[1]);
+    } else if (tensor->op == GGML_OP_ADD1) {
+        tensor_clone = ggml_add1(ggml_ctx, src_clone[0], src_clone[1]);
     } else if (tensor->op == GGML_OP_SQR) {
         tensor_clone = ggml_sqr(ggml_ctx, src_clone[0]);
     } else if (tensor->op == GGML_OP_SQRT) {

ggml/src/ggml-vulkan/vulkan-shaders/add1.comp

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
+#version 450
+
+#extension GL_EXT_shader_16bit_storage : require
+
+#include "types.glsl"
+#include "generic_binary_head.glsl"
+
+const uint num_threads = 256;
+
+layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
+
+void main() {
+    uint idx = get_idx();
+
+    const uint num_iter = 2;
+
+    [[unroll]] for (uint i = 0; i < num_iter; ++i) {
+        if (idx >= p.ne) {
+            continue;
+        }
+        uint i00, i01, i02, i03;
+        get_indices(idx, i00, i01, i02, i03);
+
+        data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset()]));
+
+        idx += num_threads;
+    }
+}
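Each workgroup of this shader covers 512 elements: 256 invocations, each handling two elements (idx and idx + 256), which matches the {512, 1, 1} granularity used when the add1 pipelines are registered. The scalar addend is read from src1 at the binding offset (get_boffset()), i.e. element 0, for every output element. A rough sketch of the arithmetic, assuming the usual ceil-division dispatch (the authoritative logic lives in ggml_vk_op_f32):

    // Illustrative arithmetic only.
    #include <cstdint>

    constexpr uint32_t num_threads  = 256;                     // local_size_x in add1.comp
    constexpr uint32_t num_iter     = 2;                       // elements per invocation
    constexpr uint32_t elems_per_wg = num_threads * num_iter;  // 512, matches wg_denoms {512, 1, 1}

    // e.g. the 1024 x 1024 f32 case added to test-backend-ops below:
    constexpr uint32_t ne  = 1024 * 1024;
    constexpr uint32_t wgs = (ne + elems_per_wg - 1) / elems_per_wg; // 2048 workgroups

The p.ne bound check inside the loop lets the last workgroup skip elements past the end of the tensor.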

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 3 additions & 0 deletions
@@ -842,6 +842,9 @@ void process_shaders() {
     string_to_spv("hardswish_f32", "hardswish.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
     string_to_spv("abs_f16", "abs.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
     string_to_spv("abs_f32", "abs.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+    string_to_spv("add1_f16_f16", "add1.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}});
+    string_to_spv("add1_f16_f32", "add1.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}});
+    string_to_spv("add1_f32_f32", "add1.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
 
     for (auto rte : {false, true}) {
         std::string suffix = rte ? "_rte" : "";

tests/test-backend-ops.cpp

Lines changed: 1 addition & 0 deletions
@@ -7015,6 +7015,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {16, 5, 4, 3}, {1, 1, 1, 1}, 16));
 
     test_cases.emplace_back(new test_add1());
+    test_cases.emplace_back(new test_add1(GGML_TYPE_F32, {1024, 1024, 1, 1}));
     test_cases.emplace_back(new test_scale());
     test_cases.emplace_back(new test_scale(GGML_TYPE_F32, {10, 10, 10, 10}, 2.0f, 1.0f));
     test_cases.emplace_back(new test_scale(GGML_TYPE_F32, {10, 10, 10, 10}, 2.0f, 1.0f, true)); // inplace test
