
Commit cd3c118

Authored by ngxson, juliendenize, and CISC
model: support Ministral3 (ggml-org#17644)
* conversion script
* support ministral 3
* maybe this is better?
* add TODO for rope_yarn_log_mul
* better ppl (tested on 14B-Instruct)
* Add Ministral3 support to Mistral format
* improve arch handling
* add sizes
* Apply suggestions from code review

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* nits

--------

Co-authored-by: Julien Denize <julien.denize@mistral.ai>
Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
1 parent 649495c commit cd3c118

File tree: 11 files changed (+342, -10 lines)
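For context, converting a Ministral 3 checkpoint goes through the same convert_hf_to_gguf.py entry point that this commit touches. The sketch below is a hypothetical invocation; the model directory, output file name, and --outtype value are placeholders, not taken from this commit.

    # Hypothetical driver for the conversion script updated by this commit.
    # Assumes a local Hugging Face checkout of a Ministral 3 model; paths and
    # the dtype are illustrative placeholders.
    import subprocess

    subprocess.run(
        [
            "python", "convert_hf_to_gguf.py",
            "models/Ministral-3-14B-Instruct",       # placeholder model directory
            "--outfile", "ministral3-14b-f16.gguf",  # placeholder output path
            "--outtype", "f16",
        ],
        check=True,
    )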

convert_hf_to_gguf.py

Lines changed: 70 additions & 4 deletions
@@ -1581,10 +1581,27 @@ def __init__(self, *args, **kwargs):
 
         # load preprocessor config
         self.preprocessor_config = {}
-        if not self.is_mistral_format:
-            with open(self.dir_model / "preprocessor_config.json", "r", encoding="utf-8") as f:
+
+        # prefer preprocessor_config.json if possible
+        preprocessor_config_path = self.dir_model / "preprocessor_config.json"
+        if preprocessor_config_path.is_file():
+            with open(preprocessor_config_path, "r", encoding="utf-8") as f:
                 self.preprocessor_config = json.load(f)
 
+        # prefer processor_config.json if possible
+        processor_config_path = self.dir_model / "processor_config.json"
+        if processor_config_path.is_file():
+            with open(processor_config_path, "r", encoding="utf-8") as f:
+                cfg = json.load(f)
+                # move image_processor to root level for compat
+                if "image_processor" in cfg:
+                    cfg = {
+                        **cfg,
+                        **cfg["image_processor"],
+                    }
+                # merge configs
+                self.preprocessor_config = {**self.preprocessor_config, **cfg}
+
     def get_vision_config(self) -> dict[str, Any] | None:
         config_name = "vision_config" if not self.is_mistral_format else "vision_encoder"
         return self.global_config.get(config_name)
@@ -2797,7 +2814,32 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
 
 @ModelBase.register("Mistral3ForConditionalGeneration")
 class Mistral3Model(LlamaModel):
-    model_arch = gguf.MODEL_ARCH.LLAMA
+    model_arch = gguf.MODEL_ARCH.MISTRAL3
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # for compatibility, we use LLAMA arch for older models
+        # TODO: remove this once everyone has migrated to newer version of llama.cpp
+        if self.hparams.get("model_type") != "ministral3":
+            self.model_arch = gguf.MODEL_ARCH.LLAMA
+            self.gguf_writer.arch = gguf.MODEL_ARCH_NAMES[self.model_arch]
+            self.gguf_writer.add_architecture()
+            self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        rope_params = self.hparams.get("rope_parameters")
+        if self.hparams.get("model_type") == "ministral3":
+            assert rope_params is not None, "ministral3 must have 'rope_parameters' config"
+            assert rope_params["rope_type"] == "yarn", "ministral3 rope_type must be 'yarn'"
+            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
+            self.gguf_writer.add_rope_scaling_factor(rope_params["factor"])
+            self.gguf_writer.add_rope_scaling_yarn_beta_fast(rope_params["beta_fast"])
+            self.gguf_writer.add_rope_scaling_yarn_beta_slow(rope_params["beta_slow"])
+            self.gguf_writer.add_rope_scaling_yarn_log_mul(rope_params["mscale_all_dim"])
+            self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_params["original_max_position_embeddings"])
+            self.gguf_writer.add_rope_freq_base(rope_params["rope_theta"])
+            self.gguf_writer.add_attn_temperature_scale(rope_params["llama_4_scaling_beta"])
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
         name = name.replace("language_model.", "")
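To make the new set_gguf_parameters path concrete, a Ministral 3 config.json is expected to carry a rope_parameters block with exactly the keys read above. The fragment below is a hand-written sketch with made-up values, not the real model config; the comments name the writer call each field feeds.

    # Hand-written sketch of the "rope_parameters" block consumed when
    # model_type == "ministral3". All numeric values are placeholders.
    rope_parameters = {
        "rope_type": "yarn",                        # must be "yarn", enforced by the assert
        "factor": 8.0,                              # -> add_rope_scaling_factor
        "beta_fast": 32.0,                          # -> add_rope_scaling_yarn_beta_fast
        "beta_slow": 1.0,                           # -> add_rope_scaling_yarn_beta_slow
        "mscale_all_dim": 1.0,                      # -> add_rope_scaling_yarn_log_mul
        "original_max_position_embeddings": 32768,  # -> add_rope_scaling_orig_ctx_len
        "rope_theta": 1000000.0,                    # -> add_rope_freq_base
        "llama_4_scaling_beta": 0.1,                # -> add_attn_temperature_scale
    }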
@@ -9809,12 +9851,22 @@ def modify_tensors(self, data_torch, name, bid):
 
 
 class MistralModel(LlamaModel):
-    model_arch = gguf.MODEL_ARCH.LLAMA
+    model_arch = gguf.MODEL_ARCH.MISTRAL3
     model_name = "Mistral"
     hf_arch = ""
     is_mistral_format = True
     undo_permute = False
 
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # for compatibility, we use LLAMA arch for older models
+        # TODO: remove this once everyone migrates to newer version of llama.cpp
+        if "llama_4_scaling" not in self.hparams:
+            self.model_arch = gguf.MODEL_ARCH.LLAMA
+            self.gguf_writer.arch = gguf.MODEL_ARCH_NAMES[self.model_arch]
+            self.gguf_writer.add_architecture()
+            self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
+
     @staticmethod
     def get_community_chat_template(vocab: MistralVocab, templates_dir: Path, is_mistral_format: bool):
         assert TokenizerVersion is not None and Tekkenizer is not None and SentencePieceTokenizer is not None, _mistral_import_error_msg
@@ -9854,6 +9906,20 @@ def get_community_chat_template(vocab: MistralVocab, templates_dir: Path, is_mis
 
         return template
 
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        if "yarn" in self.hparams:
+            yarn_params = self.hparams["yarn"]
+            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
+            self.gguf_writer.add_rope_scaling_factor(yarn_params["factor"])
+            self.gguf_writer.add_rope_scaling_yarn_beta_fast(yarn_params["beta"])
+            self.gguf_writer.add_rope_scaling_yarn_beta_slow(yarn_params["alpha"])
+            self.gguf_writer.add_rope_scaling_yarn_log_mul(1.0)  # mscale_all_dim
+            self.gguf_writer.add_rope_scaling_orig_ctx_len(yarn_params["original_max_position_embeddings"])
+
+        if "llama_4_scaling" in self.hparams:
+            self.gguf_writer.add_attn_temperature_scale(self.hparams["llama_4_scaling"]["beta"])
+
 
 class PixtralModel(LlavaVisionModel):
     model_name = "Pixtral"

gguf-py/gguf/constants.py

Lines changed: 23 additions & 0 deletions
@@ -175,6 +175,7 @@ class Attention:
         VALUE_LENGTH_MLA       = "{arch}.attention.value_length_mla"
         SHARED_KV_LAYERS       = "{arch}.attention.shared_kv_layers"
         SLIDING_WINDOW_PATTERN = "{arch}.attention.sliding_window_pattern"
+        TEMPERATURE_SCALE      = "{arch}.attention.temperature_scale"
 
     class Rope:
         DIMENSION_COUNT = "{arch}.rope.dimension_count"
@@ -444,6 +445,7 @@ class MODEL_ARCH(IntEnum):
     MINIMAXM2   = auto()
     RND1        = auto()
     PANGU_EMBED = auto()
+    MISTRAL3    = auto()
 
 
 class VISION_PROJECTOR_TYPE(IntEnum):
@@ -817,6 +819,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.COGVLM:      "cogvlm",
     MODEL_ARCH.RND1:        "rnd1",
     MODEL_ARCH.PANGU_EMBED: "pangu-embedded",
+    MODEL_ARCH.MISTRAL3:    "mistral3",
 }
 
 VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
@@ -3071,6 +3074,26 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
     ],
+    MODEL_ARCH.MISTRAL3: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_ROT_EMBD,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+        MODEL_TENSOR.FFN_UP_EXP,
+    ],
     # TODO
 }

gguf-py/gguf/gguf_writer.py

Lines changed: 3 additions & 0 deletions
@@ -904,6 +904,9 @@ def add_attn_output_scale(self, value: float) -> None:
     def add_attn_temperature_length(self, value: int) -> None:
         self.add_uint32(Keys.Attention.TEMPERATURE_LENGTH.format(arch=self.arch), value)
 
+    def add_attn_temperature_scale(self, value: float) -> None:
+        self.add_float32(Keys.Attention.TEMPERATURE_SCALE.format(arch=self.arch), value)
+
     def add_pooling_type(self, value: PoolingType) -> None:
         self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value.value)
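As a quick illustration of the new writer method, the minimal sketch below (assuming the gguf-py package from this tree; the output path and value are placeholders) shows that the value is keyed per architecture via the Keys.Attention.TEMPERATURE_SCALE template.

    # Minimal sketch: emit the new attention temperature-scale KV with gguf-py.
    # Output path and value are made up for illustration.
    import gguf

    writer = gguf.GGUFWriter("example.gguf", arch="mistral3")
    writer.add_attn_temperature_scale(0.1)  # stored under "mistral3.attention.temperature_scale"
    writer.close()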

src/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -132,6 +132,7 @@ add_library(llama
             models/t5-enc.cpp
             models/wavtokenizer-dec.cpp
             models/xverse.cpp
+            models/mistral3.cpp
             models/graph-context-mamba.cpp
             )

src/llama-arch.cpp

Lines changed: 28 additions & 0 deletions
@@ -111,6 +111,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_COGVLM,      "cogvlm"         },
     { LLM_ARCH_RND1,        "rnd1"           },
     { LLM_ARCH_PANGU_EMBED, "pangu-embedded" },
+    { LLM_ARCH_MISTRAL3,    "mistral3"       },
     { LLM_ARCH_UNKNOWN,     "(unknown)"      },
 };
 
@@ -204,6 +205,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_ATTENTION_SCALE,                "%s.attention.scale"               },
     { LLM_KV_ATTENTION_OUTPUT_SCALE,         "%s.attention.output_scale"        },
     { LLM_KV_ATTENTION_TEMPERATURE_LENGTH,   "%s.attention.temperature_length"  },
+    { LLM_KV_ATTENTION_TEMPERATURE_SCALE,    "%s.attention.temperature_scale"   },
     { LLM_KV_ATTENTION_KEY_LENGTH_MLA,       "%s.attention.key_length_mla"      },
     { LLM_KV_ATTENTION_VALUE_LENGTH_MLA,     "%s.attention.value_length_mla"    },
 
@@ -2512,6 +2514,32 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP_EXPS,    "blk.%d.ffn_up_exps" },
         },
     },
+    {
+        LLM_ARCH_MISTRAL3,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,     "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,    "output_norm" },
+            { LLM_TENSOR_OUTPUT,         "output" },
+            { LLM_TENSOR_ROPE_FREQS,     "rope_freqs" },
+            { LLM_TENSOR_ATTN_NORM,      "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,         "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,         "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,         "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,       "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_ROT_EMBD,  "blk.%d.attn_rot_embd" },
+            { LLM_TENSOR_FFN_GATE_INP,   "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_NORM,       "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE,       "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,       "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,         "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_GATE_EXP,   "blk.%d.ffn_gate.%d" },
+            { LLM_TENSOR_FFN_DOWN_EXP,   "blk.%d.ffn_down.%d" },
+            { LLM_TENSOR_FFN_UP_EXP,     "blk.%d.ffn_up.%d" },
+            { LLM_TENSOR_FFN_GATE_EXPS,  "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS,  "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS,    "blk.%d.ffn_up_exps" },
+        },
+    },
     {
         LLM_ARCH_UNKNOWN,
         {

src/llama-arch.h

Lines changed: 2 additions & 0 deletions
@@ -115,6 +115,7 @@ enum llm_arch {
     LLM_ARCH_COGVLM,
     LLM_ARCH_RND1,
     LLM_ARCH_PANGU_EMBED,
+    LLM_ARCH_MISTRAL3,
     LLM_ARCH_UNKNOWN,
 };
 
@@ -208,6 +209,7 @@ enum llm_kv {
     LLM_KV_ATTENTION_SCALE,
    LLM_KV_ATTENTION_OUTPUT_SCALE,
     LLM_KV_ATTENTION_TEMPERATURE_LENGTH,
+    LLM_KV_ATTENTION_TEMPERATURE_SCALE,
     LLM_KV_ATTENTION_KEY_LENGTH_MLA,
     LLM_KV_ATTENTION_VALUE_LENGTH_MLA,

src/llama-graph.cpp

Lines changed: 3 additions & 0 deletions
@@ -71,6 +71,9 @@ void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) {
     if (ubatch->pos && attn_scale) {
         const int64_t n_tokens = ubatch->n_tokens;
 
+        GGML_ASSERT(f_attn_temp_scale != 0.0f);
+        GGML_ASSERT(n_attn_temp_floor_scale != 0);
+
         std::vector<float> attn_scale_data(n_tokens, 0.0f);
         for (int i = 0; i < n_tokens; ++i) {
             const float pos = ubatch->pos[i];
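For orientation, the loop these asserts guard applies a llama-4-style, position-dependent temperature to attention; the loop body itself is outside this hunk. A rough Python rendering of the per-token scale, under that assumption, is shown below with illustrative values.

    import math

    # Rough sketch (assumed from the surrounding code, not part of this hunk) of the
    # position-dependent attention temperature scale. The asserts above ensure neither
    # parameter is zero before it is used as a divisor or multiplier.
    def attn_temp_scale(pos: int, f_attn_temp_scale: float, n_attn_temp_floor_scale: int) -> float:
        return math.log(math.floor((pos + 1.0) / n_attn_temp_floor_scale) + 1.0) * f_attn_temp_scale + 1.0

    print(attn_temp_scale(0,       0.1, 8192))  # ~1.0 for early positions
    print(attn_temp_scale(100_000, 0.1, 8192))  # grows slowly with position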

src/llama-hparams.h

Lines changed: 2 additions & 2 deletions
@@ -162,8 +162,8 @@ struct llama_hparams {
     // llama4 smallthinker
     uint32_t n_moe_layer_step        = 0;
     uint32_t n_no_rope_layer_step    = 4;
-    uint32_t n_attn_temp_floor_scale = 8192;
-    float    f_attn_temp_scale       = 0.1;
+    uint32_t n_attn_temp_floor_scale = 0;
+    float    f_attn_temp_scale       = 0.0f;
 
     // gemma3n altup
     uint32_t n_altup = 4; // altup_num_inputs

src/llama-model.cpp

Lines changed: 46 additions & 4 deletions
@@ -626,8 +626,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
     switch (arch) {
         case LLM_ARCH_LLAMA:
             {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-
                 if (hparams.n_expert == 8) {
                     switch (hparams.n_layer) {
                         case 32: type = LLM_TYPE_8x7B; break;
@@ -663,8 +661,10 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     hparams.swa_type = LLAMA_SWA_TYPE_NONE;
                     hparams.n_no_rope_layer_step = hparams.n_layer; // always use rope
                 } else {
-                    hparams.swa_type = LLAMA_SWA_TYPE_CHUNKED;
-                    hparams.n_swa = 8192;
+                    hparams.swa_type                = LLAMA_SWA_TYPE_CHUNKED;
+                    hparams.n_swa                   = 8192;
+                    hparams.n_attn_temp_floor_scale = 8192;
+                    hparams.f_attn_temp_scale       = 0.1f;
                     hparams.set_swa_pattern(4); // pattern: 3 chunked - 1 full
                 }
@@ -2247,6 +2247,42 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_MISTRAL3:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_SCALE, hparams.f_attn_temp_scale, false);
+
+                ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_FAST, hparams.yarn_beta_fast,    false);
+                ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, hparams.yarn_beta_slow,    false);
+                ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL,   hparams.rope_yarn_log_mul, false);
+
+                // TODO: maybe add n_attn_temp_floor_scale as a separate KV?
+                if (hparams.f_attn_temp_scale != 0.0f) {
+                    hparams.n_attn_temp_floor_scale = hparams.n_ctx_orig_yarn;
+                    if (hparams.n_attn_temp_floor_scale == 0) {
+                        throw std::runtime_error("invalid n_ctx_orig_yarn for attention temperature scaling");
+                    }
+                }
+
+                // TODO: this seems to be correct with the case of mscale == mscale_all_dims == 1.0f
+                //       but may need further verification with other values
+                if (hparams.rope_yarn_log_mul != 0.0f) {
+                    float factor = 1.0f / hparams.rope_freq_scale_train;
+                    float mscale = 1.0f;
+                    float mscale_all_dims = hparams.rope_yarn_log_mul;
+                    static auto get_mscale = [](float scale, float mscale) {
+                        return scale <= 1.0f ? 1.0f : (0.1f * mscale * logf(scale) + 1.0f);
+                    };
+                    hparams.yarn_attn_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dims);
+                }
+
+                switch (hparams.n_layer) {
+                    case 26: type = LLM_TYPE_3B; break;
+                    case 34: type = LLM_TYPE_8B; break;
+                    case 40: type = LLM_TYPE_14B; break;
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+            } break;
         default: throw std::runtime_error("unsupported model architecture");
     }
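A small worked example of the yarn_attn_factor expression above, with illustrative numbers: when mscale equals mscale_all_dims the ratio collapses to 1.0, which is the case the TODO says has been checked.

    import math

    # Worked example of the yarn_attn_factor computation added for LLM_ARCH_MISTRAL3.
    # factor = 1 / rope_freq_scale_train; mscale is fixed at 1.0 and mscale_all_dims
    # comes from rope_yarn_log_mul. The numbers below are illustrative only.
    def get_mscale(scale: float, mscale: float) -> float:
        return 1.0 if scale <= 1.0 else 0.1 * mscale * math.log(scale) + 1.0

    factor          = 8.0   # e.g. rope_freq_scale_train = 0.125
    mscale          = 1.0
    mscale_all_dims = 1.0   # rope_yarn_log_mul

    yarn_attn_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dims)
    print(yarn_attn_factor)  # 1.0 when mscale == mscale_all_dims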

@@ -2560,6 +2596,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
         case LLM_ARCH_MINICPM:
         case LLM_ARCH_GRANITE:
         case LLM_ARCH_GRANITE_MOE:
+        case LLM_ARCH_MISTRAL3:
             {
                 tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
 
@@ -7522,6 +7559,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
             {
                 llm = std::make_unique<llm_build_qwen3next>(*this, params);
             } break;
+        case LLM_ARCH_MISTRAL3:
+            {
+                llm = std::make_unique<llm_build_mistral3>(*this, params);
+            } break;
         default:
             GGML_ABORT("fatal error");
     }
@@ -7690,6 +7731,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
        case LLM_ARCH_ARCEE:
        case LLM_ARCH_ERNIE4_5:
        case LLM_ARCH_ERNIE4_5_MOE:
+        case LLM_ARCH_MISTRAL3:
            return LLAMA_ROPE_TYPE_NORM;
 
        // the pairs of head values are offset by n_rot/2

0 commit comments
