@@ -16,6 +16,7 @@
 from hashlib import sha256
 from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Literal, Sequence, TypeVar, cast
 from itertools import chain
+from transformers import AutoConfig

 import math
 import numpy as np
@@ -66,8 +67,6 @@ class ModelBase:
     part_names: list[str]
     is_safetensors: bool
     hparams: dict[str, Any]
-    block_count: int
-    tensor_map: gguf.TensorNameMap
     tensor_names: set[str] | None
     gguf_writer: gguf.GGUFWriter
     model_name: str | None
@@ -78,6 +77,10 @@ class ModelBase:
     # subclasses should define this!
     model_arch: gguf.MODEL_ARCH

+    # subclasses should initialize this!
+    block_count: int
+    tensor_map: gguf.TensorNameMap
+
     def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, *, is_big_endian: bool = False,
                  use_temp_file: bool = False, eager: bool = False,
                  metadata_override: Path | None = None, model_name: str | None = None,
@@ -113,8 +116,6 @@ def get_remote_tensors() -> Iterator[tuple[str, Tensor]]:
         if not self.is_safetensors:
             self.part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
         self.hparams = ModelBase.load_hparams(self.dir_model) if hparams is None else hparams
-        self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"])
-        self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
         self.tensor_names = None
         self.metadata_override = metadata_override
         self.model_name = model_name
@@ -417,15 +418,13 @@ def get_model_part_names(dir_model: Path, prefix: str, suffix: str) -> list[str]

     @staticmethod
     def load_hparams(dir_model: Path):
-        with open(dir_model / "config.json", "r", encoding="utf-8") as f:
-            hparams = json.load(f)
-        architectures = hparams.get("architectures")
-        if "text_config" in hparams:
-            hparams = {**hparams, **hparams["text_config"]}
-        if architectures is not None:
-            # preserve "architectures" from root level config
-            hparams["architectures"] = architectures
-        return hparams
+        try:
+            return AutoConfig.from_pretrained(dir_model).to_dict()
+        except Exception as e:
+            logger.warning(f"Failed to load model config from {dir_model}: {e}")
+            logger.warning("Trying to load config.json instead")
+            with open(dir_model / "config.json", "r", encoding="utf-8") as f:
+                return json.load(f)

     @classmethod
     def register(cls, *names: str) -> Callable[[AnyModel], AnyModel]:
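Note: loading hparams through AutoConfig also fills in architecture defaults that a hand-written config.json may omit, which is why several hard-coded fallback values are dropped later in this diff. A minimal illustration (assuming a transformers release that ships PixtralVisionConfig):

from transformers import PixtralVisionConfig

cfg = PixtralVisionConfig()        # no explicit values, library defaults only
print(cfg.num_attention_heads)     # 16
print(cfg.num_hidden_layers)       # 24
print(cfg.intermediate_size)       # 4096
print(cfg.hidden_size)             # 1024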
@@ -454,6 +453,23 @@ def from_model_architecture(cls, arch: str, model_type = ModelType.TEXT) -> type


 class TextModel(ModelBase):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        if "text_config" in self.hparams:
+            # move the text_config to the root level
+            self.hparams = {**self.hparams, **self.hparams["text_config"]}
+
+        self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"])
+        self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
+
+    @classmethod
+    def __init_subclass__(cls):
+        # can't use an abstract property, because overriding it without type errors
+        # would require using decorated functions instead of simply defining the property
+        if "model_arch" not in cls.__dict__:
+            raise TypeError(f"Missing property 'model_arch' for {cls.__name__!r}")
+
     def set_vocab(self):
         self._set_vocab_gpt2()

@@ -1070,9 +1086,9 @@ def __init__(self, *args, **kwargs):
         if self.model_arch != gguf.MODEL_ARCH.CLIP_VISION:
             raise TypeError("VisionModel must be subclassed with model_arch = gguf.MODEL_ARCH.CLIP_VISION")

-        # small hack to correct the number of layers
-        self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.CLIP_VISION, 128)
-        self.n_embd_text = self.find_hparam(["hidden_size", "n_embd"])
+        # get n_embd of the text model
+        text_config = {**self.hparams, **self.hparams["text_config"]}
+        self.n_embd_text = text_config.get("hidden_size", text_config.get("n_embd", 0))
         assert self.n_embd_text > 0, "n_embd not found in hparams"

         if "vision_config" not in self.hparams:
@@ -1081,6 +1097,9 @@ def __init__(self, *args, **kwargs):
         self.global_config = self.hparams
         self.hparams = self.hparams["vision_config"]

+        self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth"])
+        self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.CLIP_VISION, self.block_count)
+
         # load preprocessor config
         with open(self.dir_model / "preprocessor_config.json", "r", encoding="utf-8") as f:
             self.preprocessor_config = json.load(f)
@@ -1098,7 +1117,7 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_vision_patch_size(self.find_hparam(["patch_size"]))
         self.gguf_writer.add_vision_embedding_length(self.find_hparam(["hidden_size"]))
         self.gguf_writer.add_vision_feed_forward_length(self.find_hparam(["intermediate_size"]))
-        self.gguf_writer.add_vision_block_count(self.find_hparam(["num_hidden_layers"]))
+        self.gguf_writer.add_vision_block_count(self.block_count)
         self.gguf_writer.add_vision_head_count(self.find_hparam(["num_attention_heads"]))

         # preprocessor config
@@ -1719,23 +1738,12 @@ def prepare_tensors(self):
     "LlamaForCausalLM",
     "MistralForCausalLM",
     "MixtralForCausalLM",
-    "Idefics3ForConditionalGeneration",
-    "SmolVLMForConditionalGeneration",
+    "VLlama3ForCausalLM",
     "LlavaForConditionalGeneration")
 class LlamaModel(TextModel):
     model_arch = gguf.MODEL_ARCH.LLAMA
     undo_permute = True

-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        # fix for SmolVLM2, missing `num_attention_heads` in config.json
-        if self.hparams["architectures"][0] == "SmolVLMForConditionalGeneration":
-            self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32)
-        # fix for Pixtral, missing `num_attention_heads` in config.json
-        if self.hparams["architectures"][0] == "LlavaForConditionalGeneration" \
-                and self.hparams.get("model_type") == "mistral":
-            self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32)
-
     def set_vocab(self):
         try:
             self._set_vocab_sentencepiece()
@@ -1891,31 +1899,50 @@ def prepare_tensors(self):
             raise ValueError(f"Unprocessed experts: {experts}")


-@ModelBase.register("LlavaForConditionalGeneration")
+@ModelBase.register(
+    "LlavaForConditionalGeneration", # pixtral
+    "Mistral3ForConditionalGeneration", # mistral small 3.1
+)
 class LlavaVisionModel(VisionModel):
     img_break_tok_id = -1

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         if self.hparams["model_type"] == "pixtral":
-            # fix missing config.json values
-            self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 16)
-            self.hparams["num_hidden_layers"] = self.hparams.get("num_hidden_layers", 24)
-            self.hparams["intermediate_size"] = self.hparams.get("intermediate_size", 4096)
-            self.hparams["hidden_size"] = self.hparams.get("hidden_size", 1024)
+            # layer_norm_eps is not in config.json, it is hard-coded in modeling_pixtral.py
             self.hparams["layer_norm_eps"] = self.hparams.get("layer_norm_eps", 1e-5)
-            self.img_break_tok_id = 12 # see tokenizer_config.json
+            self.img_break_tok_id = self.get_token_id("[IMG_BREAK]")
+            logger.info(f"Image break token id: {self.img_break_tok_id}")
         else:
             raise ValueError(f"Unsupported model type: {self.hparams['model_type']}")

+    def get_token_id(self, token: str) -> int:
+        tokenizer_config_file = self.dir_model / 'tokenizer_config.json'
+        with open(tokenizer_config_file, "r", encoding="utf-8") as f:
+            added_tokens_decoder = json.load(f)['added_tokens_decoder']
+            for id_, token_data in added_tokens_decoder.items():
+                if token_data["content"] == token:
+                    return int(id_)
+        raise ValueError(f"Token '{token}' not found in tokenizer config.")
+
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
         hparams = self.hparams
         if hparams["model_type"] == "pixtral":
             self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.PIXTRAL)
-            # default values below are taken from HF tranformers code
             self.gguf_writer.add_vision_attention_layernorm_eps(hparams["layer_norm_eps"])
-            self.gguf_writer.add_vision_use_silu(True)
+
+            # hidden_act
+            if hparams["hidden_act"] == "silu":
+                self.gguf_writer.add_vision_use_silu(True)
+            elif hparams["hidden_act"] == "gelu":
+                self.gguf_writer.add_vision_use_gelu(True)
+            else:
+                raise ValueError(f"Unsupported hidden_act: {hparams['hidden_act']}")
+
+            # spatial_merge_size
+            if "spatial_merge_size" in self.global_config:
+                self.gguf_writer.add_vision_spatial_merge_size(self.global_config["spatial_merge_size"])

     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         del bid # unused
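A sketch of the lookup that get_token_id() performs, against a hypothetical, trimmed added_tokens_decoder (the real Pixtral tokenizer_config.json lists every added token; "[IMG_BREAK]" maps to id 12 there, the value that was previously hard-coded):

example_added_tokens_decoder = {          # hypothetical excerpt, real file has many entries
    "12": {"content": "[IMG_BREAK]", "special": True},
}

def find_token_id(added_tokens_decoder: dict, token: str) -> int:
    # same scan as LlavaVisionModel.get_token_id, minus the file I/O
    for id_, token_data in added_tokens_decoder.items():
        if token_data["content"] == token:
            return int(id_)
    raise ValueError(f"Token '{token}' not found in tokenizer config.")

assert find_token_id(example_added_tokens_decoder, "[IMG_BREAK]") == 12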
@@ -1944,13 +1971,12 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
 class SmolVLMModel(VisionModel):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        # fix for SmolVLM2, missing some keys in config.json
-        # default values are taken from transformers code
         if self.hparams["model_type"] == "smolvlm_vision":
+            # fix for SmolVLM2, missing some keys in config.json
+            # default values are taken from transformers code
            self.hparams["hidden_size"] = self.hparams.get("hidden_size", 1152)
            self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 16)
            self.hparams["intermediate_size"] = self.hparams.get("intermediate_size", 3072)
-            self.hparams["num_hidden_layers"] = self.hparams.get("num_hidden_layers", 12)

     def set_gguf_parameters(self):
         super().set_gguf_parameters()
@@ -3505,6 +3531,8 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter

 @ModelBase.register("NomicBertModel")
 class NomicBertModel(BertModel):
+    model_arch = gguf.MODEL_ARCH.BERT
+
     def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, **kwargs: Any):
         hparams = kwargs.pop("hparams", None)
         if hparams is None:
@@ -5934,6 +5962,19 @@ def split_str_to_n_bytes(split_str: str) -> int:
     return n


+def get_model_architecture(dir_model: Path, model_type: ModelType, hparams: Any = None) -> str:
+    hparams = ModelBase.load_hparams(dir_model) if hparams is None else hparams
+    text_config = hparams.get("text_config", {})
+    vision_config = hparams.get("vision_config", {})
+    arch = hparams["architectures"][0]
+    # if "architectures" is found in the sub-config, use that instead
+    if model_type == ModelType.TEXT and text_config.get("architectures") is not None:
+        arch = text_config["architectures"][0]
+    elif model_type == ModelType.VISION and vision_config.get("architectures") is not None:
+        arch = vision_config["architectures"][0]
+    return arch
+
+
 def main() -> None:
     args = parse_args()

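A worked example of the new selection logic, using a hypothetical multimodal config: a sub-config "architectures" entry wins for its model type, otherwise the root entry is used.

hypothetical_hparams = {
    "architectures": ["LlavaForConditionalGeneration"],
    "text_config": {"architectures": ["MistralForCausalLM"]},
    "vision_config": {"model_type": "pixtral"},
}

# get_model_architecture(dir_model, ModelType.TEXT, hypothetical_hparams)
#   -> "MistralForCausalLM"             (taken from text_config)
# get_model_architecture(dir_model, ModelType.VISION, hypothetical_hparams)
#   -> "LlavaForConditionalGeneration"  (vision_config has no "architectures")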
@@ -5986,16 +6027,15 @@ def main() -> None:

     logger.info(f"Loading model: {dir_model.name}")

-    hparams = ModelBase.load_hparams(dir_model)
-
     if args.mmproj:
         if "mmproj" not in fname_out.name:
             fname_out = ModelBase.add_prefix_to_filename(fname_out, "mmproj-")

     with torch.inference_mode():
         output_type = ftype_map[args.outtype]
-        model_architecture = hparams["architectures"][0]
         model_type = ModelType.VISION if args.mmproj else ModelType.TEXT
+        model_architecture = get_model_architecture(dir_model, model_type)
+        logger.info(f"Model architecture: {model_architecture}")
         try:
             model_class = ModelBase.from_model_architecture(model_architecture, model_type=model_type)
         except NotImplementedError: