Skip to content

Commit a8ad535

Browse files
Working till 01-12-2025 with around 80% accuracy
1 parent 22c8c3c commit a8ad535

File tree

13 files changed

+1496
-150
lines changed

13 files changed

+1496
-150
lines changed

convert_hf_to_gguf.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -643,6 +643,7 @@ def print_registered_models(cls):
643643
@classmethod
def from_model_architecture(cls, arch: str, model_type = ModelType.TEXT) -> type[ModelBase]:
    """Return the converter class registered for `arch` under `model_type`.

    The lookup is `cls._model_classes[model_type][arch]`.

    Raises:
        NotImplementedError: if either key is missing from the registry
            (a KeyError on `model_type` or on `arch` is reported the same way).
    """
    try:
        return cls._model_classes[model_type][arch]
    except KeyError:
        # `from None` hides the internal KeyError so the user only sees the
        # "not supported" message.
        raise NotImplementedError(f'Architecture {arch!r} not supported!') from None
@@ -4457,6 +4458,48 @@ def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
44574458
yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_LONG), torch.tensor(long_factors, dtype=torch.float32))
44584459
yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT), torch.tensor(short_factors, dtype=torch.float32))
44594460

4461+
@ModelBase.register("Phi3VForCausalLM")
class Phi3VisionModel(Phi3MiniModel):
    """
    GGUF converter for Phi-3 Vision (Text Part Only).

    This strips out the vision encoder weights and metadata, creating a
    standard Phi-3 GGUF file that can be paired with an external mmproj file.

    Vocabulary handling, GGUF parameters and extra tensors are inherited
    unchanged from Phi3MiniModel; only tensor filtering is overridden here.
    """

    # CRITICAL: Use PHI3, not PHI3_VISION.
    # This tells llama.cpp to treat this as a standard text model.
    model_arch = gguf.MODEL_ARCH.PHI3

    # The prefix for all vision-related weights in Phi-3-Vision
    _VISION_TENSOR_PREFIX = "model.vision_embed_tokens."

    def modify_tensors(
        self,
        data_torch: Tensor,
        name: str,
        bid: int | None,
    ) -> Iterable[tuple[str, Tensor]]:
        """Yield the converted tensor(s) for `name`, dropping vision weights.

        Args:
            data_torch: tensor data for the checkpoint entry.
            name: checkpoint tensor name.
            bid: block index passed through to the parent implementation, or None.

        Yields:
            Whatever the parent `modify_tensors` produces for text tensors;
            nothing at all for tensors whose name starts with the vision prefix.
        """
        # 1. If it is a vision tensor, SKIP IT completely.
        #    We do not want these weights in the text model file.
        #    (A bare `return` in a generator ends it without yielding.)
        if name.startswith(self._VISION_TENSOR_PREFIX):
            return

        # 2. If it is a text tensor, delegate to the standard Phi-3 logic.
        yield from super().modify_tensors(data_torch, name, bid)
44604503

44614504
@ModelBase.register("PhiMoEForCausalLM")
44624505
class PhiMoeModel(Phi3MiniModel):
@@ -7936,7 +7979,7 @@ def set_gguf_parameters(self):
79367979
self.gguf_writer.add_context_length(self.hparams.get("seq_length", n_embed))
79377980
self.gguf_writer.add_embedding_length(n_embed)
79387981
self.gguf_writer.add_feed_forward_length(self.hparams.get("ffn_hidden_size", self.hparams.get("intermediate_size", 4 * n_embed)))
7939-
self.gguf_writer.add_block_count(self.hparams.get("num_layers", self.hparams["num_hidden_layers"]))
7982+
self.gguf_writer.add_block_count(self.hparams.get("num_layers", self.hparams.get("num_hidden_layers", 0)))
79407983
self.gguf_writer.add_head_count(n_head)
79417984
self.gguf_writer.add_head_count_kv(n_head_kv)
79427985
self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("layernorm_epsilon",1e-5))
@@ -10143,6 +10186,7 @@ def get_model_architecture(hparams: dict[str, Any], model_type: ModelType) -> st
1014310186
# maybe we should fallback to text model's arch in that case, since not many models have both
1014410187
text_config = hparams.get("text_config", {})
1014510188
vision_config = hparams.get("vision_config", {})
10189+
print(hparams.get("architectures"))
1014610190
arch = None
1014710191
if (arches := hparams.get("architectures")) is not None and len(arches) > 0:
1014810192
arch = arches[0]

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 76 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,11 @@ typedef pthread_t ggml_thread_t;
193193
#include <TargetConditionals.h>
194194
#endif
195195

