@@ -595,11 +595,12 @@ struct clip_graph {
         cur = ggml_mul(ctx0, cur, model.mm_input_norm_w);
         cur = ggml_add(ctx0, cur, model.mm_input_norm_b);

-        cur = ggml_mul_mat(ctx0, model.mm_1_w, cur);
-        cur = ggml_add(ctx0, cur, model.mm_1_b);
-        cur = ggml_gelu(ctx0, cur);
-        cur = ggml_mul_mat(ctx0, model.mm_2_w, cur);
-        cur = ggml_add(ctx0, cur, model.mm_2_b);
+        cur = build_ffn(cur,
+                model.mm_1_w, model.mm_1_b,
+                nullptr, nullptr,
+                model.mm_2_w, model.mm_2_b,
+                FFN_GELU,
+                -1);

     } else if (ctx->proj_type() == PROJECTOR_TYPE_JANUS_PRO) {
         cur = build_ffn(cur,
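Note: this and the following four hunks are the same mechanical refactor: a hand-rolled linear → GELU → linear chain becomes one call to the shared build_ffn helper, with nullptr for the unused gate pair and -1 for the layer index. The helper's body is not part of this diff, but the removed code (including the bias guards in the LLaVA hunk below) implies roughly the following expansion for the no-gate case; this is a hedged sketch, not the verbatim helper:

```cpp
#include "ggml.h"

// Sketch of what build_ffn(cur, up_w, up_b, gate_w, gate_b, down_w, down_b, FFN_GELU, il)
// is assumed to reduce to when gate_w/gate_b are nullptr (gate path and the il callback
// are omitted here for brevity).
static ggml_tensor * build_ffn_sketch(ggml_context * ctx0, ggml_tensor * cur,
        ggml_tensor * up_w,   ggml_tensor * up_b,
        ggml_tensor * down_w, ggml_tensor * down_b) {
    cur = ggml_mul_mat(ctx0, up_w, cur);    // first linear layer
    if (up_b) {
        cur = ggml_add(ctx0, cur, up_b);    // biases are optional, skipped when null
    }
    cur = ggml_gelu(ctx0, cur);             // FFN_GELU activation
    cur = ggml_mul_mat(ctx0, down_w, cur);  // second linear layer
    if (down_b) {
        cur = ggml_add(ctx0, cur, down_b);
    }
    return cur;
}
```

The optional-bias handling is what lets the Voxtral hunk further down pass its (null) bias tensors without changing behavior.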
@@ -667,16 +668,12 @@ struct clip_graph {

         // LlavaMultiModalProjector (always using GELU activation)
         {
-            cur = ggml_mul_mat(ctx0, model.mm_1_w, cur);
-            if (model.mm_1_b) {
-                cur = ggml_add(ctx0, cur, model.mm_1_b);
-            }
-
-            cur = ggml_gelu(ctx0, cur);
-            cur = ggml_mul_mat(ctx0, model.mm_2_w, cur);
-            if (model.mm_2_b) {
-                cur = ggml_add(ctx0, cur, model.mm_2_b);
-            }
+            cur = build_ffn(cur,
+                    model.mm_1_w, model.mm_1_b,
+                    nullptr, nullptr,
+                    model.mm_2_w, model.mm_2_b,
+                    FFN_GELU,
+                    -1);
         }

         // arrangement of the [IMG_BREAK] token
@@ -866,16 +863,12 @@ struct clip_graph {
         // multimodal projection
         ggml_tensor * embeddings = inpL;
         embeddings = ggml_reshape_3d(ctx0, embeddings, n_embd * 4, n_pos / 4, batch_size);
-
-        embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
-        embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
-
-        // GELU activation
-        embeddings = ggml_gelu(ctx0, embeddings);
-
-        // Second linear layer
-        embeddings = ggml_mul_mat(ctx0, model.mm_1_w, embeddings);
-        embeddings = ggml_add(ctx0, embeddings, model.mm_1_b);
+        embeddings = build_ffn(embeddings,
+                model.mm_0_w, model.mm_0_b,
+                nullptr, nullptr,
+                model.mm_1_w, model.mm_1_b,
+                FFN_GELU,
+                -1);

         if (use_window_attn) {
             window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos / 4);
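Unlike the other projector hunks, this one is fed by a reshape that merges every 4 consecutive patch embeddings into a single row, so the FFN's first weight mm_0_w sees an input width of n_embd * 4. A small standalone illustration of that arithmetic (the concrete n_embd/n_pos values are made up for the example):

```cpp
#include <cassert>

int main() {
    const int n_embd = 1280;             // example patch embedding size (hypothetical)
    const int n_pos  = 1024;             // example number of patches (hypothetical)
    assert(n_pos % 4 == 0);              // the reshape requires groups of 4 patches
    const int merged_dim  = n_embd * 4;  // input width expected by mm_0_w
    const int merged_toks = n_pos  / 4;  // token count after merging
    // the reshape only regroups data; the total element count is unchanged
    assert(merged_dim * merged_toks == n_embd * n_pos);
    return 0;
}
```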
@@ -1253,11 +1246,12 @@ struct clip_graph {
             // projector LayerNorm uses pytorch's default eps = 1e-5
             // ref: https://huggingface.co/OpenGVLab/InternVL3-8B-Instruct/blob/a34d3e4e129a5856abfd6aa6de79776484caa14e/modeling_internvl_chat.py#L79
             cur = build_norm(cur, model.mm_0_w, model.mm_0_b, NORM_TYPE_NORMAL, 1e-5, -1);
-            cur = ggml_mul_mat(ctx0, model.mm_1_w, cur);
-            cur = ggml_add(ctx0, cur, model.mm_1_b);
-            cur = ggml_gelu(ctx0, cur);
-            cur = ggml_mul_mat(ctx0, model.mm_3_w, cur);
-            cur = ggml_add(ctx0, cur, model.mm_3_b);
+            cur = build_ffn(cur,
+                    model.mm_1_w, model.mm_1_b,
+                    nullptr, nullptr,
+                    model.mm_3_w, model.mm_3_b,
+                    FFN_GELU,
+                    -1);
         }

         // build the graph
@@ -1408,11 +1402,12 @@ struct clip_graph {
             cb(cur, "proj_inp_normed", -1);

             // projection mlp
-            cur = ggml_mul_mat(ctx0, model.mm_1_w, cur);
-            cur = ggml_add(ctx0, cur, model.mm_1_b);
-            cur = ggml_gelu(ctx0, cur);
-            cur = ggml_mul_mat(ctx0, model.mm_2_w, cur);
-            cur = ggml_add(ctx0, cur, model.mm_2_b);
+            cur = build_ffn(cur,
+                    model.mm_1_w, model.mm_1_b,
+                    nullptr, nullptr,
+                    model.mm_2_w, model.mm_2_b,
+                    FFN_GELU,
+                    -1);
             cb(cur, "proj_out", -1);
         }

@@ -1883,9 +1878,12 @@ struct clip_graph {

         } else if (ctx->proj_type() == PROJECTOR_TYPE_VOXTRAL) {
             // projector
-            cur = ggml_mul_mat(ctx0, model.mm_1_w, cur);
-            cur = ggml_gelu_erf(ctx0, cur);
-            cur = ggml_mul_mat(ctx0, model.mm_2_w, cur);
+            cur = build_ffn(cur,
+                    model.mm_1_w, model.mm_1_b,
+                    nullptr, nullptr,
+                    model.mm_2_w, model.mm_2_b,
+                    FFN_GELU_ERF,
+                    -1);

         } else {
             GGML_ABORT("%s: unknown projector type", __func__);
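Two things distinguish the Voxtral hunk from the earlier ones: the removed code never added biases (so model.mm_1_b and model.mm_2_b are presumably null here, and the helper's bias guards make them no-ops), and it selects the erf-based GELU via FFN_GELU_ERF. A hedged sketch of the activation dispatch this presumes, reusing the enum names visible in the diff (the surrounding enum type is a clip.cpp internal, assumed rather than quoted):

```cpp
// Assumed shape of the activation dispatch inside build_ffn; only the two
// variants that appear in this diff are shown.
static ggml_tensor * apply_ffn_act(ggml_context * ctx0, ggml_tensor * cur, ffn_op_type type_op) {
    switch (type_op) {
        case FFN_GELU:     return ggml_gelu(ctx0, cur);      // GELU used by the image projectors above
        case FFN_GELU_ERF: return ggml_gelu_erf(ctx0, cur);  // exact erf GELU, matching the removed call
        default:           GGML_ABORT("unsupported activation");
    }
}
```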
@@ -2070,34 +2068,66 @@ struct clip_graph {

         // self-attention
         {
-            ggml_tensor * Qcur = ggml_mul_mat(ctx0, layer.q_w, cur);
-            if (layer.q_b) {
-                Qcur = ggml_add(ctx0, Qcur, layer.q_b);
-            }
+            ggml_tensor * Qcur = nullptr;
+            ggml_tensor * Kcur = nullptr;
+            ggml_tensor * Vcur = nullptr;
+            if (layer.qkv_w != nullptr) {
+                // fused qkv
+                cur = ggml_mul_mat(ctx0, layer.qkv_w, cur);
+                if (layer.qkv_b != nullptr) {
+                    cur = ggml_add(ctx0, cur, layer.qkv_b);
+                }

-            ggml_tensor * Kcur = ggml_mul_mat(ctx0, layer.k_w, cur);
-            if (layer.k_b) {
-                Kcur = ggml_add(ctx0, Kcur, layer.k_b);
-            }
+                Qcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos,
+                    /* nb1    */ ggml_row_size(cur->type, d_head),
+                    /* nb2    */ cur->nb[1],
+                    /* offset */ 0);

-            ggml_tensor * Vcur = ggml_mul_mat(ctx0, layer.v_w, cur);
-            if (layer.v_b) {
-                Vcur = ggml_add(ctx0, Vcur, layer.v_b);
-            }
+                Kcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos,
+                    /* nb1    */ ggml_row_size(cur->type, d_head),
+                    /* nb2    */ cur->nb[1],
+                    /* offset */ ggml_row_size(cur->type, n_embd));

-            if (layer.q_norm) {
-                Qcur = build_norm(Qcur, layer.q_norm, NULL, norm_t, eps, il);
-                cb(Qcur, "Qcur_norm", il);
-            }
+                Vcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos,
+                    /* nb1    */ ggml_row_size(cur->type, d_head),
+                    /* nb2    */ cur->nb[1],
+                    /* offset */ ggml_row_size(cur->type, 2 * n_embd));

-            if (layer.k_norm) {
-                Kcur = build_norm(Kcur, layer.k_norm, NULL, norm_t, eps, il);
-                cb(Kcur, "Kcur_norm", il);
-            }
+                // TODO: q/k norm requires row size == n_embd, while here it's d_head
+                // we can add support in the future if needed
+                GGML_ASSERT(layer.q_norm == nullptr && layer.k_norm == nullptr);

-            Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_pos);
-            Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, n_pos);
-            Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head, n_pos);
+            } else {
+                // separate q, k, v
+                Qcur = ggml_mul_mat(ctx0, layer.q_w, cur);
+                if (layer.q_b) {
+                    Qcur = ggml_add(ctx0, Qcur, layer.q_b);
+                }
+
+                Kcur = ggml_mul_mat(ctx0, layer.k_w, cur);
+                if (layer.k_b) {
+                    Kcur = ggml_add(ctx0, Kcur, layer.k_b);
+                }
+
+                Vcur = ggml_mul_mat(ctx0, layer.v_w, cur);
+                if (layer.v_b) {
+                    Vcur = ggml_add(ctx0, Vcur, layer.v_b);
+                }
+
+                if (layer.q_norm) {
+                    Qcur = build_norm(Qcur, layer.q_norm, NULL, norm_t, eps, il);
+                    cb(Qcur, "Qcur_norm", il);
+                }
+
+                if (layer.k_norm) {
+                    Kcur = build_norm(Kcur, layer.k_norm, NULL, norm_t, eps, il);
+                    cb(Kcur, "Kcur_norm", il);
+                }
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_pos);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, n_pos);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head, n_pos);
+            }

             cb(Qcur, "Qcur", il);
             cb(Kcur, "Kcur", il);
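The fused branch above replaces three matmuls with one projection of width 3 * n_embd, then carves Q, K and V out of the result with zero-copy strided views: all three share the token stride cur->nb[1] and differ only in their byte offset into each [q | k | v] row. A standalone toy program (hypothetical sizes, f32 assumed in place of ggml_row_size) that prints the offsets and strides those views use:

```cpp
#include <cstdio>

int main() {
    const int d_head = 64, n_head = 16;        // hypothetical head geometry
    const int n_embd = d_head * n_head;        // 1024 in this example
    const size_t elt = sizeof(float);          // stand-in for ggml_row_size(cur->type, 1)

    // offsets passed to the three ggml_view_3d calls
    printf("Q offset: %zu bytes\n", (size_t) 0);
    printf("K offset: %zu bytes\n", elt * n_embd);      // ggml_row_size(cur->type, n_embd)
    printf("V offset: %zu bytes\n", elt * 2 * n_embd);  // ggml_row_size(cur->type, 2 * n_embd)

    // strides shared by all three views
    printf("nb1 (head stride) : %zu bytes\n", elt * d_head);      // ggml_row_size(cur->type, d_head)
    printf("nb2 (token stride): %zu bytes\n", elt * 3 * n_embd);  // == cur->nb[1], the full fused row
    return 0;
}
```

Because the views only adjust strides and offsets, no data is copied; the price is that each Q/K/V row now has size d_head rather than n_embd, which is exactly why the fused path asserts q_norm/k_norm away in the TODO above.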