Skip to content

Commit 366dbb7

Browse files
committed
Extended TRI
1 parent 96fe9ba commit 366dbb7

File tree

4 files changed

+172
-56
lines changed

4 files changed

+172
-56
lines changed

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4621,7 +4621,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
46214621
case GGML_OP_TRI:
46224622
return true;
46234623
case GGML_OP_SOLVE_TRI:
4624-
return op->src[0]->ne[0] <= 64 && op->src[1]->ne[0] <= 32;
4624+
return true;
46254625
default:
46264626
return false;
46274627
}

ggml/src/ggml-cuda/solve_tri.cu

Lines changed: 168 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,112 @@ static __global__ void solve_tri_f32_fast(const float * __restrict__ A,
114114
# pragma clang diagnostic pop
115115
#endif // __clang__
116116

// ======================
// General kernel for larger matrices
// Blocked forward substitution for a lower-triangular system A * X = B.
// One thread block solves one (batch, column) pair, walking the diagonal in
// GENERAL_TILE_SIZE x GENERAL_TILE_SIZE tiles staged through shared memory.
// Expected launch: grid = (ne02*ne03, k), block = (256, 1).
// ======================
#define GENERAL_TILE_SIZE 32

template <int n_template, int k_template>
static __global__ void solve_tri_f32_general(const float * __restrict__ A,
                                             const float * __restrict__ B,
                                             float * __restrict__ X,
                                             const uint3  ne02,
                                             const size_t nb02,
                                             const size_t nb03,
                                             const size_t nb12,
                                             const size_t nb13,
                                             const size_t nb2,
                                             const size_t nb3,
                                             const int    n_arg,
                                             const int    k_arg) {
    // a template parameter of 0 means "use the runtime value"
    const int n = n_template == 0 ? n_arg : n_template;
    const int k = k_template == 0 ? k_arg : k_template;

    const int batch_idx = blockIdx.x;
    const int col_idx   = blockIdx.y; // column of B/X this block solves for
    const int tid       = threadIdx.x;

    // col_idx is uniform across the block (it comes from blockIdx), so this
    // early return cannot diverge around the __syncthreads() calls below
    if (col_idx >= k) {
        return;
    }

    const uint2   i02_i03 = fast_div_modulo(batch_idx, ne02);
    const int64_t i02     = i02_i03.y;
    const int64_t i03     = i02_i03.x;

    // strides are given in float elements (host divides nb* by sizeof(float))
    const float * const A_batch = A + i02 * nb02 + i03 * nb03;
    const float * const B_batch = B + i02 * nb12 + i03 * nb13;
    float *             X_batch = X + i02 * nb2 + i03 * nb3;

    __shared__ float sA[GENERAL_TILE_SIZE * GENERAL_TILE_SIZE]; // diagonal tile of A
    __shared__ float sB[GENERAL_TILE_SIZE];                     // rhs minus prior-tile terms
    __shared__ float sX[GENERAL_TILE_SIZE];                     // solution of the current tile

    for (int tile_start = 0; tile_start < n; tile_start += GENERAL_TILE_SIZE) {
        const int tile_end = min(tile_start + GENERAL_TILE_SIZE, n);
        const int tile_n   = tile_end - tile_start;

        // stage the lower-triangular diagonal tile of A into shared memory
        for (int i = tid; i < tile_n * tile_n; i += blockDim.x) {
            const int local_row  = i / tile_n;
            const int local_col  = i % tile_n;
            const int global_row = tile_start + local_row;
            const int global_col = tile_start + local_col;

            sA[local_row * GENERAL_TILE_SIZE + local_col] =
                global_col <= global_row ? A_batch[global_row * n + global_col] : 0.0f;
        }

        // load the rhs segment and fold in the contributions of all previously
        // solved tiles; doing this here spreads the O(tile_start) dot product
        // over tile_n threads instead of serializing it inside the row loop
        if (tid < tile_n) {
            const int global_row = tile_start + tid;

            float acc = B_batch[global_row * k + col_idx];
            for (int j = 0; j < tile_start; ++j) {
                // X_batch entries for earlier tiles were written by this block
                // and made visible by the trailing __syncthreads() below
                acc -= A_batch[global_row * n + j] * X_batch[j * k + col_idx];
            }
            sB[tid] = acc;
        }

        // one barrier covers both the sA stores and the sB stores above
        __syncthreads();

        // sequential forward substitution within the tile: row r depends on
        // rows 0..r-1, so one thread finalizes each row per iteration
        for (int row = 0; row < tile_n; ++row) {
            if (tid == row) {
                float sum = 0.0f;
                for (int j = 0; j < row; ++j) {
                    sum += sA[row * GENERAL_TILE_SIZE + j] * sX[j];
                }
                sX[row] = (sB[row] - sum) / sA[row * GENERAL_TILE_SIZE + row];
            }
            __syncthreads();
        }

        // publish this tile's solution; the next tile's rhs update reads it
        if (tid < tile_n) {
            X_batch[(tile_start + tid) * k + col_idx] = sX[tid];
        }
        __syncthreads();
    }
}
222+
117223
static void solve_tri_f32_cuda(const float * A,
118224
const float * B,
119225
float * X,
@@ -129,56 +235,68 @@ static void solve_tri_f32_cuda(const float * A,
129235
size_t nb3,
130236
cudaStream_t stream) {
131237
const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02);
132-
dim3 threads(WARP_SIZE, k);
133-
dim3 grid(ne02 * ne03);
134-
if (n == 64) {
135-
switch (k) {
136-
case 32:
137-
solve_tri_f32_fast<64, 32>
138-
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
139-
break;
140-
case 16:
141-
solve_tri_f32_fast<64, 16>
142-
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
143-
break;
144-
case 14:
145-
solve_tri_f32_fast<64, 14>
146-
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
147-
break;
148-
case 12:
149-
solve_tri_f32_fast<64, 12>
150-
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
151-
break;
152-
case 10:
153-
solve_tri_f32_fast<64, 10>
154-
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
155-
break;
156-
case 8:
157-
solve_tri_f32_fast<64, 8>
158-
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
159-
break;
160-
case 6:
161-
solve_tri_f32_fast<64, 6>
162-
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
163-
break;
164-
case 4:
165-
solve_tri_f32_fast<64, 4>
166-
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
167-
break;
168-
case 2:
169-
solve_tri_f32_fast<64, 2>
170-
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
171-
break;
172-
case 1:
173-
solve_tri_f32_fast<64, 1>
174-
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
175-
break;
176-
default:
177-
solve_tri_f32_fast<0, 0>
178-
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
238+
239+
// Choose kernel based on matrix size
240+
if (n <= MAX_N_FAST && k <= MAX_K_FAST) {
241+
// Use fast kernel for small matrices
242+
dim3 threads(WARP_SIZE, k);
243+
dim3 grid(ne02 * ne03);
244+
if (n == 64) {
245+
switch (k) {
246+
case 32:
247+
solve_tri_f32_fast<64, 32>
248+
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
249+
break;
250+
case 16:
251+
solve_tri_f32_fast<64, 16>
252+
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
253+
break;
254+
case 14:
255+
solve_tri_f32_fast<64, 14>
256+
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
257+
break;
258+
case 12:
259+
solve_tri_f32_fast<64, 12>
260+
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
261+
break;
262+
case 10:
263+
solve_tri_f32_fast<64, 10>
264+
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
265+
break;
266+
case 8:
267+
solve_tri_f32_fast<64, 8>
268+
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
269+
break;
270+
case 6:
271+
solve_tri_f32_fast<64, 6>
272+
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
273+
break;
274+
case 4:
275+
solve_tri_f32_fast<64, 4>
276+
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
277+
break;
278+
case 2:
279+
solve_tri_f32_fast<64, 2>
280+
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
281+
break;
282+
case 1:
283+
solve_tri_f32_fast<64, 1>
284+
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
285+
break;
286+
default:
287+
solve_tri_f32_fast<0, 0>
288+
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
289+
}
290+
} else { // run general case
291+
solve_tri_f32_fast<0, 0>
292+
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
179293
}
180-
} else { // run general case
181-
solve_tri_f32_fast<0, 0>
294+
} else {
295+
// Use general kernel for larger matrices
296+
dim3 threads(256, 1); // 256 threads per block
297+
dim3 grid(ne02 * ne03, k); // One block per column
298+
299+
solve_tri_f32_general<0, 0>
182300
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
183301
}
184302
}
@@ -193,11 +311,8 @@ void ggml_cuda_op_solve_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
193311
const int64_t n = src0->ne[0];
194312
const int64_t k = src1->ne[0];
195313

196-
GGML_ASSERT(n <= 64);
197-
GGML_ASSERT(k <= 32);
198-
199314
solve_tri_f32_cuda((const float *) src0->data, (const float *) src1->data, (float *) dst->data, n, k, src0->ne[2],
200315
src0->ne[3], src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float),
201316
src1->nb[2] / sizeof(float), src1->nb[3] / sizeof(float), dst->nb[2] / sizeof(float),
202317
dst->nb[3] / sizeof(float), ctx.stream());
203-
}
318+
}

