Skip to content

Commit 407b408

Browse files
committed
fix test failure
1 parent 6ab4e50 commit 407b408

File tree

1 file changed

+12
-10
lines changed

1 file changed

+12
-10
lines changed

ggml/src/ggml-hexagon/htp/rope-ops.c

Lines changed: 12 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -151,9 +151,9 @@ static void init_rope_ctx(struct rope_th_ctx * rope_ctx, struct htp_ops_context
151151
}
152152

153153
static void hvx_calc_rope_neox_f32(const float * restrict src0,
154-
float * restrict dst,
155-
const int num_elems,
156-
const float * restrict theta_cache) {
154+
float * restrict dst,
155+
const int num_elems,
156+
const float * restrict theta_cache) {
157157
// for (int i = 0; i < num_elems; i += 2) {
158158
//const float cos_theta = theta_cache[i + 0];
159159
//const float sin_theta = theta_cache[i + 1];
@@ -192,7 +192,7 @@ static void hvx_calc_rope_neox_f32(const float * restrict src0,
192192
HVX_Vector v4 = Q6_Vqf32_vsub_Vqf32Vqf32(vx0_c, vx1_s);
193193
HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(vx0_s, vx1_c);
194194

195-
*(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v4);
195+
*(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v4);
196196
*(HVX_Vector *) (dst_curr + half_size) = Q6_Vsf_equals_Vqf32(v5);
197197

198198
src0_curr += VLEN;
@@ -259,16 +259,16 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
259259
const uint32_t ir1,
260260
int nth,
261261
int ith,
262-
int opt_path) {
262+
const int opt_path) {
263263
struct htp_ops_context * octx = rope_ctx->octx;
264264

265265
const struct htp_tensor * src0 = &octx->src0;
266266
const struct htp_tensor * src1 = &octx->src1;
267267
const struct htp_tensor * src2 = &octx->src2;
268268
struct htp_tensor * dst = &octx->dst;
269269

270-
const int32_t mode = rope_ctx->mode;
271-
const bool is_neox = mode & HTP_ROPE_TYPE_NEOX;
270+
const int32_t mode = rope_ctx->mode;
271+
const bool is_neox = mode & HTP_ROPE_TYPE_NEOX;
272272

273273
htp_rope_preamble;
274274

@@ -317,10 +317,10 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
317317

318318
if (is_neox) {
319319
const float x0 = src_loc[0];
320-
const float x1 = src_loc[rope_ctx->n_dims/2];
320+
const float x1 = src_loc[rope_ctx->n_dims / 2];
321321

322-
dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta;
323-
dst_data_loc[rope_ctx->n_dims/2] = x0 * sin_theta + x1 * cos_theta;
322+
dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta;
323+
dst_data_loc[rope_ctx->n_dims / 2] = x0 * sin_theta + x1 * cos_theta;
324324

325325
src_loc += 1;
326326
dst_data_loc += 1;
@@ -337,6 +337,8 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
337337
}
338338
}
339339

340+
src_loc += is_neox ? (rope_ctx->n_dims / 2) : 0;
341+
dst_data_loc += is_neox ? (rope_ctx->n_dims / 2) : 0;
340342
for (uint32_t i0 = rope_ctx->n_dims; i0 < ne0; i0 += 2) {
341343
dst_data_loc[0] = src_loc[0];
342344
dst_data_loc[1] = src_loc[1];

0 commit comments

Comments
 (0)