Skip to content

Commit 4ddb8a4

Browse files
committed
fix: correct scaling calculations in rope_cache_init
1 parent 407b408 commit 4ddb8a4

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

ggml/src/ggml-hexagon/htp/rope-ops.c

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -93,17 +93,18 @@ static void rope_cache_init(const float theta_base,
9393
// Get n-d rotational scaling corrected for extrapolation
9494
float theta_interp = freq_scale * theta_extrap;
9595
float theta2 = theta_interp;
96+
float mscale2 = mscale;
9697

9798
if (ext_factor != 0.0f) {
9899
float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
99100
theta2 = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
100101

101102
// Get n-d magnitude scaling corrected for interpolation
102-
mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
103+
mscale2 *= 1.0f + 0.1f * logf(1.0f / freq_scale);
103104
}
104105

105-
cache[i0 + 0] = cosf(theta2) * mscale;
106-
cache[i0 + 1] = sinf(theta2) * mscale;
106+
cache[i0 + 0] = cosf(theta2) * mscale2;
107+
cache[i0 + 1] = sinf(theta2) * mscale2;
107108

108109
theta *= theta_scale;
109110
}
@@ -337,8 +338,8 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
337338
}
338339
}
339340

340-
src_loc += is_neox ? (rope_ctx->n_dims / 2) : 0;
341-
dst_data_loc += is_neox ? (rope_ctx->n_dims / 2) : 0;
341+
src_loc += (is_neox ? (rope_ctx->n_dims / 2) : 0);
342+
dst_data_loc += (is_neox ? (rope_ctx->n_dims / 2) : 0);
342343
for (uint32_t i0 = rope_ctx->n_dims; i0 < ne0; i0 += 2) {
343344
dst_data_loc[0] = src_loc[0];
344345
dst_data_loc[1] = src_loc[1];

0 commit comments

Comments
 (0)