@@ -92,19 +92,19 @@ static void rope_cache_init(const float theta_base,
9292
9393 // Get n-d rotational scaling corrected for extrapolation
9494 float theta_interp = freq_scale * theta_extrap ;
95- float theta2 = theta_interp ;
96- float mscale2 = mscale ;
95+ float theta_final = theta_interp ;
96+ float mscale_final = mscale ;
9797
9898 if (ext_factor != 0.0f ) {
9999 float ramp_mix = rope_yarn_ramp (corr_dims [0 ], corr_dims [1 ], i0 ) * ext_factor ;
100- theta2 = theta_interp * (1 - ramp_mix ) + theta_extrap * ramp_mix ;
100+ theta_final = theta_interp * (1 - ramp_mix ) + theta_extrap * ramp_mix ;
101101
102102 // Get n-d magnitude scaling corrected for interpolation
103- mscale2 *= 1.0f + 0.1f * logf (1.0f / freq_scale );
103+ mscale_final *= 1.0f + 0.1f * logf (1.0f / freq_scale );
104104 }
105105
106- cache [i0 + 0 ] = cosf (theta2 ) * mscale2 ;
107- cache [i0 + 1 ] = sinf (theta2 ) * mscale2 ;
106+ cache [i0 + 0 ] = cosf (theta_final ) * mscale_final ;
107+ cache [i0 + 1 ] = sinf (theta_final ) * mscale_final ;
108108
109109 theta *= theta_scale ;
110110 }
@@ -282,7 +282,7 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
282282 freq_factors = (const float * ) src2 -> data ;
283283 }
284284
285- const uint32_t i0_end = MIN (ir1 , ne1 );
285+ const uint32_t i1_end = MIN (ir1 , ne1 );
286286 const int32_t half_dims = rope_ctx -> n_dims / 2 ;
287287 for (uint32_t i3 = 0 ; i3 < ne3 ; i3 ++ ) { // batch
288288 for (uint32_t i2 = 0 ; i2 < ne2 ; i2 ++ ) { // seq-len
@@ -291,7 +291,7 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
291291 rope_cache_init (p , rope_ctx -> freq_scale , freq_factors , rope_ctx -> corr_dims , ne0 , rope_ctx -> ext_factor ,
292292 rope_ctx -> attn_factor , wp0 , rope_ctx -> theta_scale );
293293
294- for (uint32_t i1 = ir0 ; i1 < i0_end ; i1 ++ ) { // attn-heads
294+ for (uint32_t i1 = ir0 ; i1 < i1_end ; i1 ++ ) { // attn-heads
295295 const float * src = (float * ) ((char * ) src0 -> data + i3 * nb03 + i2 * nb02 + i1 * nb01 );
296296 float * dst_data = (float * ) ((char * ) dst -> data + i3 * nb3 + i2 * nb2 + i1 * nb1 );
297297
0 commit comments