196+
#include <stdatomic.h>
197+
198+
// Per-op profiling accumulators, indexed by node->op.
// ggml_op_us[op]: sum of the wall-clock microseconds (ggml_time_us deltas) measured
// around ggml_compute_forward for nodes of that op. Every compute thread adds its own
// measurement, so the total is summed across threads rather than per graph node.
// NOTE(review): relies on C11 <stdatomic.h>/_Atomic unconditionally — confirm this
// builds on every compiler this file supports.
static _Atomic uint64_t ggml_op_us[GGML_OP_COUNT];
199+
// ggml_op_calls[op]: number of timed ggml_compute_forward invocations for that op,
// incremented once per thread per node (same accounting as ggml_op_us).
static _Atomic uint64_t ggml_op_calls[GGML_OP_COUNT];
200+
196201
static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
197202
[GGML_TYPE_F32] = {
198203
.from_float = (ggml_from_float_t) ggml_cpu_fp32_to_fp32,
@@ -2864,6 +2869,44 @@ struct ggml_cplan ggml_graph_plan(
28642869
return cplan;
28652870
}
28662871

2872+
// static thread_ret_t ggml_graph_compute_thread(void * data) {
2873+
// struct ggml_compute_state * state = (struct ggml_compute_state *) data;
2874+
// struct ggml_threadpool * tp = state->threadpool;
2875+
//
2876+
// const struct ggml_cgraph * cgraph = tp->cgraph;
2877+
// const struct ggml_cplan * cplan = tp->cplan;
2878+
//
2879+
// set_numa_thread_affinity(state->ith);
2880+
//
2881+
// struct ggml_compute_params params = {
2882+
// /*.ith =*/ state->ith,
2883+
// /*.nth =*/ atomic_load_explicit(&tp->n_threads_cur, memory_order_relaxed),
2884+
// /*.wsize =*/ cplan->work_size,
2885+
// /*.wdata =*/ cplan->work_data,
2886+
// /*.threadpool=*/ tp,
2887+
// };
2888+
//
2889+
// for (int node_n = 0; node_n < cgraph->n_nodes && atomic_load_explicit(&tp->abort, memory_order_relaxed) != node_n; node_n++) {
2890+
// struct ggml_tensor * node = cgraph->nodes[node_n];
2891+
//
2892+
// ggml_compute_forward(&params, node);
2893+
//
2894+
// if (state->ith == 0 && cplan->abort_callback &&
2895+
// cplan->abort_callback(cplan->abort_callback_data)) {
2896+
// atomic_store_explicit(&tp->abort, node_n + 1, memory_order_relaxed);
2897+
// tp->ec = GGML_STATUS_ABORTED;
2898+
// }
2899+
//
2900+
// if (node_n + 1 < cgraph->n_nodes) {
2901+
// ggml_barrier(state->threadpool);
2902+
// }
2903+
// }
2904+
//
2905+
// ggml_barrier(state->threadpool);
2906+
//
2907+
// return 0;
2908+
// }
2909+
28672910
static thread_ret_t ggml_graph_compute_thread(void * data) {
28682911
struct ggml_compute_state * state = (struct ggml_compute_state *) data;
28692912
struct ggml_threadpool * tp = state->threadpool;
@@ -2884,21 +2927,25 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
28842927
for (int node_n = 0; node_n < cgraph->n_nodes && atomic_load_explicit(&tp->abort, memory_order_relaxed) != node_n; node_n++) {
28852928
struct ggml_tensor * node = cgraph->nodes[node_n];
28862929

2930+
uint64_t t0 = ggml_time_us();
28872931
ggml_compute_forward(&params, node);
2932+
uint64_t dt = ggml_time_us() - t0;
2933+
2934+
atomic_fetch_add_explicit(&ggml_op_us[node->op], dt, memory_order_relaxed);
2935+
atomic_fetch_add_explicit(&ggml_op_calls[node->op], 1, memory_order_relaxed);
28882936

28892937
if (state->ith == 0 && cplan->abort_callback &&
28902938
cplan->abort_callback(cplan->abort_callback_data)) {
28912939
atomic_store_explicit(&tp->abort, node_n + 1, memory_order_relaxed);
28922940
tp->ec = GGML_STATUS_ABORTED;
2893-
}
2941+
}
28942942

28952943
if (node_n + 1 < cgraph->n_nodes) {
28962944
ggml_barrier(state->threadpool);
28972945
}
28982946
}
28992947

29002948
ggml_barrier(state->threadpool);
2901-
29022949
return 0;
29032950
}
29042951

