From c5ac7763f13ec33b5710dd7556e88b884c8ab5f4 Mon Sep 17 00:00:00 2001 From: YaelGitAccount Date: Tue, 18 Nov 2025 13:08:24 +0200 Subject: [PATCH 1/8] mtmd : add Eagle2-VL vision support --- convert_hf_to_gguf.py | 197 +++++++++++++++++++++++++++++++++++++- gguf-py/gguf/constants.py | 1 + tools/mtmd/clip-impl.h | 2 + tools/mtmd/clip.cpp | 93 ++++++++++++++++++ 4 files changed, 292 insertions(+), 1 deletion(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 0cc3df0975f..2b76f4faebd 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -3811,6 +3811,201 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return [] # skip other tensors +@ModelBase.register("Eagle2_VLForConditionalGeneration","Eagle2_5_VLForConditionalGeneration") +class Eagle2VLVisionModel(MmprojModel): + """ + Dedicated Eagle2-VL mmproj converter. + + Responsibilities: + - Emit a distinct projector_type (eagle2vl) and vision metadata (image/patch size, mean/std, eps, block_count, merge size). + - Perform Eagle2-specific layout normalization during conversion in modify_tensors (in a later follow-up). + The C++ runtime will assume canonical layout and must not include Eagle2-specific transposes or hacks. + """ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if self.hparams_vision is not None: + # Prefer overrides from vision_config when present + vc = self.get_vision_config() or {} + if "image_size" in vc and "image_size" not in self.hparams_vision: + self.hparams_vision["image_size"] = vc["image_size"] + if "patch_size" in vc and "patch_size" not in self.hparams_vision: + self.hparams_vision["patch_size"] = vc["patch_size"] + if "image_mean" in vc and "image_mean" not in self.hparams_vision: + self.hparams_vision["image_mean"] = vc["image_mean"] + if "image_std" in vc and "image_std" not in self.hparams_vision: + self.hparams_vision["image_std"] = vc["image_std"] + + # Normalize common aliases + if "num_heads" in self.hparams_vision and "num_attention_heads" not in self.hparams_vision: + self.hparams_vision["num_attention_heads"] = self.hparams_vision.get("num_heads") + if "depth" in self.hparams_vision and "num_hidden_layers" not in self.hparams_vision: + self.hparams_vision["num_hidden_layers"] = self.hparams_vision.get("depth") + + def set_gguf_parameters(self): + # Base writes general vision fields (embedding/feed-forward/heads when available) + super().set_gguf_parameters() + + # Projector type: Eagle2-VL + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.EAGLE2VL) + + # Vision attention layernorm eps: use config if available, else 1e-6 + eps = 1e-6 + if isinstance(self.global_config, dict): + eps = self.global_config.get("rms_norm_eps", eps) + self.gguf_writer.add_vision_attention_layernorm_eps(eps) + + # 2x2 spatial merge + self.gguf_writer.add_vision_spatial_merge_size(2) + + # Mirror Qwen2VL-style image metadata if provided; do not guess + hpv = self.hparams_vision or {} + + img_sz = hpv.get("image_size") + if isinstance(img_sz, (list, tuple)) and len(img_sz) > 0: + img_sz = img_sz[0] + if isinstance(img_sz, int): + self.gguf_writer.add_vision_image_size(img_sz) + + patch_sz = hpv.get("patch_size") + if isinstance(patch_sz, (list, tuple)) and len(patch_sz) > 0: + patch_sz = patch_sz[0] + if isinstance(patch_sz, int): + self.gguf_writer.add_vision_patch_size(patch_sz) + + blk_cnt = hpv.get("num_hidden_layers") + if isinstance(blk_cnt, int): + self.gguf_writer.add_vision_block_count(blk_cnt) + + img_mean = 
hpv.get("image_mean") + if isinstance(img_mean, (list, tuple)) and len(img_mean) > 0: + self.gguf_writer.add_vision_image_mean(list(img_mean)) + + img_std = hpv.get("image_std") + if isinstance(img_std, (list, tuple)) and len(img_std) > 0: + self.gguf_writer.add_vision_image_std(list(img_std)) + + # Note: + # Eagle2-specific tensor layout normalization (mlp1 → mm.*, QKV split, Conv3D → Conv2D) + # will live here in Eagle2VLVisionModel.modify_tensors() in a follow-up. + # The C++ builder will assume canonical weights and perform zero ad-hoc transposes. + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + + # 1) Name prefix normalization + if name.startswith("vision_model.vision_model."): + name = name.replace("vision_model.vision_model.", "model.vision_model.", 1) + # Eagle2 HF projector tensors often live under "multi_modal_projector.*" + # Normalize that prefix away so we can match on "mlp1.*" below + if name.startswith("multi_modal_projector."): + name = name.replace("multi_modal_projector.", "", 1) + + # 2) Skip clearly non-vision towers + if name.startswith("audio") or name.startswith("talker") or name.startswith("token2wav"): + return [] + + # 2.1) Skip classification / pooling head tensors under vision model head + # Example tensors to skip: + # - vision_model.vision_model.head.attention.in_proj_weight + # - vision_model.vision_model.head.mlp.fc1.weight + # - vision_model.vision_model.head.probe + if ".head." in name: + return [] + + # 3) Projector MLP remap: mlp1 -> mm indices (match 'mlp1.' anywhere, not just at start) + mlp_pos = name.find("mlp1.") + if mlp_pos != -1: + mlp_suffix = name[mlp_pos + len("mlp1."):] + # Skip LayerNorm (mlp1.0.*) + if mlp_suffix.startswith("0."): + return [] + # Map first Linear (mlp1.1.*) -> mm.0.* + if mlp_suffix.startswith("1."): + new_name = "mm.0." + mlp_suffix[2:] + if new_name.endswith(".weight"): + # Canonicalize to [n_in, n_out]; detect and transpose only if needed. + if data_torch.ndim == 2: + d0, d1 = int(data_torch.shape[0]), int(data_torch.shape[1]) + # Expected input width after 2x2 merge = vit_hidden * 4 (if available) + vit_hidden = (self.hparams_vision or {}).get("hidden_size") + expected_in = vit_hidden * 4 if isinstance(vit_hidden, int) and vit_hidden > 0 else None + expected_out = getattr(self, "n_embd_text", None) + # Strong orientation rule: if [out, in] = [n_embd_text, 4*hidden] -> transpose + if isinstance(expected_in, int) and isinstance(expected_out, int) and expected_in > 0 and expected_out > 0: + if d0 == expected_out and d1 == expected_in: + data_torch = data_torch.transpose(-1, -2) + elif d0 == expected_in and d1 == expected_out: + pass # already [n_in, n_out] + else: + # fall through to heuristic rules below + pass + if isinstance(expected_in, int) and expected_in > 0: + if d0 == expected_in: + pass # already canonical [n_in, n_out] + elif d1 == expected_in: + data_torch = data_torch.transpose(-1, -2) + else: + # Fallback: choose orientation that puts larger dim on axis 0 + if d0 < d1: + data_torch = data_torch.transpose(-1, -2) + else: + # Fallback when vit_hidden is unknown: assume PyTorch [out, in] and transpose if in > out + if d0 < d1: + data_torch = data_torch.transpose(-1, -2) + return [(new_name, data_torch)] + return [(new_name, data_torch)] + # Map second Linear (mlp1.3.*) -> mm.2.* + if mlp_suffix.startswith("3."): + new_name = "mm.2." 
+ mlp_suffix[2:]
+                if new_name.endswith(".weight"):
+                    # Canonicalize to [n_in, n_out] for the second layer.
+                    # Here expected_in == expected_out == text emb dim; if ambiguous, leave as-is (often square).
+                    if data_torch.ndim == 2:
+                        d0, d1 = int(data_torch.shape[0]), int(data_torch.shape[1])
+                        expected = getattr(self, "n_embd_text", None)
+                        if isinstance(expected, int) and expected > 0:
+                            if d0 == expected and d1 == expected:
+                                pass
+                            elif d1 == expected and d0 != expected:
+                                data_torch = data_torch.transpose(-1, -2)
+                    return [(new_name, data_torch)]
+                return [(new_name, data_torch)]
+            # Unknown mlp1 component -> skip
+            return []
+
+        # 4) Fused QKV split
+        if ".qkv." in name:
+            # Determine split size from leading dimension
+            c3 = data_torch.shape[0]
+            assert c3 % 3 == 0, f"qkv tensor leading dim must be divisible by 3, got {c3} for {name}"
+            c = c3 // 3
+            wq = data_torch[:c]
+            wk = data_torch[c: 2 * c]
+            wv = data_torch[2 * c:]
+            return [
+                (self.map_tensor_name(name.replace("qkv", "q")), wq),
+                (self.map_tensor_name(name.replace("qkv", "k")), wk),
+                (self.map_tensor_name(name.replace("qkv", "v")), wv),
+            ]
+
+        # 5) Conv3D patch embed -> two Conv2D kernels
+        if name.endswith("patch_embed.proj.weight") and data_torch.ndim == 5:
+            c_out, c_in, kt, kh, kw = data_torch.shape
+            del c_out, c_in, kh, kw  # unused
+            assert kt == 2, "Current implementation only supports temporal_patch_size of 2"
+            base = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".weight"
+            return [
+                (base, data_torch[:, :, 0, ...]),
+                (base + ".1", data_torch[:, :, 1, ...]),
+            ]
+
+        # 6) Default mapping for remaining Eagle2 vision tensors
+        if name.startswith("vision_model.") or name.startswith("model.vision_model.") or name.startswith("visual."):
+            return [(self.map_tensor_name(name), data_torch)]
+
+        # Not an Eagle2 vision tensor -> skip
+        return []
+
 @ModelBase.register("Qwen2_5OmniModel")
 class Qwen25OmniModel(Qwen2VLVisionModel):
     has_vision_encoder = True
@@ -9973,7 +10168,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         if name.startswith('model.language_model.'):
             name = name.replace('model.language_model.', 'model.')
-        elif name.startswith('language_model.'):
+        if name.startswith("mlp1."):
             name = name.replace('language_model.', '')
 
         return super().modify_tensors(data_torch, name, bid)
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 1cd0efad4a8..9c43e8dd8fb 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -3238,6 +3238,7 @@ class VisionProjectorType:
     LIGHTONOCR = "lightonocr"
     COGVLM = "cogvlm"
     JANUS_PRO = "janus_pro"
+    EAGLE2VL = "eagle2vl"


 # Items here are (block size, type size)
diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h
index cd47865bf4a..0bfd497b110 100644
--- a/tools/mtmd/clip-impl.h
+++ b/tools/mtmd/clip-impl.h
@@ -141,6 +141,7 @@ enum projector_type {
     PROJECTOR_TYPE_GLM_EDGE,
     PROJECTOR_TYPE_QWEN2VL,
     PROJECTOR_TYPE_QWEN3VL,
+    PROJECTOR_TYPE_EAGLE2VL,
     PROJECTOR_TYPE_GEMMA3,
     PROJECTOR_TYPE_IDEFICS3,
     PROJECTOR_TYPE_PIXTRAL,
@@ -168,6 +169,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
     { PROJECTOR_TYPE_QWEN2VL,  "qwen2vl_merger"},
     { PROJECTOR_TYPE_QWEN25VL, "qwen2.5vl_merger"},
     { PROJECTOR_TYPE_QWEN3VL,  "qwen3vl_merger"},
+    { PROJECTOR_TYPE_EAGLE2VL, "eagle2vl"},
     { PROJECTOR_TYPE_GEMMA3,   "gemma3"},
     { PROJECTOR_TYPE_IDEFICS3, "idefics3"},
     { PROJECTOR_TYPE_PIXTRAL,  "pixtral"},
diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp
index abdb778f7af..01e66c59ca8 100644
--- a/tools/mtmd/clip.cpp
+++ 
b/tools/mtmd/clip.cpp @@ -1081,6 +1081,72 @@ struct clip_graph { return gf; } + // Eagle2-VL: normalized ViT with learned absolute position embeddings and 2-layer MLP projector (mm.0, GELU, mm.2) + ggml_cgraph * build_eagle2vl() { + GGML_ASSERT(model.class_embedding == nullptr); + + const int n_pos = n_patches; + + // Use resized learned position embeddings if dynamic resolution is used + ggml_tensor * learned_pos_embd = resize_position_embeddings(); + + // Build input patches via Conv2D, add patch bias if present + ggml_tensor * inp = build_inp(); + + // Vision encoder: use RMS norm per metadata (eps), FFN op per hparams + ggml_tensor * cur = build_vit( + inp, n_pos, + NORM_TYPE_RMS, + hparams.ffn_op, + learned_pos_embd, + nullptr); + + // keep runtime quiet in normal runs; shapes are correct by construction + + // Apply spatial patch merge (e.g., 2x2) before projector if requested + { + const int scale_factor = hparams.n_merge; + if (scale_factor > 1) { + // This returns a 2D tensor shaped [n_embd * scale_factor^2, n_pos / scale_factor^2] + // which matches the expected input width for mm.0 + cur = build_patch_merge_permute(cur, scale_factor); + // merged tokens layout now matches projector input width + } + } + + // 2-layer MLP projector: mm.0 -> GELU -> mm.2 + ggml_tensor * embeddings = cur; + + // projector matmuls assume canonical [n_in, n_out] weights; no runtime transposes + GGML_ASSERT(model.mm_0_w != nullptr); + // ensure projector input is a packed 2D matrix [n_in, n_tokens] + embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]); + embeddings = ggml_cont_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]); + // embeddings are packed [n_in, n_tokens] + // Weights are canonicalized at conversion time to [n_in, n_out]; multiply directly. + embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings); + if (model.mm_0_b) { + embeddings = ggml_add(ctx0, embeddings, model.mm_0_b); + } + + embeddings = ggml_gelu(ctx0, embeddings); + + GGML_ASSERT(model.mm_2_w != nullptr); + // keep [n_in, n_tokens] layout for the second matmul as well + embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]); + embeddings = ggml_cont_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]); + // Weights are canonicalized at conversion time to [n_in, n_out]; multiply directly. 
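+        // ggml_mul_mat(a, b) contracts over ne[0]: with the weight stored as
+        // [n_in, n_out] and the activations packed as [n_in, n_tokens], the
+        // product comes out as [n_out, n_tokens], so no transpose is needed here.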
+ embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings); + if (model.mm_2_b) { + embeddings = ggml_add(ctx0, embeddings, model.mm_2_b); + } + + // build the graph + ggml_build_forward_expand(gf, embeddings); + + return gf; + } + ggml_cgraph * build_minicpmv() { GGML_ASSERT(model.class_embedding == nullptr); const int n_pos = n_patches; @@ -2493,6 +2559,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 { res = graph.build_qwen3vl(); } break; + case PROJECTOR_TYPE_EAGLE2VL: + { + res = graph.build_eagle2vl(); + } break; case PROJECTOR_TYPE_MINICPMV: { res = graph.build_minicpmv(); @@ -2767,6 +2837,15 @@ struct clip_model_loader { // model-specific params switch (model.proj_type) { + case PROJECTOR_TYPE_EAGLE2VL: + { + // spatial merge (default 2), allow override from metadata if present + hparams.n_merge = 2; + get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.n_merge, false); + // set reasonable token limits and warmup like qwen2vl + hparams.set_limit_image_tokens(8, 4096); + hparams.set_warmup_n_tokens(46*46); + } break; case PROJECTOR_TYPE_MINICPMV: { if (hparams.minicpmv_version == 0) { @@ -3048,6 +3127,14 @@ struct clip_model_loader { } model.image_newline = get_tensor(TN_IMAGE_NEWLINE, false); } break; + case PROJECTOR_TYPE_EAGLE2VL: + { + // 2-layer MLP projector using mm.0 and mm.2 (normalized at conversion time) + model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight")); + model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias"), false); + model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight")); + model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"), false); + } break; case PROJECTOR_TYPE_LDP: { // MobileVLM projection @@ -4287,6 +4374,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str case PROJECTOR_TYPE_GLM_EDGE: case PROJECTOR_TYPE_GEMMA3: + case PROJECTOR_TYPE_EAGLE2VL: case PROJECTOR_TYPE_INTERNVL: // TODO @ngxson : support dynamic resolution { clip_image_u8 resized_image; @@ -4533,6 +4621,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im n_patches = x_patch * y_patch; } break; case PROJECTOR_TYPE_GEMMA3: + case PROJECTOR_TYPE_EAGLE2VL: case PROJECTOR_TYPE_IDEFICS3: case PROJECTOR_TYPE_INTERNVL: case PROJECTOR_TYPE_LLAMA4: @@ -4911,6 +5000,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima set_input_i32("patches", patches); } break; case PROJECTOR_TYPE_GEMMA3: + case PROJECTOR_TYPE_EAGLE2VL: case PROJECTOR_TYPE_IDEFICS3: case PROJECTOR_TYPE_INTERNVL: case PROJECTOR_TYPE_QWEN2A: @@ -5029,6 +5119,9 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { return ctx->model.mm_2_w->ne[1]; case PROJECTOR_TYPE_COGVLM: return ctx->model.mm_4h_to_h_w->ne[1]; + case PROJECTOR_TYPE_EAGLE2VL: + // final projector output dim + return ctx->model.mm_2_w->ne[1]; default: GGML_ABORT("Unknown projector type"); } From e35a94b4c9288a836a0c095d7740caba0e6ef8d2 Mon Sep 17 00:00:00 2001 From: Yael Zadok <38328157276@mby.co.il> Date: Tue, 18 Nov 2025 21:13:12 +0200 Subject: [PATCH 2/8] Update convert_hf_to_gguf.py --- convert_hf_to_gguf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 2b76f4faebd..b07072d4502 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -10168,7 +10168,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter if name.startswith('model.language_model.'): name = 
name.replace('model.language_model.', 'model.') - if name.startswith("mlp1."): + elif name.startswith('language_model.'): name = name.replace('language_model.', '') return super().modify_tensors(data_torch, name, bid) From 6247fd2c7dc48d68e4ee35d5d1e01ff2867f51aa Mon Sep 17 00:00:00 2001 From: Yael Zadok <38328157276@mby.co.il> Date: Sun, 7 Dec 2025 11:34:49 +0200 Subject: [PATCH 3/8] style: fix indentation in build_eagle2vl --- tools/mtmd/clip.cpp | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 01e66c59ca8..2c3abf17553 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -1095,13 +1095,12 @@ struct clip_graph { // Vision encoder: use RMS norm per metadata (eps), FFN op per hparams ggml_tensor * cur = build_vit( - inp, n_pos, - NORM_TYPE_RMS, - hparams.ffn_op, - learned_pos_embd, - nullptr); - - // keep runtime quiet in normal runs; shapes are correct by construction + inp, n_pos, + NORM_TYPE_RMS, + hparams.ffn_op, + learned_pos_embd, + nullptr); + // keep runtime quiet in normal runs; shapes are correct by construction // Apply spatial patch merge (e.g., 2x2) before projector if requested { @@ -1114,11 +1113,11 @@ struct clip_graph { } } - // 2-layer MLP projector: mm.0 -> GELU -> mm.2 + // 2-layer MLP projector: mm.0 -> GELU -> mm.2 ggml_tensor * embeddings = cur; - // projector matmuls assume canonical [n_in, n_out] weights; no runtime transposes - GGML_ASSERT(model.mm_0_w != nullptr); + // projector matmuls assume canonical [n_in, n_out] weights; no runtime transposes + GGML_ASSERT(model.mm_0_w != nullptr); // ensure projector input is a packed 2D matrix [n_in, n_tokens] embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]); embeddings = ggml_cont_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]); From 0dda80fabfab65d8066475ba0a29c1fdb3c4f98e Mon Sep 17 00:00:00 2001 From: Yael Zadok <192727233+YaelGitAccount@users.noreply.github.com> Date: Sun, 7 Dec 2025 14:06:53 +0200 Subject: [PATCH 4/8] Refactor MLP projector to use build_ffn helper --- tools/mtmd/clip.cpp | 29 +++++++++++------------------ 1 file changed, 11 insertions(+), 18 deletions(-) diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 2c3abf17553..b2926dba037 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -1116,29 +1116,22 @@ struct clip_graph { // 2-layer MLP projector: mm.0 -> GELU -> mm.2 ggml_tensor * embeddings = cur; - // projector matmuls assume canonical [n_in, n_out] weights; no runtime transposes GGML_ASSERT(model.mm_0_w != nullptr); + GGML_ASSERT(model.mm_2_w != nullptr); + // ensure projector input is a packed 2D matrix [n_in, n_tokens] embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]); embeddings = ggml_cont_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]); - // embeddings are packed [n_in, n_tokens] - // Weights are canonicalized at conversion time to [n_in, n_out]; multiply directly. 
- embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings); - if (model.mm_0_b) { - embeddings = ggml_add(ctx0, embeddings, model.mm_0_b); - } - - embeddings = ggml_gelu(ctx0, embeddings); - GGML_ASSERT(model.mm_2_w != nullptr); - // keep [n_in, n_tokens] layout for the second matmul as well - embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]); - embeddings = ggml_cont_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]); - // Weights are canonicalized at conversion time to [n_in, n_out]; multiply directly. - embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings); - if (model.mm_2_b) { - embeddings = ggml_add(ctx0, embeddings, model.mm_2_b); - } + // Use shared FFN helper: Linear(mm.0) -> GELU -> Linear(mm.2) + embeddings = build_ffn( + embeddings, + model.mm_0_w, model.mm_0_b, + /*gate=*/nullptr, /*gate_b=*/nullptr, + model.mm_2_w, model.mm_2_b, + FFN_GELU, + /*il=*/0 + ); // build the graph ggml_build_forward_expand(gf, embeddings); From eb83923bbe0b4dcd5f7e0aed3f931b695342e18b Mon Sep 17 00:00:00 2001 From: Yael Zadok <192727233+YaelGitAccount@users.noreply.github.com> Date: Sun, 7 Dec 2025 16:13:06 +0200 Subject: [PATCH 5/8] add projector LayerNorm weight mapping (converter side) --- convert_hf_to_gguf.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index b07072d4502..bfda7a745b0 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -3912,13 +3912,22 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter if ".head." in name: return [] - # 3) Projector MLP remap: mlp1 -> mm indices (match 'mlp1.' anywhere, not just at start) + # 3) Projector MLP remap: map Eagle2-VL mlp1.* -> mm_input_norm/mm.0/mm.2 mlp_pos = name.find("mlp1.") if mlp_pos != -1: mlp_suffix = name[mlp_pos + len("mlp1."):] - # Skip LayerNorm (mlp1.0.*) + # Map Eagle2-VL projector LayerNorm: + # mlp1.0.{weight,bias} correspond to the input LayerNorm of the projector + # (structure: LayerNorm → Linear → GELU → Linear). + # The C++ runtime applies ggml_norm + scale/shift, so we store γ/β as mm_input_norm_w/b. if mlp_suffix.startswith("0."): + if mlp_suffix.endswith("weight"): + return [("mm_input_norm_w", data_torch)] + if mlp_suffix.endswith("bias"): + return [("mm_input_norm_b", data_torch)] + # any other subfield under mlp1.0.* (rare) -> skip return [] + # Map first Linear (mlp1.1.*) -> mm.0.* if mlp_suffix.startswith("1."): new_name = "mm.0." 
+ mlp_suffix[2:] From f4af853e46b3bf749f35a2c2c5148a930d84c420 Mon Sep 17 00:00:00 2001 From: Yael Zadok <192727233+YaelGitAccount@users.noreply.github.com> Date: Sun, 7 Dec 2025 16:36:02 +0200 Subject: [PATCH 6/8] Update convert_hf_to_gguf.py From b4f660f4706ccb4d3baaeddeff8f16534cafb05a Mon Sep 17 00:00:00 2001 From: Yael Zadok <192727233+YaelGitAccount@users.noreply.github.com> Date: Sun, 7 Dec 2025 17:01:14 +0200 Subject: [PATCH 7/8] apply projector LayerNorm at runtime --- tools/mtmd/clip.cpp | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index b2926dba037..4ac88c82381 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -1113,7 +1113,7 @@ struct clip_graph { } } - // 2-layer MLP projector: mm.0 -> GELU -> mm.2 + // 2-layer MLP projector: LayerNorm (mlp1.0) -> mm.0 -> GELU -> mm.2 ggml_tensor * embeddings = cur; GGML_ASSERT(model.mm_0_w != nullptr); @@ -1123,6 +1123,17 @@ struct clip_graph { embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]); embeddings = ggml_cont_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]); + // Apply projector input LayerNorm (mlp1.0) with default eps = 1e-5 + if (model.mm_input_norm_w || model.mm_input_norm_b) { + embeddings = build_norm( + embeddings, + model.mm_input_norm_w, + model.mm_input_norm_b, + NORM_TYPE_NORMAL, + 1e-5f, + /*il=*/-1); + } + // Use shared FFN helper: Linear(mm.0) -> GELU -> Linear(mm.2) embeddings = build_ffn( embeddings, @@ -3121,7 +3132,10 @@ struct clip_model_loader { } break; case PROJECTOR_TYPE_EAGLE2VL: { - // 2-layer MLP projector using mm.0 and mm.2 (normalized at conversion time) + // projector input LayerNorm (mlp1.0.{weight,bias}) + model.mm_input_norm_w = get_tensor("mm_input_norm_w", false); + model.mm_input_norm_b = get_tensor("mm_input_norm_b", false); + // 2-layer MLP projector using mm.0 and mm.2 model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight")); model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias"), false); model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight")); From e51fd1fcb52ee0bd7e305f997c159268decd8caa Mon Sep 17 00:00:00 2001 From: Yael Zadok <192727233+YaelGitAccount@users.noreply.github.com> Date: Tue, 9 Dec 2025 12:41:24 +0200 Subject: [PATCH 8/8] eagle2-vl: drop Conv3D patch embed handling --- convert_hf_to_gguf.py | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index bfda7a745b0..1f939d9e766 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -3885,7 +3885,7 @@ def set_gguf_parameters(self): self.gguf_writer.add_vision_image_std(list(img_std)) # Note: - # Eagle2-specific tensor layout normalization (mlp1 → mm.*, QKV split, Conv3D → Conv2D) + # Eagle2-specific tensor layout normalization (mlp1 → mm.*, QKV split) # will live here in Eagle2VLVisionModel.modify_tensors() in a follow-up. # The C++ builder will assume canonical weights and perform zero ad-hoc transposes. 
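For reference, the projector wired up across these patches is, end to end, LayerNorm -> Linear -> GELU -> Linear. A minimal PyTorch sketch of the equivalent module follows; the dimensions are illustrative stand-ins, not values read from any config:

    # Sketch of the Eagle2-VL projector as mapped into GGUF:
    #   mlp1.0 (LayerNorm) -> mm_input_norm_{w,b}
    #   mlp1.1 (Linear)    -> mm.0.{weight,bias}
    #   mlp1.3 (Linear)    -> mm.2.{weight,bias}
    # vit_hidden and n_embd_text are assumed example sizes.
    import torch
    import torch.nn as nn

    vit_hidden, n_embd_text = 1152, 2048
    projector = nn.Sequential(
        nn.LayerNorm(vit_hidden * 4),            # input width = ViT hidden * 4 after the 2x2 merge
        nn.Linear(vit_hidden * 4, n_embd_text),
        nn.GELU(),
        nn.Linear(n_embd_text, n_embd_text),
    )
    tokens = torch.randn(16, vit_hidden * 4)     # 16 merged patch tokens
    out = projector(tokens)                      # -> [16, n_embd_text]
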
@@ -3997,18 +3997,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter (self.map_tensor_name(name.replace("qkv", "v")), wv), ] - # 5) Conv3D patch embed -> two Conv2D kernels - if name.endswith("patch_embed.proj.weight") and data_torch.ndim == 5: - c_out, c_in, kt, kh, kw = data_torch.shape - del c_out, c_in, kh, kw # unused - assert kt == 2, "Current implementation only supports temporal_patch_size of 2" - base = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".weight" - return [ - (base, data_torch[:, :, 0, ...]), - (base + ".1", data_torch[:, :, 1, ...]), - ] - - # 6) Default mapping for remaining Eagle2 vision tensors + # 5) Default mapping for remaining Eagle2 vision tensors if name.startswith("vision_model.") or name.startswith("model.vision_model.") or name.startswith("visual."): return [(self.map_tensor_name(name), data_torch)]
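
To sanity-check the fused-QKV handling in step 4 of modify_tensors above, a standalone sketch (shapes are illustrative; a square attention projection is assumed):

    # The converter slices the fused qkv tensor into three equal parts along
    # the leading dimension; the same slicing applies to the fused bias.
    import torch

    hidden = 1152                               # assumed ViT width
    qkv_w = torch.randn(3 * hidden, hidden)     # fused [3*out, in] as stored by HF
    c = qkv_w.shape[0] // 3
    wq, wk, wv = qkv_w[:c], qkv_w[c:2 * c], qkv_w[2 * c:]
    assert wq.shape == wk.shape == wv.shape == (hidden, hidden)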