diff --git a/ggml/src/ggml-cuda/solve_tri.cu b/ggml/src/ggml-cuda/solve_tri.cu index 2e2b39720f..e161d4dc43 100644 --- a/ggml/src/ggml-cuda/solve_tri.cu +++ b/ggml/src/ggml-cuda/solve_tri.cu @@ -3,7 +3,6 @@ #include "solve_tri.cuh" #define MAX_N_FAST 64 -#define MAX_K_FAST 32 // ====================== // Fast Kernel (n <= 64, k <= 32) - Warp-based parallel reduction @@ -48,65 +47,58 @@ static __global__ void solve_tri_f32_fast(const float * __restrict__ A, float * X_batch = (float *) (X + i02 * nb2 + i03 * nb3); __shared__ float sA[MAX_N_FAST * MAX_N_FAST]; - __shared__ float sXt[MAX_N_FAST * (MAX_K_FAST + 1)]; const int offset = threadIdx.x + threadIdx.y * blockDim.x; #pragma unroll for (int i = 0; i < n * n; i += k * WARP_SIZE) { - int i0 = i + offset; + const int i0 = i + offset; if (i0 < n * n) { sA[i0] = A_batch[i0]; } } - const int rows_per_warp = (n + WARP_SIZE - 1) / WARP_SIZE; + __syncthreads(); -#pragma unroll - for (int i = 0; i < rows_per_warp; i++) { - const int i0 = lane + i * WARP_SIZE; - if (i0 < n) { - sXt[col_idx * n + i0] = B_batch[i0 * k + col_idx]; - } - } + float x_low = (lane < n) ? B_batch[lane * k + col_idx] : 0.0f; + float x_high = (WARP_SIZE + lane < n) ? B_batch[(WARP_SIZE + lane) * k + col_idx] : 0.0f; - __syncthreads(); + const int half = WARP_SIZE; + const int nrows_low = (n < half) ? n : half; #pragma unroll - for (int row = 0; row < n; ++row) { + for (int row = 0; row < nrows_low; ++row) { float sum = 0.0f; - - { - int j = lane; - if (j < row) { - sum += sA[row * n + j] * sXt[col_idx * n + j]; - } + if (lane < row) { + sum += sA[row * n + lane] * x_low; } - if (row >= WARP_SIZE) { - int j = WARP_SIZE + lane; - if (j < row) { - sum += sA[row * n + j] * sXt[col_idx * n + j]; - } + sum = warp_reduce_sum(sum); + + if (lane == row) { + x_low = (x_low - sum) / sA[row * n + row]; } + } +#pragma unroll + for (int row = half; row < n; ++row) { + float sum = sA[row * n + lane] * x_low; + const int j = half + lane; + if (j < row) { + sum += sA[row * n + j] * x_high; + } sum = warp_reduce_sum(sum); - if (lane == 0) { - const float b_val = sXt[col_idx * n + row]; - const float a_diag = sA[row * n + row]; - // no safeguards for division by zero because that indicates corrupt - // data anyway - sXt[col_idx * n + row] = (b_val - sum) / a_diag; + if (lane == row - half) { + x_high = (x_high - sum) / sA[row * n + row]; } } - __syncthreads(); - #pragma unroll - for (int i = 0; i < rows_per_warp; i++) { - const int i0 = lane + i * WARP_SIZE; - if (i0 < n) { - X_batch[i0 * k + col_idx] = sXt[col_idx * n + i0]; + for (int rr = 0; rr < 2; ++rr) { + const int row = rr * WARP_SIZE + lane; + if (row < n) { + const float val = (row < half) ? x_low : x_high; + X_batch[row * k + col_idx] = val; } } }