@@ -665,6 +665,10 @@ struct vk_device_struct {
     vk_pipeline pipeline_hardswish[2];
     vk_pipeline pipeline_abs[2];
 
+    vk_pipeline pipeline_add1_f16_f16;
+    vk_pipeline pipeline_add1_f16_f32;
+    vk_pipeline pipeline_add1_f32_f32;
+
     vk_pipeline pipeline_geglu[2];
     vk_pipeline pipeline_reglu[2];
     vk_pipeline pipeline_swiglu[2];
@@ -3839,6 +3843,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_UNARY_RTE(exp)
 #undef CREATE_UNARY_RTE
 
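+    // add1 pipelines are named add1_{src0 type}_{src1 type} and reuse the generic binary-op push constant layout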
+    ggml_vk_create_pipeline(device, device->pipeline_add1_f16_f16, "add1_f16_f16", add1_f16_f16_len, add1_f16_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_add1_f16_f32, "add1_f16_f32", add1_f16_f32_len, add1_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_add1_f32_f32, "add1_f32_f32", add1_f32_f32_len, add1_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
+
 #define CREATE_GLU(name) \
     if (device->float_controls_rte_fp16) { \
         ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32_rte", name ## _f32_rte_len, name ## _f32_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
@@ -8527,6 +8536,18 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             }
         }
         return nullptr;
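+    // add1 computes dst = src0 + scalar; dst type always matches src0 and src1 holds the scalar addend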
+    case GGML_OP_ADD1:
+        if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+            return ctx->device->pipeline_add1_f16_f16;
+        }
+        if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
+            return ctx->device->pipeline_add1_f16_f32;
+        }
+        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_add1_f32_f32;
+        }
+        return nullptr;
     default:
         return nullptr;
     }
@@ -8817,6 +8838,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
     case GGML_OP_SUB:
     case GGML_OP_DIV:
     case GGML_OP_MUL:
+    case GGML_OP_ADD1:
    case GGML_OP_SCALE:
     case GGML_OP_SQR:
     case GGML_OP_SQRT:
@@ -9423,6 +9445,23 @@ static void ggml_vk_sqrt(ggml_backend_vk_context * ctx, vk_context& subctx, cons
     ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SQRT, vk_op_unary_push_constants_init(src0, dst));
 }
 
+static void ggml_vk_add1(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    const uint32_t src0_type_size = ggml_type_size(src0->type);
+    const uint32_t src1_type_size = ggml_type_size(src1->type);
+    const uint32_t dst_type_size = ggml_type_size(dst->type);
+
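+    // pack the generic binary-op push constants: tensor shapes plus per-dim strides
+    // in elements (ggml nb[] strides are in bytes); the trailing offset/param fields are unused here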
+    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_ADD1, {
+        (uint32_t)ggml_nelements(src0),
+        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+        (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
+        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
+        0,
+        0.0f, 0.0f, 0,
+    });
+}
+
 static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
     ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SIN, vk_op_unary_push_constants_init(src0, dst));
 }
@@ -11223,6 +11262,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_SUB:
     case GGML_OP_MUL:
     case GGML_OP_DIV:
+    case GGML_OP_ADD1:
     case GGML_OP_CONCAT:
     case GGML_OP_UPSCALE:
     case GGML_OP_SCALE:
@@ -11435,6 +11475,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_UPSCALE:
         ggml_vk_upscale(ctx, compute_ctx, src0, node);
 
         break;
+    case GGML_OP_ADD1:
+        ggml_vk_add1(ctx, compute_ctx, src0, src1, node);
+
+        break;
     case GGML_OP_SCALE:
         ggml_vk_scale(ctx, compute_ctx, src0, node);
@@ -11721,6 +11765,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
     case GGML_OP_SUB:
     case GGML_OP_MUL:
     case GGML_OP_DIV:
+    case GGML_OP_ADD1:
     case GGML_OP_ADD_ID:
     case GGML_OP_CONCAT:
     case GGML_OP_UPSCALE:
@@ -13699,6 +13744,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_UPSCALE:
         case GGML_OP_ACC:
         case GGML_OP_CONCAT:
+        case GGML_OP_ADD1:
         case GGML_OP_SCALE:
         case GGML_OP_PAD:
         case GGML_OP_ROLL:
@@ -14181,6 +14227,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
     } else if (tensor->op == GGML_OP_SCALE) {
         const float * params = (const float *)tensor->op_params;
         tensor_clone = ggml_scale_bias(ggml_ctx, src_clone[0], params[0], params[1]);
+    } else if (tensor->op == GGML_OP_ADD1) {
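+        // CPU reference op used to validate the Vulkan result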
+        tensor_clone = ggml_add1(ggml_ctx, src_clone[0], src_clone[1]);
     } else if (tensor->op == GGML_OP_SQR) {
         tensor_clone = ggml_sqr(ggml_ctx, src_clone[0]);
     } else if (tensor->op == GGML_OP_SQRT) {