#include "hadamard.cuh"

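// Fast Walsh-Hadamard transform of each contiguous chunk of nh floats along dim 0,
// normalized by 1/sqrt(nh) so the transform is orthonormal:
//   H_2 = 1/sqrt(2) * [ 1  1 ]
//                     [ 1 -1 ],    H_{2m} = H_2 (Kronecker product) H_m.
// One CUDA block handles one chunk with nh/2 threads; each thread performs one
// radix-2 butterfly per stage, and there are log2(nh) stages in total.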
template <int nh>
static __global__ void hadamard_f32(const char * src, char * dst, int ne0,
        size_t nb01, size_t nb02, size_t nb03, size_t nb1, size_t nb2, size_t nb3) {

    constexpr float ksqrt2 = 0.707106781f;

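    // blockIdx.x enumerates (row, chunk) pairs: i1 is the row index along dim 1,
    // ic selects the nh-sized chunk within that row. blockIdx.y/z cover dims 2 and 3.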
    int nc = ne0/nh;
    int ii1 = blockIdx.x;
    int i1 = ii1 / nc;
    int ic = ii1 % nc;
    int i2 = blockIdx.y;
    int i3 = blockIdx.z;

    int tid = threadIdx.x;

    const float * x = (const float *)((const char *)src + i1*nb01 + i2*nb02 + i3*nb03) + ic*nh;
    float * y = (float *)((char *)dst + i1*nb1 + i2*nb2 + i3*nb3) + ic*nh;

    __shared__ float ys[nh];

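    // Stage 1 (stride-1 butterflies), fused with the load from global memory:
    // each thread reads one adjacent pair and writes its sum and difference
    // into shared memory.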
    ys[2*tid+0] = x[2*tid+0] + x[2*tid+1];
    ys[2*tid+1] = x[2*tid+0] - x[2*tid+1];

    float scale = ksqrt2;

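    // Remaining log2(nh)-1 stages: the butterfly stride h doubles every stage,
    // and each stage contributes another factor of 1/sqrt(2), so by the end
    // scale == 1/sqrt(nh).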
#pragma unroll
    for (int h = 2; h < nh; h <<= 1) {
        __syncthreads();
        int ii = tid/h, jj = tid%h;
        int j = 2*h*ii + jj;
        float u = ys[j], v = ys[j+h];
        ys[j+0] = u + v;
        ys[j+h] = u - v;
        scale *= ksqrt2;
    }

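    // Wait for the last stage to finish, then store the two normalized outputs
    // owned by this thread.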
    __syncthreads();
    y[2*tid+0] = ys[2*tid+0] * scale;
    y[2*tid+1] = ys[2*tid+1] * scale;
}

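// Launches one block per (chunk, row) pair along dims 0/1; grid.y/z cover dims 2/3.
// Each block uses nh/2 threads, one per butterfly. Only nh in {64, 128, 256} is
// instantiated here.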
static void hadamard_f32_cuda(int nh, const char * x, char * y, int ne0, int ne1, int ne2, int ne3,
        size_t nb01, size_t nb02, size_t nb03, size_t nb1, size_t nb2, size_t nb3, cudaStream_t stream) {
    int nc = ne0/nh;
    int nrows = nc*ne1;
    dim3 num_blocks = dim3(nrows, ne2, ne3);
    switch (nh) {
        case  64: hadamard_f32< 64><<<num_blocks,  32, 0, stream>>>(x, y, ne0, nb01, nb02, nb03, nb1, nb2, nb3); break;
        case 128: hadamard_f32<128><<<num_blocks,  64, 0, stream>>>(x, y, ne0, nb01, nb02, nb03, nb1, nb2, nb3); break;
        case 256: hadamard_f32<256><<<num_blocks, 128, 0, stream>>>(x, y, ne0, nb01, nb02, nb03, nb1, nb2, nb3); break;
        default: GGML_ABORT("Unsupported Hadamard block size");
    }
}

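// Portable popcount, used below to check that the requested Hadamard size is a power of two.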
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#include <intrin.h>
#include <ammintrin.h>
#include <nmmintrin.h>
#include <immintrin.h>
#include <stdlib.h>
static inline int popcount(uint32_t x) { return __popcnt(x); }
#else
static inline int popcount(uint32_t x) { return __builtin_popcount(x); }
#endif

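// ggml op entry point: op_params[0] carries the Hadamard transform size nh, which
// must be a power of two greater than 1 and must divide ne[0].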
void ggml_cuda_op_hadamard(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src = dst->src[0];
    GGML_ASSERT(src->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);
    GGML_ASSERT(ggml_are_same_shape(src, dst));

    int nh = dst->op_params[0];
    GGML_ASSERT(dst->ne[0]%nh == 0);
    GGML_ASSERT(nh > 1 && popcount(nh) == 1);

    hadamard_f32_cuda(nh, (const char *)src->data, (char *)dst->data, src->ne[0], src->ne[1], src->ne[2], src->ne[3],
            src->nb[1], src->nb[2], src->nb[3], dst->nb[1], dst->nb[2], dst->nb[3], ctx.stream());
}