@@ -114,6 +114,112 @@ static __global__ void solve_tri_f32_fast(const float * __restrict__ A,
 #pragma clang diagnostic pop
 #endif // __clang__
 
+// ======================
+// General kernel for larger matrices.
+// Uses a simpler approach with a fixed tile size.
+// ======================
+#define GENERAL_TILE_SIZE 32
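+// Each block stages one 32x32 tile of A in shared memory (4 KiB of f32),
+// plus the matching rows of B and X.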
+
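+// For reference, per batch and per column c this kernel computes plain
+// forward substitution for a lower-triangular A, i.e. (in scalar form):
+//
+//     for (int i = 0; i < n; ++i) {
+//         float sum = 0.0f;
+//         for (int j = 0; j < i; ++j) {
+//             sum += A[i*n + j] * X[j*k + c];
+//         }
+//         X[i*k + c] = (B[i*k + c] - sum) / A[i*n + i];
+//     }
+//
+// The tile loop below restructures this recurrence so that the already
+// solved prefix of X is re-read from global memory one tile at a time.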
+template <int n_template, int k_template>
+static __global__ void solve_tri_f32_general(const float * __restrict__ A,
+                                             const float * __restrict__ B,
+                                             float * __restrict__ X,
+                                             const uint3  ne02,
+                                             const size_t nb02,
+                                             const size_t nb03,
+                                             const size_t nb12,
+                                             const size_t nb13,
+                                             const size_t nb2,
+                                             const size_t nb3,
+                                             const int    n_arg,
+                                             const int    k_arg) {
+    const int n = n_template == 0 ? n_arg : n_template;
+    const int k = k_template == 0 ? k_arg : k_template;
+
+    const int batch_idx = blockIdx.x;
+    const int col_idx   = blockIdx.y;
+    const int tid       = threadIdx.x;
+
+    if (col_idx >= k) {
+        return;
+    }
+
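+    // Split the flattened batch index into (i02, i03); fast_div_modulo
+    // returns the quotient in .x and the remainder in .y.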
+    const uint2 i02_i03 = fast_div_modulo(batch_idx, ne02);
+    const int64_t i02 = i02_i03.y;
+    const int64_t i03 = i02_i03.x;
+
+    const float * const A_batch = (const float *) (A + i02 * nb02 + i03 * nb03);
+    const float * const B_batch = (const float *) (B + i02 * nb12 + i03 * nb13);
+    float * X_batch = (float *) (X + i02 * nb2 + i03 * nb3);
+
+    // Shared memory for current tile
+    __shared__ float sA[GENERAL_TILE_SIZE * GENERAL_TILE_SIZE];
+    __shared__ float sB[GENERAL_TILE_SIZE];
+    __shared__ float sX[GENERAL_TILE_SIZE];
+
+    // Process in tiles
+    for (int tile_start = 0; tile_start < n; tile_start += GENERAL_TILE_SIZE) {
+        int tile_end = min(tile_start + GENERAL_TILE_SIZE, n);
+        int tile_n   = tile_end - tile_start;
+
+        // Load the tile of A, zero-filling the upper-triangular part
+        for (int i = tid; i < tile_n * tile_n; i += blockDim.x) {
+            int local_row  = i / tile_n;
+            int local_col  = i % tile_n;
+            int global_row = tile_start + local_row;
+            int global_col = tile_start + local_col;
+
+            if (global_col <= global_row) {
+                sA[local_row * GENERAL_TILE_SIZE + local_col] = A_batch[global_row * n + global_col];
+            } else {
+                sA[local_row * GENERAL_TILE_SIZE + local_col] = 0.0f;
+            }
+        }
+
+        __syncthreads();
+
+        // Load corresponding part of B and initialize X
+        if (tid < tile_n) {
+            sB[tid] = B_batch[(tile_start + tid) * k + col_idx];
+            sX[tid] = sB[tid];
+        }
+
+        __syncthreads();
+
+        // Forward substitution for this tile
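+        // x[row] = (b[row] - sum_{j < row} A[row][j] * x[j]) / A[row][row].
+        // Rows depend on all previous rows, so a single thread (tid == row)
+        // solves each row while the rest of the block waits at the barrier.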
+        for (int row = 0; row < tile_n; ++row) {
+            if (tid == row) {
+                float sum = 0.0f;
+
+                // Sum contributions from previous rows in this tile
+                for (int j = 0; j < row; ++j) {
+                    sum += sA[row * GENERAL_TILE_SIZE + j] * sX[j];
+                }
+
+                // Sum contributions from previous tiles
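+                // (their solutions were already written back to X_batch at
+                //  the end of the corresponding earlier tile iterations)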
+                if (tile_start > 0) {
+                    int global_row = tile_start + row;
+                    for (int j = 0; j < tile_start; ++j) {
+                        sum += A_batch[global_row * n + j] * X_batch[j * k + col_idx];
+                    }
+                }
+
+                const float a_diag = sA[row * GENERAL_TILE_SIZE + row];
+                sX[row] = (sB[row] - sum) / a_diag;
+            }
+            __syncthreads();
+        }
+
+        // Store results back to global memory
+        if (tid < tile_n) {
+            int global_row = tile_start + tid;
+            X_batch[global_row * k + col_idx] = sX[tid];
+        }
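+        // The barrier below makes these global writes visible to the whole
+        // block before the next tile iteration reads them from X_batch.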
+
+        __syncthreads();
+    }
+}
+
 static void solve_tri_f32_cuda(const float * A,
                                const float * B,
                                float * X,
@@ -129,56 +235,68 @@ static void solve_tri_f32_cuda(const float * A,
                                size_t nb3,
                                cudaStream_t stream) {
     const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02);
-    dim3 threads(WARP_SIZE, k);
-    dim3 grid(ne02 * ne03);
-    if (n == 64) {
-        switch (k) {
-            case 32:
-                solve_tri_f32_fast<64, 32>
-                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
-                break;
-            case 16:
-                solve_tri_f32_fast<64, 16>
-                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
-                break;
-            case 14:
-                solve_tri_f32_fast<64, 14>
-                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
-                break;
-            case 12:
-                solve_tri_f32_fast<64, 12>
-                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
-                break;
-            case 10:
-                solve_tri_f32_fast<64, 10>
-                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
-                break;
-            case 8:
-                solve_tri_f32_fast<64, 8>
-                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
-                break;
-            case 6:
-                solve_tri_f32_fast<64, 6>
-                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
-                break;
-            case 4:
-                solve_tri_f32_fast<64, 4>
-                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
-                break;
-            case 2:
-                solve_tri_f32_fast<64, 2>
-                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
-                break;
-            case 1:
-                solve_tri_f32_fast<64, 1>
-                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
-                break;
-            default:
-                solve_tri_f32_fast<0, 0>
-                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
+
+    // Choose kernel based on matrix size
+    if (n <= MAX_N_FAST && k <= MAX_K_FAST) {
+        // Use fast kernel for small matrices
+        dim3 threads(WARP_SIZE, k);
+        dim3 grid(ne02 * ne03);
+        if (n == 64) {
+            switch (k) {
+                case 32:
+                    solve_tri_f32_fast<64, 32>
+                        <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                    break;
+                case 16:
+                    solve_tri_f32_fast<64, 16>
+                        <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                    break;
+                case 14:
+                    solve_tri_f32_fast<64, 14>
+                        <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                    break;
+                case 12:
+                    solve_tri_f32_fast<64, 12>
+                        <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                    break;
+                case 10:
+                    solve_tri_f32_fast<64, 10>
+                        <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                    break;
+                case 8:
+                    solve_tri_f32_fast<64, 8>
+                        <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                    break;
+                case 6:
+                    solve_tri_f32_fast<64, 6>
+                        <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                    break;
+                case 4:
+                    solve_tri_f32_fast<64, 4>
+                        <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                    break;
+                case 2:
+                    solve_tri_f32_fast<64, 2>
+                        <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                    break;
+                case 1:
+                    solve_tri_f32_fast<64, 1>
+                        <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                    break;
+                default:
+                    solve_tri_f32_fast<0, 0>
+                        <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
+            }
+        } else { // n is not 64: fall back to the non-templated fast kernel
+            solve_tri_f32_fast<0, 0>
+                <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
         }
-    } else { // run general case
-        solve_tri_f32_fast<0, 0>
+    } else {
+        // Use the general kernel for larger matrices
+        dim3 threads(256, 1);       // 256 threads per block
+        dim3 grid(ne02 * ne03, k);  // one block per (batch, column) pair
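+        // Only the tile loads in solve_tri_f32_general are parallel across
+        // the 256 threads; the substitution itself runs one row at a time.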
+
+        solve_tri_f32_general<0, 0>
             <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
     }
 }
@@ -193,11 +311,8 @@ void ggml_cuda_op_solve_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
     const int64_t n = src0->ne[0];
     const int64_t k = src1->ne[0];
 
-    GGML_ASSERT(n <= 64);
-    GGML_ASSERT(k <= 32);
-
     solve_tri_f32_cuda((const float *) src0->data, (const float *) src1->data, (float *) dst->data, n, k, src0->ne[2],
                        src0->ne[3], src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float),
                        src1->nb[2] / sizeof(float), src1->nb[3] / sizeof(float), dst->nb[2] / sizeof(float),
                        dst->nb[3] / sizeof(float), ctx.stream());
-}
+}