-
Notifications
You must be signed in to change notification settings - Fork 386
Add int8 static quantization workflow #3442
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 29 commits
48cdb61
0b73aed
1e49945
669b6ee
9071526
1539e0f
d9a2b1b
673f228
739fd64
750db1a
9410488
45a3a76
4e2f09c
dd80cca
7f73062
ac6a2b6
f28df4a
328585e
ce4d568
a665d45
0338016
ee39691
9eb0aa9
d4a1514
3cdea56
fa9022d
8ce5cde
6f64121
8ae921d
b5309eb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -88,6 +88,7 @@ | |
| IntxPackingFormat, | ||
| IntxUnpackedToInt8Tensor, | ||
| QuantizeTensorToFloat8Kwargs, | ||
| QuantizeTensorToInt8Kwargs, | ||
| ) | ||
| from torchao.quantization.transform_module import ( | ||
| _QUANTIZE_CONFIG_HANDLER, | ||
|
|
@@ -1590,10 +1591,6 @@ def get_weight_block_size(x): | |
| ) | ||
| quantized_weight = to_linear_activation_quantized(new_weight, input_quant_func) | ||
| else: | ||
| from torchao.quantization.quantize_.workflows.int8.int8_tensor import ( | ||
| QuantizeTensorToInt8Kwargs, | ||
| ) | ||
|
|
||
| assert config.granularity in {PerRow(), PerTensor()}, ( | ||
| "Only PerRow and PerTensor are supported" | ||
| ) | ||
|
|
@@ -1621,7 +1618,10 @@ def get_weight_block_size(x): | |
|
|
||
| @register_quantize_module_handler(Int8DynamicActivationInt8WeightConfig) | ||
| def _int8_dynamic_activation_int8_weight_transform( | ||
| module: torch.nn.Module, config: Int8DynamicActivationInt8WeightConfig | ||
| module: torch.nn.Module, | ||
| config: Int8DynamicActivationInt8WeightConfig, | ||
| *, | ||
| parameter_name="weight", | ||
| ) -> torch.nn.Module: | ||
| if config.set_inductor_config: | ||
| torchao.quantization.utils.recommended_inductor_config_setter() | ||
|
|
@@ -1634,7 +1634,90 @@ def _int8_dynamic_activation_int8_weight_transform( | |
| module.weight, config | ||
| ) | ||
| module.weight = torch.nn.Parameter(new_weight, requires_grad=False) | ||
| module.extra_repr = types.MethodType(_linear_extra_repr, module) | ||
| module.extra_repr = types.MethodType( | ||
| partial( | ||
| _module_extra_repr, | ||
| original_extra_repr=module.extra_repr, | ||
| parameter_name=parameter_name, | ||
| ), | ||
| module, | ||
| ) | ||
| return module | ||
|
|
||
|
|
||
| @dataclass | ||
| class Int8StaticActivationInt8WeightConfig(AOBaseConfig): | ||
| """ | ||
| Configuration for applying float8 static symmetric quantization to | ||
|
|
||
| Args: | ||
| scale (torch.Tensor): The scale tensor for activation quantization. | ||
| activation_dtype (torch.dtype): The target data type for activation quantization. Default is torch.float8_e4m | ||
jcaip marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m | ||
jcaip marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation. | ||
| set_inductor_config (bool): if True, adjusts `torchinductor` settings to recommended values. | ||
| """ | ||
|
|
||
| scale: torch.Tensor | ||
| granularity: Granularity = PerRow() | ||
| act_mapping_type: Optional[MappingType] = MappingType.SYMMETRIC | ||
| set_inductor_config: bool = True | ||
| version: int = 1 | ||
|
|
||
| def __post_init__(self): | ||
| torch._C._log_api_usage_once( | ||
| "torchao.quantization.Int8StaticActivationInt8WeightConfig" | ||
| ) | ||
| if isinstance(self.granularity, PerTensor): | ||
| assert self.scale.numel() == 1 | ||
|
||
|
|
||
|
|
||
| @register_quantize_module_handler(Int8StaticActivationInt8WeightConfig) | ||
| def _int8_static_activation_int8_weight_transform( | ||
| module: torch.nn.Module, | ||
| config: Int8StaticActivationInt8WeightConfig, | ||
| *, | ||
| parameter_name="weight", | ||
| ): | ||
| assert config.granularity in {PerRow(), PerTensor()}, ( | ||
| "Only PerRow and PerTensor is supported currently" | ||
| ) | ||
| assert config.act_mapping_type == MappingType.SYMMETRIC, ( | ||
| "asymmetric static quant not supported currently" | ||
| ) | ||
| assert hasattr(module, parameter_name), ( | ||
| f"Expected module to have attribute `{parameter_name}` but not found" | ||
| ) | ||
|
|
||
| if config.set_inductor_config: | ||
| torchao.quantization.utils.recommended_inductor_config_setter() | ||
|
|
||
| activation_granularity = config.granularity | ||
| weight_granularity = config.granularity | ||
|
|
||
| quantized_tensor = Int8Tensor.from_hp( | ||
| getattr(module, parameter_name), | ||
| granularity=weight_granularity, | ||
| act_quant_kwargs=QuantizeTensorToInt8Kwargs( | ||
| granularity=activation_granularity, | ||
| mapping_type=config.act_mapping_type, | ||
| ), | ||
| activation_scale=config.scale.detach(), | ||
| ) | ||
|
|
||
| setattr( | ||
| module, | ||
| parameter_name, | ||
| torch.nn.Parameter(quantized_tensor, requires_grad=False), | ||
| ) | ||
| module.extra_repr = types.MethodType( | ||
| partial( | ||
| _module_extra_repr, | ||
| original_extra_repr=module.extra_repr, | ||
| parameter_name=parameter_name, | ||
| ), | ||
| module, | ||
| ) | ||
| return module | ||
|
|
||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -12,7 +12,7 @@ | |||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| from torchao.float8.inference import _slice_scale_for_dimension | ||||||||||||||||||||||||||||||||||||||
| from torchao.kernel import int_scaled_matmul | ||||||||||||||||||||||||||||||||||||||
| from torchao.quantization.granularity import Granularity | ||||||||||||||||||||||||||||||||||||||
| from torchao.quantization.granularity import Granularity, PerRow, PerTensor | ||||||||||||||||||||||||||||||||||||||
| from torchao.quantization.quant_primitives import ( | ||||||||||||||||||||||||||||||||||||||
| MappingType, | ||||||||||||||||||||||||||||||||||||||
| choose_qparams_affine, | ||||||||||||||||||||||||||||||||||||||
|
|
@@ -53,15 +53,14 @@ class Int8Tensor(TorchAOBaseTensor): | |||||||||||||||||||||||||||||||||||||
| Tensor Attributes: | ||||||||||||||||||||||||||||||||||||||
| qdata: (N, K) or (B, N, K) int8 quantized weight data (2D or 3D) | ||||||||||||||||||||||||||||||||||||||
| scale: scale factors for dequantization | ||||||||||||||||||||||||||||||||||||||
| # TODO: Static quantization support using `static_scale` | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| Non-Tensor Attributes: | ||||||||||||||||||||||||||||||||||||||
| granularity: the granularity for quantization (e.g., PerRow(), PerTensor()) | ||||||||||||||||||||||||||||||||||||||
| act_quant_kwargs: flags for dynamic activation quantization | ||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| # TODO: Static quantization support using `static_scale` | ||||||||||||||||||||||||||||||||||||||
| tensor_data_names = ["qdata", "scale"] | ||||||||||||||||||||||||||||||||||||||
| optional_tensor_data_names = ["activation_scale"] | ||||||||||||||||||||||||||||||||||||||
| tensor_attribute_names = ["block_size", "dtype"] | ||||||||||||||||||||||||||||||||||||||
| optional_tensor_attribute_names = [ | ||||||||||||||||||||||||||||||||||||||
| "act_quant_kwargs", | ||||||||||||||||||||||||||||||||||||||
|
|
@@ -74,6 +73,7 @@ def __new__( | |||||||||||||||||||||||||||||||||||||
| block_size: List[int], | ||||||||||||||||||||||||||||||||||||||
| dtype: torch.dtype, | ||||||||||||||||||||||||||||||||||||||
| act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None, | ||||||||||||||||||||||||||||||||||||||
| activation_scale=None, | ||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||
| kwargs = { | ||||||||||||||||||||||||||||||||||||||
| "device": qdata.device, | ||||||||||||||||||||||||||||||||||||||
|
|
@@ -89,13 +89,15 @@ def __init__( | |||||||||||||||||||||||||||||||||||||
| block_size: List[int], | ||||||||||||||||||||||||||||||||||||||
| dtype: torch.dtype, | ||||||||||||||||||||||||||||||||||||||
| act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None, | ||||||||||||||||||||||||||||||||||||||
| activation_scale=None, | ||||||||||||||||||||||||||||||||||||||
jcaip marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||
| super().__init__() | ||||||||||||||||||||||||||||||||||||||
| self.qdata = qdata | ||||||||||||||||||||||||||||||||||||||
| self.scale = scale | ||||||||||||||||||||||||||||||||||||||
| self.block_size = block_size | ||||||||||||||||||||||||||||||||||||||
| # don't set dtype because this gets done in __new__ | ||||||||||||||||||||||||||||||||||||||
| self.act_quant_kwargs = act_quant_kwargs | ||||||||||||||||||||||||||||||||||||||
| self.activation_scale = activation_scale | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| def __repr__(self): | ||||||||||||||||||||||||||||||||||||||
| return ( | ||||||||||||||||||||||||||||||||||||||
|
|
@@ -116,22 +118,34 @@ def from_hp( | |||||||||||||||||||||||||||||||||||||
| granularity: Granularity, | ||||||||||||||||||||||||||||||||||||||
| act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None, | ||||||||||||||||||||||||||||||||||||||
| mapping_type=MappingType.SYMMETRIC, | ||||||||||||||||||||||||||||||||||||||
| scale: Optional[torch.Tensor] = None, | ||||||||||||||||||||||||||||||||||||||
| activation_scale: Optional[torch.Tensor] = None, | ||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||
| """Create Int8Tensor from high-precision tensor""" | ||||||||||||||||||||||||||||||||||||||
| block_size = get_block_size(hp_tensor.shape, granularity) | ||||||||||||||||||||||||||||||||||||||
| block_size = list(block_size) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| scale, zero_point = choose_qparams_affine( | ||||||||||||||||||||||||||||||||||||||
| input=hp_tensor, | ||||||||||||||||||||||||||||||||||||||
| mapping_type=mapping_type, | ||||||||||||||||||||||||||||||||||||||
| block_size=block_size, | ||||||||||||||||||||||||||||||||||||||
| target_dtype=torch.int8, | ||||||||||||||||||||||||||||||||||||||
| quant_min=-128, | ||||||||||||||||||||||||||||||||||||||
| quant_max=127, | ||||||||||||||||||||||||||||||||||||||
| scale_dtype=hp_tensor.dtype, | ||||||||||||||||||||||||||||||||||||||
| zero_point_dtype=torch.int8, | ||||||||||||||||||||||||||||||||||||||
| keepdim=True, | ||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||
| if scale is None: | ||||||||||||||||||||||||||||||||||||||
| scale, zero_point = choose_qparams_affine( | ||||||||||||||||||||||||||||||||||||||
| input=hp_tensor, | ||||||||||||||||||||||||||||||||||||||
| mapping_type=mapping_type, | ||||||||||||||||||||||||||||||||||||||
| block_size=block_size, | ||||||||||||||||||||||||||||||||||||||
| target_dtype=torch.int8, | ||||||||||||||||||||||||||||||||||||||
| quant_min=-128, | ||||||||||||||||||||||||||||||||||||||
| quant_max=127, | ||||||||||||||||||||||||||||||||||||||
| scale_dtype=hp_tensor.dtype, | ||||||||||||||||||||||||||||||||||||||
| zero_point_dtype=torch.int8, | ||||||||||||||||||||||||||||||||||||||
| keepdim=True, | ||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||
| # Scale can be provided in the case of static quant | ||||||||||||||||||||||||||||||||||||||
| assert scale.ndim == hp_tensor.ndim | ||||||||||||||||||||||||||||||||||||||
| if isinstance(granularity, PerTensor): | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
| def _is_rowwise_scaled(x: torch.Tensor) -> bool: | |
| """Checks if a quantized tensor is rowwise scaled | |
| Args: | |
| x: quantized tensor (should have `block_size` attribute) | |
| """ | |
| assert hasattr(x, "block_size"), "Expecting input to have `block_size` attribute" | |
| return tuple(x.block_size) == (1,) * (x.dim() - 1) + (x.shape[-1],) | |
| def _is_tensorwise_scaled(x: torch.Tensor) -> bool: | |
| """Checks if a quantized tensor is rowwise scaled | |
| Args: | |
| x: quantized tensor (should have `block_size` attribute) | |
| """ | |
| assert hasattr(x, "block_size"), "Expecting input to have `block_size` attribute" | |
| return all( | |
| x.block_size[i] == -1 or x.block_size[i] == x.shape[i] for i in range(x.ndim) | |
| ) |
Uh oh!
There was an error while loading. Please reload this page.