
Commit 5814b4d

cuda: optimize SOLVE_TRI using registers and FMAF (ggml-org#17703)
* ggml-cuda: optimize solve_tri_f32_fast and fix stride handling

  - Switch from using shared memory for the RHS/solution matrix to a register-based approach (x_low, x_high), reducing shared memory pressure and bank conflicts.
  - Implement explicit `fmaf` instructions for the reduction loop.
  - Update kernel arguments to pass strides in bytes rather than elements to align with standard ggml tensor arithmetic (casting to `char *` before addition).
  - Remove unused `MAX_K_FAST` definition.

* Small cleanup

* Remove comments in solve_tri.cu

* Update ggml/src/ggml-cuda/solve_tri.cu

  Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* Update ggml/src/ggml-cuda/solve_tri.cu

  Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* Update ggml/src/ggml-cuda/solve_tri.cu

  Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* Use const for variables in solve_tri.cu

* Replace fmaf with more readable code

* remove last fmaf

---------

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
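The stride change in the third bullet follows the usual ggml convention that nb strides are expressed in bytes, so a batch base pointer is advanced as a char * and only then reinterpreted as float data. A minimal sketch of that addressing pattern (the helper name and parameters below are illustrative, not part of the patch):

    // Illustrative only: advance a tensor pointer by ggml byte strides nb2/nb3
    // for batch indices i02/i03, then reinterpret the result as float data.
    static __device__ const float * batch_ptr(const float * base, const int64_t i02, const int64_t i03,
                                              const size_t nb2, const size_t nb3) {
        return (const float *) ((const char *) base + i02 * nb2 + i03 * nb3);
    }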
1 parent 79d6189 commit 5814b4d

File tree

1 file changed (+28, -36 lines)

ggml/src/ggml-cuda/solve_tri.cu

Lines changed: 28 additions & 36 deletions
@@ -3,7 +3,6 @@
 #include "solve_tri.cuh"
 
 #define MAX_N_FAST 64
-#define MAX_K_FAST 32
 
 // ======================
 // Fast Kernel (n <= 64, k <= 32) - Warp-based parallel reduction
@@ -48,65 +47,58 @@ static __global__ void solve_tri_f32_fast(const float * __restrict__ A,
     float * X_batch = (float *) (X + i02 * nb2 + i03 * nb3);
 
     __shared__ float sA[MAX_N_FAST * MAX_N_FAST];
-    __shared__ float sXt[MAX_N_FAST * (MAX_K_FAST + 1)];
 
     const int offset = threadIdx.x + threadIdx.y * blockDim.x;
 
 #pragma unroll
     for (int i = 0; i < n * n; i += k * WARP_SIZE) {
-        int i0 = i + offset;
+        const int i0 = i + offset;
         if (i0 < n * n) {
             sA[i0] = A_batch[i0];
         }
     }
 
-    const int rows_per_warp = (n + WARP_SIZE - 1) / WARP_SIZE;
+    __syncthreads();
 
-#pragma unroll
-    for (int i = 0; i < rows_per_warp; i++) {
-        const int i0 = lane + i * WARP_SIZE;
-        if (i0 < n) {
-            sXt[col_idx * n + i0] = B_batch[i0 * k + col_idx];
-        }
-    }
+    float x_low  = (lane < n) ? B_batch[lane * k + col_idx] : 0.0f;
+    float x_high = (WARP_SIZE + lane < n) ? B_batch[(WARP_SIZE + lane) * k + col_idx] : 0.0f;
 
-    __syncthreads();
+    const int half      = WARP_SIZE;
+    const int nrows_low = (n < half) ? n : half;
 
 #pragma unroll
-    for (int row = 0; row < n; ++row) {
+    for (int row = 0; row < nrows_low; ++row) {
         float sum = 0.0f;
-
-        {
-            int j = lane;
-            if (j < row) {
-                sum += sA[row * n + j] * sXt[col_idx * n + j];
-            }
+        if (lane < row) {
+            sum += sA[row * n + lane] * x_low;
         }
-        if (row >= WARP_SIZE) {
-            int j = WARP_SIZE + lane;
-            if (j < row) {
-                sum += sA[row * n + j] * sXt[col_idx * n + j];
-            }
+        sum = warp_reduce_sum(sum);
+
+        if (lane == row) {
+            x_low = (x_low - sum) / sA[row * n + row];
         }
+    }
 
+#pragma unroll
+    for (int row = half; row < n; ++row) {
+        float sum = sA[row * n + lane] * x_low;
+        const int j = half + lane;
+        if (j < row) {
+            sum += sA[row * n + j] * x_high;
+        }
         sum = warp_reduce_sum(sum);
 
-        if (lane == 0) {
-            const float b_val  = sXt[col_idx * n + row];
-            const float a_diag = sA[row * n + row];
-            // no safeguards for division by zero because that indicates corrupt
-            // data anyway
-            sXt[col_idx * n + row] = (b_val - sum) / a_diag;
+        if (lane == row - half) {
+            x_high = (x_high - sum) / sA[row * n + row];
         }
     }
 
-    __syncthreads();
-
 #pragma unroll
-    for (int i = 0; i < rows_per_warp; i++) {
-        const int i0 = lane + i * WARP_SIZE;
-        if (i0 < n) {
-            X_batch[i0 * k + col_idx] = sXt[col_idx * n + i0];
+    for (int rr = 0; rr < 2; ++rr) {
+        const int row = rr * WARP_SIZE + lane;
+        if (row < n) {
+            const float val = (row < half) ? x_low : x_high;
+            X_batch[row * k + col_idx] = val;
         }
     }
 }
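For readers unfamiliar with the warp-level pattern the new code relies on, the following is a simplified, standalone sketch of register-based forward substitution for a single right-hand-side column with n <= 32 (one register per lane). It is not the ggml kernel: warp_sum and solve_lower_tri_warp are hypothetical names, and the reduction is written out with __shfl_xor_sync instead of ggml's warp_reduce_sum helper.

    // Hypothetical helper: butterfly reduction so every lane ends up with the warp-wide sum.
    static __device__ float warp_sum(float v) {
        for (int offset = 16; offset > 0; offset >>= 1) {
            v += __shfl_xor_sync(0xffffffff, v, offset);
        }
        return v;
    }

    // Hypothetical kernel: one warp solves the lower-triangular system L * x = b (n <= 32).
    // Each lane keeps one element of the solution in a register instead of shared memory.
    static __global__ void solve_lower_tri_warp(const float * L, const float * b, float * x, const int n) {
        const int lane = threadIdx.x;                 // launched with a single 32-thread block
        float xi = (lane < n) ? b[lane] : 0.0f;       // starts as the right-hand side

        for (int row = 0; row < n; ++row) {
            // lanes below the current row already hold solved entries of x
            float sum = (lane < row) ? L[row * n + lane] * xi : 0.0f;
            sum = warp_sum(sum);                      // total off-diagonal contribution
            if (lane == row) {
                xi = (xi - sum) / L[row * n + row];   // finalize x[row] on its owning lane
            }
        }
        if (lane < n) {
            x[lane] = xi;                             // write the solution back
        }
    }

The committed kernel extends this idea to n <= 64 by giving each lane two registers (x_low for rows 0..31, x_high for rows 32..63) and splitting the substitution loop into a low half and a high half.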

Comments (0)