@@ -282,8 +282,8 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
282282 freq_factors = (const float * ) src2 -> data ;
283283 }
284284
285- int ir = 0 ;
286-
285+ int ir = 0 ;
286+ const int32_t half_dims = rope_ctx -> n_dims / 2 ;
287287 for (uint32_t i3 = 0 ; i3 < ne3 ; i3 ++ ) { // batch
288288 for (uint32_t i2 = 0 ; i2 < ne2 ; i2 ++ ) { // seq-len
289289 const int32_t p = pos [i2 ];
@@ -311,17 +311,20 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
311311 } else {
312312 hvx_calc_rope_f32 (src_loc , dst_data_loc , rope_ctx -> n_dims , wp0 );
313313 }
314+
315+ src_loc += rope_ctx -> n_dims ;
316+ dst_data_loc += rope_ctx -> n_dims ;
314317 } else {
315318 for (uint32_t i0 = 0 ; i0 < rope_ctx -> n_dims ; i0 += 2 ) {
316319 const float cos_theta = wp0 [i0 + 0 ];
317320 const float sin_theta = wp0 [i0 + 1 ];
318321
319322 if (is_neox ) {
320323 const float x0 = src_loc [0 ];
321- const float x1 = src_loc [rope_ctx -> n_dims / 2 ];
324+ const float x1 = src_loc [half_dims ];
322325
323- dst_data_loc [0 ] = x0 * cos_theta - x1 * sin_theta ;
324- dst_data_loc [rope_ctx -> n_dims / 2 ] = x0 * sin_theta + x1 * cos_theta ;
326+ dst_data_loc [0 ] = x0 * cos_theta - x1 * sin_theta ;
327+ dst_data_loc [half_dims ] = x0 * sin_theta + x1 * cos_theta ;
325328
326329 src_loc += 1 ;
327330 dst_data_loc += 1 ;
@@ -336,10 +339,11 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
336339 dst_data_loc += 2 ;
337340 }
338341 }
342+
343+ src_loc += (is_neox ? half_dims : 0 );
344+ dst_data_loc += (is_neox ? half_dims : 0 );
339345 }
340346
341- src_loc += (is_neox ? (rope_ctx -> n_dims / 2 ) : 0 );
342- dst_data_loc += (is_neox ? (rope_ctx -> n_dims / 2 ) : 0 );
343347 for (uint32_t i0 = rope_ctx -> n_dims ; i0 < ne0 ; i0 += 2 ) {
344348 dst_data_loc [0 ] = src_loc [0 ];
345349 dst_data_loc [1 ] = src_loc [1 ];
0 commit comments