// Review notes for this patch (text ahead of the "diff --git" header is not part of any hunk).
//
// rope_cache_init:
//   * By-value parameters become const; the blended angle local theta2 is renamed theta_final.
//   * The magnitude correction is now applied to a new local, mscale_final, which is
//     re-initialised from mscale next to theta_final. The old code did `mscale *= ...` on the
//     parameter itself; NOTE(review): this appears to sit inside the per-i0 loop (loop header is
//     outside the hunk), in which case the old correction compounded on every iteration whenever
//     ext_factor != 0 and this change is a behavioural fix, not just a rename - confirm.
//
// hvx_calc_rope_neox_f32:
//   * Parameter list and one store are re-aligned; no token changes are visible.
//
// rope_hex_f32:
//   * opt_path becomes const; rope_ctx->n_dims / 2 is hoisted into half_dims.
//   * The running counter `ir`, which was incremented once per (i3, i2, i1) iteration and compared
//     against ir0/ir1, is removed. The head loop now runs i1 from ir0 to MIN(ir1, ne1) inside
//     every (i3, i2) pair. NOTE(review): this changes what ir0/ir1 mean (a range over the
//     flattened row counter before, a range of i1 now); the caller that computes them is not
//     part of this patch - verify it was updated to match.
//   * After the HVX path, src_loc/dst_data_loc are advanced by n_dims; after the scalar NeoX loop
//     they are advanced by a further half_dims. Together with the existing per-iteration
//     increments this leaves both pointers n_dims elements past their starting point on every
//     path before the tail copy. The old code did not advance them on the HVX path and only by
//     half_dims on the scalar NeoX path. (Presumably src_loc/dst_data_loc start at the row
//     base - their initialisation is outside the hunks.)
//   * The pair-wise tail copy loop is replaced by one memcpy of
//     remain_bytes = (ne0 - n_dims) * sizeof(float).
//     NOTE(review): the old loop simply did not run when n_dims >= ne0; the new size expression
//     relies on n_dims <= ne0 - confirm that invariant is enforced upstream of this function.
//     NOTE(review): memcpy needs <string.h>; the include block is not visible here - confirm.
//
diff --git a/ggml/src/ggml-hexagon/htp/rope-ops.c b/ggml/src/ggml-hexagon/htp/rope-ops.c index 00419bcba6b..a4399704fcb 100644 --- a/ggml/src/ggml-hexagon/htp/rope-ops.c +++ b/ggml/src/ggml-hexagon/htp/rope-ops.c @@ -73,15 +73,15 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) { return (1 - MIN(1, MAX(0, y))); } -static void rope_cache_init(const float theta_base, - float freq_scale, - const float * freq_factors, - float * corr_dims, - uint32_t ne0, - float ext_factor, - float mscale, - float * cache, - float theta_scale) { +static void rope_cache_init(const float theta_base, + const float freq_scale, + const float * freq_factors, + float * corr_dims, + const uint32_t ne0, + const float ext_factor, + const float mscale, + float * cache, + const float theta_scale) { // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py float theta = theta_base; @@ -92,18 +92,19 @@ static void rope_cache_init(const float theta_base, // Get n-d rotational scaling corrected for extrapolation float theta_interp = freq_scale * theta_extrap; - float theta2 = theta_interp; + float theta_final = theta_interp; + float mscale_final = mscale; if (ext_factor != 0.0f) { float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor; - theta2 = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; + theta_final = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; // Get n-d magnitude scaling corrected for interpolation - mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale); + mscale_final *= 1.0f + 0.1f * logf(1.0f / freq_scale); } - cache[i0 + 0] = cosf(theta2) * mscale; - cache[i0 + 1] = sinf(theta2) * mscale; + cache[i0 + 0] = cosf(theta_final) * mscale_final; + cache[i0 + 1] = sinf(theta_final) * mscale_final; theta *= theta_scale; } @@ -151,9 +152,9 @@ static void init_rope_ctx(struct rope_th_ctx * rope_ctx, struct htp_ops_context } static void hvx_calc_rope_neox_f32(const float * restrict src0, 
- float * restrict dst, - const int num_elems, - const float * restrict theta_cache) { + float * restrict dst, + const int num_elems, + const float * restrict theta_cache) { // for (int i = 0; i < num_elems; i += 2) { //const float cos_theta = theta_cache[i + 0]; //const float sin_theta = theta_cache[i + 1]; @@ -192,7 +193,7 @@ static void hvx_calc_rope_neox_f32(const float * restrict src0, HVX_Vector v4 = Q6_Vqf32_vsub_Vqf32Vqf32(vx0_c, vx1_s); HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(vx0_s, vx1_c); - *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v4); + *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v4); *(HVX_Vector *) (dst_curr + half_size) = Q6_Vsf_equals_Vqf32(v5); src0_curr += VLEN; @@ -259,7 +260,7 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx, const uint32_t ir1, int nth, int ith, - int opt_path) { + const int opt_path) { struct htp_ops_context * octx = rope_ctx->octx; const struct htp_tensor * src0 = &octx->src0; @@ -267,8 +268,8 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx, const struct htp_tensor * src2 = &octx->src2; struct htp_tensor * dst = &octx->dst; - const int32_t mode = rope_ctx->mode; - const bool is_neox = mode & HTP_ROPE_TYPE_NEOX; + const int32_t mode = rope_ctx->mode; + const bool is_neox = mode & HTP_ROPE_TYPE_NEOX; htp_rope_preamble; @@ -281,8 +282,9 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx, freq_factors = (const float *) src2->data; } - int ir = 0; - + const uint32_t i1_end = MIN(ir1, ne1); + const int32_t half_dims = rope_ctx->n_dims / 2; + const size_t remain_bytes = (ne0 - rope_ctx->n_dims) * sizeof(float); for (uint32_t i3 = 0; i3 < ne3; i3++) { // batch for (uint32_t i2 = 0; i2 < ne2; i2++) { // seq-len const int32_t p = pos[i2]; @@ -290,14 +292,7 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx, rope_cache_init(p, rope_ctx->freq_scale, freq_factors, rope_ctx->corr_dims, ne0, rope_ctx->ext_factor, rope_ctx->attn_factor, wp0, rope_ctx->theta_scale); - for (uint32_t i1 = 0; i1 < ne1; 
i1++) { // attn-heads - if (ir++ < ir0) { - continue; - } - if (ir > ir1) { - break; - } - + for (uint32_t i1 = ir0; i1 < i1_end; i1++) { // attn-heads const float * src = (float *) ((char *) src0->data + i3 * nb03 + i2 * nb02 + i1 * nb01); float * dst_data = (float *) ((char *) dst->data + i3 * nb3 + i2 * nb2 + i1 * nb1); @@ -310,6 +305,9 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx, } else { hvx_calc_rope_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0); } + + src_loc += rope_ctx->n_dims; + dst_data_loc += rope_ctx->n_dims; } else { for (uint32_t i0 = 0; i0 < rope_ctx->n_dims; i0 += 2) { const float cos_theta = wp0[i0 + 0]; @@ -317,10 +315,10 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx, if (is_neox) { const float x0 = src_loc[0]; - const float x1 = src_loc[rope_ctx->n_dims/2]; + const float x1 = src_loc[half_dims]; - dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta; - dst_data_loc[rope_ctx->n_dims/2] = x0 * sin_theta + x1 * cos_theta; + dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta; + dst_data_loc[half_dims] = x0 * sin_theta + x1 * cos_theta; src_loc += 1; dst_data_loc += 1; @@ -335,15 +333,13 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx, dst_data_loc += 2; } } - } - - for (uint32_t i0 = rope_ctx->n_dims; i0 < ne0; i0 += 2) { - dst_data_loc[0] = src_loc[0]; - dst_data_loc[1] = src_loc[1]; - src_loc += 2; - dst_data_loc += 2; + src_loc += (is_neox ? half_dims : 0); + dst_data_loc += (is_neox ? half_dims : 0); } + + // TODO: use simd to speed up the remaining elements copy + memcpy(dst_data_loc, src_loc, remain_bytes); } } }