From 9e398502dfd6d8921fba88c9b464a9efc73060b9 Mon Sep 17 00:00:00 2001
From: wsbagnsv1
Date: Tue, 2 Dec 2025 17:34:06 +0100
Subject: [PATCH 1/9] ggml-cuda: optimize solve_tri_f32_fast and fix stride
 handling

- Switch from using shared memory for the RHS/solution matrix to a
  register-based approach (x_low, x_high), reducing shared memory
  pressure and bank conflicts.
- Implement explicit `fmaf` instructions for the reduction loop.
- Update kernel arguments to pass strides in bytes rather than elements
  to align with standard ggml tensor arithmetic (casting to `char *`
  before addition).
- Remove unused `MAX_K_FAST` definition.
---
 ggml/src/ggml-cuda/solve_tri.cu | 83 ++++++++++++++++-----------------
 1 file changed, 41 insertions(+), 42 deletions(-)

diff --git a/ggml/src/ggml-cuda/solve_tri.cu b/ggml/src/ggml-cuda/solve_tri.cu
index 2e2b39720fb..99d529e85ee 100644
--- a/ggml/src/ggml-cuda/solve_tri.cu
+++ b/ggml/src/ggml-cuda/solve_tri.cu
@@ -3,7 +3,6 @@
 #include "solve_tri.cuh"
 
 #define MAX_N_FAST 64
-#define MAX_K_FAST 32
 
 // ======================
 // Fast Kernel (n <= 64, k <= 32) - Warp-based parallel reduction
@@ -43,12 +42,11 @@ static __global__ void solve_tri_f32_fast(const float * __restrict__ A,
     const int64_t i02 = i02_i03.y;
     const int64_t i03 = i02_i03.x;
 
-    const float * const A_batch = (const float *) (A + i02 * nb02 + i03 * nb03);
-    const float * const B_batch = (const float *) (B + i02 * nb12 + i03 * nb13);
-    float * X_batch = (float *) (X + i02 * nb2 + i03 * nb3);
+    const float * const A_batch = (const float *) ((const char *) A + i02 * nb02 + i03 * nb03);
+    const float * const B_batch = (const float *) ((const char *) B + i02 * nb12 + i03 * nb13);
+    float * X_batch = (float *) ((char *) X + i02 * nb2 + i03 * nb3);
 
     __shared__ float sA[MAX_N_FAST * MAX_N_FAST];
-    __shared__ float sXt[MAX_N_FAST * (MAX_K_FAST + 1)];
 
     const int offset = threadIdx.x + threadIdx.y * blockDim.x;
 
@@ -60,53 +58,55 @@ static __global__ void solve_tri_f32_fast(const float * __restrict__ A,
         }
     }
 
-    const int rows_per_warp = (n + WARP_SIZE - 1) / WARP_SIZE;
+    __syncthreads();
 
-#pragma unroll
-    for (int i = 0; i < rows_per_warp; i++) {
-        const int i0 = lane + i * WARP_SIZE;
-        if (i0 < n) {
-            sXt[col_idx * n + i0] = B_batch[i0 * k + col_idx];
-        }
-    }
+    float x_low = (lane < n) ? B_batch[lane * k + col_idx] : 0.0f;
+    float x_high = (WARP_SIZE + lane < n) ? B_batch[(WARP_SIZE + lane) * k + col_idx] : 0.0f;
 
-    __syncthreads();
+    const int half = WARP_SIZE;
+    const int nrows_low = (n < half) ? n : half;
 
+    // Process lower rows
 #pragma unroll
-    for (int row = 0; row < n; ++row) {
+    for (int row = 0; row < nrows_low; ++row) {
         float sum = 0.0f;
-
-        {
-            int j = lane;
-            if (j < row) {
-                sum += sA[row * n + j] * sXt[col_idx * n + j];
-            }
+        if (lane < row) {
+            sum = fmaf(sA[row * n + lane], x_low, sum);
         }
-        if (row >= WARP_SIZE) {
-            int j = WARP_SIZE + lane;
-            if (j < row) {
-                sum += sA[row * n + j] * sXt[col_idx * n + j];
-            }
+        sum = warp_reduce_sum(sum);
+
+        if (lane == row) {
+            float diag = sA[row * n + row];
+            float idiv = 1.0f / diag;
+            x_low = fmaf(sum, -idiv, x_low * idiv);
         }
+    }
 
+    // Process upper rows
+#pragma unroll
+    for (int row = half; row < n; ++row) {
+        float sum = fmaf(sA[row * n + lane], x_low, 0.0f);
+        int j = half + lane;
+        if (j < row) {
+            sum = fmaf(sA[row * n + j], x_high, sum);
+        }
         sum = warp_reduce_sum(sum);
 
-        if (lane == 0) {
-            const float b_val = sXt[col_idx * n + row];
-            const float a_diag = sA[row * n + row];
-            // no safeguards for division by zero because that indicates corrupt
-            // data anyway
-            sXt[col_idx * n + row] = (b_val - sum) / a_diag;
+        int updater = row - half;
+        if (lane == updater) {
+            float diag = sA[row * n + row];
+            float idiv = 1.0f / diag;
+            x_high = fmaf(sum, -idiv, x_high * idiv);
         }
     }
 
-    __syncthreads();
-
-#pragma unroll
-    for (int i = 0; i < rows_per_warp; i++) {
-        const int i0 = lane + i * WARP_SIZE;
-        if (i0 < n) {
-            X_batch[i0 * k + col_idx] = sXt[col_idx * n + i0];
+    // Warp-wise store
+#pragma unroll 2
+    for (int rr = 0; rr < 2; ++rr) {
+        int row = rr * WARP_SIZE + lane;
+        if (row < n) {
+            float val = (row < half) ? x_low : x_high;
+            X_batch[row * k + col_idx] = val;
         }
     }
 }
@@ -197,7 +197,6 @@ void ggml_cuda_op_solve_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
     GGML_ASSERT(k <= 32);
 
     solve_tri_f32_cuda((const float *) src0->data, (const float *) src1->data, (float *) dst->data, n, k, src0->ne[2],
-                       src0->ne[3], src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float),
-                       src1->nb[2] / sizeof(float), src1->nb[3] / sizeof(float), dst->nb[2] / sizeof(float),
-                       dst->nb[3] / sizeof(float), ctx.stream());
+                       src0->ne[3], src0->nb[2], src0->nb[3], src1->nb[2], src1->nb[3], dst->nb[2], dst->nb[3],
+                       ctx.stream());
 }
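Context for the stride change above: ggml stores tensor strides (ggml_tensor::nb) in bytes, not elements, so batch pointers must be advanced through a `char *` cast before being reinterpreted as `float *`. A minimal sketch of the addressing pattern, using a hypothetical stand-in struct rather than the real ggml_tensor:

// Sketch only: tensor_view is a hypothetical stand-in for ggml_tensor,
// keeping just the fields needed to illustrate byte-stride addressing.
#include <stddef.h>

struct tensor_view {
    void * data;
    size_t nb[4]; // per-dimension strides in BYTES, as in ggml_tensor::nb
};

// Return a pointer to the (i2, i3) batch slice of an f32 tensor:
// the char* cast makes the additions operate in bytes, after which the
// result is reinterpreted as a correctly typed float pointer.
static inline const float * batch_ptr(const struct tensor_view * t, size_t i2, size_t i3) {
    return (const float *) ((const char *) t->data + i2 * t->nb[2] + i3 * t->nb[3]);
}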
From b14882fc54ebb214efc580ff708f63661f9e19f7 Mon Sep 17 00:00:00 2001
From: wsbagnsv1
Date: Tue, 2 Dec 2025 19:08:05 +0100
Subject: [PATCH 2/9] Small cleanup

---
 ggml/src/ggml-cuda/solve_tri.cu | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/ggml/src/ggml-cuda/solve_tri.cu b/ggml/src/ggml-cuda/solve_tri.cu
index 99d529e85ee..c85066981ea 100644
--- a/ggml/src/ggml-cuda/solve_tri.cu
+++ b/ggml/src/ggml-cuda/solve_tri.cu
@@ -42,9 +42,9 @@ static __global__ void solve_tri_f32_fast(const float * __restrict__ A,
     const int64_t i02 = i02_i03.y;
     const int64_t i03 = i02_i03.x;
 
-    const float * const A_batch = (const float *) ((const char *) A + i02 * nb02 + i03 * nb03);
-    const float * const B_batch = (const float *) ((const char *) B + i02 * nb12 + i03 * nb13);
-    float * X_batch = (float *) ((char *) X + i02 * nb2 + i03 * nb3);
+    const float * const A_batch = (const float *) (A + i02 * nb02 + i03 * nb03);
+    const float * const B_batch = (const float *) (B + i02 * nb12 + i03 * nb13);
+    float * X_batch = (float *) (X + i02 * nb2 + i03 * nb3);
 
     __shared__ float sA[MAX_N_FAST * MAX_N_FAST];
 
@@ -197,6 +197,7 @@ void ggml_cuda_op_solve_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
     GGML_ASSERT(k <= 32);
 
     solve_tri_f32_cuda((const float *) src0->data, (const float *) src1->data, (float *) dst->data, n, k, src0->ne[2],
-                       src0->ne[3], src0->nb[2], src0->nb[3], src1->nb[2], src1->nb[3], dst->nb[2], dst->nb[3],
-                       ctx.stream());
+                       src0->ne[3], src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float),
+                       src1->nb[2] / sizeof(float), src1->nb[3] / sizeof(float), dst->nb[2] / sizeof(float),
+                       dst->nb[3] / sizeof(float), ctx.stream());
 }
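Both reduction loops in the kernel rely on warp_reduce_sum from ggml's CUDA common header. For readers without the tree at hand, a generic shuffle-based warp sum — a sketch of the technique, not necessarily ggml's exact definition — looks like this:

// XOR-butterfly reduction: after log2(32) = 5 steps, every lane of the
// warp holds the sum of the values contributed by all 32 lanes.
static __device__ __forceinline__ float warp_sum_sketch(float v) {
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        v += __shfl_xor_sync(0xffffffff, v, offset, 32);
    }
    return v;
}

The kernel depends on the "every lane gets the result" property: the lane that owns the current row reads the reduced sum directly, with no extra broadcast step.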
From 68881efb0bd10bd3cd233691a561ebde6acfd7ff Mon Sep 17 00:00:00 2001
From: wsbagnsv1
Date: Tue, 2 Dec 2025 19:09:40 +0100
Subject: [PATCH 3/9] Remove comments in solve_tri.cu

---
 ggml/src/ggml-cuda/solve_tri.cu | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/ggml/src/ggml-cuda/solve_tri.cu b/ggml/src/ggml-cuda/solve_tri.cu
index c85066981ea..5a66e1595bc 100644
--- a/ggml/src/ggml-cuda/solve_tri.cu
+++ b/ggml/src/ggml-cuda/solve_tri.cu
@@ -66,7 +66,6 @@ static __global__ void solve_tri_f32_fast(const float * __restrict__ A,
     const int half = WARP_SIZE;
     const int nrows_low = (n < half) ? n : half;
 
-    // Process lower rows
 #pragma unroll
     for (int row = 0; row < nrows_low; ++row) {
         float sum = 0.0f;
@@ -82,7 +81,6 @@ static __global__ void solve_tri_f32_fast(const float * __restrict__ A,
         }
     }
 
-    // Process upper rows
 #pragma unroll
     for (int row = half; row < n; ++row) {
         float sum = fmaf(sA[row * n + lane], x_low, 0.0f);
@@ -100,7 +98,6 @@ static __global__ void solve_tri_f32_fast(const float * __restrict__ A,
         }
     }
 
-    // Warp-wise store
 #pragma unroll 2
     for (int rr = 0; rr < 2; ++rr) {
         int row = rr * WARP_SIZE + lane;

From c55b5bf994adca5c4fd75205c8693d62c41fb596 Mon Sep 17 00:00:00 2001
From: wsbagnsv1
Date: Thu, 4 Dec 2025 20:10:43 +0100
Subject: [PATCH 4/9] Update ggml/src/ggml-cuda/solve_tri.cu
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Johannes Gäßler
---
 ggml/src/ggml-cuda/solve_tri.cu | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ggml/src/ggml-cuda/solve_tri.cu b/ggml/src/ggml-cuda/solve_tri.cu
index 5a66e1595bc..9be9f1380de 100644
--- a/ggml/src/ggml-cuda/solve_tri.cu
+++ b/ggml/src/ggml-cuda/solve_tri.cu
@@ -98,7 +98,7 @@ static __global__ void solve_tri_f32_fast(const float * __restrict__ A,
         }
     }
 
-#pragma unroll 2
+#pragma unroll
     for (int rr = 0; rr < 2; ++rr) {
         int row = rr * WARP_SIZE + lane;
         if (row < n) {
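The kernel is a warp-parallel formulation of forward substitution: for each right-hand-side column, row r of the solution is x[r] = (b[r] - sum_{j<r} A[r][j] * x[j]) / A[r][r], with the partial products spread across the 32 lanes and combined by warp_reduce_sum. A scalar host-side reference, useful for validating the kernel (assuming row-major lower-triangular A of size n×n and row-major B/X of size n×k, matching the kernel's indexing), might look like:

// Reference forward substitution for lower-triangular A * X = B.
// No divide-by-zero guard, mirroring the kernel's stance that a zero
// diagonal indicates corrupt data anyway.
static void solve_tri_ref(const float * A, const float * B, float * X, int n, int k) {
    for (int col = 0; col < k; ++col) {
        for (int row = 0; row < n; ++row) {
            float sum = 0.0f;
            for (int j = 0; j < row; ++j) {
                sum += A[row * n + j] * X[j * k + col]; // already-solved rows
            }
            X[row * k + col] = (B[row * k + col] - sum) / A[row * n + row];
        }
    }
}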
From 2fd926483529e1b17e61d0805121ee33954bfdd9 Mon Sep 17 00:00:00 2001
From: wsbagnsv1
Date: Thu, 4 Dec 2025 20:11:05 +0100
Subject: [PATCH 5/9] Update ggml/src/ggml-cuda/solve_tri.cu
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Johannes Gäßler
---
 ggml/src/ggml-cuda/solve_tri.cu | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ggml/src/ggml-cuda/solve_tri.cu b/ggml/src/ggml-cuda/solve_tri.cu
index 9be9f1380de..62438b7ea0d 100644
--- a/ggml/src/ggml-cuda/solve_tri.cu
+++ b/ggml/src/ggml-cuda/solve_tri.cu
@@ -70,7 +70,7 @@ static __global__ void solve_tri_f32_fast(const float * __restrict__ A,
     for (int row = 0; row < nrows_low; ++row) {
         float sum = 0.0f;
         if (lane < row) {
-            sum = fmaf(sA[row * n + lane], x_low, sum);
+            sum += sA[row * n + lane] * x_low;
         }
         sum = warp_reduce_sum(sum);
 

From b27ce89a407f6ed41eed853d775e48e021e16d4e Mon Sep 17 00:00:00 2001
From: wsbagnsv1
Date: Thu, 4 Dec 2025 20:50:27 +0100
Subject: [PATCH 6/9] Update ggml/src/ggml-cuda/solve_tri.cu
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Johannes Gäßler
---
 ggml/src/ggml-cuda/solve_tri.cu | 7 ++-----
 1 file changed, 2 insertions(+), 5 deletions(-)

diff --git a/ggml/src/ggml-cuda/solve_tri.cu b/ggml/src/ggml-cuda/solve_tri.cu
index 62438b7ea0d..a999ba30d47 100644
--- a/ggml/src/ggml-cuda/solve_tri.cu
+++ b/ggml/src/ggml-cuda/solve_tri.cu
@@ -90,11 +90,8 @@ static __global__ void solve_tri_f32_fast(const float * __restrict__ A,
         }
         sum = warp_reduce_sum(sum);
 
-        int updater = row - half;
-        if (lane == updater) {
-            float diag = sA[row * n + row];
-            float idiv = 1.0f / diag;
-            x_high = fmaf(sum, -idiv, x_high * idiv);
+        if (lane == row - half) {
+            x_high = (x_high - sum) / sA[row * n + row];
         }
     }
 

From ec9b6f97d142082bd96dc5b0ff9222405bd5ef29 Mon Sep 17 00:00:00 2001
From: wsbagnsv1
Date: Thu, 4 Dec 2025 22:32:16 +0100
Subject: [PATCH 7/9] Use const for variables in solve_tri.cu

---
 ggml/src/ggml-cuda/solve_tri.cu | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/ggml/src/ggml-cuda/solve_tri.cu b/ggml/src/ggml-cuda/solve_tri.cu
index a999ba30d47..66ebef77a7d 100644
--- a/ggml/src/ggml-cuda/solve_tri.cu
+++ b/ggml/src/ggml-cuda/solve_tri.cu
@@ -52,7 +52,7 @@ static __global__ void solve_tri_f32_fast(const float * __restrict__ A,
 #pragma unroll
     for (int i = 0; i < n * n; i += k * WARP_SIZE) {
-        int i0 = i + offset;
+        const int i0 = i + offset;
         if (i0 < n * n) {
             sA[i0] = A_batch[i0];
         }
     }
@@ -75,8 +75,8 @@ static __global__ void solve_tri_f32_fast(const float * __restrict__ A,
         sum = warp_reduce_sum(sum);
 
         if (lane == row) {
-            float diag = sA[row * n + row];
-            float idiv = 1.0f / diag;
+            const float diag = sA[row * n + row];
+            const float idiv = 1.0f / diag;
             x_low = fmaf(sum, -idiv, x_low * idiv);
         }
     }
@@ -84,7 +84,7 @@ static __global__ void solve_tri_f32_fast(const float * __restrict__ A,
 #pragma unroll
     for (int row = half; row < n; ++row) {
         float sum = fmaf(sA[row * n + lane], x_low, 0.0f);
-        int j = half + lane;
+        const int j = half + lane;
         if (j < row) {
             sum = fmaf(sA[row * n + j], x_high, sum);
         }
@@ -97,9 +97,9 @@ static __global__ void solve_tri_f32_fast(const float * __restrict__ A,
 
 #pragma unroll
     for (int rr = 0; rr < 2; ++rr) {
-        int row = rr * WARP_SIZE + lane;
+        const int row = rr * WARP_SIZE + lane;
         if (row < n) {
-            float val = (row < half) ? x_low : x_high;
+            const float val = (row < half) ? x_low : x_high;
             X_batch[row * k + col_idx] = val;
         }
     }
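The final two patches replace the remaining explicit fmaf calls with plain a * b + c expressions. This is a readability change more than a performance one: fmaf guarantees a single-rounding fused multiply-add, but nvcc contracts a * b + c into FMA instructions by default anyway (controlled by --fmad, which defaults to true), so the generated code is typically the same. The two spellings side by side:

// Under nvcc's default --fmad=true both functions normally compile to a
// single FFMA instruction; the explicit fmaf merely forces the fusion
// even when contraction is disabled.
__device__ float mad_explicit(float a, float b, float c) { return fmaf(a, b, c); }
__device__ float mad_implicit(float a, float b, float c) { return a * b + c; }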
From a34a45a34330774380c74ebcba4ef644735fd2c4 Mon Sep 17 00:00:00 2001
From: wsbagnsv1
Date: Thu, 4 Dec 2025 22:44:47 +0100
Subject: [PATCH 8/9] Replace fmaf with more readable code

---
 ggml/src/ggml-cuda/solve_tri.cu | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/ggml/src/ggml-cuda/solve_tri.cu b/ggml/src/ggml-cuda/solve_tri.cu
index 66ebef77a7d..9337958972d 100644
--- a/ggml/src/ggml-cuda/solve_tri.cu
+++ b/ggml/src/ggml-cuda/solve_tri.cu
@@ -83,10 +83,10 @@ static __global__ void solve_tri_f32_fast(const float * __restrict__ A,
 
 #pragma unroll
     for (int row = half; row < n; ++row) {
-        float sum = fmaf(sA[row * n + lane], x_low, 0.0f);
+        float sum = sA[row * n + lane] * x_low;
         const int j = half + lane;
         if (j < row) {
-            sum = fmaf(sA[row * n + j], x_high, sum);
+            sum += sA[row * n + j] * x_high;
         }
         sum = warp_reduce_sum(sum);
 

From 4a637096bf08daf684d33b0cf2c8206bda49238c Mon Sep 17 00:00:00 2001
From: wsbagnsv1
Date: Fri, 5 Dec 2025 00:14:16 +0100
Subject: [PATCH 9/9] remove last fmaf

---
 ggml/src/ggml-cuda/solve_tri.cu | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/ggml/src/ggml-cuda/solve_tri.cu b/ggml/src/ggml-cuda/solve_tri.cu
index 9337958972d..e161d4dc436 100644
--- a/ggml/src/ggml-cuda/solve_tri.cu
+++ b/ggml/src/ggml-cuda/solve_tri.cu
@@ -75,9 +75,7 @@ static __global__ void solve_tri_f32_fast(const float * __restrict__ A,
         sum = warp_reduce_sum(sum);
 
         if (lane == row) {
-            const float diag = sA[row * n + row];
-            const float idiv = 1.0f / diag;
-            x_low = fmaf(sum, -idiv, x_low * idiv);
+            x_low = (x_low - sum) / sA[row * n + row];
         }
     }
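One numerical note on the diagonal updates rewritten in patches 6 and 9: fmaf(sum, -idiv, x * idiv) with idiv = 1.0f / diag rounds twice (the reciprocal and the fma), while (x - sum) / diag performs a subtraction and one correctly rounded division, so the direct form is the simpler and, if anything, slightly more accurate one. A standalone host-side check of the two variants (hypothetical inputs, purely illustrative):

// Compare the reciprocal+fma update against the direct division.
#include <math.h>
#include <stdio.h>

int main(void) {
    const float x = 3.0f, sum = 1.25f, diag = 7.0f;
    const float idiv    = 1.0f / diag;
    const float via_fma = fmaf(sum, -idiv, x * idiv); // (x - sum) * (1/diag), fused
    const float via_div = (x - sum) / diag;           // direct division
    printf("fma: %.9g div: %.9g delta: %.3g\n", via_fma, via_div, (double) (via_fma - via_div));
    return 0;
}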