Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 7 additions & 12 deletions tools/mtmd/clip.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2590,10 +2590,7 @@ struct clip_graph {
} else {
ggml_tensor * v = ggml_permute(ctx0, v_cur, 1, 2, 0, 3);
v = ggml_cont(ctx0, v);

const auto n_tokens = q->ne[1];
const auto n_head = q->ne[2];


ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
// F32 may not needed for vision encoders?
// ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
Expand All @@ -2602,7 +2599,8 @@ struct clip_graph {

ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
cur = ggml_reshape_2d(ctx0, ggml_cont(ctx0, cur), cur->ne[0] * cur->ne[1], cur->ne[2] * cur->ne[3]);

}

cb(cur, "kqv_out", il);
Expand Down Expand Up @@ -2789,15 +2787,12 @@ struct clip_graph {

Q = ggml_view_3d (ctx0, cur, n_embd, W*H, B, cur->nb[2], cur->nb[3], 0*cur->nb[1]);
Q = ggml_reshape_4d(ctx0, ggml_cont(ctx0, Q), d_heads, n_heads, W*H, B);
Q = ggml_cont (ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3)); // [B, n_heads, H*W, d_heads]

K = ggml_view_3d (ctx0, cur, n_embd, W*H, B, cur->nb[2], cur->nb[3], 1*cur->nb[1]);
K = ggml_reshape_4d(ctx0, ggml_cont(ctx0, K), d_heads, n_heads, W*H, B);
K = ggml_cont (ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3)); // [B, n_heads, H*W, d_heads]

V = ggml_view_3d (ctx0, cur, n_embd, W*H, B, cur->nb[2], cur->nb[3], 2*cur->nb[1]);
V = ggml_reshape_4d(ctx0, ggml_cont(ctx0, V), d_heads, n_heads, W*H, B);
V = ggml_cont (ctx0, ggml_permute(ctx0, V, 0, 2, 1, 3)); // [B, n_heads, H*W, d_heads]

ggml_tensor * mask;
ggml_tensor * rw;
Expand All @@ -2806,7 +2801,8 @@ struct clip_graph {

rw = get_rel_pos(ctx0, layer.rel_pos_w, W, W); // [W, W, C]
rh = get_rel_pos(ctx0, layer.rel_pos_h, H, H); // [H, H, C]
qr = ggml_reshape_4d(ctx0, Q, d_heads, W, H, B*n_heads);
qr = ggml_permute(ctx0, Q, 0, 2, 1, 3);
qr = ggml_reshape_4d(ctx0, ggml_cont(ctx0, qr), d_heads, W, H, B * n_heads);

const int WH_pad = GGML_PAD(W*H, GGML_KQ_MASK_PAD) - W*H;

Expand All @@ -2822,11 +2818,10 @@ struct clip_graph {
mask = ggml_cast (ctx0, mask, GGML_TYPE_F16);

float scale = 1.0f / sqrtf((float)d_heads);
cur = ggml_flash_attn_ext(ctx0, Q, K, V, mask, scale, 0.0f, 0.0f); // [B, H*W, n_heads, d_heads]

cur = build_attn(layer.o_w, layer.o_b, Q, K, V, mask, scale,
il); // [B, H*W, n_embd]
cur = ggml_reshape_4d(ctx0, ggml_cont(ctx0, cur), n_embd, W, H, B);
cur = ggml_mul_mat(ctx0, layer.o_w, cur);
cur = ggml_add_inplace(ctx0, cur, layer.o_b);
}

if (hparams.is_global_attn(il) == false) {
Expand Down
2 changes: 1 addition & 1 deletion tools/mtmd/mtmd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ struct mtmd_context {

clip_context_params ctx_clip_params {
/* use_gpu */ ctx_params.use_gpu,
/* flash_attn_type */ CLIP_FLASH_ATTN_TYPE_AUTO,
/* flash_attn_type */ mtmd_get_clip_flash_attn_type(ctx_params.flash_attn_type),
/* image_min_tokens */ ctx_params.image_min_tokens,
/* image_max_tokens */ ctx_params.image_max_tokens,
/* warmup */ ctx_params.warmup,
Expand Down