193 changes: 193 additions & 0 deletions convert_hf_to_gguf.py
@@ -3811,6 +3811,199 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
return [] # skip other tensors


@ModelBase.register("Eagle2_VLForConditionalGeneration","Eagle2_5_VLForConditionalGeneration")
class Eagle2VLVisionModel(MmprojModel):
"""
Dedicated Eagle2-VL mmproj converter.

Responsibilities:
- Emit a distinct projector_type (eagle2vl) and vision metadata (image/patch size, mean/std, eps, block_count, merge size).
- Perform Eagle2-specific layout normalization during conversion in modify_tensors (projector mlp1 → mm.* remap, fused QKV split, weight-orientation canonicalization).
The C++ runtime assumes canonical layout and must not include Eagle2-specific transposes or hacks.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if self.hparams_vision is not None:
# Prefer overrides from vision_config when present
vc = self.get_vision_config() or {}
if "image_size" in vc and "image_size" not in self.hparams_vision:
self.hparams_vision["image_size"] = vc["image_size"]
if "patch_size" in vc and "patch_size" not in self.hparams_vision:
self.hparams_vision["patch_size"] = vc["patch_size"]
if "image_mean" in vc and "image_mean" not in self.hparams_vision:
self.hparams_vision["image_mean"] = vc["image_mean"]
if "image_std" in vc and "image_std" not in self.hparams_vision:
self.hparams_vision["image_std"] = vc["image_std"]

# Normalize common aliases
if "num_heads" in self.hparams_vision and "num_attention_heads" not in self.hparams_vision:
self.hparams_vision["num_attention_heads"] = self.hparams_vision.get("num_heads")
if "depth" in self.hparams_vision and "num_hidden_layers" not in self.hparams_vision:
self.hparams_vision["num_hidden_layers"] = self.hparams_vision.get("depth")

def set_gguf_parameters(self):
# Base writes general vision fields (embedding/feed-forward/heads when available)
super().set_gguf_parameters()

# Projector type: Eagle2-VL
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.EAGLE2VL)

# Vision attention layernorm eps: use config if available, else 1e-6
eps = 1e-6
if isinstance(self.global_config, dict):
eps = self.global_config.get("rms_norm_eps", eps)
self.gguf_writer.add_vision_attention_layernorm_eps(eps)

# 2x2 spatial merge
self.gguf_writer.add_vision_spatial_merge_size(2)

# Mirror Qwen2VL-style image metadata if provided; do not guess
hpv = self.hparams_vision or {}

img_sz = hpv.get("image_size")
if isinstance(img_sz, (list, tuple)) and len(img_sz) > 0:
img_sz = img_sz[0]
if isinstance(img_sz, int):
self.gguf_writer.add_vision_image_size(img_sz)

patch_sz = hpv.get("patch_size")
if isinstance(patch_sz, (list, tuple)) and len(patch_sz) > 0:
patch_sz = patch_sz[0]
if isinstance(patch_sz, int):
self.gguf_writer.add_vision_patch_size(patch_sz)

blk_cnt = hpv.get("num_hidden_layers")
if isinstance(blk_cnt, int):
self.gguf_writer.add_vision_block_count(blk_cnt)

img_mean = hpv.get("image_mean")
if isinstance(img_mean, (list, tuple)) and len(img_mean) > 0:
self.gguf_writer.add_vision_image_mean(list(img_mean))

img_std = hpv.get("image_std")
if isinstance(img_std, (list, tuple)) and len(img_std) > 0:
self.gguf_writer.add_vision_image_std(list(img_std))

# Note:
# Eagle2-specific tensor layout normalization (mlp1 → mm.*, QKV split)
# is implemented in Eagle2VLVisionModel.modify_tensors() below.
# The C++ builder assumes canonical weights and performs no ad-hoc transposes.

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused

# 1) Name prefix normalization
if name.startswith("vision_model.vision_model."):
name = name.replace("vision_model.vision_model.", "model.vision_model.", 1)
# Eagle2 HF projector tensors often live under "multi_modal_projector.*"
# Normalize that prefix away so we can match on "mlp1.*" below
if name.startswith("multi_modal_projector."):
name = name.replace("multi_modal_projector.", "", 1)

# 2) Skip clearly non-vision towers
if name.startswith("audio") or name.startswith("talker") or name.startswith("token2wav"):
return []

# 2.1) Skip classification / pooling head tensors under vision model head
# Example tensors to skip:
# - vision_model.vision_model.head.attention.in_proj_weight
# - vision_model.vision_model.head.mlp.fc1.weight
# - vision_model.vision_model.head.probe
if ".head." in name:
return []

# 3) Projector MLP remap: map Eagle2-VL mlp1.* -> mm_input_norm/mm.0/mm.2
mlp_pos = name.find("mlp1.")
if mlp_pos != -1:
mlp_suffix = name[mlp_pos + len("mlp1."):]
# Map Eagle2-VL projector LayerNorm:
# mlp1.0.{weight,bias} correspond to the input LayerNorm of the projector
# (structure: LayerNorm → Linear → GELU → Linear).
# The C++ runtime applies ggml_norm + scale/shift, so we store γ/β as mm_input_norm_w/b.
if mlp_suffix.startswith("0."):
if mlp_suffix.endswith("weight"):
return [("mm_input_norm_w", data_torch)]
if mlp_suffix.endswith("bias"):
return [("mm_input_norm_b", data_torch)]
# any other subfield under mlp1.0.* (rare) -> skip
return []

# Map first Linear (mlp1.1.*) -> mm.0.*
if mlp_suffix.startswith("1."):
new_name = "mm.0." + mlp_suffix[2:]
if new_name.endswith(".weight"):
# Canonicalize to [n_in, n_out]; detect and transpose only if needed.
if data_torch.ndim == 2:
d0, d1 = int(data_torch.shape[0]), int(data_torch.shape[1])
# Expected input width after 2x2 merge = vit_hidden * 4 (if available)
vit_hidden = (self.hparams_vision or {}).get("hidden_size")
expected_in = vit_hidden * 4 if isinstance(vit_hidden, int) and vit_hidden > 0 else None
expected_out = getattr(self, "n_embd_text", None)
# Strong orientation rule: if [out, in] = [n_embd_text, 4*hidden] -> transpose
if isinstance(expected_in, int) and isinstance(expected_out, int) and expected_in > 0 and expected_out > 0:
if d0 == expected_out and d1 == expected_in:
data_torch = data_torch.transpose(-1, -2)
d0, d1 = d1, d0  # keep (d0, d1) in sync so the heuristic checks below do not transpose back
elif d0 == expected_in and d1 == expected_out:
pass # already [n_in, n_out]
else:
# fall through to heuristic rules below
pass
if isinstance(expected_in, int) and expected_in > 0:
if d0 == expected_in:
pass # already canonical [n_in, n_out]
elif d1 == expected_in:
data_torch = data_torch.transpose(-1, -2)
else:
# Fallback: choose orientation that puts larger dim on axis 0
if d0 < d1:
data_torch = data_torch.transpose(-1, -2)
else:
# Fallback when vit_hidden is unknown: assume PyTorch [out, in] and transpose if in > out
if d0 < d1:
data_torch = data_torch.transpose(-1, -2)
return [(new_name, data_torch)]
return [(new_name, data_torch)]
# Map second Linear (mlp1.3.*) -> mm.2.*
if mlp_suffix.startswith("3."):
new_name = "mm.2." + mlp_suffix[2:]
if new_name.endswith(".weight"):
# Canonicalize to [n_in, n_out] for the second layer.
# Here expected_in == expected_out == text emb dim; if ambiguous, leave as-is (often square).
if data_torch.ndim == 2:
d0, d1 = int(data_torch.shape[0]), int(data_torch.shape[1])
expected = getattr(self, "n_embd_text", None)
if isinstance(expected, int) and expected > 0:
if d0 == expected and d1 == expected:
pass
elif d1 == expected and d0 != expected:
data_torch = data_torch.transpose(-1, -2)
return [(new_name, data_torch)]
return [(new_name, data_torch)]
# Unknown mlp1 component -> skip
return []

# 4) Fused QKV split
if ".qkv." in name:
# Determine split size from leading dimension
c3 = data_torch.shape[0]
assert c3 % 3 == 0, f"qkv tensor leading dim must be divisible by 3, got {c3} for {name}"
c = c3 // 3
wq = data_torch[:c]
wk = data_torch[c: 2 * c]
wv = data_torch[2 * c:]
return [
(self.map_tensor_name(name.replace("qkv", "q")), wq),
(self.map_tensor_name(name.replace("qkv", "k")), wk),
(self.map_tensor_name(name.replace("qkv", "v")), wv),
]

# 5) Default mapping for remaining Eagle2 vision tensors
if name.startswith("vision_model.") or name.startswith("model.vision_model.") or name.startswith("visual."):
return [(self.map_tensor_name(name), data_torch)]

# Not an Eagle2 vision tensor -> skip
return []

@ModelBase.register("Qwen2_5OmniModel")
class Qwen25OmniModel(Qwen2VLVisionModel):
has_vision_encoder = True
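For reference, a minimal standalone sketch (not part of the patch) of the two conversion rules implemented above, the fused QKV split and the mm.0 weight-orientation canonicalization, on dummy tensors; all dimensions below are illustrative assumptions, not values taken from a real checkpoint:

# Illustrative only: assumed dimensions, mirroring the rules in Eagle2VLVisionModel.modify_tensors().
import torch

# Fused QKV split: a [3*C, C_in] weight is cut into three equal [C, C_in] slices along dim 0.
c, c_in = 4, 8
qkv = torch.randn(3 * c, c_in)
wq, wk, wv = qkv[:c], qkv[c:2 * c], qkv[2 * c:]
assert wq.shape == wk.shape == wv.shape == (c, c_in)

# Orientation canonicalization for mm.0.weight: PyTorch nn.Linear stores [n_out, n_in];
# the converter wants [n_in, n_out], so a [n_embd_text, 4*vit_hidden] tensor is transposed.
vit_hidden, n_embd_text = 16, 32
w = torch.randn(n_embd_text, 4 * vit_hidden)   # HF layout: [out, in]
if w.shape == (n_embd_text, 4 * vit_hidden):
    w = w.transpose(-1, -2)                    # canonical: [in, out]
assert w.shape == (4 * vit_hidden, n_embd_text)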
1 change: 1 addition & 0 deletions gguf-py/gguf/constants.py
@@ -3238,6 +3238,7 @@ class VisionProjectorType:
LIGHTONOCR = "lightonocr"
COGVLM = "cogvlm"
JANUS_PRO = "janus_pro"
EAGLE2VL = "eagle2vl"


# Items here are (block size, type size)
2 changes: 2 additions & 0 deletions tools/mtmd/clip-impl.h
@@ -141,6 +141,7 @@ enum projector_type {
PROJECTOR_TYPE_GLM_EDGE,
PROJECTOR_TYPE_QWEN2VL,
PROJECTOR_TYPE_QWEN3VL,
PROJECTOR_TYPE_EAGLE2VL,
PROJECTOR_TYPE_GEMMA3,
PROJECTOR_TYPE_IDEFICS3,
PROJECTOR_TYPE_PIXTRAL,
@@ -168,6 +169,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
{ PROJECTOR_TYPE_QWEN2VL, "qwen2vl_merger"},
{ PROJECTOR_TYPE_QWEN25VL, "qwen2.5vl_merger"},
{ PROJECTOR_TYPE_QWEN3VL, "qwen3vl_merger"},
{ PROJECTOR_TYPE_EAGLE2VL, "eagle2vl"},
{ PROJECTOR_TYPE_GEMMA3, "gemma3"},
{ PROJECTOR_TYPE_IDEFICS3, "idefics3"},
{ PROJECTOR_TYPE_PIXTRAL, "pixtral"},
99 changes: 99 additions & 0 deletions tools/mtmd/clip.cpp
@@ -1081,6 +1081,75 @@ struct clip_graph {
return gf;
}

// Eagle2-VL: normalized ViT with learned absolute position embeddings and 2-layer MLP projector (mm.0, GELU, mm.2)
ggml_cgraph * build_eagle2vl() {
GGML_ASSERT(model.class_embedding == nullptr);

const int n_pos = n_patches;

// Use resized learned position embeddings if dynamic resolution is used
ggml_tensor * learned_pos_embd = resize_position_embeddings();

// Build input patches via Conv2D, add patch bias if present
ggml_tensor * inp = build_inp();

// Vision encoder: use RMS norm per metadata (eps), FFN op per hparams
ggml_tensor * cur = build_vit(
inp, n_pos,
NORM_TYPE_RMS,
hparams.ffn_op,
learned_pos_embd,
nullptr);

// Apply spatial patch merge (e.g., 2x2) before projector if requested
{
const int scale_factor = hparams.n_merge;
if (scale_factor > 1) {
// This returns a 2D tensor shaped [n_embd * scale_factor^2, n_pos / scale_factor^2]
// which matches the expected input width for mm.0
cur = build_patch_merge_permute(cur, scale_factor);
// merged tokens layout now matches projector input width
}
}

// 2-layer MLP projector: LayerNorm (mlp1.0) -> mm.0 -> GELU -> mm.2
ggml_tensor * embeddings = cur;

GGML_ASSERT(model.mm_0_w != nullptr);
GGML_ASSERT(model.mm_2_w != nullptr);

// ensure projector input is a packed 2D matrix [n_in, n_tokens]
embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]);
embeddings = ggml_cont_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]);

// Apply projector input LayerNorm (mlp1.0) with default eps = 1e-5
if (model.mm_input_norm_w || model.mm_input_norm_b) {
embeddings = build_norm(
embeddings,
model.mm_input_norm_w,
model.mm_input_norm_b,
NORM_TYPE_NORMAL,
1e-5f,
/*il=*/-1);
}

// Use shared FFN helper: Linear(mm.0) -> GELU -> Linear(mm.2)
embeddings = build_ffn(
embeddings,
model.mm_0_w, model.mm_0_b,
/*gate=*/nullptr, /*gate_b=*/nullptr,
model.mm_2_w, model.mm_2_b,
FFN_GELU,
/*il=*/0
);

// build the graph
ggml_build_forward_expand(gf, embeddings);

return gf;
}

ggml_cgraph * build_minicpmv() {
GGML_ASSERT(model.class_embedding == nullptr);
const int n_pos = n_patches;
@@ -2493,6 +2562,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
{
res = graph.build_qwen3vl();
} break;
case PROJECTOR_TYPE_EAGLE2VL:
{
res = graph.build_eagle2vl();
} break;
case PROJECTOR_TYPE_MINICPMV:
{
res = graph.build_minicpmv();
@@ -2767,6 +2840,15 @@ struct clip_model_loader {

// model-specific params
switch (model.proj_type) {
case PROJECTOR_TYPE_EAGLE2VL:
{
// spatial merge (default 2), allow override from metadata if present
hparams.n_merge = 2;
get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.n_merge, false);
// set reasonable token limits and warmup like qwen2vl
hparams.set_limit_image_tokens(8, 4096);
hparams.set_warmup_n_tokens(46*46);
} break;
case PROJECTOR_TYPE_MINICPMV:
{
if (hparams.minicpmv_version == 0) {
@@ -3048,6 +3130,17 @@ struct clip_model_loader {
}
model.image_newline = get_tensor(TN_IMAGE_NEWLINE, false);
} break;
case PROJECTOR_TYPE_EAGLE2VL:
{
// projector input LayerNorm (mlp1.0.{weight,bias})
model.mm_input_norm_w = get_tensor("mm_input_norm_w", false);
model.mm_input_norm_b = get_tensor("mm_input_norm_b", false);
// 2-layer MLP projector using mm.0 and mm.2
model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight"));
model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias"), false);
model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"));
model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"), false);
} break;
case PROJECTOR_TYPE_LDP:
{
// MobileVLM projection
@@ -4287,6 +4380,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str

case PROJECTOR_TYPE_GLM_EDGE:
case PROJECTOR_TYPE_GEMMA3:
case PROJECTOR_TYPE_EAGLE2VL:
case PROJECTOR_TYPE_INTERNVL: // TODO @ngxson : support dynamic resolution
{
clip_image_u8 resized_image;
@@ -4533,6 +4627,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
n_patches = x_patch * y_patch;
} break;
case PROJECTOR_TYPE_GEMMA3:
case PROJECTOR_TYPE_EAGLE2VL:
case PROJECTOR_TYPE_IDEFICS3:
case PROJECTOR_TYPE_INTERNVL:
case PROJECTOR_TYPE_LLAMA4:
@@ -4911,6 +5006,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
set_input_i32("patches", patches);
} break;
case PROJECTOR_TYPE_GEMMA3:
case PROJECTOR_TYPE_EAGLE2VL:
case PROJECTOR_TYPE_IDEFICS3:
case PROJECTOR_TYPE_INTERNVL:
case PROJECTOR_TYPE_QWEN2A:
@@ -5029,6 +5125,9 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
return ctx->model.mm_2_w->ne[1];
case PROJECTOR_TYPE_COGVLM:
return ctx->model.mm_4h_to_h_w->ne[1];
case PROJECTOR_TYPE_EAGLE2VL:
// final projector output dim
return ctx->model.mm_2_w->ne[1];
default:
GGML_ABORT("Unknown projector type");
}
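For orientation, a shape-level sketch (assumed, illustrative dimensions) of the projector path that build_eagle2vl() wires up: 2x2 spatial merge, input LayerNorm, Linear (mm.0), GELU, Linear (mm.2). The real merge is the permute-based build_patch_merge_permute(); a plain reshape is used here only to show the widths:

# Shape check for the Eagle2-VL projector path (illustrative dimensions only).
import torch
import torch.nn.functional as F

vit_hidden, n_embd_text = 16, 32
n_patches, merge = 64, 2                       # 8x8 patch grid, 2x2 spatial merge

x = torch.randn(n_patches, vit_hidden)         # ViT output: one embedding per patch
x = x.reshape(n_patches // merge**2, vit_hidden * merge**2)   # merged tokens, width = 4*vit_hidden

ln_w, ln_b = torch.ones(x.shape[1]), torch.zeros(x.shape[1])  # mm_input_norm_{w,b} (mlp1.0)
mm0 = torch.randn(vit_hidden * merge**2, n_embd_text)         # mm.0.weight in canonical [in, out]
mm2 = torch.randn(n_embd_text, n_embd_text)                   # mm.2.weight in canonical [in, out]

y = F.layer_norm(x, (x.shape[1],), ln_w, ln_b, eps=1e-5)      # projector input LayerNorm
y = F.gelu(y @ mm0)                                            # mm.0 + GELU
y = y @ mm2                                                    # mm.2
assert y.shape == (n_patches // merge**2, n_embd_text)         # one text-width embedding per merged token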