Commit ca709e4

CANN: add support for partial RoPE and Vision mode (ggml-org#17543)
* cann: add support for partial RoPE and Vision mode

Add support for two important RoPE variants: partial rotation
(rope_dims < ne0) and Vision mode rotation.

1. Support for partial RoPE (rope_dims < ne0):
   - Split tensor into head (first rope_dims dimensions) and tail portions
   - Apply rotation only to the head portion using the RotaryPositionEmbedding operator
   - Copy the unrotated tail portion directly from source to destination
   - Handle both contiguous and non-contiguous tensor layouts

2. Support for Vision mode (GGML_ROPE_TYPE_VISION):
   - Set rope_dims = ne0 for Vision mode to rotate the entire tensor
   - Vision mode pairs dimension i with dimension i + n_dims (where n_dims = ne0/2)
   - No tail handling needed since the entire tensor is rotated

Implementation details:
- Use a has_tail flag to determine the execution path: head/tail splitting when
  rope_dims < ne0, or full-tensor rotation when rope_dims == ne0
- Support both F32 and F16 data types, with intermediate F32 conversion for F16
- Copy non-contiguous tensors to contiguous buffers before calling the
  RotaryPositionEmbedding operator, for compatibility
- Improve cache invalidation logic to include the rope_dims and indep_sects parameters

These enhancements enable the CANN backend to handle the RoPE configurations
used in modern vision-language models and in models with partial rotation.

* cann: fix review comment

1 parent 0cdce38 commit ca709e4
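To make the head/tail split concrete: within each row, only the first rope_dims elements are rotated and the rest pass through untouched. A minimal host-side sketch of the intended semantics (not the backend code: the helper `partial_rope_row_neox` is hypothetical, and `cos_cache`/`sin_cache` are assumed to hold one precomputed value per rotated pair, NEOX-style):

```cpp
#include <cstdint>

// Hypothetical reference for partial RoPE on a single row of length ne0.
// NEOX-style pairing: element i pairs with element i + rope_dims/2.
static void partial_rope_row_neox(float * dst, const float * src, int64_t ne0,
                                  int64_t rope_dims, const float * cos_cache,
                                  const float * sin_cache) {
    const int64_t half = rope_dims / 2;
    // Head: rotate the first rope_dims elements.
    for (int64_t i = 0; i < half; i++) {
        const float x0 = src[i];
        const float x1 = src[i + half];
        dst[i]        = x0 * cos_cache[i] - x1 * sin_cache[i];
        dst[i + half] = x0 * sin_cache[i] + x1 * cos_cache[i];
    }
    // Tail: copy the remaining ne0 - rope_dims elements unchanged.
    for (int64_t i = rope_dims; i < ne0; i++) {
        dst[i] = src[i];
    }
}
```

Vision mode is then the degenerate case of this split: rope_dims is set to ne0, so the head spans the whole row and the tail copy disappears.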

File tree

3 files changed: +161 −71 lines

ggml/src/ggml-cann/aclnn_ops.cpp
ggml/src/ggml-cann/common.h
ggml/src/ggml-cann/ggml-cann.cpp

ggml/src/ggml-cann/aclnn_ops.cpp

Lines changed: 153 additions & 61 deletions

```diff
@@ -2251,12 +2251,12 @@ static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx,
                                   int sections[4],
                                   bool mrope_used,
                                   bool is_imrope,
-                                  bool indep_sects) {
-    ggml_tensor * src0 = dst->src[0];  // input
+                                  bool indep_sects,
+                                  int64_t rope_dims) {
     ggml_tensor * src1 = dst->src[1];  // position
     ggml_tensor * src2 = dst->src[2];  // freq_factors

-    int64_t theta_scale_length = src0->ne[0] / 2;
+    int64_t theta_scale_length = rope_dims / 2;
     int64_t position_length = dst->ne[2];

     // TODO: check theta_scale_length and position_length.
@@ -2331,18 +2331,17 @@ static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx,
         ACL_CHECK(aclrtMemcpyAsync(ctx.rope_cache.theta_scale_cache, theta_scale_length * sizeof(float),
                                    ctx.rope_cache.theta_scale_exp_host, theta_scale_length * sizeof(float),
                                    ACL_MEMCPY_HOST_TO_DEVICE, ctx.stream()));
-
-        acl_theta_scale_tensor = ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float),
-                                                         theta_scale_ne, theta_scale_nb, 1);
     }
+    acl_theta_scale_tensor = ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float),
+                                                     theta_scale_ne, theta_scale_nb, 1);

     // Step1.2: prepare rope_yarn_ramp, if this part updated, should update theta_scale_tensor.
+    // TODO: make acl_yarn_ramp_tensor use the rope cache.
     bool yarn_ramp_tensor_updated = false;
     ggml_cann_pool_alloc yarn_ramp_allocator(ctx.pool());
     acl_tensor_ptr acl_yarn_ramp_tensor;
-    if (ext_factor != 0 &&
-        // TODO: check more parameter.
-        (ctx.rope_cache.theta_scale_length != theta_scale_length || ctx.rope_cache.freq_scale != freq_scale)) {
+    if (ext_factor != 0 && (theta_scale_updated || ctx.rope_cache.theta_scale_length != theta_scale_length ||
+                            ctx.rope_cache.freq_scale != freq_scale)) {
         yarn_ramp_tensor_updated = true;

         // -rope_yarn_ramp
@@ -2590,7 +2589,7 @@ static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx,
         aclnn_muls(ctx, acl_cos_tensor.get(), attn_factor, nullptr, true);
     }

-    int64_t sin_reshape_ne[4] = { src0->ne[0], 1, dst->ne[2], 1 };
+    int64_t sin_reshape_ne[4] = { rope_dims, 1, dst->ne[2], 1 };
     size_t sin_reshape_nb[GGML_MAX_DIMS];
     sin_reshape_nb[0] = sizeof(float);
     for (int i = 1; i < GGML_MAX_DIMS; i++) {
@@ -2645,7 +2644,7 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {

     // param
     float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
-    int sections[4];
+    int   sections[4];
     // const int n_past = ((int32_t *) dst->op_params)[0];
     const int n_dims = ((int32_t *) dst->op_params)[1];
     const int mode = ((int32_t *) dst->op_params)[2];
@@ -2654,44 +2653,60 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {

     GGML_TENSOR_UNARY_OP_LOCALS

-    memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
-    memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
-    memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
-    memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
-    memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
-    memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
-    memcpy(&sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);
+    memcpy(&freq_base,   (int32_t *) dst->op_params + 5,  sizeof(float));
+    memcpy(&freq_scale,  (int32_t *) dst->op_params + 6,  sizeof(float));
+    memcpy(&ext_factor,  (int32_t *) dst->op_params + 7,  sizeof(float));
+    memcpy(&attn_factor, (int32_t *) dst->op_params + 8,  sizeof(float));
+    memcpy(&beta_fast,   (int32_t *) dst->op_params + 9,  sizeof(float));
+    memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
+    memcpy(&sections,    (int32_t *) dst->op_params + 11, sizeof(int) * 4);

-    // TODO: n_dims <= ne0
-    GGML_ASSERT(n_dims == ne0);
     GGML_ASSERT(n_dims % 2 == 0);
+    GGML_ASSERT(n_dims <= ne00);

     const float theta_scale = powf(freq_base, -2.0f / n_dims);

     float corr_dims[2];
     ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);

-    bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
-    const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;  // qwen3vl apply interleaved mrope
-    const bool mrope_used = mode & GGML_ROPE_TYPE_MROPE;  // ggml_rope_multi, note: also true for vision (24 & 8 == true) and for imrope
-    const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
+    bool       is_neox   = mode & GGML_ROPE_TYPE_NEOX;
+    const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;  // qwen3vl apply interleaved mrope
+    // mrope_used means the GGML_ROPE_TYPE_MROPE bit is set.
+    // Note: this bit is also set for imrope and some vision modes,
+    // so mrope_used does NOT exclusively indicate pure mrope.
+    const bool mrope_used = mode & GGML_ROPE_TYPE_MROPE;
+    const bool is_vision  = mode == GGML_ROPE_TYPE_VISION;

     if (mrope_used) {
         GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
     }

     if (is_vision) {
-        GGML_ASSERT(n_dims == ne0/2);
+        GGML_ASSERT(n_dims == ne0 / 2);
     }

     if (is_imrope || mrope_used) {
         is_neox = true;
     }

+    int64_t rope_dims = n_dims;
+
+    // The RotaryPositionEmbedding operator does not support VISION mode directly,
+    // but VISION mode only changes how theta_base is computed (as in mrope) and
+    // then repeats sin/cos at the end the same way as is_neox; the rotation still
+    // spans all dimensions, so use the full ne0 as rope_dims.
+    if (is_vision) {
+        rope_dims = src0->ne[0];
+    }
+    int64_t tail_dims = ne00 - rope_dims;
+    bool    has_tail  = tail_dims > 0;
+
     // init ctx.rope_cos/rope_sin cache
-    aclnn_rope_cache_init(ctx, dst, corr_dims, ext_factor, theta_scale, freq_scale, attn_factor, is_neox, sections, mrope_used, is_imrope, is_vision);
+    aclnn_rope_cache_init(ctx, dst, corr_dims, ext_factor, theta_scale, freq_scale, attn_factor, is_neox, sections,
+                          mrope_used, is_imrope, is_vision, rope_dims);

-    int64_t sin_reshape_ne[4] = { ne00, 1, ne02, 1 };
+    // The sin/cos cache is generated with rope_dims entries, so reshape with rope_dims.
+    int64_t sin_reshape_ne[4] = { rope_dims, 1, ne02, 1 };
     size_t sin_reshape_nb[GGML_MAX_DIMS];
     sin_reshape_nb[0] = sizeof(float);
     for (int i = 1; i < GGML_MAX_DIMS; i++) {
@@ -2704,7 +2719,6 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {

     acl_tensor_ptr acl_src = ggml_cann_create_tensor(src0);
     acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);
-
 #ifdef ASCEND_310P
     // Special ROPE operation for 310P

@@ -2844,46 +2858,124 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
     }
     return;
 #endif
-
     int64_t acl_mode = is_neox ? 0 : 1;

-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src.get(), acl_cos_reshape_tensor.get(),
-                                        acl_sin_reshape_tensor.get(), acl_mode, acl_dst.get());
-                break;
-            }
-        case GGML_TYPE_F16:
-            {
-                ggml_cann_pool_alloc src_trans_allocator(ctx.pool(), ggml_nelements(src0) * sizeof(float));
-                void * src_trans_buffer = src_trans_allocator.get();
-                ggml_cann_pool_alloc dst_trans_allocator(ctx.pool(), ggml_nelements(dst) * sizeof(float));
-                void * dst_trans_buffer = dst_trans_allocator.get();
+    // Pre-define head and tail dimensions for reuse
+    int64_t head_ne[GGML_MAX_DIMS] = { rope_dims, ne01, ne02, ne03 };
+    int64_t tail_ne[GGML_MAX_DIMS] = { tail_dims, ne01, ne02, ne03 };
+
+    // Step 1: Prepare trans tensors for F16 type conversion to F32 if needed
+    bool src_dst_need_trans = false;
+    ggml_cann_pool_alloc src_trans_allocator(ctx.pool());
+    ggml_cann_pool_alloc dst_trans_allocator(ctx.pool());
+    acl_tensor_ptr acl_src_trans_tensor;
+    acl_tensor_ptr acl_dst_trans_tensor;
+    void * src_trans_buffer = nullptr;
+    void * dst_trans_buffer = nullptr;
+    size_t src_dst_trans_nb[GGML_MAX_DIMS];
+    if (src0->type == GGML_TYPE_F16) {
+        src_dst_need_trans = true;
+        src_trans_buffer   = src_trans_allocator.alloc(ggml_nelements(src0) * sizeof(float));
+        dst_trans_buffer   = dst_trans_allocator.alloc(ggml_nelements(dst) * sizeof(float));
+
+        src_dst_trans_nb[0] = sizeof(float);
+        for (int i = 1; i < GGML_MAX_DIMS; i++) {
+            src_dst_trans_nb[i] = src_dst_trans_nb[i - 1] * src0->ne[i - 1];
+        }
+        acl_src_trans_tensor = ggml_cann_create_tensor(src_trans_buffer, ACL_FLOAT, sizeof(float), src0->ne,
+                                                       src_dst_trans_nb, GGML_MAX_DIMS);
+        acl_dst_trans_tensor = ggml_cann_create_tensor(dst_trans_buffer, ACL_FLOAT, sizeof(float), dst->ne,
+                                                       src_dst_trans_nb, GGML_MAX_DIMS);
+        aclnn_cast(ctx, acl_src.get(), acl_src_trans_tensor.get(), ACL_FLOAT);
+    }
+
+    // Step 2: Prepare head tensors for tail splitting if needed
+    acl_tensor_ptr acl_src_head;
+    acl_tensor_ptr acl_dst_head;
+    if (has_tail) {
+        // Create head views for RotaryPositionEmbedding (only first rope_dims dimensions).
+        // RotaryPositionEmbedding requires a contiguous dst tensor, so we use a temporary buffer.
+        if (src_dst_need_trans) {
+            // Use F32 trans tensor strides
+            acl_src_head = ggml_cann_create_tensor((char *) src_trans_buffer, ACL_FLOAT, sizeof(float), head_ne,
+                                                   src_dst_trans_nb, GGML_MAX_DIMS);
+        } else {
+            // Use original F32 tensor strides
+            acl_src_head = ggml_cann_create_tensor((char *) src0->data, ACL_FLOAT, sizeof(float), head_ne, src0->nb,
+                                                   GGML_MAX_DIMS);
+        }

-                size_t src_trans_nb[GGML_MAX_DIMS];
-                src_trans_nb[0] = sizeof(float);
-                for (int i = 1; i < GGML_MAX_DIMS; i++) {
-                    src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1];
-                }
+        int64_t head_elements = rope_dims * ne01 * ne02 * ne03;
+        ggml_cann_pool_alloc dst_head_contiguous_allocator(ctx.pool(), head_elements * sizeof(float));
+        void * dst_head_contiguous_buffer = dst_head_contiguous_allocator.get();

-                acl_tensor_ptr acl_src_trans_tensor = ggml_cann_create_tensor(
-                    src_trans_buffer, ACL_FLOAT, sizeof(float), src0->ne, src_trans_nb, GGML_MAX_DIMS);
-                acl_tensor_ptr acl_dst_trans_tensor = ggml_cann_create_tensor(
-                    dst_trans_buffer, ACL_FLOAT, sizeof(float), dst->ne, src_trans_nb, GGML_MAX_DIMS);
+        size_t head_contiguous_nb[GGML_MAX_DIMS];
+        head_contiguous_nb[0] = sizeof(float);
+        for (int i = 1; i < GGML_MAX_DIMS; i++) {
+            head_contiguous_nb[i] = head_contiguous_nb[i - 1] * head_ne[i - 1];
+        }
+        acl_dst_head = ggml_cann_create_tensor(dst_head_contiguous_buffer, ACL_FLOAT, sizeof(float), head_ne,
+                                               head_contiguous_nb, GGML_MAX_DIMS);
+    }

-                aclnn_cast(ctx, acl_src.get(), acl_src_trans_tensor.get(), ACL_FLOAT);
+    // Step 3: Execute RotaryPositionEmbedding
+    if (has_tail) {
+        // Rotate only the head portion (first rope_dims dimensions)
+        GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src_head.get(), acl_cos_reshape_tensor.get(),
+                                acl_sin_reshape_tensor.get(), acl_mode, acl_dst_head.get());

-                GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src_trans_tensor.get(),
-                                        acl_cos_reshape_tensor.get(), acl_sin_reshape_tensor.get(), acl_mode,
-                                        acl_dst_trans_tensor.get());
+        // Copy head result from contiguous buffer back to destination tensor
+        if (src_dst_need_trans) {
+            acl_tensor_ptr acl_dst_head_target = ggml_cann_create_tensor(
+                (char *) dst_trans_buffer, ACL_FLOAT, sizeof(float), head_ne, src_dst_trans_nb, GGML_MAX_DIMS);
+            cann_copy(ctx, acl_dst_head.get(), acl_dst_head_target.get());
+        } else {
+            acl_tensor_ptr acl_dst_head_target =
+                ggml_cann_create_tensor((char *) dst->data, ACL_FLOAT, sizeof(float), head_ne, dst->nb, GGML_MAX_DIMS);
+            cann_copy(ctx, acl_dst_head.get(), acl_dst_head_target.get());
+        }
+    } else if (src_dst_need_trans) {
+        // Rotate full tensor (no tail), using trans tensors
+        GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src_trans_tensor.get(), acl_cos_reshape_tensor.get(),
+                                acl_sin_reshape_tensor.get(), acl_mode, acl_dst_trans_tensor.get());
+    } else {
+        // Rotate full tensor (no tail), using original tensors
+        GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src.get(), acl_cos_reshape_tensor.get(),
+                                acl_sin_reshape_tensor.get(), acl_mode, acl_dst.get());
+    }
+
+    // Step 4: Copy unrotated tail portion from source to destination
+    if (has_tail) {
+        size_t src_tail_offset;
+        size_t dst_tail_offset;
+
+        auto copy_tail_device = [&](void * src_ptr, void * dst_ptr, aclDataType dtype, size_t elem_size,
+                                    size_t * nb_src_arr, size_t * nb_dst_arr) {
+            acl_tensor_ptr acl_src_tail =
+                ggml_cann_create_tensor(src_ptr, dtype, elem_size, tail_ne, nb_src_arr, GGML_MAX_DIMS);
+            acl_tensor_ptr acl_dst_tail =
+                ggml_cann_create_tensor(dst_ptr, dtype, elem_size, tail_ne, nb_dst_arr, GGML_MAX_DIMS);
+            cann_copy(ctx, acl_src_tail.get(), acl_dst_tail.get());
+        };
+
+        if (src_dst_need_trans) {
+            // Use F32 trans tensor strides and offsets
+            src_tail_offset = rope_dims * src_dst_trans_nb[0];
+            dst_tail_offset = rope_dims * src_dst_trans_nb[0];
+            copy_tail_device((char *) src_trans_buffer + src_tail_offset, (char *) dst_trans_buffer + dst_tail_offset,
+                             ACL_FLOAT, sizeof(float), src_dst_trans_nb, src_dst_trans_nb);
+        } else {
+            // Use original tensor strides and offsets
+            src_tail_offset = rope_dims * nb00;
+            dst_tail_offset = rope_dims * nb0;
+            copy_tail_device((char *) src0->data + src_tail_offset, (char *) dst->data + dst_tail_offset,
+                             ggml_cann_type_mapping(dst->type), ggml_element_size(dst), src0->nb, dst->nb);
+        }
+    }

-                aclnn_cast(ctx, acl_dst_trans_tensor.get(), acl_dst.get(), ACL_FLOAT16);
-                break;
-            }
-        default:
-            GGML_ABORT("Unsupported tensor type for GGML_OP_ROPE");
-            break;
+    // Step 5: Cast back to F16 if needed
+    if (src_dst_need_trans) {
+        aclnn_cast(ctx, acl_dst_trans_tensor.get(), acl_dst.get(), ACL_FLOAT16);
     }
 }

```
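The Step 4 tail copy above is pure stride arithmetic: within every row, the tail begins rope_dims elements in, i.e. at byte offset rope_dims * nb[0]. A host-side sketch of the same addressing (illustrative only: `copy_tail_host` is hypothetical, `nb` follows ggml's stride convention, and the backend instead performs this on-device with `cann_copy` over strided ACL tensor views):

```cpp
#include <cstdint>
#include <cstring>

// Hypothetical host-side equivalent of the Step 4 tail copy.
// ggml stride convention: nb[0] = element size; for a contiguous tensor,
// nb[i] = nb[i-1] * ne[i-1]. Assumes the innermost dimension is contiguous.
static void copy_tail_host(char * dst, const char * src, const int64_t ne[4],
                           const size_t nb_src[4], const size_t nb_dst[4],
                           int64_t rope_dims) {
    const int64_t tail_dims = ne[0] - rope_dims;
    for (int64_t i3 = 0; i3 < ne[3]; i3++) {
        for (int64_t i2 = 0; i2 < ne[2]; i2++) {
            for (int64_t i1 = 0; i1 < ne[1]; i1++) {
                // The tail starts rope_dims elements into each row.
                const char * s = src + i3 * nb_src[3] + i2 * nb_src[2]
                               + i1 * nb_src[1] + rope_dims * nb_src[0];
                char * d = dst + i3 * nb_dst[3] + i2 * nb_dst[2]
                         + i1 * nb_dst[1] + rope_dims * nb_dst[0];
                memcpy(d, s, tail_dims * nb_src[0]);
            }
        }
    }
}
```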

ggml/src/ggml-cann/common.h

Lines changed: 2 additions & 2 deletions

```diff
@@ -315,7 +315,7 @@ struct ggml_cann_rope_cache {
         if (theta_scale_exp_host) {
             free(theta_scale_exp_host);
         }
-        if(position_select_index_host) {
+        if (position_select_index_host) {
             free(position_select_index_host);
         }
     }
@@ -340,7 +340,7 @@ struct ggml_cann_rope_cache {

     void set(int64_t theta_scale_length,
              int64_t position_length,
-             float ext_factor,
+             float   ext_factor,
              float   theta_scale,
              float   freq_scale,
              float   attn_factor,
```

ggml/src/ggml-cann/ggml-cann.cpp

Lines changed: 6 additions & 8 deletions

```diff
@@ -2308,7 +2308,7 @@ static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend,

     bool cann_graph_update_required = false;
 #ifdef USE_ACL_GRAPH
-    bool use_cann_graph = true;
+    bool use_cann_graph            = true;

     static bool prefill_use_graph = parse_bool(get_env("GGML_CANN_PREFILL_USE_GRAPH").value_or(""));
     if (!prefill_use_graph) {
@@ -2338,7 +2338,7 @@ static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend,
         }
     }
 #else
-    bool use_cann_graph = false;
+    bool use_cann_graph            = false;
 #endif  // USE_ACL_GRAPH
     evaluate_and_capture_cann_graph(cann_ctx, cgraph, use_cann_graph, cann_graph_update_required);

@@ -2474,16 +2474,14 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
             }
         case GGML_OP_ROPE:
             {
-                // TODO: with ops-test v == 1
-                // TODO: n_dims <= ne0
-                if (op->src[0]->ne[0] != op->op_params[1]) {
-                    return false;
-                }
-
                 if (op->src[0]->ne[0] > 896) {
                     return false;
                 }
 #ifdef ASCEND_310P
+                // TODO: support rope_dims < ne00 (partial rotation)
+                if (op->src[0]->ne[0] != op->op_params[1]) {
+                    return false;
+                }
                 if (!ggml_is_contiguous(op->src[0])) {
                     return false;
                 }
```
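With the ne[0] != op_params[1] rejection now confined to ASCEND_310P, other CANN devices accept RoPE nodes that rotate only part of each row. For example, such a node can be built through the public ggml API (a sketch: `ctx0`, an activation tensor `cur` with ne[0] == 128, and a position tensor `pos` are assumed to already exist, and the parameter values are illustrative):

```cpp
// Partial RoPE: rotate 64 of the 128 elements per row; the CANN backend now
// handles the remaining 64 via the head/tail split added in this commit.
struct ggml_tensor * rope = ggml_rope_ext(
    ctx0, cur, pos, /*freq_factors=*/nullptr,
    /*n_dims=*/64, GGML_ROPE_TYPE_NEOX, /*n_ctx_orig=*/4096,
    /*freq_base=*/10000.0f, /*freq_scale=*/1.0f,
    /*ext_factor=*/0.0f, /*attn_factor=*/1.0f,
    /*beta_fast=*/32.0f, /*beta_slow=*/1.0f);
```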
