@@ -2251,12 +2251,12 @@ static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx,
                                   int sections[4],
                                   bool mrope_used,
                                   bool is_imrope,
-                                  bool indep_sects) {
-    ggml_tensor * src0 = dst->src[0];  // input
+                                  bool indep_sects,
+                                  int64_t rope_dims) {
     ggml_tensor * src1 = dst->src[1];  // position
     ggml_tensor * src2 = dst->src[2];  // freq_factors

-    int64_t theta_scale_length = src0->ne[0] / 2;
+    int64_t theta_scale_length = rope_dims / 2;
     int64_t position_length = dst->ne[2];

     // TODO: check theta_scale_length and position_length.
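For reference, the sizing above follows from how RoPE consumes the table: dimensions are rotated in pairs, so only rope_dims / 2 frequencies are needed, and entry i holds theta_scale^i. A minimal host-side sketch of that layout (illustrative only, not the CANN code path; the function name is made up, and theta_scale is the freq_base^(-2/n_dims) value computed later in ggml_cann_rope):

    #include <cstdint>
    #include <vector>

    // Build the per-pair frequency table that the device-side cache mirrors.
    static std::vector<float> build_theta_table(int64_t rope_dims, float theta_scale) {
        const int64_t length = rope_dims / 2;  // one frequency per rotated pair
        std::vector<float> table(length);
        float cur = 1.0f;
        for (int64_t i = 0; i < length; i++) {
            table[i] = cur;   // theta_scale^i
            cur *= theta_scale;
        }
        return table;
    }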
@@ -2331,18 +2331,17 @@ static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx,
         ACL_CHECK(aclrtMemcpyAsync(ctx.rope_cache.theta_scale_cache, theta_scale_length * sizeof(float),
                                    ctx.rope_cache.theta_scale_exp_host, theta_scale_length * sizeof(float),
                                    ACL_MEMCPY_HOST_TO_DEVICE, ctx.stream()));
-
-        acl_theta_scale_tensor = ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float),
-                                                         theta_scale_ne, theta_scale_nb, 1);
     }
+    acl_theta_scale_tensor = ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float),
+                                                     theta_scale_ne, theta_scale_nb, 1);

     // Step1.2: prepare rope_yarn_ramp, if this part updated, should update theta_scale_tensor.
+    // TODO: acl_yarn_ramp_tensor use rope cache.
     bool yarn_ramp_tensor_updated = false;
     ggml_cann_pool_alloc yarn_ramp_allocator(ctx.pool());
     acl_tensor_ptr acl_yarn_ramp_tensor;
-    if (ext_factor != 0 &&
-        // TODO: check more parameter.
-        (ctx.rope_cache.theta_scale_length != theta_scale_length || ctx.rope_cache.freq_scale != freq_scale)) {
+    if (ext_factor != 0 && (theta_scale_updated || ctx.rope_cache.theta_scale_length != theta_scale_length ||
+                            ctx.rope_cache.freq_scale != freq_scale)) {
        yarn_ramp_tensor_updated = true;

         // -rope_yarn_ramp
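For context on the block that starts at the "-rope_yarn_ramp" comment: the YaRN ramp is a clamped linear ramp over the pair index between corr_dims[0] and corr_dims[1], which the device code appears to negate and combine with ext_factor and freq_scale when correcting the theta table. A rough host-side sketch of that formula (illustrative, mirroring the usual ggml formulation; the function name is made up and the actual kernel builds this with ACL ops on-device):

    #include <algorithm>
    #include <cstdint>

    // Returns 1 below `low`, 0 above `high`, linear in between (low/high come from corr_dims).
    static float yarn_ramp(float low, float high, int64_t i0) {
        const float y = ((float) (i0 / 2) - low) / std::max(0.001f, high - low);
        return 1.0f - std::min(1.0f, std::max(0.0f, y));
    }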
@@ -2590,7 +2589,7 @@ static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx,
         aclnn_muls(ctx, acl_cos_tensor.get(), attn_factor, nullptr, true);
     }

-    int64_t sin_reshape_ne[4] = { src0->ne[0], 1, dst->ne[2], 1 };
+    int64_t sin_reshape_ne[4] = { rope_dims, 1, dst->ne[2], 1 };
     size_t sin_reshape_nb[GGML_MAX_DIMS];
     sin_reshape_nb[0] = sizeof(float);
     for (int i = 1; i < GGML_MAX_DIMS; i++) {
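The reshape above views the sin/cos cache as a dense f32 tensor of shape {rope_dims, 1, dst->ne[2], 1} with row-major byte strides. A tiny standalone illustration of that stride computation (names are illustrative, not part of the patch):

    #include <cstddef>
    #include <cstdint>

    // Byte strides for a dense f32 view: nb[i] = nb[i-1] * ne[i-1], starting at sizeof(float).
    static void dense_f32_strides(const int64_t ne[4], size_t nb[4]) {
        nb[0] = sizeof(float);
        for (int i = 1; i < 4; i++) {
            nb[i] = nb[i - 1] * (size_t) ne[i - 1];
        }
    }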
@@ -2645,7 +2644,7 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {

     // param
     float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
-    int sections[4];
+    int   sections[4];
     // const int n_past = ((int32_t *) dst->op_params)[0];
     const int n_dims = ((int32_t *) dst->op_params)[1];
     const int mode = ((int32_t *) dst->op_params)[2];
@@ -2654,44 +2653,60 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {

     GGML_TENSOR_UNARY_OP_LOCALS

-    memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
-    memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
-    memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
-    memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
-    memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
-    memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
-    memcpy(&sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);
+    memcpy(&freq_base,   (int32_t *) dst->op_params + 5,  sizeof(float));
+    memcpy(&freq_scale,  (int32_t *) dst->op_params + 6,  sizeof(float));
+    memcpy(&ext_factor,  (int32_t *) dst->op_params + 7,  sizeof(float));
+    memcpy(&attn_factor, (int32_t *) dst->op_params + 8,  sizeof(float));
+    memcpy(&beta_fast,   (int32_t *) dst->op_params + 9,  sizeof(float));
+    memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
+    memcpy(&sections,    (int32_t *) dst->op_params + 11, sizeof(int) * 4);

-    // TODO: n_dims <= ne0
-    GGML_ASSERT(n_dims == ne0);
     GGML_ASSERT(n_dims % 2 == 0);
+    GGML_ASSERT(n_dims <= ne00);

     const float theta_scale = powf(freq_base, -2.0f / n_dims);

     float corr_dims[2];
     ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);

-    bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
-    const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;  // qwen3vl apply interleaved mrope
-    const bool mrope_used = mode & GGML_ROPE_TYPE_MROPE;  // ggml_rope_multi, note: also true for vision (24 & 8 == true) and for imrope
-    const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
+    bool       is_neox   = mode & GGML_ROPE_TYPE_NEOX;
+    const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;  // qwen3vl applies interleaved mrope
+    // mrope_used means the GGML_ROPE_TYPE_MROPE bit is set.
+    // Note: this bit is also set for imrope and some vision modes,
+    // so mrope_used does NOT exclusively indicate pure mrope.
+    const bool mrope_used = mode & GGML_ROPE_TYPE_MROPE;
+    const bool is_vision  = mode == GGML_ROPE_TYPE_VISION;

     if (mrope_used) {
         GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
     }

     if (is_vision) {
-        GGML_ASSERT(n_dims == ne0/2);
+        GGML_ASSERT(n_dims == ne0 / 2);
     }

     if (is_imrope || mrope_used) {
         is_neox = true;
     }

+    int64_t rope_dims = n_dims;
+
+    // The RotaryPositionEmbedding operator has no dedicated VISION mode,
+    // but VISION only changes how theta_base is derived (as in mrope) and
+    // then repeats it the same way as is_neox; the rotation itself still
+    // covers every dimension, so rotate the full row width here.
+    if (is_vision) {
+        rope_dims = src0->ne[0];
+    }
+    int64_t tail_dims = ne00 - rope_dims;
+    bool has_tail = tail_dims > 0;
+
     // init ctx.rope_cos/rope_sin cache
-    aclnn_rope_cache_init(ctx, dst, corr_dims, ext_factor, theta_scale, freq_scale, attn_factor, is_neox, sections, mrope_used, is_imrope, is_vision);
+    aclnn_rope_cache_init(ctx, dst, corr_dims, ext_factor, theta_scale, freq_scale, attn_factor, is_neox, sections,
+                          mrope_used, is_imrope, is_vision, rope_dims);

-    int64_t sin_reshape_ne[4] = { ne00, 1, ne02, 1 };
+    // The sin/cos cache is generated with rope_dims entries, so reshape with rope_dims.
+    int64_t sin_reshape_ne[4] = { rope_dims, 1, ne02, 1 };
     size_t sin_reshape_nb[GGML_MAX_DIMS];
     sin_reshape_nb[0] = sizeof(float);
     for (int i = 1; i < GGML_MAX_DIMS; i++) {
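The new rope_dims / tail_dims bookkeeping above can be summarized in isolation: only the leading rope_dims channels of each row are rotated, the remaining tail passes through untouched, and vision mode rotates the full row width. A self-contained sketch of just that decision (the struct and function names are illustrative, not part of the patch):

    #include <cstdint>

    struct rope_split {
        int64_t rope_dims;  // leading channels that get rotated
        int64_t tail_dims;  // trailing channels copied through unchanged
    };

    // ne00 is the row width (src0->ne[0]); n_dims comes from op_params.
    static rope_split compute_rope_split(int64_t ne00, int64_t n_dims, bool is_vision) {
        rope_split s;
        s.rope_dims = is_vision ? ne00 : n_dims;  // vision rotates the full row
        s.tail_dims = ne00 - s.rope_dims;         // > 0 only when n_dims < ne00
        return s;
    }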
@@ -2704,7 +2719,6 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {

     acl_tensor_ptr acl_src = ggml_cann_create_tensor(src0);
     acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);
-
 #ifdef ASCEND_310P
     // Special ROPE operation for 310P

@@ -2844,46 +2858,124 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
     }
     return;
 #endif
-
     int64_t acl_mode = is_neox ? 0 : 1;

-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src.get(), acl_cos_reshape_tensor.get(),
-                                        acl_sin_reshape_tensor.get(), acl_mode, acl_dst.get());
-                break;
-            }
-        case GGML_TYPE_F16:
-            {
-                ggml_cann_pool_alloc src_trans_allocator(ctx.pool(), ggml_nelements(src0) * sizeof(float));
-                void * src_trans_buffer = src_trans_allocator.get();
-                ggml_cann_pool_alloc dst_trans_allocator(ctx.pool(), ggml_nelements(dst) * sizeof(float));
-                void * dst_trans_buffer = dst_trans_allocator.get();
+    // Pre-define head and tail dimensions for reuse
+    int64_t head_ne[GGML_MAX_DIMS] = { rope_dims, ne01, ne02, ne03 };
+    int64_t tail_ne[GGML_MAX_DIMS] = { tail_dims, ne01, ne02, ne03 };
+
+    // Step 1: Prepare trans tensors for F16 -> F32 conversion if needed
+    bool src_dst_need_trans = false;
+    ggml_cann_pool_alloc src_trans_allocator(ctx.pool());
+    ggml_cann_pool_alloc dst_trans_allocator(ctx.pool());
+    acl_tensor_ptr acl_src_trans_tensor;
+    acl_tensor_ptr acl_dst_trans_tensor;
+    void * src_trans_buffer = nullptr;
+    void * dst_trans_buffer = nullptr;
+    size_t src_dst_trans_nb[GGML_MAX_DIMS];
+    if (src0->type == GGML_TYPE_F16) {
+        src_dst_need_trans = true;
+        src_trans_buffer = src_trans_allocator.alloc(ggml_nelements(src0) * sizeof(float));
+        dst_trans_buffer = dst_trans_allocator.alloc(ggml_nelements(dst) * sizeof(float));
+
+        src_dst_trans_nb[0] = sizeof(float);
+        for (int i = 1; i < GGML_MAX_DIMS; i++) {
+            src_dst_trans_nb[i] = src_dst_trans_nb[i - 1] * src0->ne[i - 1];
+        }
+        acl_src_trans_tensor = ggml_cann_create_tensor(src_trans_buffer, ACL_FLOAT, sizeof(float), src0->ne,
+                                                       src_dst_trans_nb, GGML_MAX_DIMS);
+        acl_dst_trans_tensor = ggml_cann_create_tensor(dst_trans_buffer, ACL_FLOAT, sizeof(float), dst->ne,
+                                                       src_dst_trans_nb, GGML_MAX_DIMS);
+        aclnn_cast(ctx, acl_src.get(), acl_src_trans_tensor.get(), ACL_FLOAT);
+    }
+
+    // Step 2: Prepare head tensors for tail splitting if needed
+    acl_tensor_ptr acl_src_head;
+    acl_tensor_ptr acl_dst_head;
+    if (has_tail) {
+        // Create head views for RotaryPositionEmbedding (only the first rope_dims dimensions).
+        // RotaryPositionEmbedding requires a contiguous dst tensor, so we use a temporary buffer.
+        if (src_dst_need_trans) {
+            // Use F32 trans tensor strides
+            acl_src_head = ggml_cann_create_tensor((char *) src_trans_buffer, ACL_FLOAT, sizeof(float), head_ne,
+                                                   src_dst_trans_nb, GGML_MAX_DIMS);
+        } else {
+            // Use original F32 tensor strides
+            acl_src_head = ggml_cann_create_tensor((char *) src0->data, ACL_FLOAT, sizeof(float), head_ne, src0->nb,
+                                                   GGML_MAX_DIMS);
+        }

-                size_t src_trans_nb[GGML_MAX_DIMS];
-                src_trans_nb[0] = sizeof(float);
-                for (int i = 1; i < GGML_MAX_DIMS; i++) {
-                    src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1];
-                }
+        int64_t head_elements = rope_dims * ne01 * ne02 * ne03;
+        ggml_cann_pool_alloc dst_head_contiguous_allocator(ctx.pool(), head_elements * sizeof(float));
+        void * dst_head_contiguous_buffer = dst_head_contiguous_allocator.get();

-                acl_tensor_ptr acl_src_trans_tensor = ggml_cann_create_tensor(
-                    src_trans_buffer, ACL_FLOAT, sizeof(float), src0->ne, src_trans_nb, GGML_MAX_DIMS);
-                acl_tensor_ptr acl_dst_trans_tensor = ggml_cann_create_tensor(
-                    dst_trans_buffer, ACL_FLOAT, sizeof(float), dst->ne, src_trans_nb, GGML_MAX_DIMS);
+        size_t head_contiguous_nb[GGML_MAX_DIMS];
+        head_contiguous_nb[0] = sizeof(float);
+        for (int i = 1; i < GGML_MAX_DIMS; i++) {
+            head_contiguous_nb[i] = head_contiguous_nb[i - 1] * head_ne[i - 1];
+        }
+        acl_dst_head = ggml_cann_create_tensor(dst_head_contiguous_buffer, ACL_FLOAT, sizeof(float), head_ne,
+                                               head_contiguous_nb, GGML_MAX_DIMS);
+    }

-                aclnn_cast(ctx, acl_src.get(), acl_src_trans_tensor.get(), ACL_FLOAT);
+    // Step 3: Execute RotaryPositionEmbedding
+    if (has_tail) {
+        // Rotate only the head portion (the first rope_dims dimensions)
+        GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src_head.get(), acl_cos_reshape_tensor.get(),
+                                acl_sin_reshape_tensor.get(), acl_mode, acl_dst_head.get());

-                GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src_trans_tensor.get(),
-                                        acl_cos_reshape_tensor.get(), acl_sin_reshape_tensor.get(), acl_mode,
-                                        acl_dst_trans_tensor.get());
+        // Copy the head result from the contiguous buffer back to the destination tensor
+        if (src_dst_need_trans) {
+            acl_tensor_ptr acl_dst_head_target = ggml_cann_create_tensor(
+                (char *) dst_trans_buffer, ACL_FLOAT, sizeof(float), head_ne, src_dst_trans_nb, GGML_MAX_DIMS);
+            cann_copy(ctx, acl_dst_head.get(), acl_dst_head_target.get());
+        } else {
+            acl_tensor_ptr acl_dst_head_target =
+                ggml_cann_create_tensor((char *) dst->data, ACL_FLOAT, sizeof(float), head_ne, dst->nb, GGML_MAX_DIMS);
+            cann_copy(ctx, acl_dst_head.get(), acl_dst_head_target.get());
+        }
+    } else if (src_dst_need_trans) {
+        // Rotate the full tensor (no tail), using the trans tensors
+        GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src_trans_tensor.get(), acl_cos_reshape_tensor.get(),
+                                acl_sin_reshape_tensor.get(), acl_mode, acl_dst_trans_tensor.get());
+    } else {
+        // Rotate the full tensor (no tail), using the original tensors
+        GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src.get(), acl_cos_reshape_tensor.get(),
+                                acl_sin_reshape_tensor.get(), acl_mode, acl_dst.get());
+    }
+
+    // Step 4: Copy the unrotated tail portion from source to destination
+    if (has_tail) {
+        size_t src_tail_offset;
+        size_t dst_tail_offset;
+
+        auto copy_tail_device = [&](void * src_ptr, void * dst_ptr, aclDataType dtype, size_t elem_size,
+                                    size_t * nb_src_arr, size_t * nb_dst_arr) {
+            acl_tensor_ptr acl_src_tail =
+                ggml_cann_create_tensor(src_ptr, dtype, elem_size, tail_ne, nb_src_arr, GGML_MAX_DIMS);
+            acl_tensor_ptr acl_dst_tail =
+                ggml_cann_create_tensor(dst_ptr, dtype, elem_size, tail_ne, nb_dst_arr, GGML_MAX_DIMS);
+            cann_copy(ctx, acl_src_tail.get(), acl_dst_tail.get());
+        };
+
+        if (src_dst_need_trans) {
+            // Use F32 trans tensor strides and offsets
+            src_tail_offset = rope_dims * src_dst_trans_nb[0];
+            dst_tail_offset = rope_dims * src_dst_trans_nb[0];
+            copy_tail_device((char *) src_trans_buffer + src_tail_offset, (char *) dst_trans_buffer + dst_tail_offset,
+                             ACL_FLOAT, sizeof(float), src_dst_trans_nb, src_dst_trans_nb);
+        } else {
+            // Use original tensor strides and offsets
+            src_tail_offset = rope_dims * nb00;
+            dst_tail_offset = rope_dims * nb0;
+            copy_tail_device((char *) src0->data + src_tail_offset, (char *) dst->data + dst_tail_offset,
+                             ggml_cann_type_mapping(dst->type), ggml_element_size(dst), src0->nb, dst->nb);
+        }
+    }

-                aclnn_cast(ctx, acl_dst_trans_tensor.get(), acl_dst.get(), ACL_FLOAT16);
-                break;
-            }
-        default:
-            GGML_ABORT("Unsupported tensor type for GGML_OP_ROPE");
-            break;
+    // Step 5: Cast back to F16 if needed
+    if (src_dst_need_trans) {
+        aclnn_cast(ctx, acl_dst_trans_tensor.get(), acl_dst.get(), ACL_FLOAT16);
     }
 }

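Taken together, Steps 1 through 5 implement a partial rotation: rotate the first rope_dims channels, copy the tail through unchanged, and route F16 data through F32 when needed. A rough CPU reference of the intended per-row result for the neox layout, with scaling factors such as freq_scale, ext_factor and attn_factor omitted for brevity (this illustrates the expected semantics only and is not the CANN execution path; the function name is made up):

    #include <cmath>
    #include <cstdint>

    // One row of ne00 floats: rotate pairs (i, i + rope_dims/2) within the
    // first rope_dims channels, pass the remaining tail through unchanged.
    static void rope_neox_row_ref(const float * x, float * y, int64_t ne00, int64_t rope_dims,
                                  float theta_base, float theta_scale) {
        const int64_t half = rope_dims / 2;
        float theta = theta_base;
        for (int64_t i = 0; i < half; i++) {
            const float c  = cosf(theta);
            const float s  = sinf(theta);
            const float x0 = x[i];
            const float x1 = x[i + half];
            y[i]        = x0 * c - x1 * s;
            y[i + half] = x0 * s + x1 * c;
            theta *= theta_scale;
        }
        for (int64_t i = rope_dims; i < ne00; i++) {
            y[i] = x[i];  // unrotated tail (Step 4 above)
        }
    }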