
Commit 832dc26

repack mxfp4 upon conversion
1 parent e59b2eb commit 832dc26

File tree

convert_hf_to_gguf.py
gguf-py/gguf/constants.py
src/llama-arch.cpp
src/llama-chat.cpp

4 files changed: +117 -23 lines changed

convert_hf_to_gguf.py

Lines changed: 110 additions & 18 deletions
@@ -7713,9 +7713,112 @@ def set_vocab(self):
         self.gguf_writer.add_chat_template(chat_template)


-@ModelBase.register("OpenAIMoeForCausalLM")
-class OpenAIMoeModel(TextModel):
-    model_arch = gguf.MODEL_ARCH.OPENAI_MOE
+@ModelBase.register("GptOssForCausalLM")
+class GptOssModel(TextModel):
+    model_arch = gguf.MODEL_ARCH.GPT_OSS
+
+    def repack_mxfp4(self, new_name: str, blocks: Tensor, scales: Tensor):
+        assert blocks.dtype == torch.uint8
+        assert scales.dtype == torch.uint8
+        scales = scales.unsqueeze(-1)
+        assert len(blocks.shape) == 4
+        assert len(scales.shape) == 4
+        new_data = torch.cat([scales, blocks], dim=-1)
+        new_data = new_data.numpy()
+        new_shape = [scales.shape[0], scales.shape[1], scales.shape[2] * 32]
+        logger.info(f"Repacked {new_name} with shape {new_shape} and quantization MXFP4")
+        self.gguf_writer.add_tensor(new_name, new_data, new_shape, gguf.GGMLQuantizationType.MXFP4)
+
+    def convert_moe_packed_tensors(
+        self,
+        new_name: str,
+        blocks,
+        scales,
+        *,
+        dtype: torch.dtype = torch.float16,
+        rows_per_chunk: int = 32768 * 1024,
+    ):
+        import math
+
+        scales = scales.to(torch.int32) - 127
+
+        assert blocks.shape[:-1] == scales.shape, f"{blocks.shape=} does not match {scales.shape=}"
+
+        FP4_VALUES = [
+            +0.0,
+            +0.5,
+            +1.0,
+            +1.5,
+            +2.0,
+            +3.0,
+            +4.0,
+            +6.0,
+            -0.0,
+            -0.5,
+            -1.0,
+            -1.5,
+            -2.0,
+            -3.0,
+            -4.0,
+            -6.0,
+        ]
+        blocks = blocks.to(device="cpu")
+        scales = scales.to(device="cpu")
+        lut = torch.tensor(FP4_VALUES, dtype=dtype, device=blocks.device)
+
+        *prefix_shape, G, B = blocks.shape
+        rows_total = math.prod(prefix_shape) * G
+
+        blocks = blocks.reshape(rows_total, B)
+        scales = scales.reshape(rows_total, 1)
+
+        out = torch.empty(rows_total, B * 2, dtype=dtype, device="cpu")
+
+        for r0 in range(0, rows_total, rows_per_chunk):
+            r1 = min(r0 + rows_per_chunk, rows_total)
+
+            blk = blocks[r0:r1]
+            exp = scales[r0:r1]
+
+            # nibble indices -> int64
+            idx_lo = (blk & 0x0F).to(torch.long)
+            idx_hi = (blk >> 4).to(torch.long)
+
+            sub = out[r0:r1]
+            sub[:, 0::2] = lut[idx_lo]
+            sub[:, 1::2] = lut[idx_hi]
+
+            torch.ldexp(sub, exp, out=sub)
+            del idx_lo, idx_hi, blk, exp
+
+        out = out.reshape(*prefix_shape, G, B * 2).view(*prefix_shape, G * B * 2)
+        out = out.numpy()
+        logger.info(f"Unpacked {new_name} with shape {out.shape} from MXFP4 to F16")
+        print(out.dtype, out.device, out.shape)
+        self.gguf_writer.add_tensor(new_name, out)
+
+    def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
+        blocks0: Tensor = torch.zeros(1)
+        blocks1: Tensor = torch.zeros(1)
+        # we assume that tensors are loaded in the correct order
+        for name, data_torch in self.get_tensors():
+            if "mlp.experts.down_proj_blocks" in name:
+                blocks0 = data_torch
+            elif "mlp.experts.down_proj_scales" in name:
+                new_name = self.map_tensor_name(name.replace("_scales", ".weight"))
+                #self.repack_mxfp4(new_name, blocks0, data_torch)
+                self.convert_moe_packed_tensors(new_name, blocks0, data_torch)
+            elif "mlp.experts.gate_up_proj_blocks" in name:
+                blocks0, blocks1 = data_torch[:, ::2, :, :], data_torch[:, 1::2, :, :]
+            elif "mlp.experts.gate_up_proj_scales" in name:
+                scales0, scales1 = data_torch[:, ::2, :], data_torch[:, 1::2, :]
+                new_name_gate = self.map_tensor_name(name.replace("gate_up_proj_scales", "gate_proj.weight"))
+                new_name_up = self.map_tensor_name(name.replace("gate_up_proj_scales", "up_proj.weight"))
+                # self.repack_mxfp4(new_name_gate, blocks0, scales0)
+                # self.repack_mxfp4(new_name_up, blocks1, scales1)
+                self.convert_moe_packed_tensors(new_name_gate, blocks0, scales0)
+                self.convert_moe_packed_tensors(new_name_up, blocks1, scales1)
+        return []

     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         del bid # unused
@@ -7728,32 +7831,20 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
             if name.endswith("_bias"):
                 name = name.replace("down_proj_bias", "down_proj.bias")
             else:
-                name = name.replace("down_proj", "down_proj.weight")
-                data_torch = data_torch.transpose(-1, -2)
+                return []

         # split the gate_up into gate and up
         if "gate_up_proj" in name:
             if name.endswith("_bias"):
                 name_up = name.replace("gate_up_proj_bias", "up_proj.bias")
                 name_gate = name.replace("gate_up_proj_bias", "gate_proj.bias")
-                #dim_half = data_torch.shape[-1] // 2
-                #gate_proj_bias, up_proj_bias = data_torch.split(dim_half, dim=-1)
                 gate_proj_bias, up_proj_bias = data_torch[..., ::2], data_torch[..., 1::2]
                 return [
                     (self.map_tensor_name(name_gate), gate_proj_bias),
                     (self.map_tensor_name(name_up), up_proj_bias)
                 ]
             else:
-                name_up = name.replace("gate_up_proj", "up_proj.weight")
-                name_gate = name.replace("gate_up_proj", "gate_proj.weight")
-                #dim_half = data_torch.shape[-1] // 2
-                #gate_proj_weight, up_proj_weight = data_torch.transpose(-1, -2).split(dim_half, dim=-2)
-                data_torch = data_torch.transpose(-1, -2)
-                gate_proj_weight, up_proj_weight = data_torch[:, ::2, :], data_torch[:, 1::2, :]
-                return [
-                    (self.map_tensor_name(name_gate), gate_proj_weight),
-                    (self.map_tensor_name(name_up), up_proj_weight)
-                ]
+                return []

         return [(self.map_tensor_name(name), data_torch)]

@@ -7767,7 +7858,7 @@ def set_gguf_parameters(self):

         rope_scaling = self.hparams.get("rope_scaling") or {}
         rope_type = rope_scaling.get("rope_type", rope_scaling.get("type"))
-        assert rope_type == "yarn", f"OpenAI MoE only supports yarn rope scaling, got {rope_type}"
+        assert rope_type == "yarn", f"GPT-OSS only supports yarn rope scaling, got {rope_type}"
         self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
         self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
         self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling.get("original_max_position_embeddings", 4096))
@@ -7912,6 +8003,7 @@ class LazyTorchTensor(gguf.LazyBase):
     _dtype_map: dict[torch.dtype, type] = {
         torch.float16: np.float16,
         torch.float32: np.float32,
+        torch.uint8: np.uint8,
     }

     # used for safetensors slices
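
For reference, a minimal standalone sketch (not part of the commit) of the per-block decode that convert_moe_packed_tensors performs: each MXFP4 block packs 32 FP4 (E2M1) values into 16 nibble-packed bytes plus one shared scale byte, which the converter treats as a biased power-of-two exponent (bias 127) applied with torch.ldexp. The helper name decode_mxfp4_block is hypothetical.

import torch

# FP4 (E2M1) code points, same table as FP4_VALUES in the diff above
FP4_VALUES = [+0.0, +0.5, +1.0, +1.5, +2.0, +3.0, +4.0, +6.0,
              -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]

def decode_mxfp4_block(block_bytes: torch.Tensor, scale_byte: int) -> torch.Tensor:
    # block_bytes: 16 uint8 values holding 32 packed nibbles; scale_byte: shared exponent with bias 127
    assert block_bytes.dtype == torch.uint8 and block_bytes.numel() == 16
    lut = torch.tensor(FP4_VALUES, dtype=torch.float16)
    idx_lo = (block_bytes & 0x0F).to(torch.long)  # low nibbles -> even output positions
    idx_hi = (block_bytes >> 4).to(torch.long)    # high nibbles -> odd output positions
    out = torch.empty(32, dtype=torch.float16)
    out[0::2] = lut[idx_lo]
    out[1::2] = lut[idx_hi]
    return torch.ldexp(out, torch.tensor(scale_byte - 127))  # multiply by 2**(scale_byte - 127)

# toy usage: 16 packed bytes and a scale byte of 128, i.e. a block scale of 2**1
print(decode_mxfp4_block(torch.arange(16, dtype=torch.uint8), 128))

The commit's convert_moe_packed_tensors does the same thing vectorized over all blocks, in chunks of rows_per_chunk rows, so the data written to the GGUF is the unpacked F16; the repack_mxfp4 path, which would instead keep MXFP4 by prepending each block's scale byte to its 16 data bytes, is left commented out in generate_extra_tensors.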

gguf-py/gguf/constants.py

Lines changed: 5 additions & 3 deletions
@@ -377,7 +377,7 @@ class MODEL_ARCH(IntEnum):
     ERNIE4_5_MOE = auto()
     HUNYUAN_MOE = auto()
     SMOLLM3 = auto()
-    OPENAI_MOE = auto()
+    GPT_OSS = auto()
     LFM2 = auto()
     DREAM = auto()
     SMALLTHINKER = auto()
@@ -700,7 +700,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.FALCON_H1: "falcon-h1",
     MODEL_ARCH.HUNYUAN_MOE: "hunyuan-moe",
     MODEL_ARCH.SMOLLM3: "smollm3",
-    MODEL_ARCH.OPENAI_MOE: "openai-moe",
+    MODEL_ARCH.GPT_OSS: "gpt-oss",
     MODEL_ARCH.LFM2: "lfm2",
     MODEL_ARCH.DREAM: "dream",
     MODEL_ARCH.SMALLTHINKER: "smallthinker",
@@ -2491,7 +2491,7 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
     ],
-    MODEL_ARCH.OPENAI_MOE: [
+    MODEL_ARCH.GPT_OSS: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
         MODEL_TENSOR.OUTPUT,
@@ -2661,6 +2661,7 @@ class GGMLQuantizationType(IntEnum):
     BF16 = 30
     TQ1_0 = 34
     TQ2_0 = 35
+    MXFP4 = 39


 class ExpertGatingFuncType(IntEnum):
@@ -2801,6 +2802,7 @@ class VisionProjectorType:
     GGMLQuantizationType.BF16: (1, 2),
     GGMLQuantizationType.TQ1_0: (256, 2 + 4 * 13),
     GGMLQuantizationType.TQ2_0: (256, 2 + 64),
+    GGMLQuantizationType.MXFP4: (1, 1), # quick hack to write MXFP4 as U8
 }

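
The (1, 1) entry is the "quick hack" noted in the comment: GGML_QUANT_SIZES maps each quantization type to (elements per block, bytes per block), and declaring MXFP4 as one byte per element lets the writer treat the pre-packed buffer from repack_mxfp4 as plain U8 data. Under that layout each block of 32 values costs 1 scale byte plus 16 nibble-packed bytes, which is also why repack_mxfp4 reports the last logical dimension as scales.shape[2] * 32. A small sketch of the resulting size arithmetic (illustrative only; the helper name is hypothetical):

# each MXFP4 block: 32 values -> 1 scale byte + 16 nibble-packed bytes = 17 bytes
VALUES_PER_BLOCK = 32
BYTES_PER_BLOCK = 1 + 16

def mxfp4_packed_nbytes(n_values: int) -> int:
    assert n_values % VALUES_PER_BLOCK == 0, "row size must be a multiple of 32 values"
    return n_values // VALUES_PER_BLOCK * BYTES_PER_BLOCK

print(mxfp4_packed_nbytes(32768))  # 17408 bytes, vs. 65536 bytes for the same 32768 values in F16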

src/llama-arch.cpp

Lines changed: 1 addition & 1 deletion
@@ -86,7 +86,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_ERNIE4_5_MOE, "ernie4_5-moe" },
     { LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" },
     { LLM_ARCH_SMOLLM3, "smollm3" },
-    { LLM_ARCH_OPENAI_MOE, "openai-moe" },
+    { LLM_ARCH_OPENAI_MOE, "gpt-oss" },
     { LLM_ARCH_LFM2, "lfm2" },
     { LLM_ARCH_DREAM, "dream" },
     { LLM_ARCH_SMALLTHINKER, "smallthinker" },

src/llama-chat.cpp

Lines changed: 1 addition & 1 deletion
@@ -66,7 +66,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
     { "llama4", LLM_CHAT_TEMPLATE_LLAMA4 },
     { "smolvlm", LLM_CHAT_TEMPLATE_SMOLVLM },
     { "hunyuan-moe", LLM_CHAT_TEMPLATE_HUNYUAN_MOE },
-    { "openai-moe", LLM_CHAT_TEMPLATE_OPENAI_MOE },
+    { "gpt-oss", LLM_CHAT_TEMPLATE_OPENAI_MOE },
     { "kimi-k2", LLM_CHAT_TEMPLATE_KIMI_K2 },
 };

