Skip to content

Commit 6687b4e

Browse files
authored
Merge pull request #9 from sfallah/sf/deepseek-ocr-attn
using common build_attn in sam
2 parents d0c08e3 + f5bd310 commit 6687b4e

File tree

2 files changed

+8
-13
lines changed

2 files changed

+8
-13
lines changed

tools/mtmd/clip.cpp

Lines changed: 7 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -2590,10 +2590,7 @@ struct clip_graph {
25902590
} else {
25912591
ggml_tensor * v = ggml_permute(ctx0, v_cur, 1, 2, 0, 3);
25922592
v = ggml_cont(ctx0, v);
2593-
2594-
const auto n_tokens = q->ne[1];
2595-
const auto n_head = q->ne[2];
2596-
2593+
25972594
ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
25982595
// F32 may not needed for vision encoders?
25992596
// ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
@@ -2602,7 +2599,8 @@ struct clip_graph {
26022599

26032600
ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
26042601
cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
2605-
cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
2602+
cur = ggml_reshape_2d(ctx0, ggml_cont(ctx0, cur), cur->ne[0] * cur->ne[1], cur->ne[2] * cur->ne[3]);
2603+
26062604
}
26072605

26082606
cb(cur, "kqv_out", il);
@@ -2789,15 +2787,12 @@ struct clip_graph {
27892787

27902788
Q = ggml_view_3d (ctx0, cur, n_embd, W*H, B, cur->nb[2], cur->nb[3], 0*cur->nb[1]);
27912789
Q = ggml_reshape_4d(ctx0, ggml_cont(ctx0, Q), d_heads, n_heads, W*H, B);
2792-
Q = ggml_cont (ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3)); // [B, n_heads, H*W, d_heads]
27932790

27942791
K = ggml_view_3d (ctx0, cur, n_embd, W*H, B, cur->nb[2], cur->nb[3], 1*cur->nb[1]);
27952792
K = ggml_reshape_4d(ctx0, ggml_cont(ctx0, K), d_heads, n_heads, W*H, B);
2796-
K = ggml_cont (ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3)); // [B, n_heads, H*W, d_heads]
27972793

27982794
V = ggml_view_3d (ctx0, cur, n_embd, W*H, B, cur->nb[2], cur->nb[3], 2*cur->nb[1]);
27992795
V = ggml_reshape_4d(ctx0, ggml_cont(ctx0, V), d_heads, n_heads, W*H, B);
2800-
V = ggml_cont (ctx0, ggml_permute(ctx0, V, 0, 2, 1, 3)); // [B, n_heads, H*W, d_heads]
28012796

28022797
ggml_tensor * mask;
28032798
ggml_tensor * rw;
@@ -2806,7 +2801,8 @@ struct clip_graph {
28062801

28072802
rw = get_rel_pos(ctx0, layer.rel_pos_w, W, W); // [W, W, C]
28082803
rh = get_rel_pos(ctx0, layer.rel_pos_h, H, H); // [H, H, C]
2809-
qr = ggml_reshape_4d(ctx0, Q, d_heads, W, H, B*n_heads);
2804+
qr = ggml_permute(ctx0, Q, 0, 2, 1, 3);
2805+
qr = ggml_reshape_4d(ctx0, ggml_cont(ctx0, qr), d_heads, W, H, B * n_heads);
28102806

28112807
const int WH_pad = GGML_PAD(W*H, GGML_KQ_MASK_PAD) - W*H;
28122808

@@ -2822,11 +2818,10 @@ struct clip_graph {
28222818
mask = ggml_cast (ctx0, mask, GGML_TYPE_F16);
28232819

28242820
float scale = 1.0f / sqrtf((float)d_heads);
2825-
cur = ggml_flash_attn_ext(ctx0, Q, K, V, mask, scale, 0.0f, 0.0f); // [B, H*W, n_heads, d_heads]
28262821

2822+
cur = build_attn(layer.o_w, layer.o_b, Q, K, V, mask, scale,
2823+
il); // [B, H*W, n_embd]
28272824
cur = ggml_reshape_4d(ctx0, ggml_cont(ctx0, cur), n_embd, W, H, B);
2828-
cur = ggml_mul_mat(ctx0, layer.o_w, cur);
2829-
cur = ggml_add_inplace(ctx0, cur, layer.o_b);
28302825
}
28312826

28322827
if (hparams.is_global_attn(il) == false) {

tools/mtmd/mtmd.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -175,7 +175,7 @@ struct mtmd_context {
175175

176176
clip_context_params ctx_clip_params {
177177
/* use_gpu */ ctx_params.use_gpu,
178-
/* flash_attn_type */ CLIP_FLASH_ATTN_TYPE_AUTO,
178+
/* flash_attn_type */ mtmd_get_clip_flash_attn_type(ctx_params.flash_attn_type),
179179
/* image_min_tokens */ ctx_params.image_min_tokens,
180180
/* image_max_tokens */ ctx_params.image_max_tokens,
181181
/* warmup */ ctx_params.warmup,

0 commit comments

Comments
 (0)