Skip to content

Commit cd0e3a7

Browse files
authored
SOLVE_TRI CUDA kernel for small matrices (#17457)
1 parent efaaccd commit cd0e3a7

File tree

4 files changed

+215
-0
lines changed

4 files changed

+215
-0
lines changed

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
#include "ggml-cuda/set.cuh"
5454
#include "ggml-cuda/set-rows.cuh"
5555
#include "ggml-cuda/pad_reflect_1d.cuh"
56+
#include "ggml-cuda/solve_tri.cuh"
5657
#include "ggml.h"
5758

5859
#include <algorithm>
@@ -2717,6 +2718,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
27172718
case GGML_OP_OPT_STEP_SGD:
27182719
ggml_cuda_opt_step_sgd(ctx, dst);
27192720
break;
2721+
case GGML_OP_SOLVE_TRI:
2722+
ggml_cuda_op_solve_tri(ctx, dst);
2723+
break;
27202724
default:
27212725
return false;
27222726
}
@@ -4255,6 +4259,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
42554259
case GGML_OP_OPT_STEP_ADAMW:
42564260
case GGML_OP_OPT_STEP_SGD:
42574261
return true;
4262+
case GGML_OP_SOLVE_TRI:
4263+
return op->src[0]->ne[0] <= 64 && op->src[1]->ne[0] <= 32;
42584264
default:
42594265
return false;
42604266
}

ggml/src/ggml-cuda/solve_tri.cu

Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
#include "common.cuh"
2+
#include "ggml.h"
3+
#include "solve_tri.cuh"
4+
5+
// Upper bounds for the shared-memory buffers of the fast kernel. The host-side
// supports_op check (ggml-cuda.cu) and ggml_cuda_op_solve_tri must only route
// problems with n <= MAX_N_FAST and k <= MAX_K_FAST to this kernel.
#define MAX_N_FAST 64
#define MAX_K_FAST 32

