SOLVE_TRI CUDA kernel for small matrices #17457
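The op solves the batched triangular system A·X = B by forward substitution. The inner sum in the kernel runs over j < row, so A is treated as lower-triangular, and each column c of X is computed row by row as

$$x_{r,c} = \frac{1}{a_{r,r}} \left( b_{r,c} - \sum_{j=0}^{r-1} a_{r,j} \, x_{j,c} \right), \qquad r = 0, \dots, n-1,$$

with the prefix sum spread across the 32 lanes of a warp and combined by `warp_reduce_sum`.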
First file, a new 240-line CUDA source (`@@ -0,0 +1,240 @@`):

```cuda
#include "common.cuh"
#include "ggml.h"
#include "solve_tri.cuh"

#define MAX_N_FAST 64
#define MAX_K_FAST 32

// ======================
// Fast Kernel (n <= 64, k <= 32) - Warp-based parallel reduction
// ======================
template <int n, int k>
static __global__ void solve_tri_f32_fast(
        const float * __restrict__ A,
        const float * __restrict__ B,
        float * __restrict__ X,
        const uint3 ne02,
        const size_t nb02, const size_t nb03,
        const size_t nb12, const size_t nb13,
        const size_t nb2, const size_t nb3) {
    const int batch_idx = blockIdx.x;
    const int lane      = threadIdx.x;
    const int col_idx   = threadIdx.y;

    // A block processes one batch, k warps process k columns
    if (col_idx >= k) {
        return;
    }

    const uint2   i02_i03 = fast_div_modulo(batch_idx, ne02);
    const int64_t i02     = i02_i03.y;
    const int64_t i03     = i02_i03.x;

    const float * const A_batch = (const float *)((const char *)A + i02 * nb02 + i03 * nb03);
    const float * const B_batch = (const float *)((const char *)B + i02 * nb12 + i03 * nb13);
    float * X_batch             = (float *)((char *)X + i02 * nb2 + i03 * nb3);

    __shared__ float sA[MAX_N_FAST * MAX_N_FAST];
    __shared__ float sX[MAX_N_FAST * MAX_K_FAST];

    const int offset = threadIdx.x + threadIdx.y * blockDim.x;

    // Load A into shared memory (coalesced)
#pragma unroll
    for (int i = 0; i < n * n; i += k * WARP_SIZE) {
        const int i0 = i + offset;
        sA[i0] = A_batch[i0];
    }

    // Load B into shared memory (coalesced)
#pragma unroll
    for (int i = 0; i < n * k; i += k * WARP_SIZE) {
        const int i0 = i + offset;
        sX[i0] = B_batch[i0];
    }
    __syncthreads();

    // Each warp (32 threads with the same col_idx) solves one column
    for (int row = 0; row < n; ++row) {
        float sum = 0.0f;

        // Parallel reduction for the dot product of the row prefix
        for (int j = lane; j < row; j += WARP_SIZE) {
            sum += sA[row * n + j] * sX[j * k + col_idx];
        }

        sum = warp_reduce_sum(sum);

        // Lane 0 computes and stores the final result for the current row
        if (lane == 0) {
            const float b_val  = sX[row * k + col_idx]; // value from B
            const float a_diag = sA[row * n + row];
            if (a_diag != 0.0f) {
                sX[row * k + col_idx] = (b_val - sum) / a_diag;
            } else {
                sX[row * k + col_idx] = 0.0f; // avoid division by zero
            }
        }
        // Sync the block so the updated row of sX is visible to all warps in the next iteration
        __syncthreads();
    }

    // Write results from shared memory to global memory (coalesced)
#pragma unroll
    for (int i = 0; i < n * k; i += k * WARP_SIZE) {
        const int i0 = i + offset;
        X_batch[i0] = sX[i0];
    }
}

static __global__ void solve_tri_f32_fast_general(
        const float * __restrict__ A,
        const float * __restrict__ B,
        float * __restrict__ X,
        const uint3 ne02,
        const size_t nb02, const size_t nb03,
        const size_t nb12, const size_t nb13,
        const size_t nb2, const size_t nb3,
        const int n, const int k) {
    const int batch_idx = blockIdx.x;
    const int lane      = threadIdx.x;
    const int col_idx   = threadIdx.y;

    // A block processes one batch, k warps process k columns
    if (col_idx >= k) {
        return;
    }

    const uint2   i02_i03 = fast_div_modulo(batch_idx, ne02);
    const int64_t i02     = i02_i03.y;
    const int64_t i03     = i02_i03.x;

    const float * const A_batch = (const float *)((const char *)A + i02 * nb02 + i03 * nb03);
    const float * const B_batch = (const float *)((const char *)B + i02 * nb12 + i03 * nb13);
    float * X_batch             = (float *)((char *)X + i02 * nb2 + i03 * nb3);

    __shared__ float sA[MAX_N_FAST * MAX_N_FAST];
    __shared__ float sX[MAX_N_FAST * MAX_K_FAST];

    // Load A into shared memory (coalesced)
#pragma unroll
    // ... (the remainder of the diff is collapsed behind the review thread below)
```
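A quick check of the load loops in the largest fast case: with n = 64 and k = 32 the block runs k × WARP_SIZE = 32 × 32 = 1024 threads, so the A tile (64 × 64 = 4096 floats) is loaded in 4096 / 1024 = 4 coalesced steps and the B tile (64 × 32 = 2048 floats) in 2. Note that the loops carry no bounds check, so as written they rely on n·n and n·k being multiples of k·WARP_SIZE.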
Review comment (attached at the `#pragma unroll` above): I think this whole function can go away; it should be something like:

```cuda
if constexpr (n == 0) {
    // take this path
} else {
#pragma unroll
    // the fast loop
}
```
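One way to read that suggestion (a sketch only; the `load_tile` helper, its name, and the `n_dyn`/`k_dyn` parameters are hypothetical, not from the diff): a template argument of 0 marks a size as known only at run time, and `if constexpr` compiles out the dead branch, so the separate `solve_tri_f32_fast_general` kernel becomes unnecessary. Shown here on the A-load in isolation, assuming `WARP_SIZE` from `common.cuh`:

```cuda
// Sketch of the suggested if-constexpr dispatch: a template argument of 0
// means "size known only at run time"; the untaken branch is discarded at
// compile time, so one body serves both the fast and the general path.
template <int n, int k>
static __device__ void load_tile(float * __restrict__ dst,
                                 const float * __restrict__ src,
                                 const int offset,
                                 const int n_dyn, const int k_dyn) {
    if constexpr (n == 0) {
        // Runtime bounds: a plain loop, no unrolling possible.
        for (int i = 0; i < n_dyn * n_dyn; i += k_dyn * WARP_SIZE) {
            dst[i + offset] = src[i + offset];
        }
    } else {
        // Compile-time bounds: the loop fully unrolls.
#pragma unroll
        for (int i = 0; i < n * n; i += k * WARP_SIZE) {
            dst[i + offset] = src[i + offset];
        }
    }
}
```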
Second file, a new 5-line header (`@@ -0,0 +1,5 @@`, the `solve_tri.cuh` included by the kernel above):

```cuda
#include "common.cuh"

#define CUDA_SOLVE_TRI_BLOCK_SIZE 256

void ggml_cuda_op_solve_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
```
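For context, a host-side launcher consistent with the kernel signatures above could look like the sketch below. This is not code from the PR: the function name, the instantiation set, and the stream handling are assumptions; only the grid/block mapping (one block per batch, one warp per column of X) is inferred from the kernels' indexing.

```cuda
// Hypothetical dispatch sketch: the fast kernel is templated on (n, k), so a
// launcher must map runtime sizes onto concrete instantiations and fall back
// to the general kernel for other shapes within the shared-memory limits.
static void solve_tri_f32_cuda(
        const float * A, const float * B, float * X,
        const int n, const int k, const int n_batches,
        const uint3 ne02,
        const size_t nb02, const size_t nb03,
        const size_t nb12, const size_t nb13,
        const size_t nb2, const size_t nb3,
        cudaStream_t stream) {
    const dim3 block_dims(WARP_SIZE, k, 1); // threadIdx.x = lane, threadIdx.y = column
    const dim3 grid_dims(n_batches, 1, 1);  // blockIdx.x = batch index

    if (n == 64 && k == 32) {
        solve_tri_f32_fast<64, 32><<<grid_dims, block_dims, 0, stream>>>(
            A, B, X, ne02, nb02, nb03, nb12, nb13, nb2, nb3);
    } else {
        // n <= MAX_N_FAST and k <= MAX_K_FAST must still hold: both kernels
        // size their shared-memory tiles with these constants.
        solve_tri_f32_fast_general<<<grid_dims, block_dims, 0, stream>>>(
            A, B, X, ne02, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
    }
}
```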