Skip to content

Commit 1496afd

Browse files
committed
feat: Add FILL for metal
Branch: ggml-cumsum-tri
Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
1 parent 434ec07 commit 1496afd

File tree

6 files changed

+62
-0
lines changed

6 files changed

+62
-0
lines changed

ggml/src/ggml-metal/ggml-metal-device.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary(ggml_metal_library_t
187187
const char * op_str = "undefined";
188188
switch (op->op) {
189189
case GGML_OP_SCALE: op_str = "scale"; break;
190+
case GGML_OP_FILL: op_str = "fill"; break;
190191
case GGML_OP_CLAMP: op_str = "clamp"; break;
191192
case GGML_OP_SQR: op_str = "sqr"; break;
192193
case GGML_OP_SQRT: op_str = "sqrt"; break;

ggml/src/ggml-metal/ggml-metal-device.m

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -865,6 +865,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
865865
case GGML_OP_ACC:
866866
case GGML_OP_REPEAT:
867867
case GGML_OP_SCALE:
868+
case GGML_OP_FILL:
868869
case GGML_OP_CONV_TRANSPOSE_1D:
869870
return true;
870871
case GGML_OP_CONV_TRANSPOSE_2D:

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,10 @@ typedef struct {
182182
float bias;
183183
} ggml_metal_kargs_scale;
184184

185+
// Kernel arguments for the FILL kernels: a single scalar that the kernels
// store into the destination buffer.
typedef struct {
    float val; // value written to each destination element
} ggml_metal_kargs_fill;
188+
185189
typedef struct {
186190
float min;
187191
float max;

ggml/src/ggml-metal/ggml-metal-ops.cpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
286286
{
287287
n_fuse = ggml_metal_op_scale(ctx, idx);
288288
} break;
289+
case GGML_OP_FILL:
290+
{
291+
n_fuse = ggml_metal_op_fill(ctx, idx);
292+
} break;
289293
case GGML_OP_CLAMP:
290294
{
291295
n_fuse = ggml_metal_op_clamp(ctx, idx);
@@ -737,6 +741,41 @@ int ggml_metal_op_scale(ggml_metal_op_t ctx, int idx) {
737741
return 1;
738742
}
739743

744+
// Encode a GGML_OP_FILL graph node into the Metal command encoder.
//
//   ctx - op-encoding context; supplies the node, the pipeline library (ctx->lib)
//         and the command encoder (ctx->enc)
//   idx - index of the node to encode
//
// Always returns 1 (the dispatcher stores the result in n_fuse).
int ggml_metal_op_fill(ggml_metal_op_t ctx, int idx) {
    ggml_tensor * op = ctx->node(idx);

    ggml_metal_library_t library = ctx->lib;
    ggml_metal_encoder_t encoder = ctx->enc;

    // shape / stride locals for src0 and dst; none of them is referenced below
    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);

    // the fill value is read from op parameter 0 as f32
    ggml_metal_kargs_fill kargs = {
        /*.val =*/ ggml_get_op_params_f32(op, 0),
    };

    // dispatch count: one per element, or one per 4 elements when the element
    // count is a multiple of 4
    // NOTE(review): presumably the pipeline lookup selects the float4 kernel
    // variant under the same condition - confirm in ggml_metal_library_get_pipeline_unary
    const int64_t n_elem  = ggml_nelements(op);
    const int64_t n_units = (n_elem % 4 == 0) ? n_elem / 4 : n_elem;

    ggml_metal_pipeline_t pl = ggml_metal_library_get_pipeline_unary(library, op);

    // bindings: 0 = kernel args, 1 = src0 buffer, 2 = dst buffer
    ggml_metal_encoder_set_pipeline(encoder, pl);
    ggml_metal_encoder_set_bytes   (encoder, &kargs, sizeof(kargs), 0);
    ggml_metal_encoder_set_buffer  (encoder, ggml_metal_get_buffer_id(op->src[0]), 1);
    ggml_metal_encoder_set_buffer  (encoder, ggml_metal_get_buffer_id(op),         2);

    ggml_metal_encoder_dispatch_threadgroups(encoder, n_units, 1, 1, 1, 1, 1);

    return 1;
}
778+
740779
int ggml_metal_op_clamp(ggml_metal_op_t ctx, int idx) {
741780
ggml_tensor * op = ctx->node(idx);
742781

ggml/src/ggml-metal/ggml-metal-ops.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ int ggml_metal_op_concat (ggml_metal_op_t ctx, int idx);
4747
int ggml_metal_op_repeat (ggml_metal_op_t ctx, int idx);
4848
int ggml_metal_op_acc (ggml_metal_op_t ctx, int idx);
4949
int ggml_metal_op_scale (ggml_metal_op_t ctx, int idx);
50+
int ggml_metal_op_fill (ggml_metal_op_t ctx, int idx);
5051
int ggml_metal_op_clamp (ggml_metal_op_t ctx, int idx);
5152
int ggml_metal_op_unary (ggml_metal_op_t ctx, int idx);
5253
int ggml_metal_op_glu (ggml_metal_op_t ctx, int idx);

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1249,6 +1249,22 @@ kernel void kernel_scale_f32_4(
12491249
dst[tpig] = src0[tpig] * args.scale + args.bias;
12501250
}
12511251

1252+
// FILL, scalar variant: each thread stores args.val into the dst element at its
// grid position. src0 is part of the signature but is never read.
kernel void kernel_fill_f32(
        constant ggml_metal_kargs_fill & args,
        device  const float * src0,
        device        float * dst,
        uint tpig[[thread_position_in_grid]]) {
    const float fill = args.val;

    dst[tpig] = fill;
}
1259+
1260+
// FILL, float4 variant: each thread stores args.val, broadcast to all four
// lanes, into the dst float4 at its grid position. src0 is part of the
// signature but is never read.
kernel void kernel_fill_f32_4(
        constant ggml_metal_kargs_fill & args,
        device  const float4 * src0,
        device        float4 * dst,
        uint tpig[[thread_position_in_grid]]) {
    // explicit splat of the scalar into a 4-wide vector
    dst[tpig] = float4(args.val);
}
1267+
12521268
kernel void kernel_clamp_f32(
12531269
constant ggml_metal_kargs_clamp & args,
12541270
device const float * src0,

0 commit comments

Comments
 (0)