
Commit 46eff9b

autotune target_bits example for llama recipe (#2344)
* autotune target_bits example for llama recipe Signed-off-by: He, Xin3 <xin3.he@intel.com>
1 parent abb99f1 commit 46eff9b

File tree

16 files changed: +1018 -8440 lines


docs/source/3x/PT_MXQuant.md

Lines changed: 130 additions & 8 deletions
@@ -85,6 +85,10 @@ The exponent (exp) is equal to clamp(floor(log2(amax)) - maxExp, -127, 127), MAX
To get a model quantized with Microscaling Data Types, users can use the AutoRound Quantization API as follows.

### Basic Usage

The following example demonstrates how to quantize a model using MX data types:

```python
from neural_compressor.torch.quantization import AutoRoundConfig, prepare, convert
from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -98,13 +102,13 @@ output_dir = "./saved_inc"
# quantization configuration
quant_config = AutoRoundConfig(
    tokenizer=tokenizer,  # Tokenizer for processing calibration data
    nsamples=32,  # Number of calibration samples (default: 128)
    seqlen=32,  # Sequence length of calibration data (default: 2048)
    iters=20,  # Number of optimization iterations (default: 200)
    scheme="MXFP4",  # MX quantization scheme: "MXFP4", "MXFP8"
    export_format="auto_round",  # Export format for the quantized model
    output_dir=output_dir,  # Directory to save the quantized model (default: "temp_auto_round")
)

# quantize the model and save to output_dir
@@ -120,9 +124,127 @@ inputs = tokenizer(text, return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=10)[0]))
```

### Advantages of MX Quantization

1. **Hardware-Friendly**: Uses power-of-2 scaling factors for efficient hardware implementation
2. **Fine-Grained Quantization**: Per-block scaling (block size = 32) provides better accuracy than per-tensor or per-channel methods (see the sketch after this list)
3. **Zero-Point Free**: No zero-point overhead, simplifying computation
4. **Memory Efficient**: Significantly reduces model size while maintaining competitive accuracy
5. **Energy Efficient**: Lower energy consumption for multiply-accumulate operations compared to traditional data types
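To make the per-block, power-of-2 scaling concrete, here is a minimal NumPy sketch of MXFP4-style fake quantization for one 32-element block, following the exponent rule `clamp(floor(log2(amax)) - maxExp, -127, 127)` quoted earlier in this document. It is an illustration only: the helper names (`quantize_mxfp4_block`, `FP4_GRID`) are ours, and the actual rounding logic in neural_compressor/AutoRound may differ.

```python
import numpy as np

FP4_GRID = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])  # representable E2M1 magnitudes
FP4_MAX_EXP = 2  # exponent of the largest E2M1 value (6.0 = 1.5 * 2**2)
BLOCK_SIZE = 32


def quantize_mxfp4_block(block: np.ndarray) -> np.ndarray:
    """Fake-quantize one block of 32 floats to MXFP4 and return the dequantized values."""
    amax = np.abs(block).max()
    if amax == 0:
        return np.zeros_like(block)
    # Shared scale exponent: clamp(floor(log2(amax)) - maxExp, -127, 127)
    exp = int(np.clip(np.floor(np.log2(amax)) - FP4_MAX_EXP, -127, 127))
    scale = 2.0 ** exp
    scaled = block / scale
    # Snap each scaled value to the nearest representable E2M1 magnitude, keeping the sign.
    idx = np.abs(np.abs(scaled)[:, None] - FP4_GRID[None, :]).argmin(axis=1)
    return np.sign(scaled) * FP4_GRID[idx] * scale


x = np.random.randn(BLOCK_SIZE).astype(np.float32)
print(quantize_mxfp4_block(x))
```

Because the shared scale is a pure power of two, dequantization reduces to an exponent shift, which is what makes the format hardware-friendly.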
## Mix Precision (MXFP4 + MXFP8)

To achieve optimal compression ratios with acceptable accuracy, we integrate the AutoRound automatic mixed-precision algorithm. This approach combines the MXFP4 and MXFP8 formats, quantizing each layer of the model according to its sensitivity to quantization.

### Benefits of Mix Precision

- **Better Accuracy-Compression Trade-off**: Sensitive layers use MXFP8 (higher precision) while less sensitive layers use MXFP4 (higher compression), optimizing overall model performance.
- **Flexible Configuration**: Users can customize the precision assignment strategy based on their specific accuracy and compression requirements.
- **Automatic Layer Selection**: The AutoRound algorithm automatically identifies which layers should use which precision level, reducing manual tuning effort.
### Target Bits Configuration

To achieve optimal compression ratios in mixed-precision quantization, we provide the `target_bits` parameter for automated precision configuration.

- **Single target bit**: If you pass a single float, AutoRound automatically generates an optimal quantization recipe that reaches that target average bit-width (see the sketch below).
- **Multiple target bits**: If you pass a list of floats, AutoRound generates one recipe per target bit-width, allowing you to compare trade-offs between model size and accuracy.

**Note**: For MX data types, `target_bits` ranges from 4.25 to 8.25: each 32-element block carries an 8-bit shared scale (0.25 bits per element) on top of the 4-bit (MXFP4) or 8-bit (MXFP8) element payload.
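The single-target case is not covered by the example below, so here is a minimal sketch of it. It assumes a single float passed to `target_bits` can be used with the same `prepare`/`convert` flow as the Basic Usage example; the parameter values and the `./saved_inc_mixed` path are illustrative rather than taken from the original recipes.

```python
from neural_compressor.torch.quantization import AutoRoundConfig, prepare, convert

# Sketch only: assumes `model` and `tokenizer` are loaded as in the Basic Usage example.
quant_config = AutoRoundConfig(
    tokenizer=tokenizer,
    nsamples=128,
    seqlen=2048,
    iters=200,
    target_bits=7.5,  # single target average bit-width (4.25-8.25 for MX data types)
    options=["MXFP4", "MXFP8"],  # candidate data types per layer
    export_format="auto_round",
    output_dir="./saved_inc_mixed",  # hypothetical output path
)

# Assumed to mirror the Basic Usage flow; see that example for the exact calls.
model = prepare(model, quant_config)
model = convert(model)
```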
### Usage Example

#### AutoTune with Multiple Target Bits

To automatically find the best configuration across multiple target bits:
```python
from neural_compressor.torch.quantization import AutoRoundConfig, autotune, TuningConfig
from transformers import AutoModelForCausalLM, AutoTokenizer

fp32_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")


# Define evaluation function
def eval_fn(model):
    # Implement your evaluation logic here
    # Return accuracy score
    pass


# Configuration with multiple target bits
config = AutoRoundConfig(
    tokenizer=tokenizer,
    nsamples=128,
    seqlen=2048,
    iters=200,
    target_bits=[7.2, 7.5, 7.8],  # Try multiple target bits
    options=["MXFP4", "MXFP8"],
    shared_layers=[
        ["k_proj", "v_proj", "q_proj"],
        ["gate_proj", "up_proj"],
    ],
    export_format="auto_round",
    output_dir="./llama3.1-8B-MXFP4-MXFP8",
)

# AutoTune to find the best configuration
tuning_config = TuningConfig(config_set=[config], tolerable_loss=0.01)
model = autotune(fp32_model, tuning_config, eval_fn=eval_fn)
```

### Key Parameters for Mix Precision

- **target_bits**: Target average bit-width for the model. Can be a single float or a list of floats.
  - Single value: generates one recipe for that specific target bit-width
  - Multiple values: generates multiple recipes for comparison; autotune selects the best one

- **options**: List of available data types for mixed precision (e.g., `["MXFP4", "MXFP8"]`)

- **shared_layers**: List of layer groups that should use the same precision. Each group is a list of layer name patterns.
  - Ensures architectural consistency (e.g., all attention projections use the same precision)
  - Improves model performance by maintaining balanced computation

- **tolerable_loss**: Maximum acceptable accuracy loss compared to the FP32 baseline (used with autotune; see the `eval_fn` sketch below)
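To make the `eval_fn` and `tolerable_loss` contract concrete, the sketch below shows one possible evaluation function: it returns a single accuracy-like float (higher is better) that `autotune` compares against the FP32 baseline. The prompt list and the next-token-accuracy metric are illustrative placeholders, not the project's recommended benchmark; in practice you would plug in a real harness such as lm-eval.

```python
import torch


def eval_fn(model):
    # Placeholder metric: average next-token accuracy on a few fixed prompts.
    # `tokenizer` is the AutoTokenizer created in the autotune example above.
    prompts = [
        "The capital of France is",
        "Water is composed of hydrogen and",
    ]
    correct, total = 0, 0
    model.eval()
    with torch.no_grad():
        for text in prompts:
            ids = tokenizer(text, return_tensors="pt").input_ids.to(model.device)
            logits = model(ids).logits  # [1, seq_len, vocab_size]
            preds = logits[:, :-1, :].argmax(dim=-1)  # predict token t+1 from token t
            labels = ids[:, 1:]
            correct += (preds == labels).sum().item()
            total += labels.numel()
    return correct / total  # single float; compared against the FP32 baseline under tolerable_loss
```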
## Examples

### PyTorch Examples

- **Multimodal Models**: [Llama-4-Scout-17B-16E-Instruct with MXFP4](/examples/pytorch/multimodal-modeling/quantization/auto_round/llama4)
- **Language Models**: [Llama3 series with MXFP4/MXFP8 and Mix Precision](/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/auto_round/llama3)
  - Llama 3.1 8B: MXFP8, MXFP4, and Mix Precision (target_bits=7.8)
  - Llama 3.3 70B: MXFP8, MXFP4, and Mix Precision (target_bits=5.8)
## Best Practices and Tips

### Choosing the Right Data Type

| Data Type | Compression | Accuracy | Use Case | Export Format |
|-----------|-------------|----------|----------|---------------|
| **MXFP8** | Moderate (8-bit) | High | Production models where accuracy is critical | `auto_round` |
| **MXFP4** | High (4-bit) | Moderate | Aggressive compression with acceptable accuracy loss | `auto_round` |
| **MXFP4+MXFP8 Mix** | Configurable (4.25-8.25 bits) | High | Best balance between compression and accuracy | `auto_round` |
### Common Issues and Solutions

**Issue**: Out of Memory (OOM) during quantization
- **Solution**: Set `low_gpu_mem_usage=True`, enable `enable_torch_compile`, reduce `nsamples`, or use a smaller `seqlen` (see the configuration sketch at the end of this section)

**Issue**: Accuracy drop is too large
- **Solution**: Increase `iters`, use more `nsamples`, or try mixed precision with a higher `target_bits`

**Issue**: Quantization is too slow
- **Solution**: Reduce `iters` (or set `iters=0` to fall back to RTN), decrease `nsamples`, or enable `enable_torch_compile`

**Issue**: Model loading fails after quantization
- **Solution**: Refer to [auto_round/llama3/inference](/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/auto_round/llama3/README.md#inference)
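As a quick reference for the OOM and speed tips above, here is a hedged configuration sketch that combines them. It assumes `low_gpu_mem_usage` and `enable_torch_compile` are accepted directly as `AutoRoundConfig` arguments, as the troubleshooting notes imply; the specific values and the `./saved_inc_low_mem` path are illustrative.

```python
from neural_compressor.torch.quantization import AutoRoundConfig

# Sketch only: memory/speed-oriented variant of the Basic Usage configuration.
quant_config = AutoRoundConfig(
    tokenizer=tokenizer,  # tokenizer from the Basic Usage example
    nsamples=16,  # fewer calibration samples to lower peak memory
    seqlen=512,  # shorter calibration sequences
    iters=50,  # fewer optimization iterations (set 0 to fall back to RTN)
    scheme="MXFP4",
    low_gpu_mem_usage=True,  # assumption: reduces peak GPU memory during tuning
    enable_torch_compile=True,  # assumption: speeds up tuning via torch.compile
    export_format="auto_round",
    output_dir="./saved_inc_low_mem",  # hypothetical output path
)
```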
## Reference

examples/README.md

Lines changed: 19 additions & 2 deletions
@@ -45,15 +45,32 @@ Intel® Neural Compressor validated examples with multiple compression technique
<td>Quantization (MXFP4)</td>
<td><a href="./pytorch/multimodal-modeling/quantization/auto_round/llama4">link</a></td>
</tr>
<tr>
<td rowspan="2">Llama-3.1-8B-Instruct</td>
<td rowspan="2">Natural Language Processing</td>
<td>Mixed Precision (MXFP4+MXFP8)</td>
<td><a href="./pytorch/nlp/huggingface_models/language-modeling/quantization/auto_round/llama3/README.md#llama-31-8b-mxfp4-mixed-with-mxfp8-target_bits78">link</a></td>
</tr>
<tr>
<td>Quantization (MXFP4/MXFP8/NVFP4)</td>
<td><a href="./pytorch/nlp/huggingface_models/language-modeling/quantization/auto_round/llama3/README.md#demo-mxfp4-mxfp8-nvfp4-unvfp4">link</a></td>
</tr>
<tr>
<td>Llama-3.1-70B-Instruct</td>
<td>Natural Language Processing</td>
<td>Quantization (MXFP8/NVFP4/uNVFP4)</td>
<td><a href="./pytorch/nlp/huggingface_models/language-modeling/quantization/auto_round/llama3/README.md#llama-31-70b-mxfp8">link</a></td>
</tr>
<tr>
<td rowspan="2">Llama-3.3-70B-Instruct</td>
<td rowspan="2">Natural Language Processing</td>
<td>Mixed Precision (MXFP4+MXFP8)</td>
<td><a href="./pytorch/nlp/huggingface_models/language-modeling/quantization/auto_round/llama3/README.md#llama-33-70b-mxfp4-mixed-with-mxfp8-target_bits58">link</a></td>
</tr>
<tr>
<td>Quantization (MXFP4/MXFP8/NVFP4)</td>
<td><a href="./pytorch/nlp/huggingface_models/language-modeling/quantization/auto_round/llama3/README.md#demo-mxfp4-mxfp8-nvfp4-unvfp4">link</a></td>
</tr>
<tr>
<td rowspan="2">gpt_j</td>
