Skip to content

Commit 434ec07

Browse files
committed
feat: Add EXPM1 for metal
Branch: ggml-cumsum-tri Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
1 parent 7cbbff7 commit 434ec07

File tree

3 files changed

+16
-0
lines changed

3 files changed

+16
-0
lines changed

ggml/src/ggml-metal/ggml-metal-device.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary(ggml_metal_library_t
212212
case GGML_UNARY_OP_HARDSIGMOID: op_str = "hardsigmoid"; break;
213213
case GGML_UNARY_OP_EXP: op_str = "exp"; break;
214214
case GGML_UNARY_OP_SOFTPLUS: op_str = "softplus"; break;
215+
case GGML_UNARY_OP_EXPM1: op_str = "expm1"; break;
215216
default: GGML_ABORT("fatal error");
216217
} break;
217218
default: GGML_ABORT("fatal error");

ggml/src/ggml-metal/ggml-metal-device.m

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -832,6 +832,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
832832
case GGML_UNARY_OP_HARDSIGMOID:
833833
case GGML_UNARY_OP_EXP:
834834
case GGML_UNARY_OP_SOFTPLUS:
835+
case GGML_UNARY_OP_EXPM1:
835836
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
836837
default:
837838
return false;

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1611,6 +1611,20 @@ kernel void kernel_softplus_f32_4(
16111611
dst[tpig] = select(log(1.0f + exp(x)), x, x > 20.0f);
16121612
}
16131613

1614+
// Element-wise exp(x) - 1 on float data: each thread reads src0[tpig] and
// stores the result at the same index in dst.
// No bounds check is performed here, so the dispatch presumably covers exactly
// the number of elements -- TODO confirm against the host-side launch code.
// NOTE(review): the result is formed as exp(x) - 1.0f directly; for |x| close
// to 0 that subtraction cancels most of the significant bits, so it is less
// accurate than a dedicated expm1 routine would be -- confirm this is acceptable.
kernel void kernel_expm1_f32(
1615+
device const float * src0,
1616+
device float * dst,
1617+
uint tpig[[thread_position_in_grid]]) {
1618+
dst[tpig] = exp(src0[tpig]) - 1.0f;
1619+
}
1620+
1621+
// float4 variant of exp(x) - 1: each thread reads one float4 at src0[tpig],
// applies exp to it, subtracts the scalar 1.0f, and stores the float4 result
// at dst[tpig].
// No bounds check is performed here, so the dispatch presumably covers exactly
// the number of float4 elements -- TODO confirm against the host-side launch code.
// NOTE(review): as with the scalar form, exp(x) - 1.0f loses precision for
// |x| close to 0 compared with a dedicated expm1 routine -- confirm acceptable.
kernel void kernel_expm1_f32_4(
1622+
device const float4 * src0,
1623+
device float4 * dst,
1624+
uint tpig[[thread_position_in_grid]]) {
1625+
dst[tpig] = exp(src0[tpig]) - 1.0f;
1626+
}
1627+
16141628
kernel void kernel_reglu_f32(
16151629
constant ggml_metal_kargs_glu & args,
16161630
device const char * src0,

0 commit comments

Comments
 (0)