@@ -2590,10 +2590,7 @@ struct clip_graph {
25902590 } else {
25912591 ggml_tensor * v = ggml_permute (ctx0, v_cur, 1 , 2 , 0 , 3 );
25922592 v = ggml_cont (ctx0, v);
2593-
2594- const auto n_tokens = q->ne [1 ];
2595- const auto n_head = q->ne [2 ];
2596-
2593+
25972594 ggml_tensor * kq = ggml_mul_mat (ctx0, k, q);
25982595 // F32 may not needed for vision encoders?
25992596 // ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
@@ -2602,7 +2599,8 @@ struct clip_graph {
26022599
26032600 ggml_tensor * kqv = ggml_mul_mat (ctx0, v, kq);
26042601 cur = ggml_permute (ctx0, kqv, 0 , 2 , 1 , 3 );
2605- cur = ggml_cont_2d (ctx0, cur, cur->ne [0 ]*n_head, n_tokens);
2602+ cur = ggml_reshape_2d (ctx0, ggml_cont (ctx0, cur), cur->ne [0 ] * cur->ne [1 ], cur->ne [2 ] * cur->ne [3 ]);
2603+
26062604 }
26072605
26082606 cb (cur, " kqv_out" , il);
@@ -2789,15 +2787,12 @@ struct clip_graph {
27892787
27902788 Q = ggml_view_3d (ctx0, cur, n_embd, W*H, B, cur->nb [2 ], cur->nb [3 ], 0 *cur->nb [1 ]);
27912789 Q = ggml_reshape_4d (ctx0, ggml_cont (ctx0, Q), d_heads, n_heads, W*H, B);
2792- Q = ggml_cont (ctx0, ggml_permute (ctx0, Q, 0 , 2 , 1 , 3 )); // [B, n_heads, H*W, d_heads]
27932790
27942791 K = ggml_view_3d (ctx0, cur, n_embd, W*H, B, cur->nb [2 ], cur->nb [3 ], 1 *cur->nb [1 ]);
27952792 K = ggml_reshape_4d (ctx0, ggml_cont (ctx0, K), d_heads, n_heads, W*H, B);
2796- K = ggml_cont (ctx0, ggml_permute (ctx0, K, 0 , 2 , 1 , 3 )); // [B, n_heads, H*W, d_heads]
27972793
27982794 V = ggml_view_3d (ctx0, cur, n_embd, W*H, B, cur->nb [2 ], cur->nb [3 ], 2 *cur->nb [1 ]);
27992795 V = ggml_reshape_4d (ctx0, ggml_cont (ctx0, V), d_heads, n_heads, W*H, B);
2800- V = ggml_cont (ctx0, ggml_permute (ctx0, V, 0 , 2 , 1 , 3 )); // [B, n_heads, H*W, d_heads]
28012796
28022797 ggml_tensor * mask;
28032798 ggml_tensor * rw;
@@ -2806,7 +2801,8 @@ struct clip_graph {
28062801
28072802 rw = get_rel_pos (ctx0, layer.rel_pos_w , W, W); // [W, W, C]
28082803 rh = get_rel_pos (ctx0, layer.rel_pos_h , H, H); // [H, H, C]
2809- qr = ggml_reshape_4d (ctx0, Q, d_heads, W, H, B*n_heads);
2804+ qr = ggml_permute (ctx0, Q, 0 , 2 , 1 , 3 );
2805+ qr = ggml_reshape_4d (ctx0, ggml_cont (ctx0, qr), d_heads, W, H, B * n_heads);
28102806
28112807 const int WH_pad = GGML_PAD (W*H, GGML_KQ_MASK_PAD) - W*H;
28122808
@@ -2822,11 +2818,10 @@ struct clip_graph {
28222818 mask = ggml_cast (ctx0, mask, GGML_TYPE_F16);
28232819
28242820 float scale = 1 .0f / sqrtf ((float )d_heads);
2825- cur = ggml_flash_attn_ext (ctx0, Q, K, V, mask, scale, 0 .0f , 0 .0f ); // [B, H*W, n_heads, d_heads]
28262821
2822+ cur = build_attn (layer.o_w , layer.o_b , Q, K, V, mask, scale,
2823+ il); // [B, H*W, n_embd]
28272824 cur = ggml_reshape_4d (ctx0, ggml_cont (ctx0, cur), n_embd, W, H, B);
2828- cur = ggml_mul_mat (ctx0, layer.o_w , cur);
2829- cur = ggml_add_inplace (ctx0, cur, layer.o_b );
28302825 }
28312826
28322827 if (hparams.is_global_attn (il) == false ) {
0 commit comments