diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index d1bed23d030..af03a8fe2e2 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -2590,10 +2590,7 @@ struct clip_graph { } else { ggml_tensor * v = ggml_permute(ctx0, v_cur, 1, 2, 0, 3); v = ggml_cont(ctx0, v); - - const auto n_tokens = q->ne[1]; - const auto n_head = q->ne[2]; - + ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); // F32 may not needed for vision encoders? // ggml_mul_mat_set_prec(kq, GGML_PREC_F32); @@ -2602,7 +2599,8 @@ struct clip_graph { ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq); cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3); - cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens); + cur = ggml_reshape_2d(ctx0, ggml_cont(ctx0, cur), cur->ne[0] * cur->ne[1], cur->ne[2] * cur->ne[3]); + } cb(cur, "kqv_out", il); @@ -2789,15 +2787,12 @@ struct clip_graph { Q = ggml_view_3d (ctx0, cur, n_embd, W*H, B, cur->nb[2], cur->nb[3], 0*cur->nb[1]); Q = ggml_reshape_4d(ctx0, ggml_cont(ctx0, Q), d_heads, n_heads, W*H, B); - Q = ggml_cont (ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3)); // [B, n_heads, H*W, d_heads] K = ggml_view_3d (ctx0, cur, n_embd, W*H, B, cur->nb[2], cur->nb[3], 1*cur->nb[1]); K = ggml_reshape_4d(ctx0, ggml_cont(ctx0, K), d_heads, n_heads, W*H, B); - K = ggml_cont (ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3)); // [B, n_heads, H*W, d_heads] V = ggml_view_3d (ctx0, cur, n_embd, W*H, B, cur->nb[2], cur->nb[3], 2*cur->nb[1]); V = ggml_reshape_4d(ctx0, ggml_cont(ctx0, V), d_heads, n_heads, W*H, B); - V = ggml_cont (ctx0, ggml_permute(ctx0, V, 0, 2, 1, 3)); // [B, n_heads, H*W, d_heads] ggml_tensor * mask; ggml_tensor * rw; @@ -2806,7 +2801,8 @@ struct clip_graph { rw = get_rel_pos(ctx0, layer.rel_pos_w, W, W); // [W, W, C] rh = get_rel_pos(ctx0, layer.rel_pos_h, H, H); // [H, H, C] - qr = ggml_reshape_4d(ctx0, Q, d_heads, W, H, B*n_heads); + qr = ggml_permute(ctx0, Q, 0, 2, 1, 3); + qr = ggml_reshape_4d(ctx0, ggml_cont(ctx0, qr), d_heads, W, H, B * n_heads); const int WH_pad = GGML_PAD(W*H, GGML_KQ_MASK_PAD) - W*H; @@ -2822,11 +2818,10 @@ struct clip_graph { mask = ggml_cast (ctx0, mask, GGML_TYPE_F16); float scale = 1.0f / sqrtf((float)d_heads); - cur = ggml_flash_attn_ext(ctx0, Q, K, V, mask, scale, 0.0f, 0.0f); // [B, H*W, n_heads, d_heads] + cur = build_attn(layer.o_w, layer.o_b, Q, K, V, mask, scale, + il); // [B, H*W, n_embd] cur = ggml_reshape_4d(ctx0, ggml_cont(ctx0, cur), n_embd, W, H, B); - cur = ggml_mul_mat(ctx0, layer.o_w, cur); - cur = ggml_add_inplace(ctx0, cur, layer.o_b); } if (hparams.is_global_attn(il) == false) { diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index 2c20af099b9..791ac771668 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -175,7 +175,7 @@ struct mtmd_context { clip_context_params ctx_clip_params { /* use_gpu */ ctx_params.use_gpu, - /* flash_attn_type */ CLIP_FLASH_ATTN_TYPE_AUTO, + /* flash_attn_type */ mtmd_get_clip_flash_attn_type(ctx_params.flash_attn_type), /* image_min_tokens */ ctx_params.image_min_tokens, /* image_max_tokens */ ctx_params.image_max_tokens, /* warmup */ ctx_params.warmup,