Skip to content

Commit 2d918b3

Browse files
committed
mtmd: make sam hparams configurable
1 parent 15f2ada commit 2d918b3

File tree

5 files changed

+21
-5
lines changed

5 files changed

+21
-5
lines changed

convert_hf_to_gguf.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -6032,6 +6032,7 @@ def set_gguf_parameters(self):
60326032
sam_hparams = hparams['sam']
60336033
self.gguf_writer.add_vision_sam_layers_count(sam_hparams['layers'])
60346034
self.gguf_writer.add_vision_sam_embedding_length(sam_hparams['width'])
6035+
self.gguf_writer.add_vision_sam_head_count(sam_hparams['heads'])
60356036

60366037
def get_vision_config(self) -> dict[str, Any]:
60376038
vision_config: dict[str, Any] | None = self.global_config.get("vision_config")

gguf-py/gguf/constants.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -306,6 +306,7 @@ class Projector:
306306
class SAM:
307307
BLOCK_COUNT = "clip.vision.sam.block_count"
308308
EMBEDDING_LENGTH = "clip.vision.sam.embedding_length"
309+
HEAD_COUNT = "clip.vision.sam.head_count"
309310

310311
class ClipAudio:
311312
NUM_MEL_BINS = "clip.audio.num_mel_bins"

gguf-py/gguf/gguf_writer.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1135,6 +1135,9 @@ def add_vision_sam_layers_count(self, value: int) -> None:
11351135

11361136
def add_vision_sam_embedding_length(self, value: int) -> None:
11371137
self.add_uint32(Keys.ClipVision.SAM.EMBEDDING_LENGTH, value)
1138+
1139+
def add_vision_sam_head_count(self, value: int) -> None:
1140+
self.add_uint32(Keys.ClipVision.SAM.HEAD_COUNT, value)
11381141

11391142
# audio models
11401143

tools/mtmd/clip-impl.h

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -49,7 +49,9 @@
4949
#define KEY_ATTN_WINDOW_SIZE "clip.vision.window_size"
5050
#define KEY_MINICPMV_VERSION "clip.minicpmv_version"
5151
#define KEY_MINICPMV_QUERY_NUM "clip.minicpmv_query_num"
52-
52+
#define KEY_SAM_N_HEAD "clip.vision.sam.head_count"
53+
#define KEY_SAM_N_BLOCK "clip.vision.sam.block_count"
54+
#define KEY_SAM_N_EMBD "clip.vision.sam.embedding_length"
5355
// audio-specific
5456
#define KEY_AUDIO_PROJ_TYPE "clip.audio.projector_type" // for models with mixed modalities
5557
#define KEY_A_NUM_MEL_BINS "clip.audio.num_mel_bins"

tools/mtmd/clip.cpp

Lines changed: 13 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -193,6 +193,11 @@ struct clip_hparams {
193193
int32_t attn_window_size = 0;
194194
int32_t n_wa_pattern = 0;
195195

196+
// deepseek-ocr (sam)
197+
int32_t sam_n_layer = 0;
198+
int32_t sam_n_head = 0;
199+
int32_t sam_n_embd = 0;
200+
196201
// audio
197202
int32_t n_mel_bins = 0; // whisper preprocessor
198203
int32_t proj_stack_factor = 0; // ultravox
@@ -2676,9 +2681,9 @@ struct clip_graph {
26762681
}
26772682

26782683
ggml_tensor * build_sam(ggml_tensor * inp_raw) {
2679-
const int n_embd = 768;
2680-
const int _depth = 12;
2681-
const int n_heads = 12;
2684+
const int n_embd = hparams.sam_n_embd;
2685+
const int n_layer = hparams.sam_n_layer;
2686+
const int n_heads = hparams.sam_n_head;
26822687
const int d_heads = n_embd / n_heads;
26832688
const int window = hparams.attn_window_size;
26842689

@@ -2721,7 +2726,7 @@ struct clip_graph {
27212726
}
27222727

27232728
// loop over layers
2724-
for (int il = 0; il < _depth; il++) {
2729+
for (int il = 0; il < n_layer; il++) {
27252730
auto & layer = model.sam_layers[il];
27262731
ggml_tensor * shortcut = cur;
27272732

@@ -3286,6 +3291,10 @@ struct clip_model_loader {
32863291
hparams.patch_size = 16;
32873292
hparams.image_size = 1024;
32883293
hparams.warmup_image_size = 1024;
3294+
3295+
get_u32(KEY_SAM_N_BLOCK, hparams.sam_n_layer, true);
3296+
get_u32(KEY_SAM_N_HEAD, hparams.sam_n_head, true);
3297+
get_u32(KEY_SAM_N_EMBD, hparams.sam_n_embd, true);
32893298
get_u32(KEY_ATTN_WINDOW_SIZE, hparams.attn_window_size, true);
32903299
} break;
32913300
default:

0 commit comments

Comments
 (0)