Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions ggml/src/ggml-cpu/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9346,14 +9346,15 @@ static void ggml_compute_forward_get_rel_pos_f32(
const float k_scale = MAX((float)qh / kh, 1.0f);
const float q_scale = MAX((float)kh / qh, 1.0f);

float * src0_data = (float *) src0->data;
float * dst_data = (float *) dst->data;
const char * src0_d = (const char *) src0->data;
char * dst_d = (char *) dst->data;

for (int64_t i2 = 0; i2 < ne2; ++i2) {
for (int64_t i1 = 0; i1 < ne1; ++i1) {
const int pos = int(i2*q_scale - i1*k_scale + (kh - 1)*k_scale);
for (int64_t i0 = 0; i0 < ne0; ++i0) {
dst_data[i2*ne1*ne0 + i1*ne0 + i0] = src0_data[pos*ne00 + i0];
const float val = *(const float *) (src0_d + pos*nb01 + i0*nb00);
*(float *) (dst_d + i2*nb2 + i1*nb1 + i0*nb0) = val;
}
}
}
Expand All @@ -9375,14 +9376,15 @@ static void ggml_compute_forward_get_rel_pos_f16(
const float k_scale = MAX((float)qh / kh, 1.0f);
const float q_scale = MAX((float)kh / qh, 1.0f);

ggml_fp16_t * src0_data = (ggml_fp16_t *) src0->data;
ggml_fp16_t * dst_data = (ggml_fp16_t *) dst->data;
const char * src0_d = (const char *) src0->data;
char * dst_d = (char *) dst->data;

for (int64_t i2 = 0; i2 < ne2; ++i2) {
for (int64_t i1 = 0; i1 < ne1; ++i1) {
const int pos = int(i2*q_scale - i1*k_scale + (kh - 1)*k_scale);
for (int64_t i0 = 0; i0 < ne0; ++i0) {
dst_data[i2*ne1*ne0 + i1*ne0 + i0] = src0_data[pos*ne00 + i0];
const ggml_fp16_t val = *(const ggml_fp16_t *) (src0_d + pos*nb01 + i0*nb00);
*(ggml_fp16_t *) (dst_d + i2*nb2 + i1*nb1 + i0*nb0) = val;
}
}
}
Expand Down
1 change: 1 addition & 0 deletions ggml/src/ggml-cuda/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2729,6 +2729,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
break;
case GGML_OP_GET_REL_POS:
ggml_cuda_op_get_rel_pos(ctx, dst);
break;
case GGML_OP_SOLVE_TRI:
ggml_cuda_op_solve_tri(ctx, dst);
break;
Expand Down
68 changes: 39 additions & 29 deletions ggml/src/ggml-cuda/rel-pos.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,30 @@
#include "ggml.h"
#include "ggml-cuda/rel-pos.cuh"


template <typename T>
__global__ static void get_rel_pos_kernel(const void * src, void * dst, int C) {
int kh = gridDim.x;
int qh = gridDim.y;
float k_scale = MAX((float)qh / kh, 1.0f);
float q_scale = MAX((float)kh / qh, 1.0f);
__global__ static void get_rel_pos_kernel(const void * src, void * dst,
int C, int kh, int qh,
int nb00, int nb01,
int nb0, int nb1, int nb2) {
int ki = blockIdx.x;
int qi = blockIdx.y;
int pos = int(qi*q_scale - ki*k_scale + (kh - 1)*k_scale);

int s0 = C;
int s1 = C * kh;
if (ki >= kh || qi >= qh) {
return;
}

float k_scale = MAX((float) qh / kh, 1.0f);
float q_scale = MAX((float) kh / qh, 1.0f);

int pos = int(qi * q_scale - ki * k_scale + (kh - 1) * k_scale);

const char * src_d = (const char *) src;
char * dst_d = (char *) dst;

for (int ci = threadIdx.x; ci < C; ci += blockDim.x) {
((T *) dst)[qi*s1 + ki*s0 + ci] = ((const T *) src)[pos*C + ci];
const int src_offset = pos * nb01 + ci * nb00;
const int dst_offset = qi * nb2 + ki * nb1 + ci * nb0;
*(T *) (dst_d + dst_offset) = *(const T *) (src_d + src_offset);
}
}

Expand All @@ -44,26 +52,28 @@ void ggml_cuda_op_get_rel_pos(ggml_backend_cuda_context & ctx, ggml_tensor * dst
int kh = ne1;
int qh = ne2;

int num_threads = MIN(CUDA_GET_REL_POS_BLOCK_SIZE, MAX(32, round_to_pow2(C)));
dim3 grid { (unsigned int)kh, (unsigned int)qh, 1 };
int num_threads = MIN(CUDA_GET_REL_POS_BLOCK_SIZE, MAX(32, round_to_pow2(C)));
dim3 grid{ (unsigned int) kh, (unsigned int) qh };

const void * src0_d = (const void *)src0->data;
void * dst_d = (void *)dst->data;
const void * src0_d = (const void *) src0->data;
void * dst_d = (void *) dst->data;
cudaStream_t stream = ctx.stream();

switch (src0->type)
{
case GGML_TYPE_F32:
get_rel_pos_kernel<float><<<grid, num_threads, 0, stream>>>(src0_d, dst_d, C);
break;
case GGML_TYPE_F16:
get_rel_pos_kernel<half><<<grid, num_threads, 0, stream>>>(src0_d, dst_d, C);
break;
case GGML_TYPE_BF16:
get_rel_pos_kernel<nv_bfloat16><<<grid, num_threads, 0, stream>>>(src0_d, dst_d, C);
break;
default:
GGML_ABORT("%s: unsupported type (%s)\n", __func__, ggml_type_name(src0->type));
break;
switch (src0->type) {
case GGML_TYPE_F32:
get_rel_pos_kernel<float>
<<<grid, num_threads, 0, stream>>>(src0_d, dst_d, C, kh, qh, src0->nb[0], src0->nb[1], nb0, nb1, nb2);
break;
case GGML_TYPE_F16:
get_rel_pos_kernel<half>
<<<grid, num_threads, 0, stream>>>(src0_d, dst_d, C, kh, qh, src0->nb[0], src0->nb[1], nb0, nb1, nb2);
break;
case GGML_TYPE_BF16:
get_rel_pos_kernel<nv_bfloat16>
<<<grid, num_threads, 0, stream>>>(src0_d, dst_d, C, kh, qh, src0->nb[0], src0->nb[1], nb0, nb1, nb2);
break;
default:
GGML_ABORT("%s: unsupported type (%s)\n", __func__, ggml_type_name(src0->type));
break;
}
}
}
15 changes: 8 additions & 7 deletions ggml/src/ggml-cuda/win.cu
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "common.cuh"
#include "ggml.h"
#include "convert.cuh"
#include "ggml-cuda/win.cuh"
#include "ggml.h"

/*

Expand Down Expand Up @@ -28,7 +29,7 @@ static void ggml_compute_forward_win_part_f16(
for (int64_t i3 = 0; i3 < ne3; i3++) {
int px = i3 % nep0;
int py = (i3 / nep0) % nep1;
int b = i3 / (nep0 * nep1);
int b = i3 / (nep0 * nep1);
for (int64_t i2 = 0; i2 < ne2; ++i2) {
for (int64_t i1 = 0; i1 < ne1; ++i1) {
for (int64_t i0 = 0; i0 < ne0; ++i0) {
Expand All @@ -38,7 +39,7 @@ static void ggml_compute_forward_win_part_f16(
const int64_t i00 = i0;

void * sp = ((void *) src0->data) + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00;
void * dp = ((void *) dst->data) + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0;
void * dp = ((void *) dst->data) + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0;

if (py*w + i2 >= ne02 || px*w + i1 >= ne01) {
*((ggml_fp16_t *) dp) = 0;
Expand Down Expand Up @@ -138,7 +139,7 @@ __global__ static void win_part_kernel(
if (py*p.w + i2 >= p.ne2 || px*p.w + i1 >= p.ne1) {
for (int i0 = threadIdx.x; i0 < p.C; i0 += blockDim.x) {
char * dp = (char *)dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0;
*((T *) dp) = 0;
*((T *) dp) = ggml_cuda_cast<T>(0.0f);
}
return;
}
Expand Down Expand Up @@ -210,7 +211,7 @@ static unsigned int round_to_pow2(unsigned int v) {
v++;

return v;
}
}

void ggml_cuda_op_win_part(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
Expand Down Expand Up @@ -297,12 +298,12 @@ static void ggml_compute_forward_win_unpart_f16(
for (int64_t i0 = 0; i0 < ne0; ++i0) {
const int ip2 = i2/w;
const int ip1 = i1/w;

const int64_t i03 = i3*npx*npy + ip2*npx + ip1;
const int64_t i02 = i2%w;
const int64_t i01 = i1%w;
const int64_t i00 = i0;

void * sp = ((void *) src0->data) + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00;
void * dp = ((void *) dst->data) + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0;

Expand Down
12 changes: 12 additions & 0 deletions tests/test-backend-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7871,8 +7871,20 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_get_rel_pos(type, 13, 7, 7, v));
// Square large: 14x14 attention
test_cases.emplace_back(new test_get_rel_pos(type, 27, 14, 14, v));
// Square large: 16x16 attention
test_cases.emplace_back(new test_get_rel_pos(type, 31, 16, 16, v));
// Rectangular: 14x7 attention
test_cases.emplace_back(new test_get_rel_pos(type, 27, 14, 7, v));
// Rectangular: 7x14 attention
test_cases.emplace_back(new test_get_rel_pos(type, 27, 7, 14, v));
// Rectangular: 16x8 attention
test_cases.emplace_back(new test_get_rel_pos(type, 31, 16, 8, v));
// Rectangular: 8x16 attention
test_cases.emplace_back(new test_get_rel_pos(type, 31, 8, 16, v));
// Rectangular: 28x14 attention
test_cases.emplace_back(new test_get_rel_pos(type, 55, 28, 14, v));
// Rectangular: 14x28 attention
test_cases.emplace_back(new test_get_rel_pos(type, 55, 14, 28, v));
// Edge case: 1x1 attention (minimum)
test_cases.emplace_back(new test_get_rel_pos(type, 1, 1, 1, v));
}
Expand Down