Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 24 additions & 8 deletions ggml/include/ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -2338,25 +2338,37 @@ extern "C" {
struct ggml_tensor * ids);

// partition into non-overlapping windows with padding if needed
// example:
// a: 768 64 64 1
// w: 14
// res: 768 14 14 25
// used in sam
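// a:      [B, H, W, C]          (listed outermost-first; in ggml ne order: [C, W, H, B])
// result: [B*NPY*NPX, w, w, C]  (listed outermost-first; in ggml ne order: [C, w, w, B*NPY*NPX])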
// NPY = ceil(H/w)
// NPX = ceil(W/w)
GGML_API struct ggml_tensor * ggml_win_part(
struct ggml_context * ctx,
struct ggml_tensor * a,
int w);

// reverse of ggml_win_part
// used in sam
GGML_API struct ggml_tensor * ggml_win_unpart(
struct ggml_context * ctx,
struct ggml_tensor * a,
int w0,
int h0,
int w);

// reverse of ggml_win_part with explicit output dimensions
// a: [C, w, w, B*NPY*NPX]
// result: [C, w0, h0, b0]
// w0, h0: output width and height (may differ from input due to padding removal)
// b0: output batch size
// w: window size (must match the one used in ggml_win_part)
GGML_API struct ggml_tensor * ggml_win_unpart_ext(
struct ggml_context * ctx,
struct ggml_tensor * a,
int w0,
int h0,
int b0,
int w);

GGML_API struct ggml_tensor * ggml_unary(
struct ggml_context * ctx,
struct ggml_tensor * a,
Expand All @@ -2367,14 +2379,18 @@ extern "C" {
struct ggml_tensor * a,
enum ggml_unary_op op);

// used in sam
// relative position encoding
// a: [C, rel_pos_size]
// res: [C, kh, qh]
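// where rel_pos_size >= 2*max(qh, kh) - 1 (which equals qh + kh - 1 when qh == kh);
// when qh != kh the shorter axis is rescaled, so qh + kh - 1 rows are not enough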
// extracts relative position embeddings for attention
// ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L292-L322
GGML_API struct ggml_tensor * ggml_get_rel_pos(
struct ggml_context * ctx,
struct ggml_tensor * a,
int qh,
int kh);

// used in sam
GGML_API struct ggml_tensor * ggml_add_rel_pos(
struct ggml_context * ctx,
struct ggml_tensor * a,
Expand Down
220 changes: 180 additions & 40 deletions ggml/src/ggml-cpu/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8965,35 +8965,80 @@ static void ggml_compute_forward_win_part_f32(

const ggml_tensor * src0 = dst->src[0];

GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
GGML_TENSOR_UNARY_OP_LOCALS

const int32_t nep0 = ((const int32_t *)(dst->op_params))[0];
const int32_t nep1 = ((const int32_t *)(dst->op_params))[1];
const int32_t w = ((const int32_t *)(dst->op_params))[2];
const int32_t bs = ((const int32_t *)(dst->op_params))[2];
const int32_t w = ((const int32_t *)(dst->op_params))[3];

assert(ne00 == ne0);
assert(ne3 == nep0*nep1);
GGML_ASSERT(ne00 == ne0);
GGML_ASSERT(ne3 == nep0*nep1*bs);

// TODO: optimize / multi-thread
for (int py = 0; py < nep1; ++py) {
for (int px = 0; px < nep0; ++px) {
const int64_t i3 = py*nep0 + px;
for (int64_t i2 = 0; i2 < ne2; ++i2) {
for (int64_t i1 = 0; i1 < ne1; ++i1) {
for (int64_t i0 = 0; i0 < ne0; ++i0) {
const int64_t i02 = py*w + i2;
const int64_t i01 = px*w + i1;
const int64_t i00 = i0;

const int64_t i = i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + i0;
const int64_t j = i02*ne01*ne00 + i01*ne00 + i00;

if (py*w + i2 >= ne02 || px*w + i1 >= ne01) {
((float *) dst->data)[i] = 0.0f;
} else {
((float *) dst->data)[i] = ((float *) src0->data)[j];
}
for (int64_t i3 = 0; i3 < ne3; i3++) {
int px = i3 % nep0;
int py = (i3 / nep0) % nep1;
int b = i3 / (nep0 * nep1);
for (int64_t i2 = 0; i2 < ne2; ++i2) {
for (int64_t i1 = 0; i1 < ne1; ++i1) {
for (int64_t i0 = 0; i0 < ne0; ++i0) {
const int64_t i03 = b;
const int64_t i02 = py*w + i2;
const int64_t i01 = px*w + i1;
const int64_t i00 = i0;

const char * sp = ((const char *) src0->data) + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00;
char * dp = ((char *) dst->data) + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0;

if (py*w + i2 >= ne02 || px*w + i1 >= ne01) {
*((float *) dp) = 0;
} else {
*((float *) dp) = *((const float *) sp);
}
}
}
}
}
}

static void ggml_compute_forward_win_part_f16(
const ggml_compute_params * params,
ggml_tensor * dst) {
GGML_UNUSED(params);

const ggml_tensor * src0 = dst->src[0];

GGML_TENSOR_UNARY_OP_LOCALS

const int32_t nep0 = ((const int32_t *)(dst->op_params))[0];
const int32_t nep1 = ((const int32_t *)(dst->op_params))[1];
const int32_t bs = ((const int32_t *)(dst->op_params))[2];
const int32_t w = ((const int32_t *)(dst->op_params))[3];

GGML_ASSERT(ne00 == ne0);
GGML_ASSERT(ne3 == nep0*nep1*bs);

// TODO: optimize / multi-thread
for (int64_t i3 = 0; i3 < ne3; i3++) {
int px = i3 % nep0;
int py = (i3 / nep0) % nep1;
int b = i3 / (nep0 * nep1);
for (int64_t i2 = 0; i2 < ne2; ++i2) {
for (int64_t i1 = 0; i1 < ne1; ++i1) {
for (int64_t i0 = 0; i0 < ne0; ++i0) {
const int64_t i03 = b;
const int64_t i02 = py*w + i2;
const int64_t i01 = px*w + i1;
const int64_t i00 = i0;

const char * sp = ((const char *) src0->data) + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00;
char * dp = ((char *) dst->data) + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0;

if (py*w + i2 >= ne02 || px*w + i1 >= ne01) {
*((ggml_fp16_t *) dp) = 0;
} else {
*((ggml_fp16_t *) dp) = *((const ggml_fp16_t *) sp);
}
}
}
Expand All @@ -9008,10 +9053,16 @@ void ggml_compute_forward_win_part(
const ggml_tensor * src0 = dst->src[0];

switch (src0->type) {
case GGML_TYPE_I32:
case GGML_TYPE_F32:
{
ggml_compute_forward_win_part_f32(params, dst);
} break;
case GGML_TYPE_BF16:
case GGML_TYPE_F16:
{
ggml_compute_forward_win_part_f16(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");
Expand All @@ -9028,35 +9079,82 @@ static void ggml_compute_forward_win_unpart_f32(

const ggml_tensor * src0 = dst->src[0];

GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
GGML_TENSOR_UNARY_OP_LOCALS

const int32_t w = ((const int32_t *)(dst->op_params))[0];

// padding
const int px = (w - ne1%w)%w;
//const int py = (w - ne2%w)%w;
const int py = (w - ne2%w)%w;

const int npx = (px + ne1)/w;
//const int npy = (py + ne2)/w;
const int npy = (py + ne2)/w;

assert(ne0 == ne00);
assert(ne03 == npx*npy*ne3);

// TODO: optimize / multi-thread
for (int64_t i2 = 0; i2 < ne2; ++i2) {
for (int64_t i1 = 0; i1 < ne1; ++i1) {
for (int64_t i0 = 0; i0 < ne0; ++i0) {
const int ip2 = i2/w;
const int ip1 = i1/w;
for (int64_t i3 = 0; i3 < ne3; ++i3) {
for (int64_t i2 = 0; i2 < ne2; ++i2) {
for (int64_t i1 = 0; i1 < ne1; ++i1) {
for (int64_t i0 = 0; i0 < ne0; ++i0) {
const int ip2 = i2/w;
const int ip1 = i1/w;

const int64_t i03 = i3*npx*npy + ip2*npx + ip1;
const int64_t i02 = i2%w;
const int64_t i01 = i1%w;
const int64_t i00 = i0;

const char * sp = ((const char *) src0->data) + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00;
char * dp = ((char *) dst->data) + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0;

*((float *) dp) = *((const float *) sp);
}
}
}
}
}

static void ggml_compute_forward_win_unpart_f16(
const ggml_compute_params * params,
ggml_tensor * dst) {
GGML_UNUSED(params);

const ggml_tensor * src0 = dst->src[0];

GGML_TENSOR_UNARY_OP_LOCALS

const int32_t w = ((const int32_t *)(dst->op_params))[0];

const int64_t i02 = i2%w;
const int64_t i01 = i1%w;
const int64_t i00 = i0;
// padding
const int px = (w - ne1%w)%w;
const int py = (w - ne2%w)%w;

const int64_t i = (ip2*npx + ip1)*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00 + i00;
const int64_t j = i2*ne1*ne0 + i1*ne0 + i0;
const int npx = (px + ne1)/w;
const int npy = (py + ne2)/w;

((float *) dst->data)[j] = ((float *) src0->data)[i];
assert(ne0 == ne00);
assert(ne03 == npx*npy*ne3);

// TODO: optimize / multi-thread
for (int64_t i3 = 0; i3 < ne3; ++i3) {
for (int64_t i2 = 0; i2 < ne2; ++i2) {
for (int64_t i1 = 0; i1 < ne1; ++i1) {
for (int64_t i0 = 0; i0 < ne0; ++i0) {
const int ip2 = i2/w;
const int ip1 = i1/w;

const int64_t i03 = i3*npx*npy + ip2*npx + ip1;
const int64_t i02 = i2%w;
const int64_t i01 = i1%w;
const int64_t i00 = i0;

const char * sp = ((const char *) src0->data) + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00;
char * dp = ((char *) dst->data) + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0;

*((ggml_fp16_t *) dp) = *((const ggml_fp16_t *) sp);
}
}
}
}
Expand All @@ -9069,10 +9167,16 @@ void ggml_compute_forward_win_unpart(
const ggml_tensor * src0 = dst->src[0];

switch (src0->type) {
case GGML_TYPE_I32:
case GGML_TYPE_F32:
{
ggml_compute_forward_win_unpart_f32(params, dst);
} break;
case GGML_TYPE_BF16:
case GGML_TYPE_F16:
{
ggml_compute_forward_win_unpart_f16(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");
Expand Down Expand Up @@ -9226,6 +9330,35 @@ void ggml_compute_forward_glu(

// ggml_compute_forward_get_rel_pos

static void ggml_compute_forward_get_rel_pos_f32(
const ggml_compute_params * params,
ggml_tensor * dst) {
GGML_UNUSED(params);

const ggml_tensor * src0 = dst->src[0];

// ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L292-L322

GGML_TENSOR_UNARY_OP_LOCALS

const int64_t kh = ne1;
const int64_t qh = ne2;
const float k_scale = MAX((float)qh / kh, 1.0f);
const float q_scale = MAX((float)kh / qh, 1.0f);

float * src0_data = (float *) src0->data;
float * dst_data = (float *) dst->data;

for (int64_t i2 = 0; i2 < ne2; ++i2) {
for (int64_t i1 = 0; i1 < ne1; ++i1) {
const int pos = int(i2*q_scale - i1*k_scale + (kh - 1)*k_scale);
for (int64_t i0 = 0; i0 < ne0; ++i0) {
dst_data[i2*ne1*ne0 + i1*ne0 + i0] = src0_data[pos*ne00 + i0];
}
}
}
}

static void ggml_compute_forward_get_rel_pos_f16(
const ggml_compute_params * params,
ggml_tensor * dst) {
Expand All @@ -9237,14 +9370,17 @@ static void ggml_compute_forward_get_rel_pos_f16(

GGML_TENSOR_UNARY_OP_LOCALS

const int64_t w = ne1;
const int64_t kh = ne1;
const int64_t qh = ne2;
const float k_scale = MAX((float)qh / kh, 1.0f);
const float q_scale = MAX((float)kh / qh, 1.0f);

ggml_fp16_t * src0_data = (ggml_fp16_t *) src0->data;
ggml_fp16_t * dst_data = (ggml_fp16_t *) dst->data;

for (int64_t i2 = 0; i2 < ne2; ++i2) {
for (int64_t i1 = 0; i1 < ne1; ++i1) {
const int64_t pos = (w - i1 - 1) + i2;
const int pos = int(i2*q_scale - i1*k_scale + (kh - 1)*k_scale);
for (int64_t i0 = 0; i0 < ne0; ++i0) {
dst_data[i2*ne1*ne0 + i1*ne0 + i0] = src0_data[pos*ne00 + i0];
}
Expand All @@ -9259,6 +9395,10 @@ void ggml_compute_forward_get_rel_pos(
const ggml_tensor * src0 = dst->src[0];

switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_get_rel_pos_f32(params, dst);
} break;
case GGML_TYPE_F16:
case GGML_TYPE_BF16:
{
Expand Down
Loading