
Commit bde188d

metal: TRI, FILL, EXPM1, SOFTPLUS (#16623)

* feat(wip): Port initial TRI impl from previous work

  The kernel does not work and is not optimized, but the code compiles and runs,
  so this will be the starting point now that the core op has been merged.

  Branch: ggml-cumsum-tri
  Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix: Remove argument for constant val override

  This was added in the original draft, but later removed. With this, the kernel
  now passes tests.

  Branch: ggml-cumsum-tri
  Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* feat: Move the ttype conditional to templating to avoid conditional in kernel

  Branch: ggml-cumsum-tri
  Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix: Type fixes

  Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
  Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
  Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* feat: Add softplus for metal

  Branch: ggml-cumsum-tri
  Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* feat: Add EXPM1 for metal

  Branch: ggml-cumsum-tri
  Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* feat: Add FILL for metal

  Branch: ggml-cumsum-tri
  Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* refactor: Branchless version of tri using _ggml_vec_tri_cmp as a mask

  Branch: ggml-cumsum-tri
  Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix: Remove unused arguments

  Branch: ggml-cumsum-tri
  Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* refactor: Use select instead of branch for softplus non-vec

  Branch: ggml-cumsum-tri
  Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

---------

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
1 parent 9d02299 commit bde188d

7 files changed: +265 -0 lines changed

ggml/src/ggml-metal/ggml-metal-device.cpp

Lines changed: 25 additions & 0 deletions
@@ -175,6 +175,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary(ggml_metal
     const char * op_str = "undefined";
     switch (op->op) {
         case GGML_OP_SCALE: op_str = "scale"; break;
+        case GGML_OP_FILL:  op_str = "fill";  break;
         case GGML_OP_CLAMP: op_str = "clamp"; break;
         case GGML_OP_SQR:   op_str = "sqr";   break;
         case GGML_OP_SQRT:  op_str = "sqrt";  break;
@@ -199,6 +200,8 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary(ggml_metal
                 case GGML_UNARY_OP_HARDSWISH:   op_str = "hardswish";   break;
                 case GGML_UNARY_OP_HARDSIGMOID: op_str = "hardsigmoid"; break;
                 case GGML_UNARY_OP_EXP:         op_str = "exp";         break;
+                case GGML_UNARY_OP_SOFTPLUS:    op_str = "softplus";    break;
+                case GGML_UNARY_OP_EXPM1:       op_str = "expm1";       break;
                 default: GGML_ABORT("fatal error");
             } break;
         default: GGML_ABORT("fatal error");
@@ -332,6 +335,28 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_add(ggml_
     return res;
 }
 
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tri(ggml_metal_library_t lib, const ggml_tensor * op) {
+    GGML_ASSERT(op->op == GGML_OP_TRI);
+    GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type));
+
+    char base[256];
+    char name[256];
+
+    const char * op_str = "tri";
+    const int    ttype  = op->op_params[0];
+
+    snprintf(base, 256, "kernel_%s_%s_%d", op_str, ggml_type_name(op->src[0]->type), ttype);
+
+    snprintf(name, 256, "%s", base);
+
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+    }
+
+    return res;
+}
+
 ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max(ggml_metal_library_t lib, const ggml_tensor * op) {
     GGML_ASSERT(!op->src[1] || op->src[1]->type == GGML_TYPE_F16 || op->src[1]->type == GGML_TYPE_F32);
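The getter above derives the pipeline name from the op string, the source tensor type, and the tri type packed into op_params[0], compiling the kernel lazily on a cache miss. A minimal standalone sketch of the naming (the concrete values below are hypothetical stand-ins, not part of the commit):

#include <cstdio>

int main() {
    // Hypothetical inputs: an F32 source tensor with op_params[0] == 2.
    const char * type_name = "f32"; // stands in for ggml_type_name(op->src[0]->type)
    const int    ttype     = 2;     // stands in for op->op_params[0]

    char base[256];
    // Same format string as the getter: "kernel_<op>_<type>_<ttype>".
    snprintf(base, sizeof(base), "kernel_%s_%s_%d", "tri", type_name, ttype);

    printf("%s\n", base); // -> kernel_tri_f32_2, matching a [[host_name]] in ggml-metal.metal
    return 0;
}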

ggml/src/ggml-metal/ggml-metal-device.h

Lines changed: 1 addition & 0 deletions
@@ -114,6 +114,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum_rows  (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_blk(ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_add(ggml_metal_library_t lib, const struct ggml_tensor * op);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tri       (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max  (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv  (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan  (ggml_metal_library_t lib, const struct ggml_tensor * op);

ggml/src/ggml-metal/ggml-metal-device.m

Lines changed: 5 additions & 0 deletions
@@ -818,6 +818,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
                 case GGML_UNARY_OP_HARDSWISH:
                 case GGML_UNARY_OP_HARDSIGMOID:
                 case GGML_UNARY_OP_EXP:
+                case GGML_UNARY_OP_SOFTPLUS:
+                case GGML_UNARY_OP_EXPM1:
                     return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
                 default:
                     return false;
@@ -850,6 +852,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
         case GGML_OP_ACC:
         case GGML_OP_REPEAT:
         case GGML_OP_SCALE:
+        case GGML_OP_FILL:
         case GGML_OP_CONV_TRANSPOSE_1D:
             return true;
         case GGML_OP_CONV_TRANSPOSE_2D:
@@ -867,6 +870,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
             return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_SUM:
             return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
+        case GGML_OP_TRI:
+            return ggml_is_contiguous_rows(op->src[0]);
         case GGML_OP_SUM_ROWS:
         case GGML_OP_CUMSUM:
         case GGML_OP_MEAN:

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 23 additions & 0 deletions
@@ -182,6 +182,10 @@ typedef struct {
     float bias;
 } ggml_metal_kargs_scale;
 
+typedef struct {
+    float val;
+} ggml_metal_kargs_fill;
+
 typedef struct {
     float min;
     float max;
@@ -831,6 +835,25 @@ typedef struct {
     float slope;
 } ggml_metal_kargs_leaky_relu;
 
+typedef struct {
+    int32_t  ne00;
+    int32_t  ne01;
+    int32_t  ne02;
+    int32_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne0;
+    int32_t  ne1;
+    int32_t  ne2;
+    int32_t  ne3;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+} ggml_metal_kargs_tri;
+
 typedef struct {
     int32_t ne00;
     int32_t ne01;
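As elsewhere in ggml, the ne* fields are element counts per dimension and the nb* fields are byte strides, so a kernel addresses an arbitrary row as base + i1*nb1 + i2*nb2 + i3*nb3. A small host-side sketch of that arithmetic (all sizes hypothetical):

#include <cstdint>
#include <cstdio>

int main() {
    // Hypothetical contiguous f32 tensor: 8 elements per row, 4 rows, 2 matrices.
    const int32_t  ne0 = 8;
    const uint64_t nb0 = sizeof(float); // stride between elements
    const uint64_t nb1 = ne0 * nb0;     // stride between rows     (32 bytes)
    const uint64_t nb2 = 4 * nb1;       // stride between matrices (128 bytes)
    const uint64_t nb3 = 2 * nb2;       // stride between batches  (256 bytes)

    // Byte offset of row i1 = 3 in matrix i2 = 1 of batch i3 = 0,
    // the same arithmetic kernel_tri uses for its src_row/dst_row pointers.
    const uint64_t row_offset = 3*nb1 + 1*nb2 + 0*nb3;
    printf("row byte offset = %llu\n", (unsigned long long) row_offset); // -> 224
    return 0;
}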

ggml/src/ggml-metal/ggml-metal-ops.cpp

Lines changed: 94 additions & 0 deletions
@@ -286,6 +286,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
             {
                 n_fuse = ggml_metal_op_scale(ctx, idx);
             } break;
+        case GGML_OP_FILL:
+            {
+                n_fuse = ggml_metal_op_fill(ctx, idx);
+            } break;
         case GGML_OP_CLAMP:
             {
                 n_fuse = ggml_metal_op_clamp(ctx, idx);
@@ -414,6 +418,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
             {
                 n_fuse = ggml_metal_op_leaky_relu(ctx, idx);
             } break;
+        case GGML_OP_TRI:
+            {
+                n_fuse = ggml_metal_op_tri(ctx, idx);
+            } break;
         case GGML_OP_FLASH_ATTN_EXT:
             {
                 n_fuse = ggml_metal_op_flash_attn_ext(ctx, idx);
@@ -733,6 +741,41 @@ int ggml_metal_op_scale(ggml_metal_op_t ctx, int idx) {
     return 1;
 }
 
+int ggml_metal_op_fill(ggml_metal_op_t ctx, int idx) {
+    ggml_tensor * op = ctx->node(idx);
+
+    ggml_metal_library_t lib = ctx->lib;
+    ggml_metal_encoder_t enc = ctx->enc;
+
+    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
+
+    const float val = ggml_get_op_params_f32(op, 0);
+
+    ggml_metal_kargs_fill args = {
+        /*.val =*/ val
+    };
+
+    int64_t n = ggml_nelements(op);
+
+    if (n % 4 == 0) {
+        n /= 4;
+    }
+
+    auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
+
+    ggml_metal_encoder_set_pipeline(enc, pipeline);
+    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op), 2);
+
+    ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
+
+    return 1;
+}
+
 int ggml_metal_op_clamp(ggml_metal_op_t ctx, int idx) {
     ggml_tensor * op = ctx->node(idx);
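One detail worth calling out in the fill dispatch: when the element count is divisible by 4 the count is divided by 4, which only works out if ggml_metal_library_get_pipeline_unary returns the float4 variant (kernel_fill_f32_4) in that case, so that each thread writes one float4 — an assumption about the getter's behavior, not something visible in this hunk. A standalone sketch of the choice (sizes hypothetical):

#include <cstdint>
#include <cstdio>

int main() {
    const int64_t nelem = 1024; // hypothetical ggml_nelements(op)

    int64_t n = nelem;
    const bool vec4 = (n % 4 == 0); // assumed to mirror the getter's variant selection
    if (vec4) {
        n /= 4; // each thread handles one float4 instead of one float
    }

    printf("kernel_fill_f32%s, %lld threads\n", vec4 ? "_4" : "", (long long) n);
    return 0;
}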

@@ -3899,6 +3942,57 @@ int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
     return 1;
 }
 
+int ggml_metal_op_tri(ggml_metal_op_t ctx, int idx) {
+    ggml_tensor * op = ctx->node(idx);
+
+    ggml_metal_library_t lib = ctx->lib;
+    ggml_metal_encoder_t enc = ctx->enc;
+
+    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
+
+    ggml_metal_kargs_tri args = {
+        /*.ne00 =*/ ne00,
+        /*.ne01 =*/ ne01,
+        /*.ne02 =*/ ne02,
+        /*.ne03 =*/ ne03,
+        /*.nb00 =*/ nb00,
+        /*.nb01 =*/ nb01,
+        /*.nb02 =*/ nb02,
+        /*.nb03 =*/ nb03,
+        /*.ne0  =*/ ne0,
+        /*.ne1  =*/ ne1,
+        /*.ne2  =*/ ne2,
+        /*.ne3  =*/ ne3,
+        /*.nb0  =*/ nb0,
+        /*.nb1  =*/ nb1,
+        /*.nb2  =*/ nb2,
+        /*.nb3  =*/ nb3,
+    };
+
+    auto pipeline = ggml_metal_library_get_pipeline_tri(lib, op);
+
+    int nth = 32; // SIMD width
+
+    while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
+        nth *= 2;
+    }
+
+    nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
+    nth = std::min(nth, ne00);
+
+    ggml_metal_encoder_set_pipeline(enc, pipeline);
+    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op), 2);
+
+    ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
+
+    return 1;
+}
+
 int ggml_metal_op_opt_step_adamw(ggml_metal_op_t ctx, int idx) {
     ggml_tensor * op = ctx->node(idx);
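The threadgroup width starts at one SIMD width and doubles until it covers the row or hits the pipeline limit, then is clamped to both. For example, with ne00 = 100 and a 1024-thread limit, nth grows 32 -> 64 -> 128 and is then clamped to min(128, 1024) = 128 and min(128, 100) = 100. A standalone sketch of the sizing (the limit is a hypothetical stand-in for ggml_metal_pipeline_max_theads_per_threadgroup):

#include <algorithm>
#include <cstdio>

int main() {
    const int ne00        = 100;  // row length
    const int max_threads = 1024; // assumed pipeline limit

    int nth = 32; // SIMD width
    while (nth < ne00 && nth < max_threads) {
        nth *= 2;
    }
    nth = std::min(nth, max_threads);
    nth = std::min(nth, ne00);

    printf("nth = %d\n", nth); // -> 100
    return 0;
}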

ggml/src/ggml-metal/ggml-metal-ops.h

Lines changed: 2 additions & 0 deletions
@@ -47,6 +47,7 @@ int ggml_metal_op_concat (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_repeat (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_acc    (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_scale  (ggml_metal_op_t ctx, int idx);
+int ggml_metal_op_fill   (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_clamp  (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_unary  (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_glu    (ggml_metal_op_t ctx, int idx);
@@ -83,6 +84,7 @@ int ggml_metal_op_argmax (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_argsort        (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_top_k          (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_leaky_relu     (ggml_metal_op_t ctx, int idx);
+int ggml_metal_op_tri            (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_opt_step_sgd   (ggml_metal_op_t ctx, int idx);

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 115 additions & 0 deletions
@@ -1249,6 +1249,22 @@ kernel void kernel_scale_f32_4(
     dst[tpig] = src0[tpig] * args.scale + args.bias;
 }
 
+kernel void kernel_fill_f32(
+        constant ggml_metal_kargs_fill & args,
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = args.val;
+}
+
+kernel void kernel_fill_f32_4(
+        constant ggml_metal_kargs_fill & args,
+        device const float4 * src0,
+        device       float4 * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = args.val;
+}
+
 kernel void kernel_clamp_f32(
         constant ggml_metal_kargs_clamp & args,
         device const float * src0,
@@ -1595,6 +1611,36 @@ kernel void kernel_exp_f32_4(
     dst[tpig] = exp(src0[tpig]);
 }
 
+kernel void kernel_softplus_f32(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    device const float & x = src0[tpig];
+    dst[tpig] = select(log(1.0f + exp(x)), x, x > 20.0f);
+}
+
+kernel void kernel_softplus_f32_4(
+        device const float4 * src0,
+        device       float4 * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    device const float4 & x = src0[tpig];
+    dst[tpig] = select(log(1.0f + exp(x)), x, x > 20.0f);
+}
+
+kernel void kernel_expm1_f32(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = exp(src0[tpig]) - 1.0f;
+}
+
+kernel void kernel_expm1_f32_4(
+        device const float4 * src0,
+        device       float4 * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = exp(src0[tpig]) - 1.0f;
+}
+
 kernel void kernel_reglu_f32(
         constant ggml_metal_kargs_glu & args,
         device const char * src0,
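The select at x > 20.0f in the softplus kernels is the standard large-argument shortcut. Rewriting the function as

    \operatorname{softplus}(x) = \log(1 + e^{x}) = x + \log(1 + e^{-x})

shows that for x > 20 the correction term \log(1 + e^{-x}) < e^{-20} \approx 2.1 \cdot 10^{-9}, far below f32 resolution at that magnitude, while evaluating e^{x} directly would overflow f32 once x exceeds roughly 88.7. Returning x unchanged is therefore exact to within rounding and avoids both the overflow and a divergent branch.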
@@ -1943,6 +1989,75 @@ typedef decltype(kernel_cumsum_add<float>) kernel_cumsum_add_t;
 
 template [[host_name("kernel_cumsum_add_f32")]] kernel kernel_cumsum_add_t kernel_cumsum_add<float>;
 
+
+template<uint32_t ttype>
+bool _ggml_vec_tri_cmp(const int i, const int r);
+
+template<>
+bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_LOWER */ 3>(const int i, const int r) {
+    return i < r;
+}
+
+template<>
+bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_LOWER_DIAG */ 2>(const int i, const int r) {
+    return i <= r;
+}
+
+template<>
+bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_UPPER */ 1>(const int i, const int r) {
+    return i > r;
+}
+
+template<>
+bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_UPPER_DIAG */ 0>(const int i, const int r) {
+    return i >= r;
+}
+
+template<typename T, int ttype>
+kernel void kernel_tri(
+        constant ggml_metal_kargs_tri & args,
+        device const char * src0,
+        device const char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3 ntg[[threads_per_threadgroup]]) {
+    const int i3 = tgpig.z;
+    const int i2 = tgpig.y;
+    const int i1 = tgpig.x;
+
+    if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
+        return;
+    }
+
+    device const T * src_row = (device const T *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
+    device       T * dst_row = (device       T *) ((device       char *) dst  + i1*args.nb1  + i2*args.nb2  + i3*args.nb3);
+
+    // Each thread is a single element of the row if ne00 < max threads per
+    // threadgroup, so this will loop once for each index that this thread is
+    // responsible for
+    for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
+        // Use the comparison as a mask for branchless
+        dst_row[i0] = static_cast<T>(_ggml_vec_tri_cmp<ttype>(i0, i1)) * src_row[i0];
+    }
+}
+
+typedef decltype(kernel_tri<float, 0>) kernel_tri_t;
+
+template [[host_name("kernel_tri_f32_0")]]  kernel kernel_tri_t kernel_tri<float, 0>;
+template [[host_name("kernel_tri_f32_1")]]  kernel kernel_tri_t kernel_tri<float, 1>;
+template [[host_name("kernel_tri_f32_2")]]  kernel kernel_tri_t kernel_tri<float, 2>;
+template [[host_name("kernel_tri_f32_3")]]  kernel kernel_tri_t kernel_tri<float, 3>;
+template [[host_name("kernel_tri_f16_0")]]  kernel kernel_tri_t kernel_tri<half, 0>;
+template [[host_name("kernel_tri_f16_1")]]  kernel kernel_tri_t kernel_tri<half, 1>;
+template [[host_name("kernel_tri_f16_2")]]  kernel kernel_tri_t kernel_tri<half, 2>;
+template [[host_name("kernel_tri_f16_3")]]  kernel kernel_tri_t kernel_tri<half, 3>;
+#if defined(GGML_METAL_HAS_BF16)
+template [[host_name("kernel_tri_bf16_0")]] kernel kernel_tri_t kernel_tri<bfloat, 0>;
+template [[host_name("kernel_tri_bf16_1")]] kernel kernel_tri_t kernel_tri<bfloat, 1>;
+template [[host_name("kernel_tri_bf16_2")]] kernel kernel_tri_t kernel_tri<bfloat, 2>;
+template [[host_name("kernel_tri_bf16_3")]] kernel kernel_tri_t kernel_tri<bfloat, 3>;
+#endif
+
 template<typename T>
 kernel void kernel_soft_max(
         constant ggml_metal_kargs_soft_max & args,
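For intuition about the branchless masking in kernel_tri: each element is multiplied by the comparison result cast to the element type, so positions where the predicate holds keep the source value and all others become zero. A CPU-side sketch of the same idea for a 4x4 strict lower triangle, i.e. ttype 3 (hypothetical standalone code, not part of the commit):

#include <cstdio>

int main() {
    const int n = 4;
    float src[n][n], dst[n][n];

    for (int r = 0; r < n; ++r)
        for (int i = 0; i < n; ++i)
            src[r][i] = 1.0f;

    // Same trick as the kernel: the bool comparison acts as a 0/1 mask.
    for (int r = 0; r < n; ++r)
        for (int i = 0; i < n; ++i)
            dst[r][i] = (float)(i < r) * src[r][i]; // i < r == _ggml_vec_tri_cmp<3>

    for (int r = 0; r < n; ++r) {
        for (int i = 0; i < n; ++i) printf("%.0f ", dst[r][i]);
        printf("\n");
    }
    // Output:
    // 0 0 0 0
    // 1 0 0 0
    // 1 1 0 0
    // 1 1 1 0
    return 0;
}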