// ======================
// Fast Kernel (n <= 64, k <= 32) - Warp-based parallel reduction
// ======================
// When n_template/k_template == 0 the bounds for the loops in this function are
// not known at compile time and can't be unrolled. As we want to keep pragma
// unroll for all other cases we suppress the clang transformation warning here.
#ifdef __clang__
#    pragma clang diagnostic push
#    pragma clang diagnostic ignored "-Wpass-failed"
#endif // __clang__
// Solves the triangular system A * X = B by forward substitution.
//
// Launch layout (set up by solve_tri_f32_cuda):
//   - grid.x = ne02 * ne03: one thread block per batch element
//   - block  = (WARP_SIZE, k): one warp per column of X
// A is an n x n triangular matrix, B and X are n x k (row-major, contiguous
// innermost rows). All nb* strides are in float elements, not bytes (the host
// wrapper divides the byte strides by sizeof(float)).
// n_template/k_template are either the compile-time problem size or 0, in
// which case the runtime n_arg/k_arg are used instead.
template <int n_template, int k_template>
static __global__ void solve_tri_f32_fast(const float * __restrict__ A,
                                          const float * __restrict__ B,
                                          float * __restrict__ X,
                                          const uint3 ne02,
                                          const size_t nb02,
                                          const size_t nb03,
                                          const size_t nb12,
                                          const size_t nb13,
                                          const size_t nb2,
                                          const size_t nb3,
                                          const int n_arg,
                                          const int k_arg) {
    // Prefer the compile-time sizes so the loops below can be fully unrolled.
    const int n = n_template == 0 ? n_arg : n_template;
    const int k = k_template == 0 ? k_arg : k_template;

    const int batch_idx = blockIdx.x;
    const int lane = threadIdx.x;     // lane within the warp that solves one column
    const int col_idx = threadIdx.y;  // which column of B/X this warp solves

    // Blocks are always launched with MAX-sized y dimension only when k allows;
    // guard anyway in case k < blockDim.y.
    if (col_idx >= k) {
        return;
    }

    // Decompose the flat batch index into (i02, i03) with precomputed
    // fast-division constants (ne02 holds the fastdiv triple for dim 2).
    const uint2 i02_i03 = fast_div_modulo(batch_idx, ne02);
    const int64_t i02 = i02_i03.y;
    const int64_t i03 = i02_i03.x;

    // Per-batch base pointers; strides are in float elements (see above).
    const float * const A_batch = (const float *) (A + i02 * nb02 + i03 * nb03);
    const float * const B_batch = (const float *) (B + i02 * nb12 + i03 * nb13);
    float * X_batch = (float *) (X + i02 * nb2 + i03 * nb3);

    // Shared copies for the current batch element: sA holds A (n x n,
    // row-major); sXt holds X transposed, i.e. sXt[col * n + row], so each
    // warp owns one contiguous slice of length n.
    __shared__ float sA[MAX_N_FAST * MAX_N_FAST];
    __shared__ float sXt[MAX_N_FAST * (MAX_K_FAST + 1)];

    // Flat thread index within the block, used for the cooperative load of A
    // by all k * WARP_SIZE threads.
    const int offset = threadIdx.x + threadIdx.y * blockDim.x;

#pragma unroll
    for (int i = 0; i < n * n; i += k * WARP_SIZE) {
        int i0 = i + offset;
        if (i0 < n * n) {
            sA[i0] = A_batch[i0];
        }
    }

    const int rows_per_warp = (n + WARP_SIZE - 1) / WARP_SIZE;

    // Each warp loads its own column of B, transposed into sXt; that slice is
    // then overwritten in place with the solution column during substitution.
#pragma unroll
    for (int i = 0; i < rows_per_warp; i++) {
        const int i0 = lane + i * WARP_SIZE;
        if (i0 < n) {
            sXt[col_idx * n + i0] = B_batch[i0 * k + col_idx];
        }
    }

    __syncthreads();

    // Forward substitution, one row at a time. Columns are independent and sA
    // is read-only after the barrier above, so no further inter-warp
    // synchronization is needed.
#pragma unroll
    for (int row = 0; row < n; ++row) {
        float sum = 0.0f;

        // Accumulate sum_{j < row} A[row][j] * X[j][col]. With n <= 64 each
        // lane covers at most two j values: lane and lane + WARP_SIZE.
        {
            int j = lane;
            if (j < row) {
                sum += sA[row * n + j] * sXt[col_idx * n + j];
            }
        }
        if (row >= WARP_SIZE) {
            int j = WARP_SIZE + lane;
            if (j < row) {
                sum += sA[row * n + j] * sXt[col_idx * n + j];
            }
        }

        sum = warp_reduce_sum(sum);

        // NOTE(review): the next iteration's reads of sXt rely on lane 0's
        // shared-memory write below being visible to the rest of the warp
        // without an explicit __syncwarp(); presumably the *_sync shuffles in
        // warp_reduce_sum provide the needed convergence point — TODO confirm
        // under independent thread scheduling (Volta+).
        if (lane == 0) {
            const float b_val = sXt[col_idx * n + row];
            const float a_diag = sA[row * n + row];
            // no safeguards for division by zero because that indicates corrupt
            // data anyway
            sXt[col_idx * n + row] = (b_val - sum) / a_diag;
        }
    }

    __syncthreads();

    // Write the solved columns back to global memory, undoing the transpose.
#pragma unroll
    for (int i = 0; i < rows_per_warp; i++) {
        const int i0 = lane + i * WARP_SIZE;
        if (i0 < n) {
            X_batch[i0 * k + col_idx] = sXt[col_idx * n + i0];
        }
    }
}
#ifdef __clang__
#    pragma clang diagnostic pop
#endif // __clang__
116+
117+
// Host-side dispatcher for the triangular-solve kernel.
// Picks a fully templated instantiation for the common n == 64 shapes (so the
// kernel's loops unroll) and otherwise falls back to the generic <0, 0>
// instantiation that takes n and k at runtime. All nb* strides are in float
// elements; one block per batch element, one warp per RHS column.
static void solve_tri_f32_cuda(const float * A,
                               const float * B,
                               float *       X,
                               int           n,
                               int           k,
                               int64_t       ne02,
                               int64_t       ne03,
                               size_t        nb02,
                               size_t        nb03,
                               size_t        nb12,
                               size_t        nb13,
                               size_t        nb2,
                               size_t        nb3,
                               cudaStream_t  stream) {
    // Precompute fast-division constants for decomposing the flat batch index.
    const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02);

    const dim3 block_dims(WARP_SIZE, k);
    const dim3 grid_dims(ne02 * ne03);

    // All instantiations share the same launch configuration and argument
    // list; only the template parameters and the trailing runtime sizes vary.
