diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp
index 33ab43d58f..ba3c342751 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-device.cpp
@@ -175,6 +175,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary(ggml_metal
     const char * op_str = "undefined";
     switch (op->op) {
         case GGML_OP_SCALE: op_str = "scale"; break;
+        case GGML_OP_FILL:  op_str = "fill";  break;
         case GGML_OP_CLAMP: op_str = "clamp"; break;
         case GGML_OP_SQR:   op_str = "sqr";   break;
         case GGML_OP_SQRT:  op_str = "sqrt";  break;
@@ -199,6 +200,8 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary(ggml_metal
                 case GGML_UNARY_OP_HARDSWISH:   op_str = "hardswish";   break;
                 case GGML_UNARY_OP_HARDSIGMOID: op_str = "hardsigmoid"; break;
                 case GGML_UNARY_OP_EXP:         op_str = "exp";         break;
+                case GGML_UNARY_OP_SOFTPLUS:    op_str = "softplus";    break;
+                case GGML_UNARY_OP_EXPM1:       op_str = "expm1";       break;
                 default: GGML_ABORT("fatal error");
             } break;
         default: GGML_ABORT("fatal error");
@@ -332,6 +335,28 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_add(ggml_
     return res;
 }
 
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tri(ggml_metal_library_t lib, const ggml_tensor * op) {
+    GGML_ASSERT(op->op == GGML_OP_TRI);
+    GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type));
+
+    char base[256];
+    char name[256];
+
+    const char * op_str = "tri";
+    const int    ttype  = op->op_params[0];
+
+    snprintf(base, 256, "kernel_%s_%s_%d", op_str, ggml_type_name(op->src[0]->type), ttype);
+
+    snprintf(name, 256, "%s", base);
+
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+    }
+
+    return res;
+}
+
 ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max(ggml_metal_library_t lib, const ggml_tensor * op) {
     GGML_ASSERT(!op->src[1] || op->src[1]->type == GGML_TYPE_F16 || op->src[1]->type == GGML_TYPE_F32);
 
diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h
index 17baef2017..4be0432ea7 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.h
+++ b/ggml/src/ggml-metal/ggml-metal-device.h
@@ -114,6 +114,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum_rows  (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_blk(ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_add(ggml_metal_library_t lib, const struct ggml_tensor * op);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tri       (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max  (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv  (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan  (ggml_metal_library_t lib, const struct ggml_tensor * op);
diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m
index d22672a816..0d5a9814c7 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.m
+++ b/ggml/src/ggml-metal/ggml-metal-device.m
@@ -818,6 +818,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
                 case GGML_UNARY_OP_HARDSWISH:
                 case GGML_UNARY_OP_HARDSIGMOID:
                 case GGML_UNARY_OP_EXP:
+                case GGML_UNARY_OP_SOFTPLUS:
+                case GGML_UNARY_OP_EXPM1:
                     return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
                 default:
                     return false;
@@ -850,6 +852,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
         case GGML_OP_ACC:
        case GGML_OP_REPEAT:
         case GGML_OP_SCALE:
+        case GGML_OP_FILL:
         case GGML_OP_CONV_TRANSPOSE_1D:
             return true;
         case GGML_OP_CONV_TRANSPOSE_2D:
@@ -867,6 +870,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
             return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_SUM:
             return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
+        case GGML_OP_TRI:
+            return ggml_is_contiguous_rows(op->src[0]);
         case GGML_OP_SUM_ROWS:
         case GGML_OP_CUMSUM:
         case GGML_OP_MEAN:
diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h
index 342dc4f8c3..30109f83e1 100644
--- a/ggml/src/ggml-metal/ggml-metal-impl.h
+++ b/ggml/src/ggml-metal/ggml-metal-impl.h
@@ -182,6 +182,10 @@ typedef struct {
     float bias;
 } ggml_metal_kargs_scale;
 
+typedef struct {
+    float val;
+} ggml_metal_kargs_fill;
+
 typedef struct {
     float min;
     float max;
@@ -831,6 +835,25 @@ typedef struct {
     float slope;
 } ggml_metal_kargs_leaky_relu;
 
+typedef struct {
+    int32_t  ne00;
+    int32_t  ne01;
+    int32_t  ne02;
+    int32_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne0;
+    int32_t  ne1;
+    int32_t  ne2;
+    int32_t  ne3;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+} ggml_metal_kargs_tri;
+
 typedef struct {
     int32_t ne00;
     int32_t ne01;
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp
index edb227a210..9efd51abba 100644
--- a/ggml/src/ggml-metal/ggml-metal-ops.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp
@@ -286,6 +286,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
             {
                 n_fuse = ggml_metal_op_scale(ctx, idx);
             } break;
+        case GGML_OP_FILL:
+            {
+                n_fuse = ggml_metal_op_fill(ctx, idx);
+            } break;
         case GGML_OP_CLAMP:
             {
                 n_fuse = ggml_metal_op_clamp(ctx, idx);
@@ -414,6 +418,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
             {
                 n_fuse = ggml_metal_op_leaky_relu(ctx, idx);
             } break;
+        case GGML_OP_TRI:
+            {
+                n_fuse = ggml_metal_op_tri(ctx, idx);
+            } break;
         case GGML_OP_FLASH_ATTN_EXT:
             {
                 n_fuse = ggml_metal_op_flash_attn_ext(ctx, idx);
@@ -733,6 +741,41 @@ int ggml_metal_op_scale(ggml_metal_op_t ctx, int idx) {
     return 1;
 }
 
+int ggml_metal_op_fill(ggml_metal_op_t ctx, int idx) {
+    ggml_tensor * op = ctx->node(idx);
+
+    ggml_metal_library_t lib = ctx->lib;
+    ggml_metal_encoder_t enc = ctx->enc;
+
+    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
+
+    const float val = ggml_get_op_params_f32(op, 0);
+
+    ggml_metal_kargs_fill args = {
+        /*.val =*/ val
+    };
+
+    int64_t n = ggml_nelements(op);
+
+    if (n % 4 == 0) {
+        n /= 4;
+    }
+
+    auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
+
+    ggml_metal_encoder_set_pipeline(enc, pipeline);
+    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op), 2);
+
+    ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
+
+    return 1;
+}
+
 int ggml_metal_op_clamp(ggml_metal_op_t ctx, int idx) {
     ggml_tensor * op = ctx->node(idx);
 
@@ -3899,6 +3942,57 @@ int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
     return 1;
 }
 
+int ggml_metal_op_tri(ggml_metal_op_t ctx, int idx) {
+    ggml_tensor * op = ctx->node(idx);
+
+    ggml_metal_library_t lib = ctx->lib;
+    ggml_metal_encoder_t enc = ctx->enc;
+
+    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
+
+    ggml_metal_kargs_tri args = {
+        /*.ne00 =*/ ne00,
+        /*.ne01 =*/ ne01,
+        /*.ne02 =*/ ne02,
+        /*.ne03 =*/ ne03,
+        /*.nb00 =*/ nb00,
+        /*.nb01 =*/ nb01,
+        /*.nb02 =*/ nb02,
+        /*.nb03 =*/ nb03,
+        /*.ne0  =*/ ne0,
+        /*.ne1  =*/ ne1,
+        /*.ne2  =*/ ne2,
+        /*.ne3  =*/ ne3,
+        /*.nb0  =*/ nb0,
+        /*.nb1  =*/ nb1,
+        /*.nb2  =*/ nb2,
+        /*.nb3  =*/ nb3,
+    };
+
+    auto pipeline = ggml_metal_library_get_pipeline_tri(lib, op);
+
+    int nth = 32; // SIMD width
+
+    while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
+        nth *= 2;
+    }
+
+    nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
+    nth = std::min(nth, ne00);
+
+    ggml_metal_encoder_set_pipeline(enc, pipeline);
+    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op), 2);
+
+    ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
+
+    return 1;
+}
+
 int ggml_metal_op_opt_step_adamw(ggml_metal_op_t ctx, int idx) {
     ggml_tensor * op = ctx->node(idx);
 
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h
index b5546146e1..902b544523 100644
--- a/ggml/src/ggml-metal/ggml-metal-ops.h
+++ b/ggml/src/ggml-metal/ggml-metal-ops.h
@@ -47,6 +47,7 @@ int ggml_metal_op_concat            (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_repeat            (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_acc               (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_scale             (ggml_metal_op_t ctx, int idx);
+int ggml_metal_op_fill              (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_clamp             (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_unary             (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_glu               (ggml_metal_op_t ctx, int idx);
@@ -83,6 +84,7 @@ int ggml_metal_op_argmax            (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_argsort           (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_top_k             (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_leaky_relu        (ggml_metal_op_t ctx, int idx);
+int ggml_metal_op_tri               (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_opt_step_adamw    (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_opt_step_sgd      (ggml_metal_op_t ctx, int idx);
 
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index 3ca8d9b322..4b78d5a2ba 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -1249,6 +1249,22 @@ kernel void kernel_scale_f32_4(
     dst[tpig] = src0[tpig] * args.scale + args.bias;
 }
 
+kernel void kernel_fill_f32(
+        constant ggml_metal_kargs_fill & args,
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = args.val;
+}
+
+kernel void kernel_fill_f32_4(
+        constant ggml_metal_kargs_fill & args,
+        device const float4 * src0,
+        device       float4 * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = args.val;
+}
+
 kernel void kernel_clamp_f32(
         constant ggml_metal_kargs_clamp & args,
         device const float * src0,
@@ -1595,6 +1611,36 @@ kernel void kernel_exp_f32_4(
     dst[tpig] = exp(src0[tpig]);
 }
 
+kernel void kernel_softplus_f32(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    device const float & x = src0[tpig];
+    dst[tpig] = select(log(1.0f + exp(x)), x, x > 20.0f);
+}
+
+kernel void kernel_softplus_f32_4(
+        device const float4 * src0,
+        device       float4 * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    device const float4 & x = src0[tpig];
+    dst[tpig] = select(log(1.0f + exp(x)), x, x > 20.0f);
+}
+
+kernel void kernel_expm1_f32(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = exp(src0[tpig]) - 1.0f;
+}
+
+kernel void kernel_expm1_f32_4(
+        device const float4 * src0,
+        device       float4 * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = exp(src0[tpig]) - 1.0f;
+}
+
 kernel void kernel_reglu_f32(
         constant ggml_metal_kargs_glu & args,
         device const char * src0,
@@ -1943,6 +1989,75 @@ typedef decltype(kernel_cumsum_add<float>) kernel_cumsum_add_t;
 
 template [[host_name("kernel_cumsum_add_f32")]] kernel kernel_cumsum_add_t kernel_cumsum_add<float>;
+
+// predicate deciding whether element i of a row is kept, given row index r;
+// tri_type mirrors the value passed in op_params[0] on the host side
+template <int tri_type>
+bool _ggml_vec_tri_cmp(const int i, const int r);
+
+template <>
+bool _ggml_vec_tri_cmp<0>(const int i, const int r) {
+    return i <  r;
+}
+
+template <>
+bool _ggml_vec_tri_cmp<1>(const int i, const int r) {
+    return i <= r;
+}
+
+template <>
+bool _ggml_vec_tri_cmp<2>(const int i, const int r) {
+    return i >  r;
+}
+
+template <>
+bool _ggml_vec_tri_cmp<3>(const int i, const int r) {
+    return i >= r;
+}
+
+template <typename T, int tri_type>
+kernel void kernel_tri(
+        constant ggml_metal_kargs_tri & args,
+        device const char * src0,
+        device       char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3 ntg[[threads_per_threadgroup]]) {
+    const int i3 = tgpig.z;
+    const int i2 = tgpig.y;
+    const int i1 = tgpig.x;
+
+    if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
+        return;
+    }
+
+    device const T * src_row = (device const T *) (src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
+    device       T * dst_row = (device       T *) (dst  + i1*args.nb1  + i2*args.nb2  + i3*args.nb3);
+
+    // each thread handles a single element of the row when ne00 fits in the
+    // threadgroup; otherwise each thread strides over the row and handles
+    // every ntg.x-th element
+    for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
+        // use the 0/1 result of the comparison as a branchless mask
+        dst_row[i0] = static_cast<T>(_ggml_vec_tri_cmp<tri_type>(i0, i1)) * src_row[i0];
+    }
+}
+
+typedef decltype(kernel_tri<float, 0>) kernel_tri_t;
+
+template [[host_name("kernel_tri_f32_0")]]  kernel kernel_tri_t kernel_tri<float,  0>;
+template [[host_name("kernel_tri_f32_1")]]  kernel kernel_tri_t kernel_tri<float,  1>;
+template [[host_name("kernel_tri_f32_2")]]  kernel kernel_tri_t kernel_tri<float,  2>;
+template [[host_name("kernel_tri_f32_3")]]  kernel kernel_tri_t kernel_tri<float,  3>;
+template [[host_name("kernel_tri_f16_0")]]  kernel kernel_tri_t kernel_tri<half,   0>;
+template [[host_name("kernel_tri_f16_1")]]  kernel kernel_tri_t kernel_tri<half,   1>;
+template [[host_name("kernel_tri_f16_2")]]  kernel kernel_tri_t kernel_tri<half,   2>;
+template [[host_name("kernel_tri_f16_3")]]  kernel kernel_tri_t kernel_tri<half,   3>;
+#if defined(GGML_METAL_HAS_BF16)
+template [[host_name("kernel_tri_bf16_0")]] kernel kernel_tri_t kernel_tri<bfloat, 0>;
[[host_name("kernel_tri_bf16_1")]] kernel kernel_tri_t kernel_tri; +template [[host_name("kernel_tri_bf16_2")]] kernel kernel_tri_t kernel_tri; +template [[host_name("kernel_tri_bf16_3")]] kernel kernel_tri_t kernel_tri; +#endif + template kernel void kernel_soft_max( constant ggml_metal_kargs_soft_max & args,