
Conversation

@ngxson (Collaborator) commented Dec 3, 2025

WIP: the code is quite ugly for now, but I just want to get it working.

Remember to convert with the --mistral-format argument, as the weights are not yet transformers-compatible.

The output F16 weights are 1.35 TB and the Q8_0 weights are 716 GB, and I don't have enough hardware to test it.

Edit: thanks @bartowski1182 for testing it!

Disclaimer: unlike the Ministral release, this PR is not affiliated with the Mistral team.


NOTE: this PR only covers the conversion to GGUF. The C++ code is still missing the Llama 4 scaling needed for this model to work; that will be another PR.

@ngxson changed the title from "convert: support Mistral 3 Large MoE" to "convert: support Mistral 3 Large MoE (need help for testing)" Dec 3, 2025
@bartowski1182 (Contributor)

So far so good with this; in a couple of hours I will be able to test generation.

@github-actions bot added the python (python script changes) label Dec 3, 2025
@bartowski1182 (Contributor)

seems to work and produce coherent results!

@ngxson marked this pull request as ready for review December 3, 2025 16:54
@ngxson requested a review from CISC as a code owner December 3, 2025 16:54
@ngxson marked this pull request as draft December 3, 2025 16:55
@ngxson (Collaborator, Author) commented Dec 3, 2025

This PR still needs to be cleaned up before it is ready for review 😅

@ngxson marked this pull request as ready for review December 3, 2025 18:10
@ngxson changed the title from "convert: support Mistral 3 Large MoE (need help for testing)" to "convert: support Mistral 3 Large MoE" Dec 3, 2025
Comment on lines +9941 to +9942
# remap hparams from Mistral MoE format to DeepseekV2 format
# we do this way to be able to reuse DeepseekV2Model set_gguf_parameters logic
Collaborator

Somewhat ugly but an acceptable trade-off.
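
The remapping done here is essentially a key rename on the loaded hparams dict, so that the existing DeepseekV2Model set_gguf_parameters logic sees the field names it expects. Below is a minimal sketch of that idea; the key names are hypothetical placeholders for illustration, not the actual Mistral or DeepseekV2 fields touched by this PR.

# Sketch of the hparams-remapping technique referenced above. The key names
# ("moe_hidden_dim", "n_routed_experts", ...) are assumptions for illustration;
# the point is only renaming fields so DeepseekV2-oriented code can be reused.

def remap_hparams(hparams: dict) -> dict:
    renames = {
        # assumed Mistral-MoE-style key -> assumed DeepseekV2-style key
        "moe_hidden_dim": "moe_intermediate_size",
        "num_experts": "n_routed_experts",
    }
    remapped = dict(hparams)
    for src, dst in renames.items():
        if src in remapped:
            remapped[dst] = remapped.pop(src)
    return remapped


if __name__ == "__main__":
    example = {"moe_hidden_dim": 2048, "num_experts": 128, "dim": 4096}
    print(remap_hparams(example))
    # {'dim': 4096, 'moe_intermediate_size': 2048, 'n_routed_experts': 128}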

@taronaeo linked an issue (Can't convert mistral 3 large) Dec 5, 2025 that may be closed by this pull request
@csabakecskemeti (Contributor)

@ngxson Thank you so much for this.
I've also tested the conversion from your branch. The convert script succeeded (with --mistral-format), but at inference time (Q8_0) I received:
llama_model_load: error loading model: missing tensor 'blk.0.attn_k_b.weight'
I tried F16 too and it failed on the same error (it should have failed on insufficient memory instead).

I've tried your Q4_K_M and it seems to work just fine.
(Now downloading @bartowski1182's Q8_0 version to test that too.)

Is there any other setting or change needed for the conversion?
Note: I used Mistral's own BF16 version as the source, which has now disappeared.

@bartowski1182 (Contributor)

It disappeared?? 👀 I can re-upload if necessary I guess ..

The only difference is using --mistral-format.

@csabakecskemeti (Contributor)

Yeah, I used the Mistral format. Then I guess I have a corrupted BF16 version (I can't think of anything else).
Yeah, I can't see the BF16 version on HF.
If you can upload it, that would be nice.
I made a dequantizer that I've used with the Ministral 3 instruct models.
In case anyone needs it:

https://github.com/csabakecskemeti/ministral-3_dequantizer_fp8-bf16

@bartowski1182 (Contributor)

I can see it here:

https://huggingface.co/mistralai/Mistral-Large-3-675B-Instruct-2512-BF16

@csabakecskemeti (Contributor)

You're right, they just removed it from the collection (if it was ever there :p); that's where I was looking. My bad.

@CISC (Collaborator) commented Dec 5, 2025

> @ngxson Thank you so much for this. I've also tested the conversion from your branch. The convert script succeeded (with --mistral-format), but at inference time (Q8_0) I received: llama_model_load: error loading model: missing tensor 'blk.0.attn_k_b.weight'. I tried F16 too and it failed on the same error (it should have failed on insufficient memory instead).

It looks like @ngxson forgot the wkv_b remapping in the cleanup.
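
For context on the error itself: on the DeepseekV2 conversion path this PR reuses, the separate attn_k_b / attn_v_b tensors are derived by splitting the combined kv_b projection during conversion, so if the Mistral wkv_b weight never gets remapped onto that path the split never runs and the resulting GGUF lacks blk.*.attn_k_b.weight. (Once the rename is in place, the split's k_b_proj / v_b_proj outputs still need tensor-mapping entries, which is what the follow-up error below is about.) Here is a rough, simplified sketch of such a split; the shapes, hparam names, and layout are assumptions for illustration, not the exact code in this PR.

# Simplified sketch of splitting a combined kv_b projection into separate
# k_b and v_b weights, in the spirit of the DeepseekV2 conversion path.
# Head-dimension names and the exact layout are assumptions, for illustration only.
import torch

def split_kv_b(kv_b: torch.Tensor, n_head_kv: int,
               qk_nope_head_dim: int, v_head_dim: int):
    # kv_b: (n_head_kv * (qk_nope_head_dim + v_head_dim), kv_lora_rank)
    kv_lora_rank = kv_b.shape[-1]
    kv_b = kv_b.view(n_head_kv, qk_nope_head_dim + v_head_dim, kv_lora_rank)
    k_b, v_b = torch.split(kv_b, [qk_nope_head_dim, v_head_dim], dim=1)
    # flatten back to 2-D weight matrices (a real converter may additionally
    # transpose k_b so it can be used directly in the absorbed attention matmul)
    k_b = k_b.transpose(1, 2).reshape(n_head_kv * kv_lora_rank, qk_nope_head_dim)
    v_b = v_b.reshape(n_head_kv * v_head_dim, kv_lora_rank)
    return k_b, v_b

if __name__ == "__main__":
    # tiny hypothetical shapes just to show the bookkeeping
    n_head_kv, qk_nope, v_dim, rank = 2, 4, 3, 5
    w = torch.randn(n_head_kv * (qk_nope + v_dim), rank)
    k_b, v_b = split_kv_b(w, n_head_kv, qk_nope, v_dim)
    print(k_b.shape, v_b.shape)  # torch.Size([10, 4]) torch.Size([6, 5])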

@CISC (Collaborator) left a comment

@csabakecskemeti This should work.

name = name.replace(".qscale_act", ".input_scale")
if name.endswith(".qscale_weight"):
name = name.replace(".qscale_weight", ".weight_scale")
if ".experts." in name:
Collaborator

Suggested change
-if ".experts." in name:
+if ".wkv_b." in name:
+    name = name.replace(".wkv_b.", ".kv_b_proj.")
+if ".experts." in name:

Contributor

This change gave me:

ValueError: Can not map tensor 'layers.32.attention.k_b_proj.weight'

Collaborator

Yeah, you need the changes below as well (cannot be applied directly because GitHub's "new experience" is useless).

Contributor

Working so far with the other change included :)

Comment on lines 944 to 955
MODEL_TENSOR.ATTN_KV_B: (
    "model.layers.{bid}.self_attn.kv_b_proj", # deepseek2
    "layers.{bid}.attention.wkv_b",           # mistral-large
),

MODEL_TENSOR.ATTN_K_B: (
    "model.layers.{bid}.self_attn.k_b_proj",  # deepseek2
),

MODEL_TENSOR.ATTN_V_B: (
    "model.layers.{bid}.self_attn.v_b_proj",  # deepseek2
),
Collaborator

Suggested change
MODEL_TENSOR.ATTN_KV_B: (
    "model.layers.{bid}.self_attn.kv_b_proj", # deepseek2
),
MODEL_TENSOR.ATTN_K_B: (
    "model.layers.{bid}.self_attn.k_b_proj",  # deepseek2
    "layers.{bid}.attention.k_b_proj",        # mistral-large
),
MODEL_TENSOR.ATTN_V_B: (
    "model.layers.{bid}.self_attn.v_b_proj",  # deepseek2
    "layers.{bid}.attention.v_b_proj",        # mistral-large
),

GitHub will mess up the diff here, but you get the gist.
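
For anyone unfamiliar with how these tuples are consumed: gguf-py expands each template with the block index {bid} to build a name-lookup table, so adding the two mistral-large entries is what lets the split k_b_proj / v_b_proj tensors resolve to the blk.{bid}.attn_k_b / attn_v_b names the loader expects. A stripped-down sketch of that lookup idea follows; it is not the actual TensorNameMap implementation.

# Stripped-down illustration of template-based tensor-name mapping; the real
# gguf-py TensorNameMap is more involved, this only shows the {bid} expansion
# behind the entries suggested above.
MAPPINGS = {
    "attn_k_b": (
        "model.layers.{bid}.self_attn.k_b_proj",  # deepseek2
        "layers.{bid}.attention.k_b_proj",        # mistral-large
    ),
    "attn_v_b": (
        "model.layers.{bid}.self_attn.v_b_proj",  # deepseek2
        "layers.{bid}.attention.v_b_proj",        # mistral-large
    ),
}

def build_lookup(n_blocks: int) -> dict:
    lookup = {}
    for gguf_base, templates in MAPPINGS.items():
        for bid in range(n_blocks):
            for tmpl in templates:
                lookup[tmpl.format(bid=bid)] = f"blk.{bid}.{gguf_base}"
    return lookup

if __name__ == "__main__":
    table = build_lookup(n_blocks=64)
    print(table["layers.32.attention.k_b_proj"])  # blk.32.attn_k_b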

Contributor

Ah, I guess I needed this one too.

Collaborator Author

Hmm, yeah, I didn't notice that the changes were overwritten by the git merge. Thanks!

(feel free to ping me when these changes are OK to be added)

Contributor

It's working so far, @ngxson, but I can wait until I have a quant I can run and confirm with that first.
