Skip to content

Commit ec9b6f9

Browse files
authored
Use const for variables in solve_tri.cu
1 parent 12d108a commit ec9b6f9

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

ggml/src/ggml-cuda/solve_tri.cu

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ static __global__ void solve_tri_f32_fast(const float * __restrict__ A,
5252

5353
#pragma unroll
5454
for (int i = 0; i < n * n; i += k * WARP_SIZE) {
55-
int i0 = i + offset;
55+
const int i0 = i + offset;
5656
if (i0 < n * n) {
5757
sA[i0] = A_batch[i0];
5858
}
@@ -75,16 +75,16 @@ static __global__ void solve_tri_f32_fast(const float * __restrict__ A,
7575
sum = warp_reduce_sum(sum);
7676

7777
if (lane == row) {
78-
float diag = sA[row * n + row];
79-
float idiv = 1.0f / diag;
78+
const float diag = sA[row * n + row];
79+
const float idiv = 1.0f / diag;
8080
x_low = fmaf(sum, -idiv, x_low * idiv);
8181
}
8282
}
8383

8484
#pragma unroll
8585
for (int row = half; row < n; ++row) {
8686
float sum = fmaf(sA[row * n + lane], x_low, 0.0f);
87-
int j = half + lane;
87+
const int j = half + lane;
8888
if (j < row) {
8989
sum = fmaf(sA[row * n + j], x_high, sum);
9090
}
@@ -97,9 +97,9 @@ static __global__ void solve_tri_f32_fast(const float * __restrict__ A,
9797

9898
#pragma unroll
9999
for (int rr = 0; rr < 2; ++rr) {
100-
int row = rr * WARP_SIZE + lane;
100+
const int row = rr * WARP_SIZE + lane;
101101
if (row < n) {
102-
float val = (row < half) ? x_low : x_high;
102+
const float val = (row < half) ? x_low : x_high;
103103
X_batch[row * k + col_idx] = val;
104104
}
105105
}

0 commit comments

Comments
 (0)