src/llama-context.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1388,7 +1388,7 @@ void llama_context::output_reorder() {
13881388

13891389
uint32_t llama_context::graph_max_nodes() const {
13901390
if (model.arch == LLM_ARCH_QWEN3NEXT) {
1391-
return std::max<uint32_t>(8192u, 32u*model.n_tensors());
1391+
return std::max<uint32_t>(32768, 64u*model.n_tensors());
13921392
}
13931393
return std::max<uint32_t>(1024u, 8u*model.n_tensors());
13941394
}

tests/test-backend-ops.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7756,6 +7756,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
77567756
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 42, 42, 5, 2 }, { 10, 42, 5, 2 }));
77577757
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 2, 2 }, { 10, 64, 2, 2 }));
77587758
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 100, 100, 4, 4 }, { 41, 100, 4, 4 }));
7759+
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 4 }, { 64, 128, 4, 4 }));
77597760

77607761
for (bool v : {false, true}) {
77617762
test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {512, 512, 1, 1}, 0, 1, 0, 1, 0, 0, 0, 0, v));
@@ -7953,7 +7954,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
79537954
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 1, 16416, {8, 1}, {4, 1}, {0, 1, 2, 3}, 2*16416));
79547955

79557956
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 2 }, { 6, 64, 4, 2 }));
7956-
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 1 }, { 8, 128, 4, 1 }));
7957+
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 1 }, { 32, 128, 4, 1 }));
79577958

79587959
test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER, GGML_TYPE_F32, { 256, 256, 4, 4 }));
79597960
test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER_DIAG, GGML_TYPE_F32, { 1024, 1024, 8, 4 }));

0 commit comments

Comments
 (0)