@@ -54,43 +54,16 @@ accelerate launch --config-file accelerate_config/fsdp1.yaml \

##### Step 2:

Quantize the trained model using `prepare_qat()` by setting the flags `--quant_scheme MXFP8 --do_train False`. This inserts fake quantization modules into the model without starting training yet. Then save the model directly to get a post-training quantized model.
Save the model directly to get a post-training quantized model using [auto-round](https://github.com/intel/auto-round).


```
accelerate launch --config-file accelerate_config/fsdp1.yaml \
--fsdp_transformer_layer_cls_to_wrap LlamaDecoderLayer \
main.py \
--model_name_or_path ./llama3.1-finetuned \
--model_max_length 4096 \
--dataloader_drop_last True \
--do_train False \
--do_eval False \
--quant_scheme MXFP8 \
--output_dir ./llama3.1-finetuned-ptq \
--dataset Daring-Anteater \
--num_train_epochs 2.0 \
--per_device_train_batch_size 4 \
--per_device_eval_batch_size 4 \
--gradient_accumulation_steps 1 \
--eval_accumulation_steps 1 \
--save_strategy steps \
--save_steps 3000 \
--eval_strategy steps \
--eval_steps 3000 \
--load_best_model_at_end True \
--save_total_limit 2 \
--learning_rate 1e-5 \
--weight_decay 0.0 \
--warmup_ratio 0.1 \
--lr_scheduler_type linear \
--logging_steps 1 \
--report_to tensorboard
python quantize_autoround.py
```
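
To sanity-check the exported checkpoint, a minimal inference sketch is shown below. It assumes `quantize_autoround.py` keeps its default output directory (`./Llama-3.1-8B-Instruct_autoround_rtn_mxfp4`) and that your vLLM build can load llm_compressor/compressed-tensors MXFP4 checkpoints in emulation (there are no real MXFP4 kernels); both are assumptions, not part of the tested flow here.

```
# Illustrative sketch only: requires a vLLM build with (emulated) compressed-tensors MXFP4 support.
from vllm import LLM, SamplingParams

llm = LLM(model="./Llama-3.1-8B-Instruct_autoround_rtn_mxfp4")
outputs = llm.generate(
    ["Briefly explain MXFP4 quantization."],
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)
```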

##### Step 3:

Train/fine-tune the quantized model with a small learning rate, e.g. 1e-5 for the Adam optimizer, by setting `--quant_scheme MXFP8 --do_train True`.
Train/fine-tune the quantized model with a small learning rate, e.g. 1e-5 for the Adam optimizer, by setting `--quant_scheme MXFP4 --do_train True`.

```
accelerate launch --config-file accelerate_config/fsdp1.yaml \
@@ -101,7 +74,7 @@ accelerate launch --config-file accelerate_config/fsdp1.yaml \
--dataloader_drop_last True \
--do_train True \
--do_eval True \
--quant_scheme MXFP8 \
--quant_scheme MXFP4 \
--output_dir ./llama3.1-finetuned-qat \
--dataset Daring-Anteater \
--max_steps 1000 \
@@ -0,0 +1,25 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
enable_cpu_affinity: false
fsdp_config:
fsdp_activation_checkpointing: true
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_cpu_ram_efficient_loading: true
fsdp_offload_params: false
fsdp_reshard_after_forward: true
fsdp_state_dict_type: SHARDED_STATE_DICT
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
fsdp_version: 2
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: gpu
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
@@ -12,7 +12,6 @@
AutoModelForCausalLM,
AutoTokenizer,
HfArgumentParser,
Trainer,
default_data_collator,
set_seed,
TrainerCallback,
@@ -21,6 +20,7 @@
from utils import (
get_metrics_with_perplexity,
make_supervised_data_module,
QATTrainer
)

logger = logging.getLogger(__name__)
@@ -47,7 +47,7 @@ class TrainingArguments(transformers.TrainingArguments):
class DataArguments:
dataset: str = field(
default="Daring-Anteater",
metadata={"help": "Specify the dataset.", "choices": ["Daring-Anteater"]},
metadata={"help": "Specify the dataset.", "choices": ["Daring-Anteater", "cnn_dailymail"]},
)
train_size: int = field(
default=0,
@@ -69,7 +69,7 @@ class QuantizationArguments:
"Specify the quantization format for PTQ/QAT. if specified, PTQ/QAT will be enabled"
" with the specified quantization format"
),
"choices": ["MXFP8"],
"choices": ["MXFP8", "MXFP4"],
},
)

@@ -124,9 +124,16 @@ def train():
# prepare model for quantization
if quant_args.quant_scheme is not None:
from neural_compressor.torch.quantization.quantize import prepare_qat

model.train()
# inplace
# default mxfp8
prepare_qat(model)
if quant_args.quant_scheme == "MXFP8":
# default mxfp8
prepare_qat(model)
if quant_args.quant_scheme == "MXFP4":
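# restrict MXFP4 fake quantization to torch.nn.Linear modules via an explicit module mapping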
mappings = {torch.nn.Linear: "MXFP4"}
prepare_qat(model, mappings)


logger.info("Finish model preparation for QAT.")

@@ -154,7 +161,7 @@ def train():
if training_args.gradient_checkpointing and training_args.gradient_checkpointing_kwargs is None:
training_args.gradient_checkpointing_kwargs = {"use_reentrant": True}

trainer = Trainer(
trainer = QATTrainer(
model=model,
processing_class=tokenizer,
args=training_args,
@@ -172,15 +179,8 @@ def train():
metrics = get_metrics_with_perplexity(metrics)
logger.info(f"Evaluation results: \n{metrics}")

if training_args.do_train and quant_args.quant_scheme is None:
logger.info("Saving the model...")
trainer.save_model(training_args.output_dir)
elif quant_args.quant_scheme is not None:
from neural_compressor.torch.export.export_hf import export_hf2compressored_model
# export quantized model for vllm inference using llm-compressor and compressed_tensor
export_hf2compressored_model(model, training_args.output_dir, quant_args.quant_scheme)
if tokenizer is not None:
tokenizer.save_pretrained(training_args.output_dir)
logger.info("Saving the model...")
trainer.save_model(training_args.output_dir)


if __name__ == "__main__":
@@ -0,0 +1,10 @@

from auto_round import AutoRound

model_name_or_path = "./llama3.1-finetuned"
output_dir = "./Llama-3.1-8B-Instruct_autoround_rtn_mxfp4"

# Available schemes: "W2A16", "W3A16", "W4A16", "W8A16", "NVFP4", "MXFP4" (no real kernels), "GGUF:Q4_K_M", etc.
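# iters=0 skips the AutoRound tuning loop, i.e. plain round-to-nearest (RTN) weight rounding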
ar = AutoRound(model_name_or_path, scheme="MXFP4", iters=0)

ar.quantize_and_save(output_dir=output_dir, format="llm_compressor")
@@ -17,12 +17,12 @@
import types
from contextlib import contextmanager
from functools import partial

import json
import datasets
import torch

import transformers
from transformers import default_data_collator
from transformers import default_data_collator, Trainer

IGNORE_INDEX = -100

@@ -146,3 +146,78 @@ def get_metrics_with_perplexity(metrics):
if "eval_loss" in metrics:
metrics["perplexity"] = float(torch.exp(torch.tensor(metrics["eval_loss"])))
return metrics


def print_rank_0(*args, **kwargs):
"""Prints only on the master process."""

if torch.distributed.is_available() and torch.distributed.is_initialized():
if torch.distributed.get_rank(group=None) == 0:
print(*args, **kwargs, flush=True)
else:
print(*args, **kwargs, flush=True)

class QATTrainer(Trainer):
"""A drop-in replacement for HuggingFace's Trainer for QAT.

This class adds extra utilities for FSDP checkpoint saving and for restoring the original model dtype in the saved config.
"""

def __init__(self, *args, **kwargs):
"""Initialize."""
# enable_huggingface_checkpointing()
super().__init__(*args, **kwargs)

self._original_dtype = getattr(
getattr(self.model, "config", None), "dtype", None
) or getattr(getattr(self.model, "config", None), "torch_dtype", None)

def save_model(self, *args, **kwargs):
"""Save the quantized model."""
if (
(not self.is_in_train)
and self.is_fsdp_enabled
and self.accelerator.state.fsdp_plugin.state_dict_type != "FULL_STATE_DICT"
):
print_rank_0("Setting state_dict_type to FULL_STATE_DICT for final checkpoint save.")
original_type = self.accelerator.state.fsdp_plugin.state_dict_type
self.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
outputs = super().save_model(*args, **kwargs)
if torch.distributed.is_initialized():
torch.distributed.barrier()

self.accelerator.state.fsdp_plugin.set_state_dict_type(original_type)
else:
outputs = super().save_model(*args, **kwargs)
if (not self.is_in_train) and self.args.should_save:
out_dir = args[0]
# FSDP may upcast parameters to float32 during mixed-precision training;
# restore the original dtype by updating `dtype`/`torch_dtype` in `config.json`
self._update_config_json_dtype(out_dir, str(self._original_dtype).split(".")[1])
return outputs

def _update_config_json_dtype(self, output_dir: str, dtype_str: str | None) -> None:
"""Rewrite <output_dir>/config.json 'dtype' (preferred) or 'torch_dtype' to dtype_str."""
cfg_path = os.path.join(output_dir, "config.json")
if not os.path.isfile(cfg_path):
print_rank_0(f"[warn] config.json not found under {output_dir}; skip dtype rewrite.")
return
try:
with open(cfg_path, encoding="utf-8") as f:
data = json.load(f)
# Prefer 'dtype', else fall back to 'torch_dtype'
key_to_update = (
"dtype" if "dtype" in data else ("torch_dtype" if "torch_dtype" in data else None)
)
if key_to_update is None:
print_rank_0(
"[warn] Neither 'dtype' nor 'torch_dtype' present in config.json; skip dtype rewrite."
)
return
if data.get(key_to_update) != dtype_str:
data[key_to_update] = dtype_str
with open(cfg_path, "w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=2)
print_rank_0(f'Updated config.json: {key_to_update} -> "{dtype_str}"')
except Exception as e:
print_rank_0(f"[warn] Failed to update dtype in config.json: {e}")
4 changes: 4 additions & 0 deletions neural_compressor/torch/algorithms/qat/quant_utils.py
@@ -108,6 +108,7 @@ def get_quant_config(scheme: str) -> dict[str, Any]:
quantization_config["provider"] = "auto-round"
quantization_config["config_groups"]["group_0"]["weights"]["is_mx"] = True
quantization_config["config_groups"]["group_0"]["input_activations"]["is_mx"] = True
quantization_config["format"] = "float-quantized"

except ImportError:
quantization_config = None
@@ -133,6 +134,9 @@ def _get_quantization_from_layer(layer):
if weight_quantizer.num_bits == 8 and weight_quantizer.data_type == "mx_fp8":
return "MXFP8"

if weight_quantizer.num_bits == 4 and weight_quantizer.data_type == "mx_fp4":
return "MXFP4"

# Raise error for unsupported num_bits
raise NotImplementedError(f"Unsupported quantizer with num_bits: {weight_quantizer.num_bits}")

12 changes: 12 additions & 0 deletions neural_compressor/torch/algorithms/qat/tensor_quantizer.py
@@ -161,6 +161,18 @@ def weight_pack(self, weight, scale):
e8m0_scale = (scale + 127).to(torch.uint8)
return qweight.reshape(original_shape), e8m0_scale.reshape(original_shape[0], -1)

if self.data_type == "mx_fp4":
qweight = weight.reshape(-1, self.block_size) / torch.exp2(scale.float())

from auto_round.export.export_to_autoround.qlinear_fp import pack_fp4_to_uint8

qweight_packed = pack_fp4_to_uint8(qweight)
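# pack_fp4_to_uint8 packs two FP4 values per uint8, so the packed weight's last dim is halved;
# the shared block exponent below is stored as a biased (bias = 127) E8M0 uint8 scale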

e8m0_scale = (scale + 127).to(torch.uint8)
return qweight_packed.reshape(original_shape[0], original_shape[1] // 2), e8m0_scale.reshape(
original_shape[0], -1
)

def __repr__(self):
if self._disabled or not self._if_quant:
return "TensorQuantizer(disabled)"
8 changes: 7 additions & 1 deletion neural_compressor/torch/export/export_hf.py
@@ -41,7 +41,13 @@ def _export_quantized_weight(sub_module: nn.Module, quantization_format: str = N

sub_module.register_buffer("weight_scale", e8m0_scale)

setattr(sub_module, weight_name, nn.Parameter(quantized_weight, requires_grad=False))
if quantization_format == "MXFP8":
setattr(sub_module, weight_name, nn.Parameter(quantized_weight, requires_grad=False))

if quantization_format == "MXFP4":
delattr(sub_module, weight_name)
# name aligned for vllm emulation
sub_module.register_buffer("weight_packed", quantized_weight)


def _export_hf_checkpoint(model: nn.Module, scheme: str | None = None) -> tuple[dict[str, Any], dict[str, Any]]:
4 changes: 4 additions & 0 deletions neural_compressor/torch/quantization/config.py
@@ -2091,6 +2091,10 @@ def get_config_set_for_tuning(cls, dtype="int8"):
torch.nn.Linear: "MXFP8",
}

QAT_MODULE_MAPPINGS: dict[Callable, Any] = {
torch.nn.Linear: ["MXFP8", "MXFP4"],
}


def get_default_qat_module_mappings() -> dict[Callable, Any]:
"""Get default module mapping for quantization aware training."""
25 changes: 25 additions & 0 deletions test/3x/torch/algorithms/qat/test_qat.py
@@ -65,3 +65,28 @@
for name, param in model.named_parameters():
assert param.grad is not None
optimizer.step()


def test_train_mxfp4():
"""QAT test with the MXFP4 scheme."""
setup_seed(20)

model = TinyModel()
mappings = {torch.nn.Linear: "MXFP4"}
prepare_qat(model, mappings)

inp = torch.randn([2, 32])

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
output = model(inp)
loss = output.mean()

optimizer.zero_grad()
loss.backward()

# check the grad
for name, param in model.named_parameters():
assert param.grad is not None
optimizer.step()
11 changes: 11 additions & 0 deletions test/3x/torch/algorithms/qat/test_quant_utils.py
@@ -206,3 +206,14 @@ def test_get_quantization_format_disabled_returns_none(disabled):
assert fmt is None
else:
assert fmt == "MXFP8"

layer.weight_quantizer = TensorQuantizer(bits=4, data_type="mx_fp4")
layer.weight_quantizer._disabled = disabled
layer.input_quantizer = TensorQuantizer(bits=4, data_type="mx_fp4")
layer.input_quantizer._disabled = disabled

fmt = quant_utils.get_quantization_format(layer)
if disabled:
assert fmt is None
else:
assert fmt == "MXFP4"
21 changes: 21 additions & 0 deletions test/3x/torch/algorithms/qat/test_quantizer_and_linear.py
@@ -163,3 +163,24 @@ def test_tensor_quantizer_scale_persistence():
assert tq.scale.dtype == torch.uint8
# Heuristic: at least one non-zero (if all zero it may still be valid, but improbable)
assert (tq.scale != 0).any() or (shared_exp == 0).all()


def test_weight_pack():
# Provide scale_shape so internal buffer is registered & updated
tq = TensorQuantizer(scale_shape=(4, 32), block_size=32)
x = torch.randn(4, 32)
# Use internal fake quant function to generate scale
q, shared_exp = tq._fake_quantize(x)

q_packed, scale = tq.weight_pack(q, shared_exp)

assert q_packed.dtype == torch.float8_e4m3fn

tq = TensorQuantizer(data_type="mx_fp4", bits=4, scale_shape=(4, 32), block_size=32)
x = torch.randn(4, 32)
# Use internal fake quant function to generate scale
q, shared_exp = tq._fake_quantize(x)

q_packed, scale = tq.weight_pack(q, shared_exp)

assert q_packed.dtype == torch.uint8