
Commit 134dd92

add mxfp4 qat, mainly packing code. (#2347)
1 parent 18a29ce commit 134dd92

12 files changed: +215 additions, −49 deletions

examples/pytorch/nlp/huggingface_models/language-modeling/quantization/llm_qat/README.md

Lines changed: 4 additions & 31 deletions

@@ -54,43 +54,16 @@ accelerate launch --config-file accelerate_config/fsdp1.yaml \
 
 ##### Step 2:
 
-Quantize the trained model using `prepare_qat()` by setting the following flags `--quant_scheme MXFP8 --do_train False`. This inserts fake quantization modules into the model without starting training yet. Then save the model directly to a get post training quantization model.
+Save the model directly to get a post-training quantization model using [auto-round](https://github.com/intel/auto-round).
 
 
 ```
-accelerate launch --config-file accelerate_config/fsdp1.yaml \
---fsdp_transformer_layer_cls_to_wrap LlamaDecoderLayer \
-main.py \
---model_name_or_path ./llama3.1-finetuned \
---model_max_length 4096 \
---dataloader_drop_last True \
---do_train False \
---do_eval False \
---quant_scheme MXFP8 \
---output_dir ./llama3.1-finetuned-ptq \
---dataset Daring-Anteater \
---num_train_epochs 2.0 \
---per_device_train_batch_size 4 \
---per_device_eval_batch_size 4 \
---gradient_accumulation_steps 1 \
---eval_accumulation_steps 1 \
---save_strategy steps \
---save_steps 3000 \
---eval_strategy steps \
---eval_steps 3000 \
---load_best_model_at_end True \
---save_total_limit 2 \
---learning_rate 1e-5 \
---weight_decay 0.0 \
---warmup_ratio 0.1 \
---lr_scheduler_type linear \
---logging_steps 1 \
---report_to tensorboard
+python quantize_autoround.py
 ```
 
 ##### Step 3:
 
-Train/fine-tune the quantized model with a small learning rate, e.g. 1e-5 for Adam optimizer by setting `--quant_scheme MXFP8 --do_train True`
+Train/fine-tune the quantized model with a small learning rate, e.g. 1e-5 for the Adam optimizer, by setting `--quant_scheme MXFP4 --do_train True`
 
 ```
 accelerate launch --config-file accelerate_config/fsdp1.yaml \
@@ -101,7 +74,7 @@ accelerate launch --config-file accelerate_config/fsdp1.yaml \
 --dataloader_drop_last True \
 --do_train True \
 --do_eval True \
---quant_scheme MXFP8 \
+--quant_scheme MXFP4 \
 --output_dir ./llama3.1-finetuned-qat \
 --dataset Daring-Anteater \
 --max_steps 1000 \
Lines changed: 25 additions & 0 deletions

@@ -0,0 +1,25 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+distributed_type: FSDP
+downcast_bf16: 'no'
+enable_cpu_affinity: false
+fsdp_config:
+  fsdp_activation_checkpointing: true
+  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
+  fsdp_cpu_ram_efficient_loading: true
+  fsdp_offload_params: false
+  fsdp_reshard_after_forward: true
+  fsdp_state_dict_type: SHARDED_STATE_DICT
+  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
+  fsdp_version: 2
+machine_rank: 0
+main_training_function: main
+mixed_precision: bf16
+num_machines: 1
+num_processes: gpu
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false

examples/pytorch/nlp/huggingface_models/language-modeling/quantization/llm_qat/main.py

Lines changed: 15 additions & 15 deletions

@@ -12,7 +12,6 @@
     AutoModelForCausalLM,
     AutoTokenizer,
     HfArgumentParser,
-    Trainer,
     default_data_collator,
     set_seed,
     TrainerCallback,
@@ -21,6 +20,7 @@
 from utils import (
     get_metrics_with_perplexity,
     make_supervised_data_module,
+    QATTrainer
 )
 
 logger = logging.getLogger(__name__)
@@ -47,7 +47,7 @@ class TrainingArguments(transformers.TrainingArguments):
 class DataArguments:
     dataset: str = field(
         default="Daring-Anteater",
-        metadata={"help": "Specify the dataset.", "choices": ["Daring-Anteater"]},
+        metadata={"help": "Specify the dataset.", "choices": ["Daring-Anteater", "cnn_dailymail"]},
     )
     train_size: int = field(
         default=0,
@@ -69,7 +69,7 @@ class QuantizationArguments:
                 "Specify the quantization format for PTQ/QAT. if specified, PTQ/QAT will be enabled"
                 " with the specified quantization format"
             ),
-            "choices": ["MXFP8"],
+            "choices": ["MXFP8", "MXFP4"],
         },
     )
 
@@ -124,9 +124,16 @@ def train():
     # prepare model for quantization
     if quant_args.quant_scheme is not None:
         from neural_compressor.torch.quantization.quantize import prepare_qat
+
+        model.train()
         # inplace
-        # default mxfp8
-        prepare_qat(model)
+        if quant_args.quant_scheme == "MXFP8":
+            # default mxfp8
+            prepare_qat(model)
+        if quant_args.quant_scheme == "MXFP4":
+            mappings = {torch.nn.Linear: "MXFP4"}
+            prepare_qat(model, mappings)
+
 
     logger.info("Finish model preparation for QAT.")
 
@@ -154,7 +161,7 @@ def train():
     if training_args.gradient_checkpointing and training_args.gradient_checkpointing_kwargs is None:
         training_args.gradient_checkpointing_kwargs = {"use_reentrant": True}
 
-    trainer = Trainer(
+    trainer = QATTrainer(
         model=model,
         processing_class=tokenizer,
         args=training_args,
@@ -172,15 +179,8 @@ def train():
         metrics = get_metrics_with_perplexity(metrics)
         logger.info(f"Evaluation results: \n{metrics}")
 
-    if training_args.do_train and quant_args.quant_scheme is None:
-        logger.info("Saving the model...")
-        trainer.save_model(training_args.output_dir)
-    elif quant_args.quant_scheme is not None:
-        from neural_compressor.torch.export.export_hf import export_hf2compressored_model
-        # export quantized model for vllm inference using llm-compressor and compressed_tensor
-        export_hf2compressored_model(model, training_args.output_dir, quant_args.quant_scheme)
-        if tokenizer is not None:
-            tokenizer.save_pretrained(training_args.output_dir)
+    logger.info("Saving the model...")
+    trainer.save_model(training_args.output_dir)
 
 
 if __name__ == "__main__":
Lines changed: 10 additions & 0 deletions

@@ -0,0 +1,10 @@
+
+from auto_round import AutoRound
+
+model_name_or_path = "./llama3.1-finetuned"
+output_dir = "./Llama-3.1-8B-Instruct_autoround_rtn_mxfp4"
+
+# Available schemes: "W2A16", "W3A16", "W4A16", "W8A16", "NVFP4", "MXFP4" (no real kernels), "GGUF:Q4_K_M", etc.
+ar = AutoRound(model_name_or_path, scheme="MXFP4", iters=0)
+
+ar.quantize_and_save(output_dir=output_dir, format="llm_compressor")

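Note: `iters=0` makes auto-round skip its tuning loop and fall back to plain round-to-nearest (RTN) quantization, which matches the `_rtn_mxfp4` suffix of the output directory. For reference, a tuned variant would only change the constructor arguments; the sketch below is illustrative, and `nsamples`/`seqlen` are assumed auto-round parameters rather than anything this commit uses.

```python
# Illustrative sketch only: run auto-round's tuning loop instead of RTN.
# `nsamples` and `seqlen` are assumed auto-round knobs, not used by this commit.
from auto_round import AutoRound

ar = AutoRound(
    "./llama3.1-finetuned",
    scheme="MXFP4",
    iters=200,      # > 0 enables the tuning loop; 0 means plain RTN
    nsamples=128,   # calibration samples (assumed parameter)
    seqlen=2048,    # calibration sequence length (assumed parameter)
)
ar.quantize_and_save(output_dir="./llama3.1-finetuned-mxfp4-tuned", format="llm_compressor")
```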
examples/pytorch/nlp/huggingface_models/language-modeling/quantization/llm_qat/utils.py

Lines changed: 77 additions & 2 deletions

@@ -17,12 +17,12 @@
 import types
 from contextlib import contextmanager
 from functools import partial
-
+import json
 import datasets
 import torch
 
 import transformers
-from transformers import default_data_collator
+from transformers import default_data_collator, Trainer
 
 IGNORE_INDEX = -100
 
@@ -146,3 +146,78 @@ def get_metrics_with_perplexity(metrics):
     if "eval_loss" in metrics:
         metrics["perplexity"] = float(torch.exp(torch.tensor(metrics["eval_loss"])))
     return metrics
+
+
+def print_rank_0(*args, **kwargs):
+    """Prints only on the master process."""
+
+    if torch.distributed.is_available() and torch.distributed.is_initialized():
+        if torch.distributed.get_rank(group=None) == 0:
+            print(*args, **kwargs, flush=True)
+    else:
+        print(*args, **kwargs, flush=True)
+
+
+class QATTrainer(Trainer):
+    """A drop-in replacement for HuggingFace's Trainer.
+
+    This class adds extra utilities for saving QAT checkpoints under FSDP.
+    """
+
+    def __init__(self, *args, **kwargs):
+        """Initialize."""
+        super().__init__(*args, **kwargs)
+
+        self._original_dtype = getattr(
+            getattr(self.model, "config", None), "dtype", None
+        ) or getattr(getattr(self.model, "config", None), "torch_dtype", None)
+
+    def save_model(self, *args, **kwargs):
+        """Save the quantized model."""
+        if (
+            (not self.is_in_train)
+            and self.is_fsdp_enabled
+            and self.accelerator.state.fsdp_plugin.state_dict_type != "FULL_STATE_DICT"
+        ):
+            print_rank_0("Setting state_dict_type to FULL_STATE_DICT for final checkpoint save.")
+            original_type = self.accelerator.state.fsdp_plugin.state_dict_type
+            self.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
+            outputs = super().save_model(*args, **kwargs)
+            if torch.distributed.is_initialized():
+                torch.distributed.barrier()
+
+            self.accelerator.state.fsdp_plugin.set_state_dict_type(original_type)
+        else:
+            outputs = super().save_model(*args, **kwargs)
+        if (not self.is_in_train) and self.args.should_save:
+            out_dir = args[0]
+            # FSDP may upcast parameter dtype to float32 during mixed-precision training;
+            # convert it back to the original dtype by updating `dtype`/`torch_dtype` in `config.json`
+            self._update_config_json_dtype(out_dir, str(self._original_dtype).split(".")[1])
+        return outputs
+
+    def _update_config_json_dtype(self, output_dir: str, dtype_str: str | None) -> None:
+        """Rewrite <output_dir>/config.json 'dtype' (preferred) or 'torch_dtype' to dtype_str."""
+        cfg_path = os.path.join(output_dir, "config.json")
+        if not os.path.isfile(cfg_path):
+            print_rank_0(f"[warn] config.json not found under {output_dir}; skip dtype rewrite.")
+            return
+        try:
+            with open(cfg_path, encoding="utf-8") as f:
+                data = json.load(f)
+            # Prefer 'dtype', else fall back to 'torch_dtype'
+            key_to_update = (
+                "dtype" if "dtype" in data else ("torch_dtype" if "torch_dtype" in data else None)
+            )
+            if key_to_update is None:
+                print_rank_0(
+                    "[warn] Neither 'dtype' nor 'torch_dtype' present in config.json; skip dtype rewrite."
+                )
+                return
+            if data.get(key_to_update) != dtype_str:
+                data[key_to_update] = dtype_str
+                with open(cfg_path, "w", encoding="utf-8") as f:
+                    json.dump(data, f, ensure_ascii=False, indent=2)
+                print_rank_0(f'Updated config.json: {key_to_update} -> "{dtype_str}"')
+        except Exception as e:
+            print_rank_0(f"[warn] Failed to update dtype in config.json: {e}")

neural_compressor/torch/algorithms/qat/quant_utils.py

Lines changed: 4 additions & 0 deletions

@@ -108,6 +108,7 @@ def get_quant_config(scheme: str) -> dict[str, Any]:
         quantization_config["provider"] = "auto-round"
         quantization_config["config_groups"]["group_0"]["weights"]["is_mx"] = True
         quantization_config["config_groups"]["group_0"]["input_activations"]["is_mx"] = True
+        quantization_config["format"] = "float-quantized"
 
     except ImportError:
         quantization_config = None
@@ -133,6 +134,9 @@ def _get_quantization_from_layer(layer):
     if weight_quantizer.num_bits == 8 and weight_quantizer.data_type == "mx_fp8":
         return "MXFP8"
 
+    if weight_quantizer.num_bits == 4 and weight_quantizer.data_type == "mx_fp4":
+        return "MXFP4"
+
     # Raise error for unsupported num_bits
     raise NotImplementedError(f"Unsupported quantizer with num_bits: {weight_quantizer.num_bits}")

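For orientation, `get_quant_config("MXFP4")` emits a compressed-tensors style `quantization_config`; only `provider`, the `is_mx` flags, and the new `format` key are visible in this diff, so every other field in the sketch below is an assumption about the compressed-tensors schema rather than the exact output of this function.

```python
# Rough sketch of an MXFP4 quantization_config. Fields other than "provider",
# "is_mx", and "format" are assumed, not taken from this diff.
quantization_config = {
    "provider": "auto-round",
    "format": "float-quantized",
    "config_groups": {
        "group_0": {
            "targets": ["Linear"],
            "weights": {"num_bits": 4, "type": "float", "group_size": 32, "is_mx": True},
            "input_activations": {"num_bits": 4, "type": "float", "group_size": 32, "is_mx": True},
        }
    },
}
```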
neural_compressor/torch/algorithms/qat/tensor_quantizer.py

Lines changed: 12 additions & 0 deletions

@@ -161,6 +161,18 @@ def weight_pack(self, weight, scale):
             e8m0_scale = (scale + 127).to(torch.uint8)
             return qweight.reshape(original_shape), e8m0_scale.reshape(original_shape[0], -1)
 
+        if self.data_type == "mx_fp4":
+            qweight = weight.reshape(-1, self.block_size) / torch.exp2(scale.float())
+
+            from auto_round.export.export_to_autoround.qlinear_fp import pack_fp4_to_uint8
+
+            qweight_packed = pack_fp4_to_uint8(qweight)
+
+            e8m0_scale = (scale + 127).to(torch.uint8)
+            return qweight_packed.reshape(original_shape[0], original_shape[1] // 2), e8m0_scale.reshape(
+                original_shape[0], -1
+            )
+
     def __repr__(self):
         if self._disabled or not self._if_quant:
             return "TensorQuantizer(disabled)"

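The packing above leans on auto-round's `pack_fp4_to_uint8`, which stores two 4-bit E2M1 codes per byte, while the per-block scale is kept separately as a biased exponent (`scale + 127` cast to uint8). A minimal sketch of the packing idea is below, assuming a sign-magnitude code layout and low-nibble-first ordering; neither is guaranteed to match auto-round's actual layout.

```python
# Minimal sketch of FP4 (E2M1) packing: snap each value to the nearest of the
# eight E2M1 magnitudes, prepend a sign bit, and pack two 4-bit codes per uint8.
# Code layout and nibble order are assumptions for illustration only.
import torch

E2M1_VALUES = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def pack_fp4_to_uint8_sketch(x: torch.Tensor) -> torch.Tensor:
    sign = (x < 0).to(torch.uint8)                                    # 1 sign bit
    mag = (x.abs().unsqueeze(-1) - E2M1_VALUES).abs().argmin(dim=-1)  # nearest E2M1 magnitude
    codes = (sign << 3) | mag.to(torch.uint8)                         # 4-bit code per element
    codes = codes.reshape(-1, 2)                                      # assumes an even element count
    return codes[:, 0] | (codes[:, 1] << 4)                           # two codes per byte
```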
neural_compressor/torch/export/export_hf.py

Lines changed: 7 additions & 1 deletion

@@ -41,7 +41,13 @@ def _export_quantized_weight(sub_module: nn.Module, quantization_format: str = N
 
     sub_module.register_buffer("weight_scale", e8m0_scale)
 
-    setattr(sub_module, weight_name, nn.Parameter(quantized_weight, requires_grad=False))
+    if quantization_format == "MXFP8":
+        setattr(sub_module, weight_name, nn.Parameter(quantized_weight, requires_grad=False))
+
+    if quantization_format == "MXFP4":
+        delattr(sub_module, weight_name)
+        # name aligned for vllm emulation
+        sub_module.register_buffer("weight_packed", quantized_weight)
 
 
 def _export_hf_checkpoint(model: nn.Module, scheme: str | None = None) -> tuple[dict[str, Any], dict[str, Any]]:

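After this change an exported MXFP4 `nn.Linear` no longer carries a `weight` parameter; its state dict holds a `weight_packed` uint8 buffer (two FP4 codes per byte) plus the `weight_scale` E8M0 buffer registered earlier in this function. A rough illustration of the resulting layout, with made-up layer dimensions and an assumed block size of 32:

```python
# Rough illustration of the exported MXFP4 tensor layout; dimensions are made up
# and the 32-wide block size is assumed from the MX format.
import torch
import torch.nn as nn

out_features, in_features, block_size = 16, 64, 32
layer = nn.Linear(in_features, out_features, bias=False)

weight_packed = torch.zeros(out_features, in_features // 2, dtype=torch.uint8)          # two FP4 codes per byte
weight_scale = torch.zeros(out_features, in_features // block_size, dtype=torch.uint8)  # one E8M0 scale per block

delattr(layer, "weight")                         # mirrors delattr(sub_module, weight_name)
layer.register_buffer("weight_packed", weight_packed)
layer.register_buffer("weight_scale", weight_scale)

print(sorted(layer.state_dict().keys()))         # ['weight_packed', 'weight_scale']
```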
neural_compressor/torch/quantization/config.py

Lines changed: 4 additions & 0 deletions

@@ -2091,6 +2091,10 @@ def get_config_set_for_tuning(cls, dtype="int8"):
     torch.nn.Linear: "MXFP8",
 }
 
+QAT_MODULE_MAPPINGS: dict[Callable, Any] = {
+    torch.nn.Linear: ["MXFP8", "MXFP4"],
+}
+
 
 def get_default_qat_module_mappings() -> dict[Callable, Any]:
     """Get default module mapping for quantization aware training."""

test/3x/torch/algorithms/qat/test_qat.py

Lines changed: 25 additions & 0 deletions

@@ -65,3 +65,28 @@ def test_train():
     for name, param in model.named_parameters():
         assert param.grad is not None
     optimizer.step()
+
+
+def test_train_mxfp4():
+    """QAT test."""
+    setup_seed(20)
+
+    model = TinyModel()
+    mappings = {torch.nn.Linear: "MXFP4"}
+    prepare_qat(model, mappings)
+
+    inp = torch.randn([2, 32])
+
+    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
+
+    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
+        output = model(inp)
+        loss = output.mean()
+
+    optimizer.zero_grad()
+    loss.backward()
+
+    # check the grad
+    for name, param in model.named_parameters():
+        assert param.grad is not None
+    optimizer.step()
