@@ -41,78 +41,16 @@ def _is_auto_round_available():
 
 from neural_compressor.common.utils import Statistics
 from neural_compressor.torch.algorithms import Quantizer
+from neural_compressor.torch.algorithms.weight_only.utility import CapturedDataloader, InputCaptureModule
 from neural_compressor.torch.utils import get_accelerator, logger
 
-from .utility import CapturedDataloader, InputCaptureModule
-
 
 class AutoRoundQuantizer(Quantizer):
     """AutoRound Quantizer."""
 
     def __init__(
         self,
-        bits: int = None,
-        group_size: int = None,
-        sym: bool = None,
-        data_type: str = None,
-        act_bits: int = None,
-        act_group_size: int = None,
-        act_sym: bool = None,
-        act_data_type: str = None,
-        act_dynamic: bool = None,
-        super_bits: int = None,
-        super_group_size: int = None,
-        quant_config: dict = {},  # for INC
-        layer_config: dict[str, Union[str, dict, QuantizationScheme]] = None,
-        enable_full_range: bool = False,  ##for symmetric, TODO support later
-        batch_size: int = 8,
-        amp: bool = True,
-        device_map: str = None,
-        quant_lm_head: bool = False,
-        lr_scheduler=None,
-        dataset: Union[str, list, tuple, torch.utils.data.DataLoader] = "NeelNanda/pile-10k",
-        enable_quanted_input: bool = True,
-        enable_minmax_tuning: bool = True,
-        lr: float = None,
-        minmax_lr: float = None,
-        low_gpu_mem_usage: bool = False,
-        iters: int = 200,
-        seqlen: int = 2048,
-        nsamples: int = 128,
-        sampler: str = "rand",
-        seed: int = 42,
-        nblocks: int = 1,
-        gradient_accumulate_steps: int = 1,
-        not_use_best_mse: bool = False,
-        dynamic_max_gap: int = -1,
-        scale_dtype: str = "fp16",
-        to_quant_block_names: list = None,
-        low_cpu_mem_usage: bool = False,
-        export_format: str = "itrex",
-        # v0.4
-        enable_norm_bias_tuning: bool = False,
-        enable_torch_compile: bool = None,
-        # mllm
-        quant_nontext_module: bool = False,
-        extra_data_dir: str = None,
-        image_processor=None,
-        processor=None,
-        template: Union[str, Template] = None,
-        truncation: bool = False,
-        # 0.7
-        scheme: Union[str, dict, QuantizationScheme] = "W4A16",
-        # diffusion
-        guidance_scale: float = 7.5,
-        num_inference_steps: int = 50,
-        generator_seed: int = None,
-        # 0.9
-        target_bits: int = None,
-        options: Union[str, list[Union[str]], tuple[Union[str], ...]] = ("MXFP4", "MXFP8"),
-        shared_layers: Optional[Iterable[Iterable[str]]] = None,
-        ignore_scale_zp_bits: bool = False,
-        auto_scheme_method: str = "default",
-        auto_scheme_batch_size: int = None,
-        auto_scheme_device_map: str = None,
+        quant_config: Optional[dict] = None,
         **kwargs,
     ):
11856 """Init a AutQRoundQuantizer object.
@@ -193,71 +131,14 @@ def __init__(
         Returns:
             The quantized model.
         """
-        super().__init__(quant_config)
-        self.layer_config = layer_config
-        self.output_dir = kwargs.pop("output_dir", "temp_auto_round")
-        self.tokenizer = kwargs.pop("tokenizer", "Placeholder")  # for AutoRound initialization
-        self.enable_full_range = enable_full_range
-        self.bits = bits
-        self.group_size = group_size
-        self.sym = sym
-        self.data_type = data_type
-        self.act_bits = act_bits
-        self.act_group_size = act_group_size
-        self.act_sym = act_sym
-        self.act_data_type = act_data_type
-        self.act_dynamic = act_dynamic
-        self.super_bits = super_bits
-        self.super_group_size = super_group_size
-        self.batch_size = batch_size
-        self.amp = amp
+        super().__init__(quant_config=quant_config)
+        for k, v in kwargs.items():
+            setattr(self, k, v)
         self.accelerator = get_accelerator(kwargs.pop("device", "auto"))
         self.device = self.accelerator.name()
-        self.lr_scheduler = lr_scheduler
-        self.dataset = dataset
-        self.enable_quanted_input = enable_quanted_input
-        self.enable_minmax_tuning = enable_minmax_tuning
-        self.lr = lr
-        self.minmax_lr = minmax_lr
-        self.low_gpu_mem_usage = low_gpu_mem_usage
-        self.iters = iters
-        self.seqlen = seqlen
-        self.nsamples = nsamples
-        self.sampler = sampler
-        self.seed = seed
-        self.nblocks = nblocks
-        self.gradient_accumulate_steps = gradient_accumulate_steps
-        self.not_use_best_mse = not_use_best_mse
-        self.dynamic_max_gap = dynamic_max_gap
-        self.scale_dtype = scale_dtype
-        self.to_quant_block_names = to_quant_block_names
-        self.low_cpu_mem_usage = low_cpu_mem_usage
-        self.export_format = export_format
-        self.enable_norm_bias_tuning = enable_norm_bias_tuning
-        self.enable_torch_compile = enable_torch_compile
-        self.quant_nontext_module = quant_nontext_module
-        self.extra_data_dir = extra_data_dir
-        self.processor = processor
-        self.image_processor = image_processor
-        self.template = template
-        self.truncation = truncation
-        self.scheme = scheme
-        self.device_map = device_map
-        self.quant_lm_head = quant_lm_head
-        self.enable_w4afp8 = self._is_w4afp8()
-        self.guidance_scale = guidance_scale
-        self.num_inference_steps = num_inference_steps
-        self.generator_seed = generator_seed
-        self.target_bits = target_bits
-        self.options = options
-        self.shared_layers = shared_layers
-        self.ignore_scale_zp_bits = ignore_scale_zp_bits
-        self.auto_scheme_method = auto_scheme_method
-        self.auto_scheme_batch_size = auto_scheme_batch_size
-        self.auto_scheme_device_map = auto_scheme_device_map
 
     def _is_w4afp8(self) -> bool:
-        return any([v.get("data_type", None) == "fp8_to_int_sym" for v in self.quant_config.values()])
+        return self.data_type == "fp8_to_int_sym"
 
     def prepare(self, model: torch.nn.Module, *args, **kwargs):
         """Prepares a given model for quantization.
@@ -290,7 +171,9 @@ def convert(self, model: torch.nn.Module, *args, **kwargs):
         model = model.orig_model
         if pipe is not None:
             model = pipe
-        if self.target_bits is not None:
+        # Remove AutoRound specific args before passing to AutoRound constructor
+        keys_to_pop = ["quant_config", "device", "export_format", "output_dir", "accelerator", "reloading"]
+        if hasattr(self, "target_bits") and self.target_bits is not None:
             from auto_round import AutoScheme
 
             self.scheme = AutoScheme(
@@ -303,65 +186,28 @@ def convert(self, model: torch.nn.Module, *args, **kwargs):
                 device_map=self.auto_scheme_device_map,
                 low_gpu_mem_usage=self.low_gpu_mem_usage,
             )
+            # Remove AutoRound specific AutoScheme args before passing to AutoRound constructor
+            keys_to_pop += [
+                "target_bits",
+                "options",
+                "shared_layers",
+                "ignore_scale_zp_bits",
+                "auto_scheme_method",
+                "auto_scheme_batch_size",
+                "auto_scheme_device_map",
+            ]
 
         rounder = AutoRound(
             model,
-            layer_config=self.layer_config,
-            bits=self.bits,
-            data_type=self.data_type,
-            group_size=self.group_size,
-            sym=self.sym,
-            act_bits=self.act_bits,
-            act_group_size=self.act_group_size,
-            act_sym=self.act_sym,
-            act_data_type=self.act_data_type,
-            act_dynamic=self.act_dynamic,
-            super_bits=self.super_bits,
-            super_group_size=self.super_group_size,
             tokenizer=tokenizer,
-            scheme=self.scheme,
-            processor=self.processor,
-            image_processor=self.image_processor,
-            enable_full_range=self.enable_full_range,
-            batch_size=self.batch_size,
-            amp=self.amp,
-            device_map=self.device_map,
-            lr_scheduler=self.lr_scheduler,
-            dataset=self.dataset,
-            extra_data_dir=self.extra_data_dir,
-            template=self.template,
-            quant_nontext_module=self.quant_nontext_module,
-            enable_quanted_input=self.enable_quanted_input,
-            enable_minmax_tuning=self.enable_minmax_tuning,
-            lr=self.lr,
-            minmax_lr=self.minmax_lr,
-            low_gpu_mem_usage=self.low_gpu_mem_usage,
-            low_cpu_mem_usage=self.low_gpu_mem_usage,
-            iters=self.iters,
-            seqlen=self.seqlen,
-            nsamples=self.nsamples,
-            sampler=self.sampler,
-            seed=self.seed,
-            nblocks=self.nblocks,
-            gradient_accumulate_steps=self.gradient_accumulate_steps,
-            not_use_best_mse=self.not_use_best_mse,
-            dynamic_max_gap=self.dynamic_max_gap,
-            scale_dtype=self.scale_dtype,
-            to_quant_block_names=self.to_quant_block_names,
-            enable_norm_bias_tuning=self.enable_norm_bias_tuning,
-            truncation=self.truncation,
-            enable_torch_compile=self.enable_torch_compile,
-            quant_lm_head=self.quant_lm_head,
-            guidance_scale=self.guidance_scale,
-            num_inference_steps=self.num_inference_steps,
-            generator_seed=self.generator_seed,
+            **{k: v for k, v in self.__dict__.items() if k not in keys_to_pop},
         )
 
-        if self.enable_w4afp8:
+        if self._is_w4afp8():
             model, weight_config = rounder.quantize()
             model.autoround_config = weight_config
             return rounder.save_quantized(output_dir=self.output_dir, inplace=True)
-        elif "itrex" in self.export_format:
+        elif "itrex" in self.export_format:  # TODO: remove itrex related code later
             model, weight_config = rounder.quantize()
             model.autoround_config = weight_config
             model = pack_model(model, weight_config, device=self.device, inplace=True)
@@ -373,10 +219,14 @@ def convert(self, model: torch.nn.Module, *args, **kwargs):
         self.accelerator.empty_cache()
         dump_model_op_stats(rounder.layer_config)
 
-        if self.export_format in ["auto_round", "llm_compressor"]:
+        reloading = self.__dict__.get("reloading", True)
+        if self.export_format in ["auto_round", "llm_compressor"] and reloading:
             # the directly returned model is QuantLinear, which is used for packing.
             try:
-                logger.info(f"Quantization is done, reloading model from saved directory({self.output_dir})...")
+                logger.info(
+                    f"Quantization is done, reloading model from saved directory({self.output_dir})...\n"
+                    "Set reloading=False to skip."
+                )
                 import transformers  # pylint: disable=E0401
 
                 model = transformers.AutoModelForCausalLM.from_pretrained(self.output_dir)
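
The snippet below is a minimal, self-contained sketch (not part of the commit) of the kwargs-forwarding pattern this change adopts in AutoRoundQuantizer: __init__ stores every unknown keyword via setattr, and convert() later re-expands self.__dict__, minus the INC-only bookkeeping keys, into the AutoRound constructor. The names FakeAutoRound, KwargsForwardingQuantizer, and the sample kwargs are illustrative stand-ins, not the neural_compressor or auto_round API.

# Illustrative sketch only -- FakeAutoRound and KwargsForwardingQuantizer are
# hypothetical stand-ins, not neural_compressor or auto_round classes.


class FakeAutoRound:
    """Stand-in for an AutoRound-like constructor; merely records the kwargs it receives."""

    def __init__(self, model, tokenizer=None, **kwargs):
        self.model = model
        self.tokenizer = tokenizer
        self.kwargs = kwargs


class KwargsForwardingQuantizer:
    """Mimics the refactored __init__/convert flow: store kwargs, forward them later."""

    # Bookkeeping keys that must not be forwarded to the AutoRound-like constructor.
    _inc_only_keys = ("quant_config", "device", "export_format", "output_dir", "accelerator", "reloading")

    def __init__(self, quant_config=None, **kwargs):
        self.quant_config = quant_config
        # Every unknown keyword becomes an attribute instead of an explicit parameter,
        # so new tuning options need no signature change here.
        for k, v in kwargs.items():
            setattr(self, k, v)

    def convert(self, model):
        # Re-expand the stored attributes, minus the INC-only keys, into the constructor.
        return FakeAutoRound(
            model,
            **{k: v for k, v in self.__dict__.items() if k not in self._inc_only_keys},
        )


quantizer = KwargsForwardingQuantizer(quant_config={}, iters=200, seqlen=2048, export_format="auto_round")
rounder = quantizer.convert(model="dummy-model")
print(rounder.kwargs)  # {'iters': 200, 'seqlen': 2048} -- export_format stays behind

The design choice this illustrates is that the quantizer no longer needs to mirror every AutoRound parameter in its own signature; only the INC-specific keys (for example export_format, output_dir, and the new reloading flag) are filtered out before the remaining attributes are handed through.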