@@ -2152,19 +2152,44 @@ struct clip_graph {
21522152
21532153 // self-attention
21542154 {
2155- ggml_tensor * Qcur = ggml_mul_mat (ctx0, layer.q_w , cur);
2156- if (layer.q_b ) {
2157- Qcur = ggml_add (ctx0, Qcur, layer.q_b );
2158- }
2159-
2160- ggml_tensor * Kcur = ggml_mul_mat (ctx0, layer.k_w , cur);
2161- if (layer.k_b ) {
2162- Kcur = ggml_add (ctx0, Kcur, layer.k_b );
2163- }
2155+ ggml_tensor * Qcur;
2156+ ggml_tensor * Kcur;
2157+ ggml_tensor * Vcur;
2158+
2159+ if (layer.qkv_w ) {
2160+ ggml_tensor * QKV;
21642161
2165- ggml_tensor * Vcur = ggml_mul_mat (ctx0, layer.v_w , cur);
2166- if (layer.v_b ) {
2167- Vcur = ggml_add (ctx0, Vcur, layer.v_b );
2162+ QKV = ggml_mul_mat (ctx0, layer.qkv_w , cur);
2163+ if (layer.qkv_b ) {
2164+ QKV = ggml_add (ctx0, QKV, layer.qkv_b );
2165+ }
2166+ QKV = ggml_reshape_4d (ctx0, QKV, cur->ne [0 ], 3 , cur->ne [1 ]*cur->ne [2 ], cur->ne [3 ]);
2167+
2168+ const int ne0 = QKV->ne [0 ];
2169+ const int ne2 = QKV->ne [2 ];
2170+ const int ne3 = QKV->ne [3 ];
2171+ const int nb1 = QKV->nb [1 ];
2172+ const int nb2 = QKV->nb [2 ];
2173+ const int nb3 = QKV->nb [3 ];
2174+
2175+ Qcur = ggml_cont (ctx0, ggml_view_3d (ctx0, QKV, ne0, ne2, ne3, nb2, nb3, 0 *nb1));
2176+ Kcur = ggml_cont (ctx0, ggml_view_3d (ctx0, QKV, ne0, ne2, ne3, nb2, nb3, 1 *nb1));
2177+ Vcur = ggml_cont (ctx0, ggml_view_3d (ctx0, QKV, ne0, ne2, ne3, nb2, nb3, 2 *nb1));
2178+ } else {
2179+ Qcur = ggml_mul_mat (ctx0, layer.q_w , cur);
2180+ if (layer.q_b ) {
2181+ Qcur = ggml_add (ctx0, Qcur, layer.q_b );
2182+ }
2183+
2184+ Kcur = ggml_mul_mat (ctx0, layer.k_w , cur);
2185+ if (layer.k_b ) {
2186+ Kcur = ggml_add (ctx0, Kcur, layer.k_b );
2187+ }
2188+
2189+ Vcur = ggml_mul_mat (ctx0, layer.v_w , cur);
2190+ if (layer.v_b ) {
2191+ Vcur = ggml_add (ctx0, Vcur, layer.v_b );
2192+ }
21682193 }
21692194
21702195 if (layer.q_norm ) {
0 commit comments