@@ -151,9 +151,9 @@ static void init_rope_ctx(struct rope_th_ctx * rope_ctx, struct htp_ops_context
151151}
152152
153153static void hvx_calc_rope_neox_f32 (const float * restrict src0 ,
154- float * restrict dst ,
155- const int num_elems ,
156- const float * restrict theta_cache ) {
154+ float * restrict dst ,
155+ const int num_elems ,
156+ const float * restrict theta_cache ) {
157157 // for (int i = 0; i < num_elems; i += 2) {
158158 //const float cos_theta = theta_cache[i + 0];
159159 //const float sin_theta = theta_cache[i + 1];
@@ -192,7 +192,7 @@ static void hvx_calc_rope_neox_f32(const float * restrict src0,
192192 HVX_Vector v4 = Q6_Vqf32_vsub_Vqf32Vqf32 (vx0_c , vx1_s );
193193 HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32 (vx0_s , vx1_c );
194194
195- * (HVX_Vector * ) dst_curr = Q6_Vsf_equals_Vqf32 (v4 );
195+ * (HVX_Vector * ) dst_curr = Q6_Vsf_equals_Vqf32 (v4 );
196196 * (HVX_Vector * ) (dst_curr + half_size ) = Q6_Vsf_equals_Vqf32 (v5 );
197197
198198 src0_curr += VLEN ;
@@ -259,16 +259,16 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
259259 const uint32_t ir1 ,
260260 int nth ,
261261 int ith ,
262- int opt_path ) {
262+ const int opt_path ) {
263263 struct htp_ops_context * octx = rope_ctx -> octx ;
264264
265265 const struct htp_tensor * src0 = & octx -> src0 ;
266266 const struct htp_tensor * src1 = & octx -> src1 ;
267267 const struct htp_tensor * src2 = & octx -> src2 ;
268268 struct htp_tensor * dst = & octx -> dst ;
269269
270- const int32_t mode = rope_ctx -> mode ;
271- const bool is_neox = mode & HTP_ROPE_TYPE_NEOX ;
270+ const int32_t mode = rope_ctx -> mode ;
271+ const bool is_neox = mode & HTP_ROPE_TYPE_NEOX ;
272272
273273 htp_rope_preamble ;
274274
@@ -317,10 +317,10 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
317317
318318 if (is_neox ) {
319319 const float x0 = src_loc [0 ];
320- const float x1 = src_loc [rope_ctx -> n_dims / 2 ];
320+ const float x1 = src_loc [rope_ctx -> n_dims / 2 ];
321321
322- dst_data_loc [0 ] = x0 * cos_theta - x1 * sin_theta ;
323- dst_data_loc [rope_ctx -> n_dims / 2 ] = x0 * sin_theta + x1 * cos_theta ;
322+ dst_data_loc [0 ] = x0 * cos_theta - x1 * sin_theta ;
323+ dst_data_loc [rope_ctx -> n_dims / 2 ] = x0 * sin_theta + x1 * cos_theta ;
324324
325325 src_loc += 1 ;
326326 dst_data_loc += 1 ;
@@ -337,6 +337,8 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
337337 }
338338 }
339339
340+ src_loc += is_neox ? (rope_ctx -> n_dims / 2 ) : 0 ;
341+ dst_data_loc += is_neox ? (rope_ctx -> n_dims / 2 ) : 0 ;
340342 for (uint32_t i0 = rope_ctx -> n_dims ; i0 < ne0 ; i0 += 2 ) {
341343 dst_data_loc [0 ] = src_loc [0 ];
342344 dst_data_loc [1 ] = src_loc [1 ];
0 commit comments