Skip to content

Commit 09661b9

Browse files
committed
adding conv_transpose_2d
1 parent 1dc9406 commit 09661b9

File tree

3 files changed

+144
-0
lines changed

3 files changed

+144
-0
lines changed

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -513,6 +513,19 @@ typedef struct {
513513
uint64_t nb1;
514514
} ggml_metal_kargs_conv_transpose_1d;
515515

516+
typedef struct {
517+
int32_t IC;
518+
int32_t IH;
519+
int32_t IW;
520+
int32_t KH;
521+
int32_t KW;
522+
int32_t OC;
523+
int32_t s0;
524+
uint64_t nb0;
525+
uint64_t nb1;
526+
uint64_t nb2;
527+
} ggml_metal_kargs_conv_transpose_2d;
528+
516529
typedef struct {
517530
uint64_t ofs0;
518531
uint64_t ofs1;

ggml/src/ggml-metal/ggml-metal-ops.cpp

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
364364
{
365365
n_fuse = ggml_metal_op_conv_transpose_1d(ctx, idx);
366366
} break;
367+
case GGML_OP_CONV_TRANSPOSE_2D:
368+
{
369+
n_fuse = ggml_metal_op_conv_transpose_2d(ctx, idx);
370+
} break;
367371
case GGML_OP_UPSCALE:
368372
{
369373
n_fuse = ggml_metal_op_upscale(ctx, idx);
@@ -3068,6 +3072,58 @@ int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) {
30683072
return 1;
30693073
}
30703074

3075+
int ggml_metal_op_conv_transpose_2d(ggml_metal_op_t ctx, int idx) {
3076+
ggml_tensor * op = ctx->node(idx);
3077+
3078+
ggml_metal_library_t lib = ctx->lib;
3079+
ggml_metal_encoder_t enc = ctx->enc;
3080+
3081+
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3082+
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3083+
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
3084+
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
3085+
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3086+
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
3087+
3088+
const int32_t s0 = ((const int32_t *)(op->op_params))[0];
3089+
3090+
const int32_t IC = op->src[1]->ne[2];
3091+
const int32_t IH = op->src[1]->ne[1];
3092+
const int32_t IW = op->src[1]->ne[0];
3093+
3094+
const int32_t KH = op->src[0]->ne[1];
3095+
const int32_t KW = op->src[0]->ne[0];
3096+
3097+
const int32_t OW = op->ne[0];
3098+
const int32_t OH = op->ne[1];
3099+
const int32_t OC = op->ne[2];
3100+
3101+
ggml_metal_kargs_conv_transpose_2d args = {
3102+
/*.IC =*/ IC,
3103+
/*.IH =*/ IH,
3104+
/*.IW =*/ IW,
3105+
/*.KH =*/ KH,
3106+
/*.KW =*/ KW,
3107+
/*.OC =*/ OC,
3108+
/*.s0 =*/ s0,
3109+
/*.nb0 =*/ nb0,
3110+
/*.nb1 =*/ nb1,
3111+
/*.nb2 =*/ nb2,
3112+
};
3113+
3114+
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_conv_transpose_2d(lib, op);
3115+
3116+
ggml_metal_encoder_set_pipeline(enc, pipeline);
3117+
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3118+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
3119+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
3120+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
3121+
3122+
ggml_metal_encoder_dispatch_threadgroups(enc, OW, OH, OC, 1, 1, 1);
3123+
3124+
return 1;
3125+
}
3126+
30713127
int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) {
30723128
ggml_tensor * op = ctx->node(idx);
30733129

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4131,6 +4131,81 @@ kernel void kernel_conv_transpose_1d<half>(
41314131
uint3 tgpig[[threadgroup_position_in_grid]],
41324132
uint3 tgpg[[threadgroups_per_grid]]);
41334133

4134+
4135+
typedef void (conv_transpose_2d_t)(
4136+
constant ggml_metal_kargs_conv_transpose_2d & args,
4137+
device const float * src0,
4138+
device const float * src1,
4139+
device char * dst,
4140+
uint3 tgpig[[threadgroup_position_in_grid]],
4141+
uint3 tgpg[[threadgroups_per_grid]]);
4142+
4143+
template <typename T>
4144+
kernel void kernel_conv_transpose_2d(
4145+
constant ggml_metal_kargs_conv_transpose_2d & args,
4146+
device const T * src0,
4147+
device const float * src1,
4148+
device char * dst,
4149+
uint3 tgpig[[threadgroup_position_in_grid]],
4150+
uint3 tgpg[[threadgroups_per_grid]]) {
4151+
4152+
const int32_t out_x = tgpig[0];
4153+
const int32_t out_y = tgpig[1];
4154+
const int32_t out_c = tgpig[2];
4155+
4156+
float v = 0.0f;
4157+
4158+
for (int32_t in_c = 0; in_c<args.IC; in_c++){
4159+
for (int32_t kh = 0; kh<args.KH; kh++){
4160+
4161+
int32_t in_y = out_y - kh;
4162+
4163+
if (in_y < 0 || in_y % args.s0) continue;
4164+
4165+
in_y /= args.s0;
4166+
4167+
if (in_y >= args.IH) continue;
4168+
4169+
for (int32_t kw = 0; kw<args.KW; kw++){
4170+
int32_t in_x = out_x - kw;
4171+
4172+
if (in_x <0 || in_x % args.s0) continue;
4173+
4174+
in_x /= args.s0;
4175+
4176+
if (in_x >= args.IW) continue;
4177+
4178+
const int32_t input_idx = (args.IW * args.IH) * in_c + (args.IW) * in_y + in_x;
4179+
const int32_t kernel_idx = (args.KH * args.KW * args.OC) * in_c + (args.KH * args.KW) * out_c + (args.KW) * kh + kw;
4180+
4181+
v += (float)src0[kernel_idx] * src1[input_idx];
4182+
4183+
}
4184+
}
4185+
}
4186+
device float * dst_ptr = (device float *) (dst + out_x*args.nb0 + out_y * args.nb1 + out_c*args.nb2);
4187+
4188+
dst_ptr[0] = v;
4189+
}
4190+
4191+
template [[host_name("kernel_conv_transpose_2d_f32_f32")]]
4192+
kernel void kernel_conv_transpose_2d<float>(
4193+
constant ggml_metal_kargs_conv_transpose_2d & args,
4194+
device const float * src0,
4195+
device const float * src1,
4196+
device char * dst,
4197+
uint3 tgpig[[threadgroup_position_in_grid]],
4198+
uint3 tgpg[[threadgroups_per_grid]]);
4199+
4200+
template [[host_name("kernel_conv_transpose_2d_f16_f32")]]
4201+
kernel void kernel_conv_transpose_2d<half>(
4202+
constant ggml_metal_kargs_conv_transpose_2d & args,
4203+
device const half * src0,
4204+
device const float * src1,
4205+
device char * dst,
4206+
uint3 tgpig[[threadgroup_position_in_grid]],
4207+
uint3 tgpg[[threadgroups_per_grid]]);
4208+
41344209
kernel void kernel_upscale_f32(
41354210
constant ggml_metal_kargs_upscale & args,
41364211
device const char * src0,

0 commit comments

Comments
 (0)