@@ -590,11 +590,9 @@ static struct ggml_tensor * llm_build_kqv(
590590 cb (v, " v" , il);
591591
592592 struct ggml_tensor * padded_v = v;
593- int64_t n_embd_head_v_out = n_embd_head_v;
594593 if (n_embd_head_v < n_embd_head_k) {
595594 padded_v = ggml_pad (ctx, v, 0 , k->ne [0 ] - v->ne [1 ], 0 , 0 );
596595 cb (padded_v, " padded_v" , il);
597- n_embd_head_v_out = n_embd_head_k;
598596 padded_v = ggml_cont (ctx, padded_v);
599597 }
600598
@@ -604,11 +602,7 @@ static struct ggml_tensor * llm_build_kqv(
604602 ggml_flash_attn_ext_set_prec (cur, GGML_PREC_F32);
605603
606604 if (n_embd_head_v < n_embd_head_k) {
607- cur = ggml_reshape_3d (ctx, cur, n_embd_head_v_out, n_head, n_tokens);
608- cur = ggml_cont (ctx, ggml_view_3d (ctx, cur, n_embd_head_v, n_head, n_tokens,
609- ggml_row_size (cur->type , n_embd_head_v_out),
610- ggml_row_size (cur->type , n_embd_head_v_out * n_head),
611- ggml_element_size (cur) * (n_embd_head_k - n_embd_head_v)));
605+ cur = ggml_view_1d (ctx, ggml_cont (ctx, cur), n_embd_head_k*n_head, n_tokens);
612606 }
613607
614608 cur = ggml_reshape_2d (ctx, cur, n_embd_head_v*n_head, n_tokens);
0 commit comments