#define LAUNCH_SOLVE_TRI(N_T, K_T, N_ARG, K_ARG)                                        \
    solve_tri_f32_fast<N_T, K_T><<<grid_dims, block_dims, 0, stream>>>(                 \
        A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, N_ARG, K_ARG)

    if (n != 64) {  // general case: runtime loop bounds
        LAUNCH_SOLVE_TRI(0, 0, n, k);
    } else {
        switch (k) {
            case 32: LAUNCH_SOLVE_TRI(64, 32, 0, 0); break;
            case 16: LAUNCH_SOLVE_TRI(64, 16, 0, 0); break;
            case 14: LAUNCH_SOLVE_TRI(64, 14, 0, 0); break;
            case 12: LAUNCH_SOLVE_TRI(64, 12, 0, 0); break;
            case 10: LAUNCH_SOLVE_TRI(64, 10, 0, 0); break;
            case 8:  LAUNCH_SOLVE_TRI(64, 8, 0, 0);  break;
            case 6:  LAUNCH_SOLVE_TRI(64, 6, 0, 0);  break;
            case 4:  LAUNCH_SOLVE_TRI(64, 4, 0, 0);  break;
            case 2:  LAUNCH_SOLVE_TRI(64, 2, 0, 0);  break;
            case 1:  LAUNCH_SOLVE_TRI(64, 1, 0, 0);  break;
            default: LAUNCH_SOLVE_TRI(0, 0, n, k);   break;
        }
    }

#undef LAUNCH_SOLVE_TRI
}
185+
186+
// Entry point for GGML_OP_SOLVE_TRI on CUDA: solves A * X = B for X, where A
// (src0) is a triangular n x n matrix and B (src1) supplies k right-hand-side
// columns of length n; the result X is written to dst.
void ggml_cuda_op_solve_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0]; // A (triangular n x n matrix)
    const ggml_tensor * src1 = dst->src[1]; // B (right hand side of n x k equation columns)

    // Bug fix: these contiguity checks were bare calls whose results were
    // discarded; assert them so non-contiguous inputs fail loudly instead of
    // producing garbage from the stride arithmetic below.
    GGML_ASSERT(ggml_is_contiguous(src0));
    GGML_ASSERT(ggml_is_contiguous(src1));

    const int64_t n = src0->ne[0];
    const int64_t k = src1->ne[0];

    // The kernel sizes its shared-memory buffers for at most 64x64 A and 32
    // RHS columns; must match the supports_op check in ggml-cuda.cu.
    GGML_ASSERT(n <= 64);
    GGML_ASSERT(k <= 32);

    // Byte strides are converted to float-element strides because the kernel
    // performs float-pointer arithmetic.
    solve_tri_f32_cuda((const float *) src0->data, (const float *) src1->data, (float *) dst->data, n, k, src0->ne[2],
                       src0->ne[3], src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float),
                       src1->nb[2] / sizeof(float), src1->nb[3] / sizeof(float), dst->nb[2] / sizeof(float),
                       dst->nb[3] / sizeof(float), ctx.stream());
}

ggml/src/ggml-cuda/solve_tri.cuh

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
#include "common.cuh"

// Computes GGML_OP_SOLVE_TRI on CUDA: solves the triangular system given by
// dst->src[0] (A) and dst->src[1] (B) into dst. Implemented in solve_tri.cu.
void ggml_cuda_op_solve_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

tests/test-backend-ops.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7935,6 +7935,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
79357935
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16416, 1, 128, {8, 1}, {4, 1}, {0, 2, 1, 3}));
79367936
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 1, 16416, {8, 1}, {4, 1}, {0, 1, 2, 3}, 2*16416));
79377937

7938+
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 2 }, { 6, 64, 4, 2 }));
7939+
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 1 }, { 8, 128, 4, 1 }));
7940+
79387941
for (int bs : {1, 2, 3, 4, 5, 8, 512}) {
79397942
for (ggml_type type_a : all_types) {
79407943
for (ggml_type type_b : {GGML_TYPE_F32}) {

0 commit comments

Comments
 (0)