Skip to content

Commit e9a02fd

Browse files
committed
wip
1 parent cfca78b commit e9a02fd

File tree

1 file changed

+11
-7
lines changed

1 file changed

+11
-7
lines changed

ggml/src/ggml-hexagon/htp/rope-ops.c

Lines changed: 11 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -282,8 +282,8 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
282282
freq_factors = (const float *) src2->data;
283283
}
284284

285-
int ir = 0;
286-
285+
int ir = 0;
286+
const int32_t half_dims = rope_ctx->n_dims / 2;
287287
for (uint32_t i3 = 0; i3 < ne3; i3++) { // batch
288288
for (uint32_t i2 = 0; i2 < ne2; i2++) { // seq-len
289289
const int32_t p = pos[i2];
@@ -311,17 +311,20 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
311311
} else {
312312
hvx_calc_rope_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0);
313313
}
314+
315+
src_loc += rope_ctx->n_dims;
316+
dst_data_loc += rope_ctx->n_dims;
314317
} else {
315318
for (uint32_t i0 = 0; i0 < rope_ctx->n_dims; i0 += 2) {
316319
const float cos_theta = wp0[i0 + 0];
317320
const float sin_theta = wp0[i0 + 1];
318321

319322
if (is_neox) {
320323
const float x0 = src_loc[0];
321-
const float x1 = src_loc[rope_ctx->n_dims / 2];
324+
const float x1 = src_loc[half_dims];
322325

323-
dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta;
324-
dst_data_loc[rope_ctx->n_dims / 2] = x0 * sin_theta + x1 * cos_theta;
326+
dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta;
327+
dst_data_loc[half_dims] = x0 * sin_theta + x1 * cos_theta;
325328

326329
src_loc += 1;
327330
dst_data_loc += 1;
@@ -336,10 +339,11 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
336339
dst_data_loc += 2;
337340
}
338341
}
342+
343+
src_loc += (is_neox ? half_dims : 0);
344+
dst_data_loc += (is_neox ? half_dims : 0);
339345
}
340346

341-
src_loc += (is_neox ? (rope_ctx->n_dims / 2) : 0);
342-
dst_data_loc += (is_neox ? (rope_ctx->n_dims / 2) : 0);
343347
for (uint32_t i0 = rope_ctx->n_dims; i0 < ne0; i0 += 2) {
344348
dst_data_loc[0] = src_loc[0];
345349
dst_data_loc[1] = src_loc[1];

0 commit comments

Comments (0)