@@ -3201,6 +3248,33 @@ enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cpl
32013248
ggml_threadpool_free(threadpool);
32023249
}
32033250

3251+
// printf("\n========= GGML OP PERF =========\n");
3252+
// for (int i = 0; i < GGML_OP_COUNT; i++) {
3253+
// uint64_t us = atomic_load(&ggml_op_us[i]);
3254+
// uint64_t calls = atomic_load(&ggml_op_calls[i]);
3255+
// if (calls == 0) continue;
3256+
//
3257+
// printf("%-16s : %8llu us %6llu calls avg %6llu us\n",
3258+
// ggml_op_name(i),
3259+
// (unsigned long long)us,
3260+
// (unsigned long long)calls,
3261+
// (unsigned long long)(us / calls));
3262+
// }
3263+
// printf("================================\n\n");
3264+
3265+
// printf("\n");
3266+
// for (int i = 0; i < GGML_OP_COUNT; i++) {
3267+
// uint64_t us = atomic_load(&ggml_op_us[i]);
3268+
// uint64_t calls = atomic_load(&ggml_op_calls[i]);
3269+
// if (calls == 0) continue;
3270+
//
3271+
// printf("%-16s,%8llu us,%6llu,%6llu us,",
3272+
// ggml_op_name(i),
3273+
// (unsigned long long)us,
3274+
// (unsigned long long)calls,
3275+
// (unsigned long long)(us / calls));
3276+
// }
3277+
32043278
return ret;
32053279
}
32063280

