Skip to content

Commit 7cbbff7

Browse files
committed
feat: Add softplus for metal
Branch: ggml-cumsum-tri Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
1 parent 6a27050 commit 7cbbff7

File tree

3 files changed

+18
-0
lines changed

3 files changed

+18
-0
lines changed

ggml/src/ggml-metal/ggml-metal-device.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary(ggml_metal_library_t
211211
case GGML_UNARY_OP_HARDSWISH: op_str = "hardswish"; break;
212212
case GGML_UNARY_OP_HARDSIGMOID: op_str = "hardsigmoid"; break;
213213
case GGML_UNARY_OP_EXP: op_str = "exp"; break;
214+
case GGML_UNARY_OP_SOFTPLUS: op_str = "softplus"; break;
214215
default: GGML_ABORT("fatal error");
215216
} break;
216217
default: GGML_ABORT("fatal error");

ggml/src/ggml-metal/ggml-metal-device.m

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -831,6 +831,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
831831
case GGML_UNARY_OP_HARDSWISH:
832832
case GGML_UNARY_OP_HARDSIGMOID:
833833
case GGML_UNARY_OP_EXP:
834+
case GGML_UNARY_OP_SOFTPLUS:
834835
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
835836
default:
836837
return false;

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1595,6 +1595,22 @@ kernel void kernel_exp_f32_4(
15951595
dst[tpig] = exp(src0[tpig]);
15961596
}
15971597

1598+
// Element-wise softplus on f32 data: dst[i] = log(1 + exp(src0[i])).
//
// src0 - input values, read-only
// dst  - output values, written at the same index that is read from src0
// tpig - this thread's position in the grid; used directly as the element index,
//        so each thread produces exactly one output element
//
// NOTE(review): there is no bounds check on tpig — this presumably relies on the
// host dispatching exactly one thread per element; confirm against the dispatch code.
kernel void kernel_softplus_f32(
1599+
device const float * src0,
1600+
device float * dst,
1601+
uint tpig[[thread_position_in_grid]]) {
1602+
device const float & x = src0[tpig];
1603+
// For x > 20 the result is x itself: log(1 + exp(x)) - x = log(1 + exp(-x)) < exp(-20) ~= 2e-9,
// which is below f32 resolution at that magnitude, and it avoids evaluating exp() on large inputs.
dst[tpig] = (x > 20.0f) ? x : log(1.0f + exp(x));
1604+
}
1605+
1606+
// Vectorized softplus: same computation as the scalar f32 softplus, applied
// component-wise to one float4 (four consecutive floats) per thread.
//
// src0 - input values viewed as float4, read-only
// dst  - output values viewed as float4, written at the same index that is read
// tpig - this thread's position in the grid; used directly as the float4 index
//
// NOTE(review): no bounds check on tpig, and the float4 view presumably requires the
// element count to be a multiple of 4 — confirm against the host-side kernel selection.
kernel void kernel_softplus_f32_4(
1607+
device const float4 * src0,
1608+
device float4 * dst,
1609+
uint tpig[[thread_position_in_grid]]) {
1610+
device const float4 & x = src0[tpig];
1611+
// select(a, b, c) yields b in lanes where c is true and a otherwise, so lanes with x > 20
// pass x through unchanged and the remaining lanes get log(1 + exp(x)).
// Unlike a scalar ternary, the log/exp expression is written for the whole vector;
// any value it produces in the x > 20 lanes is discarded by the select.
dst[tpig] = select(log(1.0f + exp(x)), x, x > 20.0f);
1612+
}
1613+
15981614
kernel void kernel_reglu_f32(
15991615
constant ggml_metal_kargs_glu & args,
16001616
device const char * src0,

0 commit comments

Comments
 (0)