@@ -9912,17 +9912,124 @@ def get_community_chat_template(vocab: MistralVocab, templates_dir: Path, is_mis
 
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
-        if "yarn" in self.hparams:
-            yarn_params = self.hparams["yarn"]
-            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
-            self.gguf_writer.add_rope_scaling_factor(yarn_params["factor"])
-            self.gguf_writer.add_rope_scaling_yarn_beta_fast(yarn_params["beta"])
-            self.gguf_writer.add_rope_scaling_yarn_beta_slow(yarn_params["alpha"])
-            self.gguf_writer.add_rope_scaling_yarn_log_mul(1.0)  # mscale_all_dim
-            self.gguf_writer.add_rope_scaling_orig_ctx_len(yarn_params["original_max_position_embeddings"])
+        MistralModel.set_mistral_config(self.gguf_writer, self.hparams)
+
+    @staticmethod
+    def set_mistral_config(gguf_writer: gguf.GGUFWriter, hparams: dict):
+        if "yarn" in hparams:
+            yarn_params = hparams["yarn"]
+            gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
+            gguf_writer.add_rope_scaling_factor(yarn_params["factor"])
+            gguf_writer.add_rope_scaling_yarn_beta_fast(yarn_params["beta"])
+            gguf_writer.add_rope_scaling_yarn_beta_slow(yarn_params["alpha"])
+            gguf_writer.add_rope_scaling_yarn_log_mul(1.0)  # mscale_all_dim
+            gguf_writer.add_rope_scaling_orig_ctx_len(yarn_params["original_max_position_embeddings"])
+
+        if "llama_4_scaling" in hparams:
+            gguf_writer.add_attn_temperature_scale(hparams["llama_4_scaling"]["beta"])
+
+
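Note on the extracted `set_mistral_config` helper: it reads raw Mistral `params.json` keys rather than HF config names. A minimal sketch of the input it expects, with invented values (only the key names and the GGUF setters they feed come from the diff above):

```python
# Hypothetical "yarn" / "llama_4_scaling" sections of a Mistral params.json;
# the values here are made up, the key names match what set_mistral_config reads.
hparams = {
    "yarn": {
        "factor": 16.0,                             # -> add_rope_scaling_factor
        "beta": 32.0,                               # -> add_rope_scaling_yarn_beta_fast
        "alpha": 1.0,                               # -> add_rope_scaling_yarn_beta_slow
        "original_max_position_embeddings": 32768,  # -> add_rope_scaling_orig_ctx_len
    },
    "llama_4_scaling": {"beta": 0.1},               # -> add_attn_temperature_scale
}
```

Worth noting the crossed naming: Mistral's `beta` lands in GGUF's `beta_fast` and `alpha` in `beta_slow`, while `yarn_log_mul` (`mscale_all_dim`) is hard-coded to 1.0.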
+class MistralMoeModel(DeepseekV2Model):
+    model_arch = gguf.MODEL_ARCH.DEEPSEEK2
+    model_name = "Mistral"
+    hf_arch = ""
+    is_mistral_format = True
+    undo_permute = False
 
-        if "llama_4_scaling" in self.hparams:
-            self.gguf_writer.add_attn_temperature_scale(self.hparams["llama_4_scaling"]["beta"])
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        logger.info("Using MistralMoeModel")
+        # ref: https://github.com/vllm-project/vllm/blob/b294e28db2c5dee61bc25157664edcada8b90b31/vllm/transformers_utils/configs/mistral.py
+        config = self.hparams
+        # Mistral key -> HF key
+        config_mapping = {
+            "dim": "hidden_size",
+            "norm_eps": "rms_norm_eps",
+            "n_kv_heads": "num_key_value_heads",
+            "n_layers": "num_hidden_layers",
+            "n_heads": "num_attention_heads",
+            "hidden_dim": "intermediate_size",
+        }
+        # HF key -> (Mistral key, default value)
+        top_level_mapping_with_default = {
+            "model_type": ("model_type", "transformer"),
+            "hidden_act": ("activation", "silu"),
+            "tie_word_embeddings": ("tied_embeddings", False),
+            "max_seq_len": ("max_seq_len", config.get("max_position_embeddings", 128_000)),
+            "max_position_embeddings": ("max_position_embeddings", 128_000),
+        }
+        for key, new_key in config_mapping.items():
+            if key in config:
+                config[new_key] = config[key]
+        for new_key, (key, default_value) in top_level_mapping_with_default.items():
+            config[new_key] = config.get(key, default_value)
+        moe_config_map = {
+            "route_every_n": "moe_layer_freq",
+            "first_k_dense_replace": "first_k_dense_replace",
+            "num_experts_per_tok": "num_experts_per_tok",
+            "num_experts": "n_routed_experts",
+            "expert_hidden_dim": "moe_intermediate_size",
+            "routed_scale": "routed_scaling_factor",
+            "num_shared_experts": "n_shared_experts",
+            "num_expert_groups": "n_group",
+            "num_expert_groups_per_tok": "topk_group",
+        }
+        moe = config["moe"]
+        for key, new_key in moe_config_map.items():
+            if key in moe:
+                config[new_key] = moe[key]
+
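The constructor mutates `self.hparams` in place so that the inherited `DeepseekV2Model` code can read its usual HF/DeepSeek key names. A toy before/after of that remapping, assuming a trimmed Mistral-format config (values invented, key names from the tables above):

```python
# Trimmed, hypothetical Mistral-format hparams as loaded from params.json.
config = {
    "dim": 4096, "n_layers": 32, "n_heads": 32, "n_kv_heads": 8,
    "norm_eps": 1e-5, "hidden_dim": 14336,
    "moe": {"num_experts": 64, "num_experts_per_tok": 6, "expert_hidden_dim": 2048},
}

# After __init__ runs, the same dict additionally carries the names that
# DeepseekV2Model.set_gguf_parameters() expects, e.g.:
#   config["hidden_size"]           == 4096    (from "dim")
#   config["num_hidden_layers"]     == 32      (from "n_layers")
#   config["n_routed_experts"]      == 64      (from moe["num_experts"])
#   config["moe_intermediate_size"] == 2048    (from moe["expert_hidden_dim"])
#   config["hidden_act"]            == "silu"  (default; no "activation" key present)
```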
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        MistralModel.set_mistral_config(self.gguf_writer, self.hparams)
+
+    # TODO @ngxson : this should be in tensor_mapping, but I don't have time for now
+    # copied from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/mistral_large_3.py
+    remapping = {
+        r"layers\.(\d+)\.attention_norm\.weight": r"model.layers.\1.input_layernorm.weight",  # noqa: E501
+        r"layers\.(\d+)\.attention\.wq_a\.(\w+)": r"model.layers.\1.self_attn.q_a_proj.\2",  # noqa: E501
+        r"layers\.(\d+)\.attention\.q_a_norm\.weight": r"model.layers.\1.self_attn.q_a_layernorm.weight",  # noqa: E501
+        r"layers\.(\d+)\.attention\.wq_b\.(\w+)": r"model.layers.\1.self_attn.q_b_proj.\2",  # noqa: E501
+        r"layers\.(\d+)\.attention\.wkv_a_with_mqa\.(\w+)": r"model.layers.\1.self_attn.kv_a_proj_with_mqa.\2",  # noqa: E501
+        r"layers\.(\d+)\.attention\.kv_a_norm\.weight": r"model.layers.\1.self_attn.kv_a_layernorm.weight",  # noqa: E501
+        r"layers\.(\d+)\.attention\.wkv_b\.(\w+)": r"model.layers.\1.self_attn.kv_b_proj.\2",  # noqa: E501
+        r"layers\.(\d+)\.attention\.wo\.(\w+)": r"model.layers.\1.self_attn.o_proj.\2",  # noqa: E501
+        r"layers\.(\d+)\.ffn_norm\.weight": r"model.layers.\1.post_attention_layernorm.weight",  # noqa: E501
+        r"layers\.(\d+)\.feed_forward\.w1\.(\w+)": r"model.layers.\1.mlp.gate_proj.\2",  # noqa: E501
+        r"layers\.(\d+)\.feed_forward\.w2\.(\w+)": r"model.layers.\1.mlp.down_proj.\2",  # noqa: E501
+        r"layers\.(\d+)\.feed_forward\.w3\.(\w+)": r"model.layers.\1.mlp.up_proj.\2",  # noqa: E501
+        r"layers\.(\d+)\.gate\.weight": r"model.layers.\1.mlp.gate.weight",  # noqa: E501
+        r"layers\.(\d+)\.shared_experts\.w1\.(\w+)": r"model.layers.\1.mlp.shared_experts.gate_proj.\2",  # noqa: E501
+        r"layers\.(\d+)\.shared_experts\.w2\.(\w+)": r"model.layers.\1.mlp.shared_experts.down_proj.\2",  # noqa: E501
+        r"layers\.(\d+)\.shared_experts\.w3\.(\w+)": r"model.layers.\1.mlp.shared_experts.up_proj.\2",  # noqa: E501
+        r"layers\.(\d+)\.experts\.(\d+)\.w1\.(\w+)": r"model.layers.\1.mlp.experts.\2.gate_proj.\3",  # noqa: E501
+        r"layers\.(\d+)\.experts\.(\d+)\.w2\.(\w+)": r"model.layers.\1.mlp.experts.\2.down_proj.\3",  # noqa: E501
+        r"layers\.(\d+)\.experts\.(\d+)\.w3\.(\w+)": r"model.layers.\1.mlp.experts.\2.up_proj.\3",  # noqa: E501
+        r"norm\.weight": "model.norm.weight",  # noqa: E501
+        r"tok_embeddings\.weight": "model.embed_tokens.weight",  # noqa: E501
+        r"output\.weight": "lm_head.weight",  # noqa: E501
+    }
+
+    def _remap_mistral_to_ds(self, name: str) -> str:
+        for k, v in self.remapping.items():
+            match = re.fullmatch(k, name)
+            if match:
+                name = re.sub(k, v, name)
+                break
+        else:
+            raise ValueError(f"Cannot remap {name}")
+
+        # Remapping scale names. We could do this in the regex above but it
+        # would triple the number of lines for most layers.
+        if name.endswith(".qscale_act"):
+            name = re.sub(r"\.qscale_act$", ".input_scale", name)
+        elif name.endswith(".qscale_weight"):
+            name = re.sub(r"\.qscale_weight$", ".weight_scale", name)
+        return name
+
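To make the regex table concrete, here is a standalone sketch of the same remapping logic that can be run outside the converter (the `remapping` dict is trimmed to three of the patterns above; the tensor names are invented to match them):

```python
import re

# Trimmed copy of the remapping table from the diff, enough for a sanity check.
remapping = {
    r"layers\.(\d+)\.attention\.wq_a\.(\w+)": r"model.layers.\1.self_attn.q_a_proj.\2",
    r"layers\.(\d+)\.experts\.(\d+)\.w2\.(\w+)": r"model.layers.\1.mlp.experts.\2.down_proj.\3",
    r"output\.weight": "lm_head.weight",
}

def remap(name: str) -> str:
    # First full-match wins, mirroring _remap_mistral_to_ds above.
    for pat, repl in remapping.items():
        if re.fullmatch(pat, name):
            name = re.sub(pat, repl, name)
            break
    else:
        raise ValueError(f"Cannot remap {name}")
    # Same scale-name rewrite the method does with endswith() checks.
    name = re.sub(r"\.qscale_weight$", ".weight_scale", name)
    return re.sub(r"\.qscale_act$", ".input_scale", name)

assert remap("layers.0.attention.wq_a.weight") == "model.layers.0.self_attn.q_a_proj.weight"
assert remap("layers.3.experts.17.w2.qscale_weight") == "model.layers.3.mlp.experts.17.down_proj.weight_scale"
assert remap("output.weight") == "lm_head.weight"
```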
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
+        name = self._remap_mistral_to_ds(name)
+        return super().modify_tensors(data_torch, name, bid)
 
 
 class PixtralModel(LlavaVisionModel):
@@ -10478,6 +10585,8 @@ def main() -> None:
     elif args.mmproj:
         assert hparams.get("vision_encoder") is not None, "This model does not support multimodal"
        model_class = PixtralModel
+    elif "moe" in hparams:
+        model_class = MistralMoeModel
    else:
        model_class = MistralModel
 
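With this change, a Mistral-format MoE checkpoint is picked up automatically: `main()` inspects the parsed `params.json` and the presence of a top-level `"moe"` object is what selects `MistralMoeModel` over plain `MistralModel` (vision checkpoints still go through `PixtralModel` via `--mmproj`). A minimal sketch of that dispatch, assuming a hypothetical checkpoint directory name:

```python
import json
from pathlib import Path

# Hypothetical checkpoint directory; Mistral-format models ship a params.json.
hparams = json.loads(Path("Mistral-Large-3/params.json").read_text())

# Mirrors the dispatch added in main(): a top-level "moe" object signals
# that this checkpoint needs the DeepSeek-2-style MoE converter.
model_class = "MistralMoeModel" if "moe" in hparams else "MistralModel"
print(model_class)
```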