gguf-py/gguf/constants.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,7 @@ class MODEL_ARCH(IntEnum):
356356
QWEN3VLMOE = auto()
357357
PHI2 = auto()
358358
PHI3 = auto()
359+
PHI3_VISION = auto()
359360
PHIMOE = auto()
360361
PLAMO = auto()
361362
PLAMO2 = auto()
@@ -723,6 +724,7 @@ class MODEL_TENSOR(IntEnum):
723724
MODEL_ARCH.QWEN3VLMOE: "qwen3vlmoe",
724725
MODEL_ARCH.PHI2: "phi2",
725726
MODEL_ARCH.PHI3: "phi3",
727+
MODEL_ARCH.PHI3_VISION: "phi3_vision",
726728
MODEL_ARCH.PHIMOE: "phimoe",
727729
MODEL_ARCH.PLAMO: "plamo",
728730
MODEL_ARCH.PLAMO2: "plamo2",
@@ -1670,6 +1672,22 @@ class MODEL_TENSOR(IntEnum):
16701672
MODEL_TENSOR.FFN_DOWN,
16711673
MODEL_TENSOR.FFN_UP,
16721674
],
1675+
MODEL_ARCH.PHI3_VISION: [
1676+
MODEL_TENSOR.TOKEN_EMBD,
1677+
MODEL_TENSOR.OUTPUT_NORM,
1678+
MODEL_TENSOR.OUTPUT,
1679+
MODEL_TENSOR.ROPE_FACTORS_LONG,
1680+
MODEL_TENSOR.ROPE_FACTORS_SHORT,
1681+
MODEL_TENSOR.ATTN_NORM,
1682+
MODEL_TENSOR.ATTN_QKV,
1683+
MODEL_TENSOR.ATTN_Q,
1684+
MODEL_TENSOR.ATTN_K,
1685+
MODEL_TENSOR.ATTN_V,
1686+
MODEL_TENSOR.ATTN_OUT,
1687+
MODEL_TENSOR.FFN_NORM,
1688+
MODEL_TENSOR.FFN_DOWN,
1689+
MODEL_TENSOR.FFN_UP,
1690+
],
16731691
MODEL_ARCH.PHIMOE: [
16741692
MODEL_TENSOR.TOKEN_EMBD,
16751693
MODEL_TENSOR.OUTPUT_NORM,

src/llama-arch.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
3636
{ LLM_ARCH_QWEN3VLMOE, "qwen3vlmoe" },
3737
{ LLM_ARCH_PHI2, "phi2" },
3838
{ LLM_ARCH_PHI3, "phi3" },
39+
{ LLM_ARCH_PHI3_VISION, "phi3_vision" },
3940
{ LLM_ARCH_PHIMOE, "phimoe" },
4041
{ LLM_ARCH_PLAMO, "plamo" },
4142
{ LLM_ARCH_PLAMO2, "plamo2" },

src/llama-arch.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ enum llm_arch {
4040
LLM_ARCH_QWEN3VLMOE,
4141
LLM_ARCH_PHI2,
4242
LLM_ARCH_PHI3,
43+
LLM_ARCH_PHI3_VISION,
4344
LLM_ARCH_PHIMOE,
4445
LLM_ARCH_PLAMO,
4546
LLM_ARCH_PLAMO2,

src/llama-graph.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -473,7 +473,7 @@ class llm_graph_result {
473473

474474
virtual ~llm_graph_result() = default;
475475

476-
ggml_tensor * get_tokens() const { return t_tokens; }
476+
// Accessor: returns the stored t_tokens pointer (declared in the same form as the other get_* accessors).
ggml_tensor * get_tokens() const { return t_tokens; }
477477
ggml_tensor * get_logits() const { return t_logits; }
478478
ggml_tensor * get_embd() const { return t_embd; }
479479
ggml_tensor * get_embd_pooled() const { return t_embd_pooled; }

tools/mtmd/clip-impl.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,11 @@
3939
#define KEY_FEATURE_LAYER "clip.vision.feature_layer"
4040
#define KEY_PROJ_SCALE_FACTOR "clip.vision.projector.scale_factor"
4141
#define KEY_SPATIAL_MERGE_SIZE "clip.vision.spatial_merge_size"
42+
// [NEW] Phi-3-Vision Specific Keys
43+
#define KEY_PHI3_HD_ORDER "clip.vision.hd_transform_order" // Stores "sub_glb"
44+
#define KEY_PHI3_NUM_IMG_TOKENS "clip.vision.num_img_tokens" // Stores 144
45+
#define KEY_PHI3_USE_HD "clip.vision.use_hd_transform" // Stores true
46+
#define KEY_PHI3_WITH_SEP "clip.vision.with_learnable_separator" // Stores true
4247
#define KEY_IS_DEEPSTACK_LAYERS "clip.vision.is_deepstack_layers"
4348

4449
#define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type"
@@ -86,6 +91,21 @@
8691
#define TN_MVLM_PROJ_BLOCK "mm.model.mb_block.%d.block.%d.%s"
8792
#define TN_MVLM_PROJ_PEG "mm.model.peg.%d.%s"
8893
#define TN_IMAGE_NEWLINE "model.image_newline"
94+
95+
// [NEW] Phi-3-Vision Specific Tensors
96+
// Mapping for: model.vision_embed_tokens.glb_GN
97+
#define TN_PHI3_GLB_GN "v.glb_GN"
98+
// Mapping for: model.vision_embed_tokens.sub_GN
99+
#define TN_PHI3_SUB_GN "v.sub_GN"
100+
101+
// [NEW] Projector Mapping
102+
// The checkpoint's tensor map contains "model.vision_embed_tokens.img_projection.0.weight"
103+
// and "model.vision_embed_tokens.img_projection.2.weight".
104+
// This confirms it is a 2-layer MLP (Layer 0 = Linear, Layer 1 = GELU (implicit), Layer 2 = Linear).
105+
// We can reuse TN_LLAVA_PROJ ("mm.%d.%s") or define a specific one if the conversion script names them uniquely.
106+
// To be safe and specific:
107+
#define TN_PHI3_PROJ_MLP "mm.phi3_mlp.%d.%s"
108+
89109
#define TN_MM_INP_NORM "mm.input_norm.weight"
90110
#define TN_MM_INP_NORM_B "mm.input_norm.bias"
91111
#define TN_MM_INP_PROJ "mm.input_projection.weight" // gemma3
@@ -156,6 +176,7 @@ enum projector_type {
156176
PROJECTOR_TYPE_LIGHTONOCR,
157177
PROJECTOR_TYPE_COGVLM,
158178
PROJECTOR_TYPE_JANUS_PRO,
179+
PROJECTOR_TYPE_PHI3_V,
159180
PROJECTOR_TYPE_UNKNOWN,
160181
};
161182

@@ -182,6 +203,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
182203
{ PROJECTOR_TYPE_LIGHTONOCR,"lightonocr"},
183204
{ PROJECTOR_TYPE_COGVLM, "cogvlm"},
184205
{ PROJECTOR_TYPE_JANUS_PRO, "janus_pro"},
206+
{ PROJECTOR_TYPE_PHI3_V, "phi3_v"},
185207
};
186208

187209
static projector_type clip_projector_type_from_string(const std::string & str) {

0 commit comments

Comments
 (0)