Skip to content

Commit fc3f625

Browse files
committed
mtmd: support combined QKV projection in buid_vit
1 parent 2dd9924 commit fc3f625

File tree

1 file changed

+37
-12
lines changed

1 file changed

+37
-12
lines changed

tools/mtmd/clip.cpp

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2152,19 +2152,44 @@ struct clip_graph {
21522152

21532153
// self-attention
21542154
{
2155-
ggml_tensor * Qcur = ggml_mul_mat(ctx0, layer.q_w, cur);
2156-
if (layer.q_b) {
2157-
Qcur = ggml_add(ctx0, Qcur, layer.q_b);
2158-
}
2159-
2160-
ggml_tensor * Kcur = ggml_mul_mat(ctx0, layer.k_w, cur);
2161-
if (layer.k_b) {
2162-
Kcur = ggml_add(ctx0, Kcur, layer.k_b);
2163-
}
2155+
ggml_tensor * Qcur;
2156+
ggml_tensor * Kcur;
2157+
ggml_tensor * Vcur;
2158+
2159+
if (layer.qkv_w) {
2160+
ggml_tensor * QKV;
21642161

2165-
ggml_tensor * Vcur = ggml_mul_mat(ctx0, layer.v_w, cur);
2166-
if (layer.v_b) {
2167-
Vcur = ggml_add(ctx0, Vcur, layer.v_b);
2162+
QKV = ggml_mul_mat(ctx0, layer.qkv_w, cur);
2163+
if (layer.qkv_b) {
2164+
QKV = ggml_add(ctx0, QKV, layer.qkv_b);
2165+
}
2166+
QKV = ggml_reshape_4d(ctx0, QKV, cur->ne[0], 3, cur->ne[1]*cur->ne[2], cur->ne[3]);
2167+
2168+
const int ne0 = QKV->ne[0];
2169+
const int ne2 = QKV->ne[2];
2170+
const int ne3 = QKV->ne[3];
2171+
const int nb1 = QKV->nb[1];
2172+
const int nb2 = QKV->nb[2];
2173+
const int nb3 = QKV->nb[3];
2174+
2175+
Qcur = ggml_cont(ctx0, ggml_view_3d(ctx0, QKV, ne0, ne2, ne3, nb2, nb3, 0*nb1));
2176+
Kcur = ggml_cont(ctx0, ggml_view_3d(ctx0, QKV, ne0, ne2, ne3, nb2, nb3, 1*nb1));
2177+
Vcur = ggml_cont(ctx0, ggml_view_3d(ctx0, QKV, ne0, ne2, ne3, nb2, nb3, 2*nb1));
2178+
} else {
2179+
Qcur = ggml_mul_mat(ctx0, layer.q_w, cur);
2180+
if (layer.q_b) {
2181+
Qcur = ggml_add(ctx0, Qcur, layer.q_b);
2182+
}
2183+
2184+
Kcur = ggml_mul_mat(ctx0, layer.k_w, cur);
2185+
if (layer.k_b) {
2186+
Kcur = ggml_add(ctx0, Kcur, layer.k_b);
2187+
}
2188+
2189+
Vcur = ggml_mul_mat(ctx0, layer.v_w, cur);
2190+
if (layer.v_b) {
2191+
Vcur = ggml_add(ctx0, Vcur, layer.v_b);
2192+
}
21682193
}
21692194

21702195
if (layer.q_norm) {

0 commit comments

Comments
 (0)