@@ -52,7 +52,7 @@ static __global__ void solve_tri_f32_fast(const float * __restrict__ A,
5252
5353#pragma unroll
5454 for (int i = 0 ; i < n * n; i += k * WARP_SIZE) {
55- int i0 = i + offset;
55+ const int i0 = i + offset;
5656 if (i0 < n * n) {
5757 sA [i0] = A_batch[i0];
5858 }
@@ -75,16 +75,16 @@ static __global__ void solve_tri_f32_fast(const float * __restrict__ A,
7575 sum = warp_reduce_sum (sum);
7676
7777 if (lane == row) {
78- float diag = sA [row * n + row];
79- float idiv = 1 .0f / diag;
78+ const float diag = sA [row * n + row];
79+ const float idiv = 1 .0f / diag;
8080 x_low = fmaf (sum, -idiv, x_low * idiv);
8181 }
8282 }
8383
8484#pragma unroll
8585 for (int row = half; row < n; ++row) {
8686 float sum = fmaf (sA [row * n + lane], x_low, 0 .0f );
87- int j = half + lane;
87+ const int j = half + lane;
8888 if (j < row) {
8989 sum = fmaf (sA [row * n + j], x_high, sum);
9090 }
@@ -97,9 +97,9 @@ static __global__ void solve_tri_f32_fast(const float * __restrict__ A,
9797
9898#pragma unroll
9999 for (int rr = 0 ; rr < 2 ; ++rr) {
100- int row = rr * WARP_SIZE + lane;
100+ const int row = rr * WARP_SIZE + lane;
101101 if (row < n) {
102- float val = (row < half) ? x_low : x_high;
102+ const float val = (row < half) ? x_low : x_high;
103103 X_batch[row * k + col_idx] = val;
104104 }
105105 }
0 commit comments