Commit 3e623b9 (parent 7feb0a1)

convert: support Mistral 3 Large MoE

1 file changed: convert_hf_to_gguf.py (+119, -10 lines)
@@ -9912,17 +9912,124 @@ def get_community_chat_template(vocab: MistralVocab, templates_dir: Path, is_mis
 
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
-        if "yarn" in self.hparams:
-            yarn_params = self.hparams["yarn"]
-            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
-            self.gguf_writer.add_rope_scaling_factor(yarn_params["factor"])
-            self.gguf_writer.add_rope_scaling_yarn_beta_fast(yarn_params["beta"])
-            self.gguf_writer.add_rope_scaling_yarn_beta_slow(yarn_params["alpha"])
-            self.gguf_writer.add_rope_scaling_yarn_log_mul(1.0) # mscale_all_dim
-            self.gguf_writer.add_rope_scaling_orig_ctx_len(yarn_params["original_max_position_embeddings"])
+        MistralModel.set_mistral_config(self.gguf_writer, self.hparams)
+
+    @staticmethod
+    def set_mistral_config(gguf_writer: gguf.GGUFWriter, hparams: dict):
+        if "yarn" in hparams:
+            yarn_params = hparams["yarn"]
+            gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
+            gguf_writer.add_rope_scaling_factor(yarn_params["factor"])
+            gguf_writer.add_rope_scaling_yarn_beta_fast(yarn_params["beta"])
+            gguf_writer.add_rope_scaling_yarn_beta_slow(yarn_params["alpha"])
+            gguf_writer.add_rope_scaling_yarn_log_mul(1.0) # mscale_all_dim
+            gguf_writer.add_rope_scaling_orig_ctx_len(yarn_params["original_max_position_embeddings"])
+
+        if "llama_4_scaling" in hparams:
+            gguf_writer.add_attn_temperature_scale(hparams["llama_4_scaling"]["beta"])
+
+
+class MistralMoeModel(DeepseekV2Model):
+    model_arch = gguf.MODEL_ARCH.DEEPSEEK2
+    model_name = "Mistral"
+    hf_arch = ""
+    is_mistral_format = True
+    undo_permute = False
 
-        if "llama_4_scaling" in self.hparams:
-            self.gguf_writer.add_attn_temperature_scale(self.hparams["llama_4_scaling"]["beta"])
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        logger.info("Using MistralMoeModel")
+        # ref: https://github.com/vllm-project/vllm/blob/b294e28db2c5dee61bc25157664edcada8b90b31/vllm/transformers_utils/configs/mistral.py
+        config = self.hparams
+        # Mistral key -> HF key
+        config_mapping = {
+            "dim": "hidden_size",
+            "norm_eps": "rms_norm_eps",
+            "n_kv_heads": "num_key_value_heads",
+            "n_layers": "num_hidden_layers",
+            "n_heads": "num_attention_heads",
+            "hidden_dim": "intermediate_size",
+        }
+        # HF key -> (Mistral key, default value)
+        top_level_mapping_with_default = {
+            "model_type": ("model_type", "transformer"),
+            "hidden_act": ("activation", "silu"),
+            "tie_word_embeddings": ("tied_embeddings", False),
+            "max_seq_len": ("max_seq_len", config.get("max_position_embeddings", 128_000)),
+            "max_position_embeddings": ("max_position_embeddings", 128_000),
+        }
+        for key, new_key in config_mapping.items():
+            if key in config:
+                config[new_key] = config[key]
+        for new_key, (key, default_value) in top_level_mapping_with_default.items():
+            config[new_key] = config.get(key, default_value)
+        moe_config_map = {
+            "route_every_n": "moe_layer_freq",
+            "first_k_dense_replace": "first_k_dense_replace",
+            "num_experts_per_tok": "num_experts_per_tok",
+            "num_experts": "n_routed_experts",
+            "expert_hidden_dim": "moe_intermediate_size",
+            "routed_scale": "routed_scaling_factor",
+            "num_shared_experts": "n_shared_experts",
+            "num_expert_groups": "n_group",
+            "num_expert_groups_per_tok": "topk_group",
+        }
+        moe = config["moe"]
+        for key, new_key in moe_config_map.items():
+            if key in moe:
+                config[new_key] = moe[key]
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        MistralModel.set_mistral_config(self.gguf_writer, self.hparams)
+
+    # TODO @ngxson : this should be in tensor_mapping, but I don't have time for now
+    # copied from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/mistral_large_3.py
+    remapping = {
+        r"layers\.(\d+)\.attention_norm\.weight": r"model.layers.\1.input_layernorm.weight", # noqa: E501
+        r"layers\.(\d+)\.attention\.wq_a\.(\w+)": r"model.layers.\1.self_attn.q_a_proj.\2", # noqa: E501
+        r"layers\.(\d+)\.attention\.q_a_norm\.weight": r"model.layers.\1.self_attn.q_a_layernorm.weight", # noqa: E501
+        r"layers\.(\d+)\.attention\.wq_b\.(\w+)": r"model.layers.\1.self_attn.q_b_proj.\2", # noqa: E501
+        r"layers\.(\d+)\.attention\.wkv_a_with_mqa\.(\w+)": r"model.layers.\1.self_attn.kv_a_proj_with_mqa.\2", # noqa: E501
+        r"layers\.(\d+)\.attention\.kv_a_norm\.weight": r"model.layers.\1.self_attn.kv_a_layernorm.weight", # noqa: E501
+        r"layers\.(\d+)\.attention\.wkv_b\.(\w+)": r"model.layers.\1.self_attn.kv_b_proj.\2", # noqa: E501
+        r"layers\.(\d+)\.attention\.wo\.(\w+)": r"model.layers.\1.self_attn.o_proj.\2", # noqa: E501
+        r"layers\.(\d+)\.ffn_norm\.weight": r"model.layers.\1.post_attention_layernorm.weight", # noqa: E501
+        r"layers\.(\d+)\.feed_forward\.w1\.(\w+)": r"model.layers.\1.mlp.gate_proj.\2", # noqa: E501
+        r"layers\.(\d+)\.feed_forward\.w2\.(\w+)": r"model.layers.\1.mlp.down_proj.\2", # noqa: E501
+        r"layers\.(\d+)\.feed_forward\.w3\.(\w+)": r"model.layers.\1.mlp.up_proj.\2", # noqa: E501
+        r"layers\.(\d+)\.gate\.weight": r"model.layers.\1.mlp.gate.weight", # noqa: E501
+        r"layers\.(\d+)\.shared_experts\.w1\.(\w+)": r"model.layers.\1.mlp.shared_experts.gate_proj.\2", # noqa: E501
+        r"layers\.(\d+)\.shared_experts\.w2\.(\w+)": r"model.layers.\1.mlp.shared_experts.down_proj.\2", # noqa: E501
+        r"layers\.(\d+)\.shared_experts\.w3\.(\w+)": r"model.layers.\1.mlp.shared_experts.up_proj.\2", # noqa: E501
+        r"layers\.(\d+)\.experts\.(\d+)\.w1\.(\w+)": r"model.layers.\1.mlp.experts.\2.gate_proj.\3", # noqa: E501
+        r"layers\.(\d+)\.experts\.(\d+)\.w2\.(\w+)": r"model.layers.\1.mlp.experts.\2.down_proj.\3", # noqa: E501
+        r"layers\.(\d+)\.experts\.(\d+)\.w3\.(\w+)": r"model.layers.\1.mlp.experts.\2.up_proj.\3", # noqa: E501
+        r"norm\.weight": "model.norm.weight", # noqa: E501
+        r"tok_embeddings\.weight": "model.embed_tokens.weight", # noqa: E501
+        r"output\.weight": "lm_head.weight", # noqa: E501
+    }
+
+    def _remap_mistral_to_ds(self, name: str) -> str:
+        for k, v in self.remapping.items():
+            match = re.fullmatch(k, name)
+            if match:
+                name = re.sub(k, v, name)
+                break
+        else:
+            raise ValueError(f"Cannot remap {name}")
+
+        # Remapping scale names. We could do this in the regex above but it
+        # would triple the number of lines for most layers.
+        if name.endswith(".qscale_act"):
+            name = re.sub(r"\.qscale_act$", ".input_scale", name)
+        elif name.endswith(".qscale_weight"):
+            name = re.sub(r"\.qscale_weight$", ".weight_scale", name)
+        return name
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
+        name = self._remap_mistral_to_ds(name)
+        return super().modify_tensors(data_torch, name, bid)
 
 
 class PixtralModel(LlavaVisionModel):
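
Note (not part of the commit): the remapping table above is what lets the Mistral-native checkpoint reuse the existing DeepSeek-V2 conversion path; _remap_mistral_to_ds rewrites each Mistral tensor name into the HF/DeepSeek-V2 name before handing it to DeepseekV2Model.modify_tensors. A minimal standalone sketch of that renaming, using a hypothetical three-entry subset of the table and made-up tensor names:

    import re

    # Hypothetical subset of the remapping table above (Mistral name -> HF/DeepSeek-V2 name).
    REMAP = {
        r"layers\.(\d+)\.attention\.wq_a\.(\w+)": r"model.layers.\1.self_attn.q_a_proj.\2",
        r"layers\.(\d+)\.experts\.(\d+)\.w2\.(\w+)": r"model.layers.\1.mlp.experts.\2.down_proj.\3",
        r"tok_embeddings\.weight": "model.embed_tokens.weight",
    }

    def remap(name: str) -> str:
        # Same shape as _remap_mistral_to_ds: the first pattern that fully matches
        # wins; a name that matches nothing is an error.
        for pattern, repl in REMAP.items():
            if re.fullmatch(pattern, name):
                name = re.sub(pattern, repl, name)
                break
        else:
            raise ValueError(f"Cannot remap {name}")
        # Quantization scale suffixes are renamed in a second pass, as in the diff.
        if name.endswith(".qscale_act"):
            name = re.sub(r"\.qscale_act$", ".input_scale", name)
        elif name.endswith(".qscale_weight"):
            name = re.sub(r"\.qscale_weight$", ".weight_scale", name)
        return name

    print(remap("layers.0.attention.wq_a.weight"))
    # -> model.layers.0.self_attn.q_a_proj.weight
    print(remap("layers.3.experts.17.w2.qscale_weight"))
    # -> model.layers.3.mlp.experts.17.down_proj.weight_scale

Because every remapped name lands in standard DeepSeek-V2/HF form, the class can inherit the DEEPSEEK2 tensor mapping instead of defining its own.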
@@ -10478,6 +10585,8 @@ def main() -> None:
         elif args.mmproj:
             assert hparams.get("vision_encoder") is not None, "This model does not support multimodal"
             model_class = PixtralModel
+        elif "moe" in hparams:
+            model_class = MistralMoeModel
         else:
             model_class = MistralModel
 
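
For orientation (again, not from the commit): the new branch keys on the top-level "moe" entry in the Mistral-format hparams, the same entry MistralMoeModel.__init__ reads as config["moe"]. A simplified, hypothetical re-statement of the selection logic, with strings standing in for the real classes and made-up hparams:

    def pick_model_class(hparams: dict, mmproj: bool = False) -> str:
        # Mirrors the dispatch in main(): multimodal first, then MoE, then dense.
        if mmproj:
            assert hparams.get("vision_encoder") is not None, "This model does not support multimodal"
            return "PixtralModel"
        if "moe" in hparams:  # new in this commit
            return "MistralMoeModel"
        return "MistralModel"

    assert pick_model_class({"moe": {"num_experts": 8}}) == "MistralMoeModel"
    assert pick_model_class({"dim": 4096, "n_layers": 32}) == "MistralModel"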
