From a9c924409686287d669980e40f68b063fac5f622 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 3 Dec 2025 08:07:13 -0700 Subject: [PATCH 01/10] feat(wip): Port initial TRI impl from pervious work The kernel does not work and is not optimized, but the code compiles and runs, so this will be the starting point now that the core op has been merged. Branch: ggml-cumsum-tri Signed-off-by: Gabe Goodhart --- ggml/src/ggml-metal/ggml-metal-device.cpp | 20 ++++++++ ggml/src/ggml-metal/ggml-metal-device.h | 1 + ggml/src/ggml-metal/ggml-metal-device.m | 2 + ggml/src/ggml-metal/ggml-metal-impl.h | 21 ++++++++ ggml/src/ggml-metal/ggml-metal-ops.cpp | 60 +++++++++++++++++++++++ ggml/src/ggml-metal/ggml-metal-ops.h | 1 + ggml/src/ggml-metal/ggml-metal.metal | 50 +++++++++++++++++++ 7 files changed, 155 insertions(+) diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index c647baef878..e92397a3b14 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -356,6 +356,26 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_add(ggml_metal_libr return res; } +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_tri(ggml_metal_library_t lib, const ggml_tensor * op) { + GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type)); + + char base[256]; + char name[256]; + + const char * op_str = "tri"; + + snprintf(base, 256, "kernel_%s_%s", op_str, ggml_type_name(op->src[0]->type)); + + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + return ggml_metal_library_compile_pipeline(lib, base, name, nullptr); +} + ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max(ggml_metal_library_t lib, const ggml_tensor * op) { GGML_ASSERT(!op->src[1] || op->src[1]->type == GGML_TYPE_F16 || op->src[1]->type == GGML_TYPE_F32); diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index 3976e622b9b..2cfb92bbb66 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -115,6 +115,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum (ggml_me ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_blk (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_add (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_tri (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op); diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 4d2bfcf91c6..84171665352 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -880,6 +880,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_SUM: return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]); + case GGML_OP_TRI: + return ggml_is_contiguous_rows(op->src[0]); case GGML_OP_SUM_ROWS: case GGML_OP_CUMSUM: case GGML_OP_MEAN: diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 342dc4f8c37..2c691658635 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -831,6 +831,27 @@ typedef struct { float slope; } ggml_metal_kargs_leaky_relu; +typedef struct { + int64_t ne00; + int64_t ne01; + int64_t ne02; + int64_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int64_t ne0; + int64_t ne1; + int64_t ne2; + int64_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; + float c; + uint32_t ttype; +} ggml_metal_kargs_tri; + typedef struct { int32_t ne00; int32_t ne01; diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 9871e976f23..bedb9f94ddb 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -414,6 +414,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { { n_fuse = ggml_metal_op_leaky_relu(ctx, idx); } break; + case GGML_OP_TRI: + { + n_fuse = ggml_metal_op_tri(ctx, idx); + } break; case GGML_OP_FLASH_ATTN_EXT: { n_fuse = ggml_metal_op_flash_attn_ext(ctx, idx); @@ -3899,6 +3903,62 @@ int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) { return 1; } +int ggml_metal_op_tri(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + const ggml_tri_type ttype = (ggml_tri_type) op->op_params[0]; + const float c = *((float *) &(op->op_params[1])); + + ggml_metal_kargs_tri args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.c =*/ c, + /*.ttype =*/ static_cast(ttype) + }; + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_tri(lib, op); + + int nth = 32; // SIMD width + + while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + nth *= 2; + } + + nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + nth = std::min(nth, ne00); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1); + + return 1; +} + int ggml_metal_op_opt_step_adamw(ggml_metal_op_t ctx, int idx) { ggml_tensor * op = ctx->node(idx); diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h index b5546146e13..d07028652c2 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.h +++ b/ggml/src/ggml-metal/ggml-metal-ops.h @@ -83,6 +83,7 @@ int ggml_metal_op_argmax (ggml_metal_op_t ctx, int idx); int ggml_metal_op_argsort (ggml_metal_op_t ctx, int idx); int ggml_metal_op_top_k (ggml_metal_op_t ctx, int idx); int ggml_metal_op_leaky_relu (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_tri (ggml_metal_op_t ctx, int idx); int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx); int ggml_metal_op_opt_step_sgd (ggml_metal_op_t ctx, int idx); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 3ca8d9b322b..863cc2520f0 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1943,6 +1943,56 @@ typedef decltype(kernel_cumsum_add) kernel_cumsum_add_t; template [[host_name("kernel_cumsum_add_f32")]] kernel kernel_cumsum_add_t kernel_cumsum_add; +inline static bool _ggml_vec_tri_cmp(const int i, const int r, const uint32_t type) { + switch (type) { + // ggml.h:620 + case /* GGML_TRI_TYPE_LOWER */ 3: return i < r; break; + case /* GGML_TRI_TYPE_LOWER_DIAG */ 2: return i <= r; break; + case /* GGML_TRI_TYPE_UPPER */ 1: return i > r; break; + case /* GGML_TRI_TYPE_UPPER_DIAG */ 0: return i >= r; break; + } +} + +template +kernel void kernel_tri( + constant ggml_metal_kargs_tri & args, + device const char * src0, + device const char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int64_t i3 = tgpig.z; + const int64_t i2 = tgpig.y; + const int64_t i1 = tgpig.x; + + if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) { + return; + } + + device const T * src_row = (device const T *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03); + device T * dst_row = (device T *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3); + + // Each thread is a single element of the row if ne00 < max threads per + // threadgroup, so this will loop once for each index that this thread is + // responsible for + const bool keep_org_val = isnan(args.c); + for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) { + dst_row[i0] = _ggml_vec_tri_cmp(i0, i1, args.ttype) + ? (keep_org_val ? src_row[i0] : static_cast(args.c)) + : static_cast(0.f); + } +} + +typedef decltype(kernel_tri) kernel_tri_t; + +template [[host_name("kernel_tri_f32")]] kernel kernel_tri_t kernel_tri; +template [[host_name("kernel_tri_f16")]] kernel kernel_tri_t kernel_tri; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_tri_bf16")]] kernel kernel_tri_t kernel_tri; +#endif + template kernel void kernel_soft_max( constant ggml_metal_kargs_soft_max & args, From 2a7bbc77ed30b7692b257320308e5a5a685486cf Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 3 Dec 2025 08:18:51 -0700 Subject: [PATCH 02/10] fix: Remove argument for constant val override This was added in the original draft, but later removed. With this, the kernel now passes tests. Branch: ggml-cumsum-tri Signed-off-by: Gabe Goodhart --- ggml/src/ggml-metal/ggml-metal-impl.h | 1 - ggml/src/ggml-metal/ggml-metal-ops.cpp | 2 -- ggml/src/ggml-metal/ggml-metal.metal | 5 +---- 3 files changed, 1 insertion(+), 7 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 2c691658635..4fd02792ea3 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -848,7 +848,6 @@ typedef struct { uint64_t nb1; uint64_t nb2; uint64_t nb3; - float c; uint32_t ttype; } ggml_metal_kargs_tri; diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index bedb9f94ddb..8f675de11f5 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -3915,7 +3915,6 @@ int ggml_metal_op_tri(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); const ggml_tri_type ttype = (ggml_tri_type) op->op_params[0]; - const float c = *((float *) &(op->op_params[1])); ggml_metal_kargs_tri args = { /*.ne00 =*/ ne00, @@ -3934,7 +3933,6 @@ int ggml_metal_op_tri(ggml_metal_op_t ctx, int idx) { /*.nb1 =*/ nb1, /*.nb2 =*/ nb2, /*.nb3 =*/ nb3, - /*.c =*/ c, /*.ttype =*/ static_cast(ttype) }; diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 863cc2520f0..556550cc3e5 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1977,11 +1977,8 @@ kernel void kernel_tri( // Each thread is a single element of the row if ne00 < max threads per // threadgroup, so this will loop once for each index that this thread is // responsible for - const bool keep_org_val = isnan(args.c); for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) { - dst_row[i0] = _ggml_vec_tri_cmp(i0, i1, args.ttype) - ? (keep_org_val ? src_row[i0] : static_cast(args.c)) - : static_cast(0.f); + dst_row[i0] = _ggml_vec_tri_cmp(i0, i1, args.ttype) ? src_row[i0] : static_cast(0.f); } } From f2ad88776f953dc32c45473a7959735f6cd2def0 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 3 Dec 2025 08:34:51 -0700 Subject: [PATCH 03/10] feat: Move the ttype conditional to templating to avoid conditional in kernel Branch: ggml-cumsum-tri Signed-off-by: Gabe Goodhart --- ggml/src/ggml-metal/ggml-metal-device.cpp | 3 +- ggml/src/ggml-metal/ggml-metal-impl.h | 1 - ggml/src/ggml-metal/ggml-metal-ops.cpp | 3 -- ggml/src/ggml-metal/ggml-metal.metal | 51 ++++++++++++++++------- 4 files changed, 39 insertions(+), 19 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index e92397a3b14..979ae19804a 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -363,8 +363,9 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_tri(ggml_metal_library_t l char name[256]; const char * op_str = "tri"; + const int ttype = op->op_params[0]; - snprintf(base, 256, "kernel_%s_%s", op_str, ggml_type_name(op->src[0]->type)); + snprintf(base, 256, "kernel_%s_%s_%d", op_str, ggml_type_name(op->src[0]->type), ttype); snprintf(name, 256, "%s", base); diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 4fd02792ea3..a0543684f83 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -848,7 +848,6 @@ typedef struct { uint64_t nb1; uint64_t nb2; uint64_t nb3; - uint32_t ttype; } ggml_metal_kargs_tri; typedef struct { diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 8f675de11f5..abda2e8d087 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -3914,8 +3914,6 @@ int ggml_metal_op_tri(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); - const ggml_tri_type ttype = (ggml_tri_type) op->op_params[0]; - ggml_metal_kargs_tri args = { /*.ne00 =*/ ne00, /*.ne01 =*/ ne01, @@ -3933,7 +3931,6 @@ int ggml_metal_op_tri(ggml_metal_op_t ctx, int idx) { /*.nb1 =*/ nb1, /*.nb2 =*/ nb2, /*.nb3 =*/ nb3, - /*.ttype =*/ static_cast(ttype) }; ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_tri(lib, op); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 556550cc3e5..f479fb46bf7 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1943,17 +1943,31 @@ typedef decltype(kernel_cumsum_add) kernel_cumsum_add_t; template [[host_name("kernel_cumsum_add_f32")]] kernel kernel_cumsum_add_t kernel_cumsum_add; -inline static bool _ggml_vec_tri_cmp(const int i, const int r, const uint32_t type) { - switch (type) { - // ggml.h:620 - case /* GGML_TRI_TYPE_LOWER */ 3: return i < r; break; - case /* GGML_TRI_TYPE_LOWER_DIAG */ 2: return i <= r; break; - case /* GGML_TRI_TYPE_UPPER */ 1: return i > r; break; - case /* GGML_TRI_TYPE_UPPER_DIAG */ 0: return i >= r; break; - } + +template +bool _ggml_vec_tri_cmp(const int i, const int r); + +template<> +bool _ggml_vec_tri_cmp(const int i, const int r) { + return i < r; } -template +template<> +bool _ggml_vec_tri_cmp(const int i, const int r) { + return i <= r; +} + +template<> +bool _ggml_vec_tri_cmp(const int i, const int r) { + return i > r; +} + +template<> +bool _ggml_vec_tri_cmp(const int i, const int r) { + return i >= r; +} + +template kernel void kernel_tri( constant ggml_metal_kargs_tri & args, device const char * src0, @@ -1978,16 +1992,25 @@ kernel void kernel_tri( // threadgroup, so this will loop once for each index that this thread is // responsible for for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) { - dst_row[i0] = _ggml_vec_tri_cmp(i0, i1, args.ttype) ? src_row[i0] : static_cast(0.f); + dst_row[i0] = _ggml_vec_tri_cmp(i0, i1) ? src_row[i0] : static_cast(0.f); } } -typedef decltype(kernel_tri) kernel_tri_t; +typedef decltype(kernel_tri) kernel_tri_t; -template [[host_name("kernel_tri_f32")]] kernel kernel_tri_t kernel_tri; -template [[host_name("kernel_tri_f16")]] kernel kernel_tri_t kernel_tri; +template [[host_name("kernel_tri_f32_0")]] kernel kernel_tri_t kernel_tri; +template [[host_name("kernel_tri_f32_1")]] kernel kernel_tri_t kernel_tri; +template [[host_name("kernel_tri_f32_2")]] kernel kernel_tri_t kernel_tri; +template [[host_name("kernel_tri_f32_3")]] kernel kernel_tri_t kernel_tri; +template [[host_name("kernel_tri_f16_0")]] kernel kernel_tri_t kernel_tri; +template [[host_name("kernel_tri_f16_1")]] kernel kernel_tri_t kernel_tri; +template [[host_name("kernel_tri_f16_2")]] kernel kernel_tri_t kernel_tri; +template [[host_name("kernel_tri_f16_3")]] kernel kernel_tri_t kernel_tri; #if defined(GGML_METAL_HAS_BF16) -template [[host_name("kernel_tri_bf16")]] kernel kernel_tri_t kernel_tri; +template [[host_name("kernel_tri_bf16_0")]] kernel kernel_tri_t kernel_tri; +template [[host_name("kernel_tri_bf16_1")]] kernel kernel_tri_t kernel_tri; +template [[host_name("kernel_tri_bf16_2")]] kernel kernel_tri_t kernel_tri; +template [[host_name("kernel_tri_bf16_3")]] kernel kernel_tri_t kernel_tri; #endif template From 6a2705036a6f1550b2432fd792856931626b3507 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 3 Dec 2025 10:02:13 -0700 Subject: [PATCH 04/10] fix: Type fixes Signed-off-by: Gabe Goodhart Co-authored-by: Georgi Gerganov Co-authored-by: Georgi Gerganov --- ggml/src/ggml-metal/ggml-metal-device.cpp | 1 + ggml/src/ggml-metal/ggml-metal-impl.h | 16 ++++++++-------- ggml/src/ggml-metal/ggml-metal-ops.cpp | 2 +- ggml/src/ggml-metal/ggml-metal.metal | 6 +++--- 4 files changed, 13 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 979ae19804a..84f1423d660 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -357,6 +357,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_add(ggml_metal_libr } ggml_metal_pipeline_t ggml_metal_library_get_pipeline_tri(ggml_metal_library_t lib, const ggml_tensor * op) { + GGML_ASSERT(op->op == GGML_OP_TRI); GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type)); char base[256]; diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index a0543684f83..c93f75eb5cb 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -832,18 +832,18 @@ typedef struct { } ggml_metal_kargs_leaky_relu; typedef struct { - int64_t ne00; - int64_t ne01; - int64_t ne02; - int64_t ne03; + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; uint64_t nb00; uint64_t nb01; uint64_t nb02; uint64_t nb03; - int64_t ne0; - int64_t ne1; - int64_t ne2; - int64_t ne3; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; uint64_t nb0; uint64_t nb1; uint64_t nb2; diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index abda2e8d087..3c81c9da763 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -3912,7 +3912,7 @@ int ggml_metal_op_tri(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); GGML_TENSOR_LOCALS( int32_t, ne, op, ne); - GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); ggml_metal_kargs_tri args = { /*.ne00 =*/ ne00, diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index f479fb46bf7..fb9ba97ddb9 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1977,9 +1977,9 @@ kernel void kernel_tri( ushort sgitg[[simdgroup_index_in_threadgroup]], ushort tiisg[[thread_index_in_simdgroup]], ushort3 ntg[[threads_per_threadgroup]]) { - const int64_t i3 = tgpig.z; - const int64_t i2 = tgpig.y; - const int64_t i1 = tgpig.x; + const int i3 = tgpig.z; + const int i2 = tgpig.y; + const int i1 = tgpig.x; if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) { return; From 7cbbff74b99dfcb0cbb88c2aff3572ec0f5f083e Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 3 Dec 2025 11:48:24 -0700 Subject: [PATCH 05/10] feat: Add softplus for metal Branch: ggml-cumsum-tri Signed-off-by: Gabe Goodhart --- ggml/src/ggml-metal/ggml-metal-device.cpp | 1 + ggml/src/ggml-metal/ggml-metal-device.m | 1 + ggml/src/ggml-metal/ggml-metal.metal | 16 ++++++++++++++++ 3 files changed, 18 insertions(+) diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 84f1423d660..f678da076cc 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -211,6 +211,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary(ggml_metal_library_t case GGML_UNARY_OP_HARDSWISH: op_str = "hardswish"; break; case GGML_UNARY_OP_HARDSIGMOID: op_str = "hardsigmoid"; break; case GGML_UNARY_OP_EXP: op_str = "exp"; break; + case GGML_UNARY_OP_SOFTPLUS: op_str = "softplus"; break; default: GGML_ABORT("fatal error"); } break; default: GGML_ABORT("fatal error"); diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 84171665352..c9cfdfaac23 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -831,6 +831,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_UNARY_OP_HARDSWISH: case GGML_UNARY_OP_HARDSIGMOID: case GGML_UNARY_OP_EXP: + case GGML_UNARY_OP_SOFTPLUS: return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; default: return false; diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index fb9ba97ddb9..c91f5371482 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1595,6 +1595,22 @@ kernel void kernel_exp_f32_4( dst[tpig] = exp(src0[tpig]); } +kernel void kernel_softplus_f32( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + dst[tpig] = (x > 20.0f) ? x : log(1.0f + exp(x)); +} + +kernel void kernel_softplus_f32_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + device const float4 & x = src0[tpig]; + dst[tpig] = select(log(1.0f + exp(x)), x, x > 20.0f); +} + kernel void kernel_reglu_f32( constant ggml_metal_kargs_glu & args, device const char * src0, From 434ec07479813c6df61e47b4778a84ca54b75a61 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 3 Dec 2025 12:01:11 -0700 Subject: [PATCH 06/10] feat: Add EXPM1 for metal Branch: ggml-cumsum-tri Signed-off-by: Gabe Goodhart --- ggml/src/ggml-metal/ggml-metal-device.cpp | 1 + ggml/src/ggml-metal/ggml-metal-device.m | 1 + ggml/src/ggml-metal/ggml-metal.metal | 14 ++++++++++++++ 3 files changed, 16 insertions(+) diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index f678da076cc..168be07163b 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -212,6 +212,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary(ggml_metal_library_t case GGML_UNARY_OP_HARDSIGMOID: op_str = "hardsigmoid"; break; case GGML_UNARY_OP_EXP: op_str = "exp"; break; case GGML_UNARY_OP_SOFTPLUS: op_str = "softplus"; break; + case GGML_UNARY_OP_EXPM1: op_str = "expm1"; break; default: GGML_ABORT("fatal error"); } break; default: GGML_ABORT("fatal error"); diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index c9cfdfaac23..ea26cd1bf7c 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -832,6 +832,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_UNARY_OP_HARDSIGMOID: case GGML_UNARY_OP_EXP: case GGML_UNARY_OP_SOFTPLUS: + case GGML_UNARY_OP_EXPM1: return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; default: return false; diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index c91f5371482..4cee08a2756 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1611,6 +1611,20 @@ kernel void kernel_softplus_f32_4( dst[tpig] = select(log(1.0f + exp(x)), x, x > 20.0f); } +kernel void kernel_expm1_f32( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = exp(src0[tpig]) - 1.0f; +} + +kernel void kernel_expm1_f32_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = exp(src0[tpig]) - 1.0f; +} + kernel void kernel_reglu_f32( constant ggml_metal_kargs_glu & args, device const char * src0, From 1496afdfb51535e118f4eef5520fad11decc0459 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 3 Dec 2025 12:57:18 -0700 Subject: [PATCH 07/10] feat: Add FILL for metal Branch: ggml-cumsum-tri Signed-off-by: Gabe Goodhart --- ggml/src/ggml-metal/ggml-metal-device.cpp | 1 + ggml/src/ggml-metal/ggml-metal-device.m | 1 + ggml/src/ggml-metal/ggml-metal-impl.h | 4 +++ ggml/src/ggml-metal/ggml-metal-ops.cpp | 39 +++++++++++++++++++++++ ggml/src/ggml-metal/ggml-metal-ops.h | 1 + ggml/src/ggml-metal/ggml-metal.metal | 16 ++++++++++ 6 files changed, 62 insertions(+) diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 168be07163b..ea6640d313d 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -187,6 +187,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary(ggml_metal_library_t const char * op_str = "undefined"; switch (op->op) { case GGML_OP_SCALE: op_str = "scale"; break; + case GGML_OP_FILL: op_str = "fill"; break; case GGML_OP_CLAMP: op_str = "clamp"; break; case GGML_OP_SQR: op_str = "sqr"; break; case GGML_OP_SQRT: op_str = "sqrt"; break; diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index ea26cd1bf7c..a9c46a83772 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -865,6 +865,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_ACC: case GGML_OP_REPEAT: case GGML_OP_SCALE: + case GGML_OP_FILL: case GGML_OP_CONV_TRANSPOSE_1D: return true; case GGML_OP_CONV_TRANSPOSE_2D: diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index c93f75eb5cb..30109f83e10 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -182,6 +182,10 @@ typedef struct { float bias; } ggml_metal_kargs_scale; +typedef struct { + float val; +} ggml_metal_kargs_fill; + typedef struct { float min; float max; diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 3c81c9da763..66569ade028 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -286,6 +286,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { { n_fuse = ggml_metal_op_scale(ctx, idx); } break; + case GGML_OP_FILL: + { + n_fuse = ggml_metal_op_fill(ctx, idx); + } break; case GGML_OP_CLAMP: { n_fuse = ggml_metal_op_clamp(ctx, idx); @@ -737,6 +741,41 @@ int ggml_metal_op_scale(ggml_metal_op_t ctx, int idx) { return 1; } +int ggml_metal_op_fill(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + const float val = ggml_get_op_params_f32(op, 0); + + ggml_metal_kargs_fill args = { + /*.val =*/ val + }; + + int64_t n = ggml_nelements(op); + + if (n % 4 == 0) { + n /= 4; + } + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_unary(lib, op); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1); + + return 1; +} + int ggml_metal_op_clamp(ggml_metal_op_t ctx, int idx) { ggml_tensor * op = ctx->node(idx); diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h index d07028652c2..902b5445232 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.h +++ b/ggml/src/ggml-metal/ggml-metal-ops.h @@ -47,6 +47,7 @@ int ggml_metal_op_concat (ggml_metal_op_t ctx, int idx); int ggml_metal_op_repeat (ggml_metal_op_t ctx, int idx); int ggml_metal_op_acc (ggml_metal_op_t ctx, int idx); int ggml_metal_op_scale (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_fill (ggml_metal_op_t ctx, int idx); int ggml_metal_op_clamp (ggml_metal_op_t ctx, int idx); int ggml_metal_op_unary (ggml_metal_op_t ctx, int idx); int ggml_metal_op_glu (ggml_metal_op_t ctx, int idx); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 4cee08a2756..097008600f9 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1249,6 +1249,22 @@ kernel void kernel_scale_f32_4( dst[tpig] = src0[tpig] * args.scale + args.bias; } +kernel void kernel_fill_f32( + constant ggml_metal_kargs_fill & args, + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = args.val; +} + +kernel void kernel_fill_f32_4( + constant ggml_metal_kargs_fill & args, + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = args.val; +} + kernel void kernel_clamp_f32( constant ggml_metal_kargs_clamp & args, device const float * src0, From 7690808999fa19c57752e6a4cf5792746943139f Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 3 Dec 2025 21:21:12 -0700 Subject: [PATCH 08/10] refactor: Branchless version of tri using _ggml_vec_tri_cmp as a mask Branch: ggml-cumsum-tri Signed-off-by: Gabe Goodhart --- ggml/src/ggml-metal/ggml-metal.metal | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 097008600f9..153ea2c4465 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -2038,7 +2038,8 @@ kernel void kernel_tri( // threadgroup, so this will loop once for each index that this thread is // responsible for for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) { - dst_row[i0] = _ggml_vec_tri_cmp(i0, i1) ? src_row[i0] : static_cast(0.f); + // Use the comparison as a mask for branchless + dst_row[i0] = static_cast(_ggml_vec_tri_cmp(i0, i1)) * src_row[i0]; } } From 60fe39b0497534c1259a85b7bddb214cc921fafc Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Thu, 4 Dec 2025 09:18:05 -0700 Subject: [PATCH 09/10] fix: Remove unused arguments Branch: ggml-cumsum-tri Signed-off-by: Gabe Goodhart --- ggml/src/ggml-metal/ggml-metal.metal | 2 -- 1 file changed, 2 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 153ea2c4465..f348a91c4fc 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -2020,8 +2020,6 @@ kernel void kernel_tri( device const char * dst, uint3 tgpig[[threadgroup_position_in_grid]], ushort3 tpitg[[thread_position_in_threadgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]], - ushort tiisg[[thread_index_in_simdgroup]], ushort3 ntg[[threads_per_threadgroup]]) { const int i3 = tgpig.z; const int i2 = tgpig.y; From 338acb32b4bb1280cc301c1ab920bbaa21429345 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Thu, 4 Dec 2025 09:28:29 -0700 Subject: [PATCH 10/10] refactor: Use select instead of branch for softplus non-vec Branch: ggml-cumsum-tri Signed-off-by: Gabe Goodhart --- ggml/src/ggml-metal/ggml-metal.metal | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index f348a91c4fc..4b78d5a2bad 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1616,7 +1616,7 @@ kernel void kernel_softplus_f32( device float * dst, uint tpig[[thread_position_in_grid]]) { device const float & x = src0[tpig]; - dst[tpig] = (x > 20.0f) ? x : log(1.0f + exp(x)); + dst[tpig] = select(log(1.0f + exp(x)), x, x > 20.0f); } kernel void kernel_softplus_f32_4(