Commit b6cc848

Merge branch 'ikawrakow:main' into main

2 parents: 83bba6d + b715342

File tree: 3 files changed (+91, −0)

ggml/src/ggml-cuda.cu

Lines changed: 6 additions & 0 deletions
@@ -47,6 +47,7 @@
 #include "ggml-cuda/set-rows.cuh"
 #include "ggml-cuda/argmax.cuh"
 #include "ggml-cuda/multiadd.cuh"
+#include "ggml-cuda/hadamard.cuh"

 #include <algorithm>
 #include <array>
@@ -2958,6 +2959,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_ARGMAX:
             ggml_cuda_argmax(ctx, dst);
             break;
+        case GGML_OP_HADAMARD:
+            ggml_cuda_op_hadamard(ctx, dst);
+            break;
         case GGML_OP_REPEAT:
             ggml_cuda_op_repeat(ctx, dst);
             break;
@@ -4060,6 +4064,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         } break;
         case GGML_OP_ARGMAX:
             return true;
+        case GGML_OP_HADAMARD:
+            return (op->ne[0] == 64 || op->ne[0] == 128 || op->ne[0] == 256) && op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_DUP:
         case GGML_OP_REPEAT:
         case GGML_OP_CONCAT:
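
For context (not part of the commit): the new supports_op clause admits GGML_OP_HADAMARD only when the row size ne[0] matches one of the kernel's template sizes (64, 128, 256) and both tensors are F32. Semantically, each nh-wide block is multiplied by the normalized Hadamard matrix, i.e. y = H·x/√nh. A trivial standalone check of that normalization for nh = 2, matching the kernel's ksqrt2 constant:

    // Not from the diff: size-2 worked example of the normalized transform.
    #include <math.h>
    #include <stdio.h>

    int main(void) {
        const float a = 3.0f, b = 1.0f;
        const float s = 1.0f/sqrtf(2.0f);        // the kernel's ksqrt2
        printf("%f %f\n", (a + b)*s, (a - b)*s); // 2.828427 1.414214
        return 0;
    }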

ggml/src/ggml-cuda/hadamard.cu

Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@
#include "hadamard.cuh"

template <int nh>
static __global__ void hadamard_f32(const char * src, char * dst, int ne0,
        size_t nb01, size_t nb02, size_t nb03, size_t nb1, size_t nb2, size_t nb3) {

    constexpr float ksqrt2 = 0.707106781f; // 1/sqrt(2)

    int nc  = ne0/nh;     // Hadamard blocks per row
    int ii1 = blockIdx.x;
    int i1  = ii1 / nc;   // row index
    int ic  = ii1 % nc;   // block index within the row
    int i2  = blockIdx.y;
    int i3  = blockIdx.z;

    int tid = threadIdx.x; // nh/2 threads; each handles one butterfly pair

    const float * x = (const float *)(src + i1*nb01 + i2*nb02 + i3*nb03) + ic*nh;
    float       * y = (float       *)(dst + i1*nb1  + i2*nb2  + i3*nb3 ) + ic*nh;

    __shared__ float ys[nh];

    // First stage: distance-1 butterflies, fused with the load from global memory.
    ys[2*tid+0] = x[2*tid+0] + x[2*tid+1];
    ys[2*tid+1] = x[2*tid+0] - x[2*tid+1];

    float scale = ksqrt2;

#pragma unroll
    for (int h = 2; h < nh; h <<= 1) { // butterfly distance doubles each stage
        __syncthreads();
        int ii = tid/h, jj = tid%h;
        int j = 2*h*ii + jj;
        float u = ys[j], v = ys[j+h];
        ys[j+0] = u + v;
        ys[j+h] = u - v;
        scale *= ksqrt2; // accumulates to 1/sqrt(nh) after log2(nh) stages
    }

    __syncthreads();
    y[2*tid+0] = ys[2*tid+0] * scale;
    y[2*tid+1] = ys[2*tid+1] * scale;
}

static void hadamard_f32_cuda(int nh, const char * x, char * y, int ne0, int ne1, int ne2, int ne3,
        size_t nb01, size_t nb02, size_t nb03, size_t nb1, size_t nb2, size_t nb3, cudaStream_t stream) {
    int nc    = ne0/nh;
    int nrows = nc*ne1;
    dim3 num_blocks = dim3(nrows, ne2, ne3);
    switch (nh) {
        case  64: hadamard_f32< 64><<<num_blocks,  32, 0, stream>>>(x, y, ne0, nb01, nb02, nb03, nb1, nb2, nb3); break;
        case 128: hadamard_f32<128><<<num_blocks,  64, 0, stream>>>(x, y, ne0, nb01, nb02, nb03, nb1, nb2, nb3); break;
        case 256: hadamard_f32<256><<<num_blocks, 128, 0, stream>>>(x, y, ne0, nb01, nb02, nb03, nb1, nb2, nb3); break;
        default: GGML_ABORT("Unsupported Hadamard block size");
    }
}

#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#include <intrin.h>
static inline int popcount(uint32_t x) { return __popcnt(x); }
#else
static inline int popcount(uint32_t x) { return __builtin_popcount(x); }
#endif

void ggml_cuda_op_hadamard(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src = dst->src[0];
    GGML_ASSERT(src->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);
    GGML_ASSERT(ggml_are_same_shape(src, dst));

    int nh = dst->op_params[0];
    GGML_ASSERT(dst->ne[0]%nh == 0);
    GGML_ASSERT(nh > 1 && popcount(nh) == 1); // nh must be a power of two

    hadamard_f32_cuda(nh, (const char *)src->data, (char *)dst->data, src->ne[0], src->ne[1], src->ne[2], src->ne[3],
                      src->nb[1], src->nb[2], src->nb[3], dst->nb[1], dst->nb[2], dst->nb[3], ctx.stream());
}
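
Each thread block of nh/2 threads transforms one contiguous nh-wide slice, so the launcher's grid is (ne1·ne0/nh, ne2, ne3). For illustration only (not in the commit), a direct launch for a single contiguous row of 64 floats; since the kernel is static, this would have to live in the same translation unit as hadamard_f32 above:

    // Hypothetical helper: one row, one 64-wide Hadamard block, 32 threads
    // (one butterfly pair per thread). All strides equal the contiguous row size.
    static void hadamard_single_row_64(const float * d_src, float * d_dst, cudaStream_t stream) {
        const int    ne0 = 64;                // ne1 = ne2 = ne3 = 1
        const size_t nb  = ne0*sizeof(float); // contiguous strides
        hadamard_f32<64><<<dim3(1, 1, 1), 32, 0, stream>>>(
                (const char *)d_src, (char *)d_dst, ne0, nb, nb, nb, nb, nb, nb);
    }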

ggml/src/ggml-cuda/hadamard.cuh

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
#include "common.cuh"

void ggml_cuda_op_hadamard(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
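
Not part of the commit: a minimal host-side reference of the same transform, handy for checking the kernel's output. It mirrors the kernel's butterfly schedule (distances 1, 2, ..., nh/2) and the 1/√nh scaling that the accumulated ksqrt2 factors produce:

    // Reference orthonormal fast Walsh-Hadamard transform, in place, applied
    // to each contiguous block of nh floats in a row of ne0 elements.
    #include <cassert>
    #include <cmath>

    static void fwht_ref(float * x, int ne0, int nh) {
        assert(nh > 1 && ne0 % nh == 0);
        const float scale = 1.0f/std::sqrt((float)nh);
        for (int ic = 0; ic < ne0/nh; ++ic) {
            float * b = x + ic*nh;
            for (int h = 1; h < nh; h <<= 1) {          // butterfly distance doubles per stage
                for (int i = 0; i < nh; i += 2*h) {
                    for (int j = i; j < i + h; ++j) {
                        const float u = b[j], v = b[j + h];
                        b[j]     = u + v;
                        b[j + h] = u - v;
                    }
                }
            }
            for (int j = 0; j < nh; ++j) b[j] *= scale; // orthonormal normalization
        }
    }

Since the normalized Hadamard matrix is its own inverse, applying fwht_ref twice reproduces the input up to rounding, which gives a quick round-trip check against the CUDA path.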
