
Commit 0dda80f

Refactor MLP projector to use build_ffn helper
1 parent 6247fd2 commit 0dda80f

File tree

1 file changed: +11 -18 lines changed


tools/mtmd/clip.cpp

Lines changed: 11 additions & 18 deletions
@@ -1116,29 +1116,22 @@ struct clip_graph {
     // 2-layer MLP projector: mm.0 -> GELU -> mm.2
     ggml_tensor * embeddings = cur;
 
-    // projector matmuls assume canonical [n_in, n_out] weights; no runtime transposes
     GGML_ASSERT(model.mm_0_w != nullptr);
+    GGML_ASSERT(model.mm_2_w != nullptr);
+
     // ensure projector input is a packed 2D matrix [n_in, n_tokens]
     embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]);
     embeddings = ggml_cont_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]);
-    // embeddings are packed [n_in, n_tokens]
-    // Weights are canonicalized at conversion time to [n_in, n_out]; multiply directly.
-    embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
-    if (model.mm_0_b) {
-        embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
-    }
-
-    embeddings = ggml_gelu(ctx0, embeddings);
 
-    GGML_ASSERT(model.mm_2_w != nullptr);
-    // keep [n_in, n_tokens] layout for the second matmul as well
-    embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]);
-    embeddings = ggml_cont_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]);
-    // Weights are canonicalized at conversion time to [n_in, n_out]; multiply directly.
-    embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings);
-    if (model.mm_2_b) {
-        embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
-    }
+    // Use shared FFN helper: Linear(mm.0) -> GELU -> Linear(mm.2)
+    embeddings = build_ffn(
+        embeddings,
+        model.mm_0_w, model.mm_0_b,
+        /*gate=*/nullptr, /*gate_b=*/nullptr,
+        model.mm_2_w, model.mm_2_b,
+        FFN_GELU,
+        /*il=*/0
+    );
 
     // build the graph
     ggml_build_forward_expand(gf, embeddings);
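For context, the sketch below shows what the no-gate FFN_GELU path of a build_ffn-style helper is expected to compute for this projector. It is reconstructed from the hand-written code removed in this commit, not copied from clip.cpp: the function name build_ffn_sketch, its parameter list, and the absence of the activation enum and layer index il are assumptions made for illustration only.

// Hypothetical sketch only: NOT the real clip.cpp build_ffn.
// It mirrors the gateless Linear(mm.0) -> GELU -> Linear(mm.2) path that the
// removed hand-written code implemented, with biases treated as optional.
#include "ggml.h"

static ggml_tensor * build_ffn_sketch(
        ggml_context * ctx0,
        ggml_tensor  * cur,                          // packed [n_in, n_tokens] input
        ggml_tensor  * up_w,   ggml_tensor * up_b,   // mm.0 weight / optional bias
        ggml_tensor  * gate_w, ggml_tensor * gate_b, // unused here: this projector has no gate
        ggml_tensor  * down_w, ggml_tensor * down_b) // mm.2 weight / optional bias
{
    // this sketch only covers the gateless case used by the 2-layer MLP projector
    GGML_ASSERT(gate_w == nullptr && gate_b == nullptr);

    // up projection (mm.0), bias is optional
    cur = ggml_mul_mat(ctx0, up_w, cur);
    if (up_b) {
        cur = ggml_add(ctx0, cur, up_b);
    }

    // FFN_GELU activation
    cur = ggml_gelu(ctx0, cur);

    // down projection (mm.2), bias is optional
    cur = ggml_mul_mat(ctx0, down_w, cur);
    if (down_b) {
        cur = ggml_add(ctx0, cur, down_b);
    }

    return cur;
}

Routing the projector through the shared helper removes the duplicated matmul/bias/activation boilerplate and keeps this projector on the same code path as the other FFN blocks in the graph builder.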
