ggml/src/ggml-cuda/solve_tri.cu: 33 additions & 36 deletions
@@ -3,7 +3,6 @@
 #include "solve_tri.cuh"
 
 #define MAX_N_FAST 64
-#define MAX_K_FAST 32
 
 // ======================
 // Fast Kernel (n <= 64, k <= 32) - Warp-based parallel reduction
@@ -48,7 +47,6 @@ static __global__ void solve_tri_f32_fast(const float * __restrict__ A,
     float * X_batch = (float *) (X + i02 * nb2 + i03 * nb3);
 
     __shared__ float sA[MAX_N_FAST * MAX_N_FAST];
-    __shared__ float sXt[MAX_N_FAST * (MAX_K_FAST + 1)];
 
     const int offset = threadIdx.x + threadIdx.y * blockDim.x;
 
@@ -60,53 +58,52 @@
         }
     }
 
-    const int rows_per_warp = (n + WARP_SIZE - 1) / WARP_SIZE;
     __syncthreads();
 
-#pragma unroll
-    for (int i = 0; i < rows_per_warp; i++) {
-        const int i0 = lane + i * WARP_SIZE;
-        if (i0 < n) {
-            sXt[col_idx * n + i0] = B_batch[i0 * k + col_idx];
-        }
-    }
+    float x_low = (lane < n) ? B_batch[lane * k + col_idx] : 0.0f;
+    float x_high = (WARP_SIZE + lane < n) ? B_batch[(WARP_SIZE + lane) * k + col_idx] : 0.0f;
Comment on lines +63 to +64

Collaborator:
Suggested change:
-    float x_low = (lane < n) ? B_batch[lane * k + col_idx] : 0.0f;
-    float x_high = (WARP_SIZE + lane < n) ? B_batch[(WARP_SIZE + lane) * k + col_idx] : 0.0f;
+    const float x_low = (lane < n) ? B_batch[lane * k + col_idx] : 0.0f;
+    const float x_high = (WARP_SIZE + lane < n) ? B_batch[(WARP_SIZE + lane) * k + col_idx] : 0.0f;

Please use const wherever applicable so that one can easily tell which variables are subject to change in the future.

Author:
Both of those are changed in the #pragma unroll loops, so they can't be const (;

Collaborator:

Seems like you're right. And confusion like this is precisely why I want a clear and consistent distinction between const and non-const variables.

Author:

Should I add a comment to clear things up?

Collaborator:

A comment explaining the purposes of x_low and x_high would be nice to have but not required. The problem here was rather that I read the kernel top to bottom, wasn't sure whether these particular values are supposed to be constant, didn't see the part further down where they are modified (but saw that you are not consistently adding const where applicable), and then left this comment.

Author:

I've added const to every variable that should use it now (;
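
For reference, a minimal sketch of the convention settled above (hypothetical names, not code from this PR): a value carried across loop iterations cannot be const, while temporaries assigned once per iteration can and should be.

// Hypothetical illustration of the const convention discussed above.
__device__ float running_update_sketch(const float * diags, int n) {
    float x = 1.0f;                      // not const: overwritten every iteration
    for (int row = 0; row < n; ++row) {
        const float diag = diags[row];   // const: set once per iteration
        const float idiv = 1.0f / diag;  // const: derived once per iteration
        x *= idiv;                       // the carried value is updated here
    }
    return x;
}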


     __syncthreads();
+    const int half = WARP_SIZE;
+    const int nrows_low = (n < half) ? n : half;
 
 #pragma unroll
-    for (int row = 0; row < n; ++row) {
+    for (int row = 0; row < nrows_low; ++row) {
         float sum = 0.0f;
 
-        {
-            int j = lane;
-            if (j < row) {
-                sum += sA[row * n + j] * sXt[col_idx * n + j];
-            }
-        }
-        if (row >= WARP_SIZE) {
-            int j = WARP_SIZE + lane;
-            if (j < row) {
-                sum += sA[row * n + j] * sXt[col_idx * n + j];
-            }
-        }
+        if (lane < row) {
+            sum = fmaf(sA[row * n + lane], x_low, sum);
+        }
         sum = warp_reduce_sum(sum);
 
+        if (lane == row) {
+            float diag = sA[row * n + row];
+            float idiv = 1.0f / diag;
+            x_low = fmaf(sum, -idiv, x_low * idiv);
+        }
+    }

+#pragma unroll
+    for (int row = half; row < n; ++row) {
+        float sum = fmaf(sA[row * n + lane], x_low, 0.0f);
+        int j = half + lane;
+        if (j < row) {
+            sum = fmaf(sA[row * n + j], x_high, sum);
+        }
+        sum = warp_reduce_sum(sum);
 
-        if (lane == 0) {
-            const float b_val = sXt[col_idx * n + row];
-            const float a_diag = sA[row * n + row];
-            // no safeguards for division by zero because that indicates corrupt
-            // data anyway
-            sXt[col_idx * n + row] = (b_val - sum) / a_diag;
-        }
+        int updater = row - half;
+        if (lane == updater) {
+            float diag = sA[row * n + row];
+            float idiv = 1.0f / diag;
+            x_high = fmaf(sum, -idiv, x_high * idiv);
+        }
     }

-    __syncthreads();
 
-#pragma unroll
-    for (int i = 0; i < rows_per_warp; i++) {
-        const int i0 = lane + i * WARP_SIZE;
-        if (i0 < n) {
-            X_batch[i0 * k + col_idx] = sXt[col_idx * n + i0];
+#pragma unroll 2
+    for (int rr = 0; rr < 2; ++rr) {
+        int row = rr * WARP_SIZE + lane;
+        if (row < n) {
+            float val = (row < half) ? x_low : x_high;
+            X_batch[row * k + col_idx] = val;
         }
     }
 }
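
Net effect of the hunk above: the kernel still solves the lower-triangular system A*X = B by forward substitution, x[row] = (b[row] - sum over j < row of A[row][j] * x[j]) / A[row][row], but each warp now keeps its column of X entirely in registers (x_low for rows 0..31, x_high for rows 32..63) instead of staging it in the shared sXt tile, with a warp-wide reduction supplying the dot product. Below is a minimal, self-contained sketch of the same warp-per-column technique, simplified to n <= 32 so a single register per lane suffices; all names are hypothetical, and this is not the ggml code above.

#include <cstdio>
#include <cuda_runtime.h>

#define WARP_SIZE 32

// Butterfly reduction: every lane ends up holding the sum over all 32 lanes.
static __device__ float warp_reduce_sum_sketch(float v) {
    for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) {
        v += __shfl_xor_sync(0xffffffff, v, offset);
    }
    return v;
}

// Solves A*x = b with A lower-triangular, row-major, n <= 32.
// Launch as a single warp: solve_tri_warp_sketch<<<1, WARP_SIZE>>>(...).
static __global__ void solve_tri_warp_sketch(const float * __restrict__ A,
                                             const float * __restrict__ b,
                                             float * __restrict__ x, int n) {
    const int lane = threadIdx.x;            // one matrix row per lane
    float xr = (lane < n) ? b[lane] : 0.0f;  // starts as b[lane], becomes x[lane]

    for (int row = 0; row < n; ++row) {
        // Lanes j < row already hold solved x[j]; reduce their contributions
        // so every lane gets sum over j < row of A[row][j] * x[j].
        float sum = (lane < row) ? A[row * n + lane] * xr : 0.0f;
        sum = warp_reduce_sum_sketch(sum);

        if (lane == row) {  // forward-substitution step for this row
            xr = (xr - sum) / A[row * n + row];
        }
    }

    if (lane < n) {
        x[lane] = xr;
    }
}

int main() {
    const int n = 3;
    const float hA[n * n] = { 2.0f, 0.0f, 0.0f,
                              1.0f, 3.0f, 0.0f,
                              4.0f, 5.0f, 6.0f };
    const float hb[n] = { 2.0f, 7.0f, 32.0f };  // chosen so that x = {1, 2, 3}

    float *dA, *db, *dx;
    cudaMalloc(&dA, sizeof(hA));
    cudaMalloc(&db, sizeof(hb));
    cudaMalloc(&dx, n * sizeof(float));
    cudaMemcpy(dA, hA, sizeof(hA), cudaMemcpyHostToDevice);
    cudaMemcpy(db, hb, sizeof(hb), cudaMemcpyHostToDevice);

    solve_tri_warp_sketch<<<1, WARP_SIZE>>>(dA, db, dx, n);

    float hx[n];
    cudaMemcpy(hx, dx, sizeof(hx), cudaMemcpyDeviceToHost);
    printf("x = %g %g %g\n", hx[0], hx[1], hx[2]);  // expect: x = 1 2 3
    return 0;
}

Compiled with nvcc and run, this prints x = 1 2 3. The PR's kernel extends the idea by pairing two registers per lane (x_low, x_high) to reach n <= 64, using fmaf for the updates, and batching over i02/i03.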