Skip to content

Commit 3fcfc3a

Browse files
authored
Merge pull request #3 from bluebread/sf/deepseek-ocr
Fixed get_rel_pos & add_rel_pos_inplace operator
2 parents 86f111f + effe669 commit 3fcfc3a

File tree

3 files changed

+114
-94
lines changed

3 files changed

+114
-94
lines changed

src/llama-graph.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1106,7 +1106,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
11061106

11071107
if (!weight_before_ffn) {
11081108
experts = ggml_mul(ctx0, experts, weights);
1109-
cb(cur, "ffn_moe_weighted", il);
1109+
cb(experts, "ffn_moe_weighted", il);
11101110
}
11111111

11121112
ggml_tensor * cur_experts[LLAMA_MAX_EXPERTS] = { nullptr };

src/models/deepseek2.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
7474
cur = build_attn(inp_attn,
7575
model.layers[il].wo, NULL,
7676
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
77+
cb(cur, "attn_out", il);
7778
}
7879
else {
7980
ggml_tensor * q = NULL;

tools/mtmd/clip.cpp

Lines changed: 112 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -667,9 +667,9 @@ struct clip_graph {
667667
constexpr int _depth = 12;
668668
constexpr int enc_n_heads = 12;
669669
constexpr int enc_d_heads = enc_n_embd / enc_n_heads;
670-
constexpr int _prompt_n_embd = 256;
670+
// constexpr int _prompt_n_embd = 256;
671671
constexpr int enc_patch_size = 16;
672-
constexpr int _window_size = 14;
672+
// constexpr int _window_size = 14;
673673

674674
const int enc_n_patches = enc_image_size / enc_patch_size; // 64
675675

@@ -739,13 +739,14 @@ struct clip_graph {
739739

740740
struct ggml_tensor * q_r = ggml_reshape_4d(ctx0, Qcur, enc_d_heads, W, H, B * enc_n_heads);
741741

742-
struct ggml_tensor * rel_w = ggml_cont(
743-
ctx0,
744-
ggml_permute(ctx0, ggml_mul_mat(ctx0, rw, ggml_cont(ctx0, ggml_permute(ctx0, q_r, 0, 2, 1, 3))), 0,
745-
2, 1, 3));
742+
struct ggml_tensor * rel_w = ggml_cont(ctx0,ggml_permute(ctx0,
743+
ggml_mul_mat(ctx0,
744+
rw,
745+
ggml_cont(ctx0, ggml_permute(ctx0, q_r, 0, 2, 1, 3))),
746+
0, 2, 1, 3));
746747
struct ggml_tensor * rel_h = ggml_mul_mat(ctx0, rh, q_r);
747748

748-
struct ggml_tensor * attn = add_rel_pos_inplace(ctx0, KQ_scaled, rel_w, rel_h, W);
749+
struct ggml_tensor * attn = add_rel_pos_inplace(ctx0, KQ_scaled, rel_w, rel_h);
749750

750751
struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, attn);
751752

@@ -835,7 +836,7 @@ struct clip_graph {
835836

836837
ggml_tensor * global_features_1 = build_sam_enc(inp_raw, std::max(img.nx, img.ny));
837838

838-
ggml_tensor * global_features_2 = build_dp_ocr_clip(inp_raw, global_features_1);
839+
ggml_tensor * global_features_2 = build_dp_ocr_clip(global_features_1);
839840

840841
// torch global_features = torch.cat((global_features_2[:, 1:], global_features_1.flatten(2).permute(0, 2, 1)), dim=-1)
841842
global_features_1 = ggml_cont(ctx0,ggml_permute(ctx0, global_features_1,2,1,0,3));
@@ -1533,7 +1534,7 @@ struct clip_graph {
15331534
return gf;
15341535
}
15351536

1536-
ggml_tensor * build_dp_ocr_clip(ggml_tensor * inpL, ggml_tensor * patch_embeds) {
1537+
ggml_tensor * build_dp_ocr_clip(ggml_tensor * patch_embeds) {
15371538
GGML_ASSERT(model.class_embedding != nullptr);
15381539
GGML_ASSERT(model.position_embeddings != nullptr);
15391540

@@ -2466,103 +2467,119 @@ struct clip_graph {
24662467
return inpL;
24672468
}
24682469

2469-
// attn: [k_h*k_w, q_h*q_w]
2470-
// rel_h: [q_h, q_w, k_h]
2471-
// rel_w: [q_h, q_w, k_w]
2472-
2473-
static ggml_tensor * add_rel_pos_inplace(
2474-
ggml_context * ctx,
2475-
ggml_tensor * attn,
2476-
ggml_tensor * rel_w,
2477-
ggml_tensor * rel_h,
2478-
int q_size
2479-
) {
2480-
2481-
ggml_tensor *attn_4d =
2482-
ggml_reshape_4d(ctx, attn, q_size,q_size, attn->ne[1], attn->ne[2]);
2483-
2484-
ggml_tensor *rel_h_4d =
2485-
ggml_reshape_4d(ctx, rel_h, 1, q_size, attn->ne[1], attn->ne[2]);
2486-
2487-
ggml_tensor *rel_h_rep = ggml_repeat(ctx, rel_h_4d, attn_4d); // now same shape as attn_5d
2488-
2489-
ggml_tensor *rel_w_4d =
2490-
ggml_reshape_4d(ctx, rel_w, q_size, 1, attn->ne[1], attn->ne[2]);
2491-
2492-
ggml_tensor *rel_w_rep = ggml_repeat(ctx, rel_w_4d, attn_4d); // now same shape as attn_5d
2493-
2494-
ggml_tensor * result = ggml_add(ctx, attn_4d, ggml_add(ctx, rel_h_rep, rel_w_rep));
2495-
result = ggml_reshape_3d(ctx, result, attn->ne[0], attn->ne[1], attn->ne[2]);
2496-
2497-
2498-
return result;
2499-
}
2500-
2501-
2502-
static ggml_tensor * get_rel_pos(
2503-
ggml_context * ctx,
2504-
ggml_tensor * rel_pos, // [L, C]
2505-
int q_size,
2506-
int k_size
2507-
) {
2508-
2509-
const auto dtype = rel_pos->type;
2510-
2511-
const int64_t L = rel_pos->ne[0]; // length
2512-
const int64_t C = rel_pos->ne[1]; // channels
2513-
2514-
// -------------------------------------------------
2515-
// 1) q_idx ← arange(0..q_size-1) [q_size]
2516-
// 2) k_idx ← arange(0..k_size-1) [k_size]
2517-
// -------------------------------------------------
2470+
// attn: [q_h*q_w, k_h*k_w]
2471+
// rel_h: [q_h, q_w, k_h]
2472+
// rel_w: [q_h, q_w, k_w]
25182473

2474+
static ggml_tensor * add_rel_pos_inplace(
2475+
ggml_context * ctx,
2476+
ggml_tensor * attn,
2477+
ggml_tensor * rel_w,
2478+
ggml_tensor * rel_h
2479+
) {
2480+
const int k_w = rel_w->ne[0];
2481+
const int k_h = rel_h->ne[0];
2482+
const int q_w = rel_h->ne[1];
2483+
const int q_h = rel_h->ne[2];
25192484

2520-
ggml_tensor * q_coord = ggml_cast(ctx,
2521-
ggml_arange(ctx, 0.0f, static_cast<float>(q_size), 1.0f),
2522-
GGML_TYPE_F32); // [q_size]
2523-
ggml_tensor * k_coord = ggml_cast(ctx,
2524-
ggml_arange(ctx, 0.0f, static_cast<float>(k_size), 1.0f),
2525-
GGML_TYPE_F32); // [k_size]
2485+
GGML_ASSERT(q_w == rel_w->ne[1]);
2486+
GGML_ASSERT(q_h == rel_w->ne[2]);
2487+
GGML_ASSERT(attn->ne[0] == k_h*k_w);
2488+
GGML_ASSERT(attn->ne[1] == q_h*q_w);
25262489

2527-
ggml_tensor * rel = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, q_size, k_size);
2528-
q_coord = ggml_cont(ctx,ggml_repeat(ctx, q_coord, rel)); // [q_size, k_size]
2490+
ggml_tensor *attn_4d = ggml_reshape_4d(ctx, attn, k_w, k_h, attn->ne[1], attn->ne[2]);
25292491

2530-
// broadcast reshape:
2531-
k_coord = ggml_reshape_2d(ctx, k_coord, 1, k_size); // [1, k_size]
2532-
k_coord = ggml_cont(ctx,ggml_repeat(ctx, k_coord, rel)); // [q_size, k_size]
2492+
ggml_tensor *rel_h_4d = ggml_reshape_4d(ctx, rel_h, 1, k_h, attn->ne[1], attn->ne[2]);
25332493

2534-
// -------------------------------------------------
2535-
// relative_coords = q - k + (k_size - 1) // SAME as PyTorch when no scaling
2536-
// -------------------------------------------------
2537-
rel = ggml_sub(ctx, k_coord, q_coord); // [q_size, k_size]
2494+
ggml_tensor *rel_h_rep = ggml_repeat(ctx, rel_h_4d, attn_4d); // now same shape as attn_5d
25382495

2539-
rel = ggml_scale_bias(ctx, rel, 1.0f, static_cast<float>(k_size) - 1.0f); // [q_size, k_size]
2496+
ggml_tensor *rel_w_4d = ggml_reshape_4d(ctx, rel_w, k_w, 1, attn->ne[1], attn->ne[2]);
25402497

2541-
// -------------------------------------------------
2542-
// clamp to [0, L-1] and cast to int32 (for ggml_get_rows)
2543-
// -------------------------------------------------
2498+
ggml_tensor *rel_w_rep = ggml_repeat(ctx, rel_w_4d, attn_4d); // now same shape as attn_5d
25442499

2545-
ggml_tensor * rel_clamped = ggml_clamp(ctx, rel, 0, static_cast<float>(L - 1));
2500+
ggml_tensor * result = ggml_add_inplace(ctx, attn_4d, ggml_add_inplace(ctx, rel_h_rep, rel_w_rep));
2501+
result = ggml_reshape_3d(ctx, result, attn->ne[0], attn->ne[1], attn->ne[2]);
25462502

2547-
ggml_tensor * idx_2d = ggml_cast(ctx, rel_clamped, GGML_TYPE_I32); // [q_size, k_size]
25482503

2549-
// flatten to 1D for ggml_get_rows
2550-
const int64_t qk = static_cast<int64_t>(q_size) * static_cast<int64_t>(k_size);
2551-
ggml_tensor * idx_flat = ggml_reshape_1d(ctx, idx_2d, qk); // [qk]
2504+
return result;
2505+
}
25522506

2553-
// -------------------------------------------------
2554-
// Gather from rel_pos → [qk, C]
2555-
// -------------------------------------------------
2556-
ggml_tensor * gathered = ggml_get_rows(ctx, rel_pos, idx_flat); // [qk, C]
25572507

2558-
// reshape to final output → [q_size, k_size, C]
2559-
ggml_tensor * out = ggml_reshape_3d(ctx, gathered,rel_pos->ne[0],
2560-
q_size,
2561-
k_size);
2508+
static ggml_tensor * get_rel_pos(
2509+
ggml_context * ctx,
2510+
ggml_tensor * rel_pos, // [L, C]
2511+
int q_size,
2512+
int k_size
2513+
) {
2514+
const int64_t C = rel_pos->ne[0]; // channels
2515+
const int64_t L = rel_pos->ne[1]; // length
2516+
2517+
GGML_ASSERT(2*std::max(q_size, k_size) - 1 == L);
2518+
2519+
// -------------------------------------------------
2520+
// 1) q_idx ← arange(0..q_size-1) [q_size]
2521+
// 2) k_idx ← arange(0..k_size-1) [k_size]
2522+
// -------------------------------------------------
2523+
2524+
// ggml_arange always returns FP32 tensor
2525+
ggml_tensor * q_coord = ggml_arange(ctx, 0.0f, static_cast<float>(q_size), 1.0f); // [q_size]
2526+
ggml_tensor * k_coord = ggml_arange(ctx, 0.0f, static_cast<float>(k_size), 1.0f); // [k_size]
2527+
ggml_tensor * rel = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, k_size, q_size);
2528+
2529+
// broadcast reshape:
2530+
q_coord = ggml_cont(ctx,
2531+
ggml_repeat(ctx,
2532+
ggml_reshape_2d(ctx, q_coord, 1, q_size), // [q_size, 1]
2533+
rel
2534+
)
2535+
); // [q_size, k_size]
2536+
k_coord = ggml_cont(ctx, ggml_repeat(ctx, k_coord, rel)); // [q_size, k_size]
2537+
2538+
float q_scale = std::max((float)k_size/q_size, 1.0f);
2539+
float k_scale = std::max((float)q_size/k_size, 1.0f);
2540+
2541+
// This wouldn't be triggered in DeepSeek-OCR. Just for compatibility with
2542+
// the original implementation.
2543+
if (q_size != k_size) {
2544+
q_coord = ggml_scale_inplace(ctx, q_coord, q_scale);
2545+
k_coord = ggml_scale_inplace(ctx, k_coord, k_scale);
2546+
}
25622547

2563-
return out; // [q_size, k_size, C]
2564-
}
2548+
// -------------------------------------------------
2549+
// relative_coords = q - k + (k_size - 1) // SAME as PyTorch when no scaling
2550+
// -------------------------------------------------
2551+
2552+
rel = ggml_sub(ctx, q_coord, k_coord); // [q_size, k_size]
2553+
rel = ggml_scale_bias(ctx, rel, 1.0f, (k_size - 1.0f)*k_scale); // [q_size, k_size]
2554+
// Clamp to [0, L-1] range for valid indexing
2555+
rel = ggml_clamp(ctx, rel, 0.0f, static_cast<float>(rel_pos->ne[1] - 1));
2556+
2557+
// -------------------------------------------------
2558+
// clamp to [0, L-1] and cast to int32 (for ggml_get_rows)
2559+
// -------------------------------------------------
2560+
2561+
ggml_tensor * idx_2d = ggml_cast(ctx, rel, GGML_TYPE_I32); // [q_size, k_size]
2562+
2563+
// Gather from rel_pos → [qk, C]
2564+
// -------------------------------------------------
2565+
2566+
// flatten to 1D for ggml_get_rows
2567+
int qk = q_size * k_size;
2568+
ggml_tensor * idx_flat = ggml_reshape_1d(ctx, idx_2d, qk); // [qk]
2569+
ggml_tensor * gathered = ggml_get_rows(ctx, rel_pos, idx_flat); // [qk, C]
2570+
2571+
// -------------------------------------------------
2572+
// Gather from rel_pos → [qk, C]
2573+
// -------------------------------------------------
2574+
2575+
ggml_tensor * out = ggml_reshape_3d(ctx, gathered, C, k_size, q_size); // [qk, C]
2576+
2577+
2578+
return out; // [q_size, k_size, C]
2579+
}
25652580

2581+
// Implementation based on approach suggested by Acly
2582+
// See: https://github.com/ggml-org/llama.cpp/pull/17383#issuecomment-3554227091
25662583
static ggml_tensor* window_partition(ggml_context* ctx, ggml_tensor* x, int window) {
25672584
auto [c, w, h, b] = x->ne;
25682585
// same as
@@ -2583,6 +2600,8 @@ static ggml_tensor * get_rel_pos(
25832600
return x;
25842601
}
25852602

2603+
// Implementation based on approach suggested by Acly
2604+
// See: https://github.com/ggml-org/llama.cpp/pull/17383#issuecomment-3554227091
25862605
static ggml_tensor* window_unpartition(ggml_context* m, ggml_tensor* x, int w, int h, int window) {
25872606
int64_t c = x->ne[0];
25882607
// same as
@@ -4978,7 +4997,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
49784997
const int min_num = 2;
49794998
const int max_num = 9;
49804999
const int image_size = params.image_size; // typically 640
4981-
const bool use_thumbnail = true; // mimic python's use_thumbnail
5000+
// const bool use_thumbnail = true; // mimic python's use_thumbnail
49825001

49835002
// original image size
49845003
const int orig_w = original_size.width;

0 commit comments

Comments
 (0)