
Commit 18a29ce

AutoRound Refactor (#2348)
Signed-off-by: Kaihui-intel <kaihui.tang@intel.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 46eff9b commit 18a29ce

7 files changed: +342 −450 lines
neural_compressor/torch/algorithms/autoround/__init__.py (new file)

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+# Copyright (c) 2024 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""The AutoRound-related modules."""
+
+from .autoround import *
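
The package `__init__.py` re-exports the contents of `autoround.py`, so the quantizer should be importable from the relocated package. A minimal sketch of the expected import, assuming `AutoRoundQuantizer` is among the re-exported public names:

```python
from neural_compressor.torch.algorithms.autoround import AutoRoundQuantizer
```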

neural_compressor/torch/algorithms/weight_only/autoround.py renamed to neural_compressor/torch/algorithms/autoround/autoround.py

Lines changed: 28 additions & 178 deletions
@@ -41,78 +41,16 @@ def _is_auto_round_available():
 
 from neural_compressor.common.utils import Statistics
 from neural_compressor.torch.algorithms import Quantizer
+from neural_compressor.torch.algorithms.weight_only.utility import CapturedDataloader, InputCaptureModule
 from neural_compressor.torch.utils import get_accelerator, logger
 
-from .utility import CapturedDataloader, InputCaptureModule
-
 
 class AutoRoundQuantizer(Quantizer):
     """AutoRound Quantizer."""
 
     def __init__(
         self,
-        bits: int = None,
-        group_size: int = None,
-        sym: bool = None,
-        data_type: str = None,
-        act_bits: int = None,
-        act_group_size: int = None,
-        act_sym: bool = None,
-        act_data_type: str = None,
-        act_dynamic: bool = None,
-        super_bits: int = None,
-        super_group_size: int = None,
-        quant_config: dict = {},  # for INC
-        layer_config: dict[str, Union[str, dict, QuantizationScheme]] = None,
-        enable_full_range: bool = False,  # for symmetric, TODO: support later
-        batch_size: int = 8,
-        amp: bool = True,
-        device_map: str = None,
-        quant_lm_head: bool = False,
-        lr_scheduler=None,
-        dataset: Union[str, list, tuple, torch.utils.data.DataLoader] = "NeelNanda/pile-10k",
-        enable_quanted_input: bool = True,
-        enable_minmax_tuning: bool = True,
-        lr: float = None,
-        minmax_lr: float = None,
-        low_gpu_mem_usage: bool = False,
-        iters: int = 200,
-        seqlen: int = 2048,
-        nsamples: int = 128,
-        sampler: str = "rand",
-        seed: int = 42,
-        nblocks: int = 1,
-        gradient_accumulate_steps: int = 1,
-        not_use_best_mse: bool = False,
-        dynamic_max_gap: int = -1,
-        scale_dtype: str = "fp16",
-        to_quant_block_names: list = None,
-        low_cpu_mem_usage: bool = False,
-        export_format: str = "itrex",
-        # v0.4
-        enable_norm_bias_tuning: bool = False,
-        enable_torch_compile: bool = None,
-        # mllm
-        quant_nontext_module: bool = False,
-        extra_data_dir: str = None,
-        image_processor=None,
-        processor=None,
-        template: Union[str, Template] = None,
-        truncation: bool = False,
-        # 0.7
-        scheme: Union[str, dict, QuantizationScheme] = "W4A16",
-        # diffusion
-        guidance_scale: float = 7.5,
-        num_inference_steps: int = 50,
-        generator_seed: int = None,
-        # 0.9
-        target_bits: int = None,
-        options: Union[str, list[Union[str]], tuple[Union[str], ...]] = ("MXFP4", "MXFP8"),
-        shared_layers: Optional[Iterable[Iterable[str]]] = None,
-        ignore_scale_zp_bits: bool = False,
-        auto_scheme_method: str = "default",
-        auto_scheme_batch_size: int = None,
-        auto_scheme_device_map: str = None,
+        quant_config: Optional[dict] = None,
         **kwargs,
     ):
         """Init an AutoRoundQuantizer object.
@@ -193,71 +131,14 @@ def __init__(
         Returns:
             The quantized model.
         """
-        super().__init__(quant_config)
-        self.layer_config = layer_config
-        self.output_dir = kwargs.pop("output_dir", "temp_auto_round")
-        self.tokenizer = kwargs.pop("tokenizer", "Placeholder")  # for AutoRound initialization
-        self.enable_full_range = enable_full_range
-        self.bits = bits
-        self.group_size = group_size
-        self.sym = sym
-        self.data_type = data_type
-        self.act_bits = act_bits
-        self.act_group_size = act_group_size
-        self.act_sym = act_sym
-        self.act_data_type = act_data_type
-        self.act_dynamic = act_dynamic
-        self.super_bits = super_bits
-        self.super_group_size = super_group_size
-        self.batch_size = batch_size
-        self.amp = amp
+        super().__init__(quant_config=quant_config)
+        for k, v in kwargs.items():
+            setattr(self, k, v)
         self.accelerator = get_accelerator(kwargs.pop("device", "auto"))
         self.device = self.accelerator.name()
-        self.lr_scheduler = lr_scheduler
-        self.dataset = dataset
-        self.enable_quanted_input = enable_quanted_input
-        self.enable_minmax_tuning = enable_minmax_tuning
-        self.lr = lr
-        self.minmax_lr = minmax_lr
-        self.low_gpu_mem_usage = low_gpu_mem_usage
-        self.iters = iters
-        self.seqlen = seqlen
-        self.nsamples = nsamples
-        self.sampler = sampler
-        self.seed = seed
-        self.nblocks = nblocks
-        self.gradient_accumulate_steps = gradient_accumulate_steps
-        self.not_use_best_mse = not_use_best_mse
-        self.dynamic_max_gap = dynamic_max_gap
-        self.scale_dtype = scale_dtype
-        self.to_quant_block_names = to_quant_block_names
-        self.low_cpu_mem_usage = low_cpu_mem_usage
-        self.export_format = export_format
-        self.enable_norm_bias_tuning = enable_norm_bias_tuning
-        self.enable_torch_compile = enable_torch_compile
-        self.quant_nontext_module = quant_nontext_module
-        self.extra_data_dir = extra_data_dir
-        self.processor = processor
-        self.image_processor = image_processor
-        self.template = template
-        self.truncation = truncation
-        self.scheme = scheme
-        self.device_map = device_map
-        self.quant_lm_head = quant_lm_head
-        self.enable_w4afp8 = self._is_w4afp8()
-        self.guidance_scale = guidance_scale
-        self.num_inference_steps = num_inference_steps
-        self.generator_seed = generator_seed
-        self.target_bits = target_bits
-        self.options = options
-        self.shared_layers = shared_layers
-        self.ignore_scale_zp_bits = ignore_scale_zp_bits
-        self.auto_scheme_method = auto_scheme_method
-        self.auto_scheme_batch_size = auto_scheme_batch_size
-        self.auto_scheme_device_map = auto_scheme_device_map
 
     def _is_w4afp8(self) -> bool:
-        return any([v.get("data_type", None) == "fp8_to_int_sym" for v in self.quant_config.values()])
+        return self.data_type == "fp8_to_int_sym"
 
     def prepare(self, model: torch.nn.Module, *args, **kwargs):
         """Prepares a given model for quantization.
@@ -290,7 +171,9 @@ def convert(self, model: torch.nn.Module, *args, **kwargs):
             model = model.orig_model
         if pipe is not None:
             model = pipe
-        if self.target_bits is not None:
+        # Remove INC-specific args before passing them on to the AutoRound constructor
+        keys_to_pop = ["quant_config", "device", "export_format", "output_dir", "accelerator", "reloading"]
+        if hasattr(self, "target_bits") and self.target_bits is not None:
             from auto_round import AutoScheme
 
             self.scheme = AutoScheme(
@@ -303,65 +186,28 @@ def convert(self, model: torch.nn.Module, *args, **kwargs):
303186
device_map=self.auto_scheme_device_map,
304187
low_gpu_mem_usage=self.low_gpu_mem_usage,
305188
)
189+
# Remove AutoRound specific AutoScheme args before passing to AutoRound constructor
190+
keys_to_pop += [
191+
"target_bits",
192+
"options",
193+
"shared_layers",
194+
"ignore_scale_zp_bits",
195+
"auto_scheme_method",
196+
"auto_scheme_batch_size",
197+
"auto_scheme_device_map",
198+
]
306199

307200
rounder = AutoRound(
308201
model,
309-
layer_config=self.layer_config,
310-
bits=self.bits,
311-
data_type=self.data_type,
312-
group_size=self.group_size,
313-
sym=self.sym,
314-
act_bits=self.act_bits,
315-
act_group_size=self.act_group_size,
316-
act_sym=self.act_sym,
317-
act_data_type=self.act_data_type,
318-
act_dynamic=self.act_dynamic,
319-
super_bits=self.super_bits,
320-
super_group_size=self.super_group_size,
321202
tokenizer=tokenizer,
322-
scheme=self.scheme,
323-
processor=self.processor,
324-
image_processor=self.image_processor,
325-
enable_full_range=self.enable_full_range,
326-
batch_size=self.batch_size,
327-
amp=self.amp,
328-
device_map=self.device_map,
329-
lr_scheduler=self.lr_scheduler,
330-
dataset=self.dataset,
331-
extra_data_dir=self.extra_data_dir,
332-
template=self.template,
333-
quant_nontext_module=self.quant_nontext_module,
334-
enable_quanted_input=self.enable_quanted_input,
335-
enable_minmax_tuning=self.enable_minmax_tuning,
336-
lr=self.lr,
337-
minmax_lr=self.minmax_lr,
338-
low_gpu_mem_usage=self.low_gpu_mem_usage,
339-
low_cpu_mem_usage=self.low_gpu_mem_usage,
340-
iters=self.iters,
341-
seqlen=self.seqlen,
342-
nsamples=self.nsamples,
343-
sampler=self.sampler,
344-
seed=self.seed,
345-
nblocks=self.nblocks,
346-
gradient_accumulate_steps=self.gradient_accumulate_steps,
347-
not_use_best_mse=self.not_use_best_mse,
348-
dynamic_max_gap=self.dynamic_max_gap,
349-
scale_dtype=self.scale_dtype,
350-
to_quant_block_names=self.to_quant_block_names,
351-
enable_norm_bias_tuning=self.enable_norm_bias_tuning,
352-
truncation=self.truncation,
353-
enable_torch_compile=self.enable_torch_compile,
354-
quant_lm_head=self.quant_lm_head,
355-
guidance_scale=self.guidance_scale,
356-
num_inference_steps=self.num_inference_steps,
357-
generator_seed=self.generator_seed,
203+
**{k: v for k, v in self.__dict__.items() if k not in keys_to_pop},
358204
)
359205

360-
if self.enable_w4afp8:
206+
if self._is_w4afp8():
361207
model, weight_config = rounder.quantize()
362208
model.autoround_config = weight_config
363209
return rounder.save_quantized(output_dir=self.output_dir, inplace=True)
364-
elif "itrex" in self.export_format:
210+
elif "itrex" in self.export_format: # TODO: remove itrex related code later
365211
model, weight_config = rounder.quantize()
366212
model.autoround_config = weight_config
367213
model = pack_model(model, weight_config, device=self.device, inplace=True)
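
The single splat call above replaces the long keyword list: everything stored on the quantizer instance is forwarded to `AutoRound`, minus the INC-side bookkeeping keys collected in `keys_to_pop`. A sketch of that filter-and-forward shape, using a hypothetical stand-in for the `AutoRound` constructor:

```python
class FakeAutoRound:
    """Hypothetical stand-in for auto_round.AutoRound, only to show the call shape."""

    def __init__(self, model, **options):
        self.model = model
        self.options = options


# Attributes accumulated on the quantizer instance, shown here as a plain dict.
state = {"quant_config": {}, "device": "cpu", "output_dir": "tmp", "bits": 4, "iters": 200}
# INC-side bookkeeping keys that the AutoRound constructor should never see.
keys_to_pop = ["quant_config", "device", "export_format", "output_dir"]

rounder = FakeAutoRound("model", **{k: v for k, v in state.items() if k not in keys_to_pop})
print(rounder.options)  # {'bits': 4, 'iters': 200}
```

Anything not explicitly popped, such as `scheme` or `iters`, reaches the constructor unchanged.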
@@ -373,10 +219,14 @@ def convert(self, model: torch.nn.Module, *args, **kwargs):
             self.accelerator.empty_cache()
         dump_model_op_stats(rounder.layer_config)
 
-        if self.export_format in ["auto_round", "llm_compressor"]:
+        reloading = self.__dict__.get("reloading", True)
+        if self.export_format in ["auto_round", "llm_compressor"] and reloading:
             # the directly returned model is QuantLinear, which is used for packing.
             try:
-                logger.info(f"Quantization is done, reloading model from saved directory({self.output_dir})...")
+                logger.info(
+                    f"Quantization is done, reloading model from saved directory({self.output_dir})...\n"
+                    "Set reloading=False to skip."
+                )
                 import transformers  # pylint: disable=E0401
 
                 model = transformers.AutoModelForCausalLM.from_pretrained(self.output_dir)