38 changes: 38 additions & 0 deletions neural_compressor/torch/algorithms/qat/mxfp4_packing.py
@@ -0,0 +1,38 @@
# Copyright (c) 2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

E2M1_max = 6.0

# Representable E2M1 (FP4) magnitudes, and the round-to-nearest decision
# boundaries (the midpoints between consecutive representable values).
E2M1_values = [0, 0.5, 1, 1.5, 2, 3, 4, 6]
E2M1_bounds = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5])


def cast_fp4(x):
    """Quantize a float tensor to E2M1 and return the 4-bit codes as uint8."""
    sign = torch.sign(x)
    sign_bit = (2 - sign) // 2  # maps sign +1 -> 0 and sign -1 (or 0) -> 1
    # Count how many decision boundaries each magnitude exceeds; that count is
    # the ordinal (0..7) of the nearest representable E2M1 value.
    ord_ = torch.sum((x.abs().unsqueeze(-1) - E2M1_bounds.to(x.device)) > 0, dim=-1)
    fp4_val = (sign_bit * 0b1000 + ord_).to(torch.uint8)
    return fp4_val


def fuse_uint4_to_uint8(x):
    # If the last dimension is odd, pad with a zero nibble so every value has a partner.
    if x.shape[-1] % 2:
        x = torch.nn.functional.pad(x, (0, 1))
    left_side = x[..., 0::2]  # Even indices (0, 2, 4, ...)
    right_side = x[..., 1::2]  # Odd indices (1, 3, 5, ...)
    new_data = right_side.clone() << 4  # Odd indices (higher addresses) go in the high nibble
    new_data += left_side  # Even indices go in the low nibble
    return new_data
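For reference, here is a minimal round-trip sketch (not part of the diff) showing how the packed bytes map back to E2M1 values. The helpers `unfuse_uint8_to_uint4` and `decode_fp4` are hypothetical, written only to illustrate the nibble order and code layout chosen above:

```python
import torch

from neural_compressor.torch.algorithms.qat.mxfp4_packing import cast_fp4, fuse_uint4_to_uint8

E2M1 = torch.tensor([0, 0.5, 1, 1.5, 2, 3, 4, 6])

def unfuse_uint8_to_uint4(packed):
    # Inverse of fuse_uint4_to_uint8: low nibble holds even indices, high nibble odd indices.
    low, high = packed & 0x0F, packed >> 4
    return torch.stack((low, high), dim=-1).reshape(*packed.shape[:-1], -1)

def decode_fp4(codes):
    # Bit 3 is the sign; bits 0-2 index the E2M1 magnitude table.
    sign = 1.0 - 2.0 * ((codes >> 3) & 1).float()
    return sign * E2M1[(codes & 0b0111).long()]

w = torch.randn(4, 32).clamp(-6, 6)  # values already inside the E2M1 range
restored = decode_fp4(unfuse_uint8_to_uint4(fuse_uint4_to_uint8(cast_fp4(w))))
assert restored.shape == w.shape
assert (restored - w).abs().max() <= 1.0  # half the largest E2M1 gap (4 -> 6)
```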
4 changes: 4 additions & 0 deletions neural_compressor/torch/algorithms/qat/quant_utils.py
@@ -108,6 +108,7 @@ def get_quant_config(scheme: str) -> dict[str, Any]:
quantization_config["provider"] = "auto-round"
quantization_config["config_groups"]["group_0"]["weights"]["is_mx"] = True
quantization_config["config_groups"]["group_0"]["input_activations"]["is_mx"] = True
quantization_config["format"] = "float-quantized"

except ImportError:
quantization_config = None
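For orientation, the hunk above yields a `quantization_config` along these lines. Only `provider`, the two `is_mx` flags, and `format` are set here; the remaining keys are assumptions sketched from the compressed-tensors config convention:

```python
quantization_config = {
    "provider": "auto-round",
    "format": "float-quantized",
    "config_groups": {
        "group_0": {
            "weights": {"num_bits": 4, "is_mx": True},            # assumed fields
            "input_activations": {"num_bits": 4, "is_mx": True},  # assumed fields
        }
    },
}
```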
@@ -133,6 +134,9 @@ def _get_quantization_from_layer(layer):
    if weight_quantizer.num_bits == 8 and weight_quantizer.data_type == "mx_fp8":
        return "MXFP8"

    if weight_quantizer.num_bits == 4 and weight_quantizer.data_type == "mx_fp4":
        return "MXFP4"

    # Raise an error for unsupported quantizer configurations.
    raise NotImplementedError(f"Unsupported quantizer with num_bits: {weight_quantizer.num_bits}")
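The dispatch can be exercised standalone with a stand-in quantizer object (a sketch; the real `weight_quantizer` is read from the layer's attributes):

```python
from types import SimpleNamespace

def format_for(quantizer):
    # Mirrors _get_quantization_from_layer's dispatch, for illustration only.
    if quantizer.num_bits == 8 and quantizer.data_type == "mx_fp8":
        return "MXFP8"
    if quantizer.num_bits == 4 and quantizer.data_type == "mx_fp4":
        return "MXFP4"
    raise NotImplementedError(f"Unsupported quantizer with num_bits: {quantizer.num_bits}")

assert format_for(SimpleNamespace(num_bits=8, data_type="mx_fp8")) == "MXFP8"
assert format_for(SimpleNamespace(num_bits=4, data_type="mx_fp4")) == "MXFP4"
```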

13 changes: 13 additions & 0 deletions neural_compressor/torch/algorithms/qat/tensor_quantizer.py
@@ -161,6 +161,19 @@ def weight_pack(self, weight, scale):
            e8m0_scale = (scale + 127).to(torch.uint8)
            return qweight.reshape(original_shape), e8m0_scale.reshape(original_shape[0], -1)

        if self.data_type == "mx_fp4":
            qweight = weight.reshape(-1, self.block_size) / torch.exp2(scale.float())

            from .mxfp4_packing import cast_fp4, fuse_uint4_to_uint8

            # Round each scaled block to E2M1 codes, then pack two 4-bit codes per byte.
            qweight = cast_fp4(qweight)
            qweight_packed = fuse_uint4_to_uint8(qweight)

            # Store the shared exponents as biased E8M0 (exponent + 127).
            e8m0_scale = (scale + 127).to(torch.uint8)
            return qweight_packed.reshape(original_shape[0], original_shape[1] // 2), e8m0_scale.reshape(
                original_shape[0], -1
            )
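A quick shape check for the MXFP4 branch (a standalone sketch assuming `block_size=32`; the crude shared-exponent choice below stands in for the quantizer's real calibration):

```python
import torch

from neural_compressor.torch.algorithms.qat.mxfp4_packing import cast_fp4, fuse_uint4_to_uint8

out_features, in_features, block_size = 8, 64, 32
weight = torch.randn(out_features, in_features)

# One shared exponent per 32-element block; floor(log2(amax)) - 2 keeps block
# maxima near the E2M1 range (E2M1 max is 6 = 1.5 * 2^2).
blocks = weight.reshape(-1, block_size)
scale = torch.floor(torch.log2(blocks.abs().amax(dim=-1, keepdim=True) + 1e-12)) - 2

qweight = cast_fp4(blocks / torch.exp2(scale))
weight_packed = fuse_uint4_to_uint8(qweight).reshape(out_features, in_features // 2)
e8m0_scale = (scale + 127).to(torch.uint8).reshape(out_features, -1)

assert weight_packed.shape == (8, 32)  # two FP4 codes per byte
assert e8m0_scale.shape == (8, 2)      # in_features // block_size scales per output row
```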

    def __repr__(self):
        if self._disabled or not self._if_quant:
            return "TensorQuantizer(disabled)"
8 changes: 7 additions & 1 deletion neural_compressor/torch/export/export_hf.py
@@ -41,7 +41,13 @@ def _export_quantized_weight(sub_module: nn.Module, quantization_format: str = None):

sub_module.register_buffer("weight_scale", e8m0_scale)

setattr(sub_module, weight_name, nn.Parameter(quantized_weight, requires_grad=False))
if quantization_format == "MXFP8":
setattr(sub_module, weight_name, nn.Parameter(quantized_weight, requires_grad=False))

if quantization_format == "MXFP4":
delattr(sub_module, weight_name)
# name aligned for vllm emulation
sub_module.register_buffer("weight_packed", quantized_weight)


def _export_hf_checkpoint(model: nn.Module, scheme: str | None = None) -> tuple[dict[str, Any], dict[str, Any]]:
4 changes: 4 additions & 0 deletions neural_compressor/torch/quantization/config.py
@@ -2254,6 +2254,10 @@ def get_config_set_for_tuning(cls, dtype="int8"):
torch.nn.Linear: "MXFP8",
}

QAT_MODULE_MAPPINGS: dict[Callable, Any] = {
torch.nn.Linear: ["MXFP8", "MXFP4"],
}


def get_default_qat_module_mappings() -> dict[Callable, Any]:
"""Get default module mapping for quantization aware training."""