Merged
Changes from 29 commits
92 changes: 92 additions & 0 deletions test/quantization/quantize_/workflows/int8/test_int8_tensor.py
@@ -14,11 +14,15 @@

from torchao.quantization import (
Int8DynamicActivationInt8WeightConfig,
Int8StaticActivationInt8WeightConfig,
Int8WeightOnlyConfig,
quantize_,
)
from torchao.quantization.granularity import PerRow, PerTensor
from torchao.quantization.quant_primitives import MappingType
from torchao.quantization.quantize_.common import (
_choose_quant_func_and_quantize_tensor,
)
from torchao.quantization.utils import compute_error, get_block_size
from torchao.testing.model_architectures import ToyTwoLinearModel
from torchao.testing.utils import TorchAOIntegrationTestCase
@@ -221,5 +225,93 @@ def test_available_gpu_kernels(self):
).check_count("triton_poi_fused", 1).run(code[0])


@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@common_utils.instantiate_parametrized_tests
class TestInt8StaticQuant(TorchAOIntegrationTestCase):
@common_utils.parametrize("granularity", [PerRow(), PerTensor()])
@common_utils.parametrize("dtype", [torch.bfloat16])
def test_static_activation_per_row_int8_weight_eager(self, granularity, dtype):
M, N, K = 32, 32, 32
input_tensor = torch.randn(M, K, dtype=dtype, device="cuda")

model_static_quant = (
torch.nn.Linear(K, N, bias=False).eval().to(device="cuda", dtype=dtype)
)
model_dynamic_quant = copy.deepcopy(model_static_quant)

dynamic_config = Int8DynamicActivationInt8WeightConfig(
version=2, granularity=granularity
)
quantize_(model_dynamic_quant, dynamic_config)

dynamic_quantize_out = model_dynamic_quant(input_tensor)

int8_input = _choose_quant_func_and_quantize_tensor(
input_tensor, model_dynamic_quant.weight.act_quant_kwargs
)

static_config = Int8StaticActivationInt8WeightConfig(
scale=int8_input.scale.detach().clone(), granularity=granularity
)
quantize_(model_static_quant, static_config)

static_quantize_out = model_static_quant(input_tensor)
torch.testing.assert_close(dynamic_quantize_out, static_quantize_out)

@common_utils.parametrize("granularity", [PerRow(), PerTensor()])
@common_utils.parametrize("dtype", [torch.bfloat16])
def test_static_activation_per_row_int8_weight_compile(self, granularity, dtype):
# for compile, we can't compare dynamic vs static because we may get slightly different qparams
torch.compiler.reset()

M, N, K = 32, 32, 32
input_tensor = torch.randn(M, K, dtype=dtype, device="cuda")

model = torch.nn.Linear(K, N, bias=False).eval().to(device="cuda", dtype=dtype)
model_static_quant = copy.deepcopy(model)
model_dynamic_quant = copy.deepcopy(model)

model_out_baseline = model(input_tensor)

dynamic_config = Int8DynamicActivationInt8WeightConfig(
version=2, granularity=granularity
)
quantize_(model_dynamic_quant, dynamic_config)

dynamic_out_eager = model_dynamic_quant(input_tensor)
sqnr_dynamic_eager = compute_error(model_out_baseline, dynamic_out_eager)

model_dynamic_quant = torch.compile(model_dynamic_quant, fullgraph=True)

dynamic_out_compile = model_dynamic_quant(input_tensor)
sqnr_dynamic_compile = compute_error(model_out_baseline, dynamic_out_compile)

# use the activation scales computed in eager mode to build the static config
int8_input = _choose_quant_func_and_quantize_tensor(
input_tensor, model_dynamic_quant.weight.act_quant_kwargs
)

static_config = Int8StaticActivationInt8WeightConfig(
scale=int8_input.scale.detach().clone(),
granularity=granularity,
)
quantize_(model_static_quant, static_config)

static_out_eager = model_static_quant(input_tensor)
sqnr_static_eager = compute_error(model_out_baseline, static_out_eager)

model_static_quant = torch.compile(model_static_quant, fullgraph=True)

static_out_compile = model_static_quant(input_tensor)
sqnr_static_compile = compute_error(model_out_baseline, static_out_compile)

assert (
sqnr_static_compile
== sqnr_static_eager
== sqnr_dynamic_compile
== sqnr_dynamic_eager
)


if __name__ == "__main__":
common_utils.run_tests()
2 changes: 2 additions & 0 deletions torchao/quantization/__init__.py
@@ -59,6 +59,7 @@
Int8DynamicActivationInt4WeightConfig,
Int8DynamicActivationInt8WeightConfig,
Int8DynamicActivationIntxWeightConfig,
Int8StaticActivationInt8WeightConfig,
Int8WeightOnlyConfig,
IntxWeightOnlyConfig,
ModuleFqnToConfig,
@@ -150,6 +151,7 @@
"Int8DynamicActivationInt4WeightConfig",
"Int8DynamicActivationInt8WeightConfig",
"Int8DynamicActivationIntxWeightConfig",
"Int8StaticActivationInt8WeightConfig",
"Int4WeightOnlyConfig",
"Float8DynamicActivationInt4WeightConfig",
"Int8WeightOnlyConfig",
95 changes: 89 additions & 6 deletions torchao/quantization/quant_api.py
@@ -88,6 +88,7 @@
IntxPackingFormat,
IntxUnpackedToInt8Tensor,
QuantizeTensorToFloat8Kwargs,
QuantizeTensorToInt8Kwargs,
)
from torchao.quantization.transform_module import (
_QUANTIZE_CONFIG_HANDLER,
@@ -1590,10 +1591,6 @@ def get_weight_block_size(x):
)
quantized_weight = to_linear_activation_quantized(new_weight, input_quant_func)
else:
from torchao.quantization.quantize_.workflows.int8.int8_tensor import (
QuantizeTensorToInt8Kwargs,
)

assert config.granularity in {PerRow(), PerTensor()}, (
"Only PerRow and PerTensor are supported"
)
@@ -1621,7 +1618,10 @@

@register_quantize_module_handler(Int8DynamicActivationInt8WeightConfig)
def _int8_dynamic_activation_int8_weight_transform(
module: torch.nn.Module, config: Int8DynamicActivationInt8WeightConfig
module: torch.nn.Module,
config: Int8DynamicActivationInt8WeightConfig,
*,
parameter_name="weight",
) -> torch.nn.Module:
if config.set_inductor_config:
torchao.quantization.utils.recommended_inductor_config_setter()
@@ -1634,7 +1634,90 @@ def _int8_dynamic_activation_int8_weight_transform(
module.weight, config
)
module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
module.extra_repr = types.MethodType(_linear_extra_repr, module)
module.extra_repr = types.MethodType(
partial(
_module_extra_repr,
original_extra_repr=module.extra_repr,
parameter_name=parameter_name,
),
module,
)
return module


@dataclass
class Int8StaticActivationInt8WeightConfig(AOBaseConfig):
"""
Configuration for applying int8 static symmetric activation quantization and int8 weight quantization to linear layers, using a precomputed activation scale.

Args:
scale (torch.Tensor): precomputed scale tensor for activation quantization (e.g. obtained from calibration).
granularity (Granularity): quantization granularity for activations and weights; PerRow() and PerTensor() are supported. Default is PerRow().
act_mapping_type (MappingType): mapping type for activation quantization; only MappingType.SYMMETRIC is currently supported.
set_inductor_config (bool): if True, adjusts `torchinductor` settings to recommended values.
version (int): version of the config. Default is 1.
"""

scale: torch.Tensor
granularity: Granularity = PerRow()
act_mapping_type: Optional[MappingType] = MappingType.SYMMETRIC
set_inductor_config: bool = True
version: int = 1

def __post_init__(self):
torch._C._log_api_usage_once(
"torchao.quantization.Int8StaticActivationInt8WeightConfig"
)
if isinstance(self.granularity, PerTensor):
assert self.scale.numel() == 1
Contributor:
nit: also check the shapes, and check PerRow as well

Contributor (Author):
I think we might want to enable `scale` to be None, for passing Int8StaticActivationInt8WeightConfig() as a base config; we can discuss on #3468

Contributor:
OK, makes sense, sounds good. This is easier for the user, I think; otherwise they would have to run a separate flow to get the scale here.



@register_quantize_module_handler(Int8StaticActivationInt8WeightConfig)
def _int8_static_activation_int8_weight_transform(
module: torch.nn.Module,
config: Int8StaticActivationInt8WeightConfig,
*,
parameter_name="weight",
):
assert config.granularity in {PerRow(), PerTensor()}, (
"Only PerRow and PerTensor is supported currently"
)
assert config.act_mapping_type == MappingType.SYMMETRIC, (
"asymmetric static quant not supported currently"
)
assert hasattr(module, parameter_name), (
f"Expected module to have attribute `{parameter_name}` but not found"
)

if config.set_inductor_config:
torchao.quantization.utils.recommended_inductor_config_setter()

activation_granularity = config.granularity
weight_granularity = config.granularity

quantized_tensor = Int8Tensor.from_hp(
getattr(module, parameter_name),
granularity=weight_granularity,
act_quant_kwargs=QuantizeTensorToInt8Kwargs(
granularity=activation_granularity,
mapping_type=config.act_mapping_type,
),
activation_scale=config.scale.detach(),
)

setattr(
module,
parameter_name,
torch.nn.Parameter(quantized_tensor, requires_grad=False),
)
module.extra_repr = types.MethodType(
partial(
_module_extra_repr,
original_extra_repr=module.extra_repr,
parameter_name=parameter_name,
),
module,
)
return module
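To tie the config and handler above together, here is a minimal usage sketch drawn from the eager test added in this PR. The API names come from this diff; the PerTensor granularity and the reuse of the dynamic flow for calibration are illustrative choices (PerTensor keeps the calibrated activation scale independent of batch shape), not the only supported setup.

import copy

import torch

from torchao.quantization import (
    Int8DynamicActivationInt8WeightConfig,
    Int8StaticActivationInt8WeightConfig,
    quantize_,
)
from torchao.quantization.granularity import PerTensor
from torchao.quantization.quantize_.common import (
    _choose_quant_func_and_quantize_tensor,
)

model = torch.nn.Linear(32, 32, bias=False).eval().to(device="cuda", dtype=torch.bfloat16)
calibration_input = torch.randn(32, 32, device="cuda", dtype=torch.bfloat16)

# Calibration: quantize a throwaway copy dynamically, then reuse its activation
# quant kwargs to compute an int8 activation scale on representative data.
calib_model = copy.deepcopy(model)
quantize_(calib_model, Int8DynamicActivationInt8WeightConfig(version=2, granularity=PerTensor()))
int8_input = _choose_quant_func_and_quantize_tensor(
    calibration_input, calib_model.weight.act_quant_kwargs
)

# Freeze the calibrated activation scale into the static config.
quantize_(
    model,
    Int8StaticActivationInt8WeightConfig(
        scale=int8_input.scale.detach().clone(), granularity=PerTensor()
    ),
)
output = model(calibration_input)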


@@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.

import abc
from typing import ClassVar
from typing import ClassVar, Optional

import torch

@@ -31,7 +31,9 @@ def from_hp(cls, tensor, quant_kwargs: QuantizeTensorKwargs)


def _choose_quant_func_and_quantize_tensor(
tensor: torch.Tensor, quant_kwargs: QuantizeTensorKwargs
tensor: torch.Tensor,
quant_kwargs: QuantizeTensorKwargs,
scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Given a tensor and a kwargs container, chooses a derived dtype (float8, int8, etc) to quantize tensor to, based on the type of quant_kwargs
quantizes tensor to the derived dtype chosen in (1)
@@ -60,6 +62,7 @@ def _choose_quant_func_and_quantize_tensor(
tensor,
quant_kwargs.granularity,
mapping_type=quant_kwargs.mapping_type,
scale=scale,
)

raise NotImplementedError(f"Quant kwargs not supported: {quant_kwargs}")
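A short sketch of the new optional `scale` argument, assuming the signatures in this diff: when `scale` is omitted, quantization parameters are derived from the tensor itself (the dynamic path); when a precomputed scale is passed, it is forwarded to Int8Tensor.from_hp (the static path used by the linear dispatch in int8_tensor.py below). The import of QuantizeTensorToInt8Kwargs follows the path shown in the code removed from quant_api.py above.

import torch

from torchao.quantization.granularity import PerTensor
from torchao.quantization.quant_primitives import MappingType
from torchao.quantization.quantize_.common import (
    _choose_quant_func_and_quantize_tensor,
)
from torchao.quantization.quantize_.workflows.int8.int8_tensor import (
    QuantizeTensorToInt8Kwargs,
)

x = torch.randn(8, 64, dtype=torch.bfloat16)
kwargs = QuantizeTensorToInt8Kwargs(
    granularity=PerTensor(), mapping_type=MappingType.SYMMETRIC
)

# Dynamic: qparams are computed from x itself.
x_dynamic = _choose_quant_func_and_quantize_tensor(x, kwargs)

# Static: reuse a previously calibrated scale instead of recomputing it.
x_static = _choose_quant_func_and_quantize_tensor(x, kwargs, scale=x_dynamic.scale)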
47 changes: 32 additions & 15 deletions torchao/quantization/quantize_/workflows/int8/int8_tensor.py
@@ -12,7 +12,7 @@

from torchao.float8.inference import _slice_scale_for_dimension
from torchao.kernel import int_scaled_matmul
from torchao.quantization.granularity import Granularity
from torchao.quantization.granularity import Granularity, PerRow, PerTensor
from torchao.quantization.quant_primitives import (
MappingType,
choose_qparams_affine,
@@ -53,15 +53,14 @@ class Int8Tensor(TorchAOBaseTensor):
Tensor Attributes:
qdata: (N, K) or (B, N, K) int8 quantized weight data (2D or 3D)
scale: scale factors for dequantization
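activation_scale: optional precomputed activation scale used for static activation quantization (forwarded when the input activation is quantized at linear time)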
# TODO: Static quantization support using `static_scale`

Non-Tensor Attributes:
granularity: the granularity for quantization (e.g., PerRow(), PerTensor())
act_quant_kwargs: flags for dynamic activation quantization
"""

# TODO: Static quantization support using `static_scale`
tensor_data_names = ["qdata", "scale"]
optional_tensor_data_names = ["activation_scale"]
tensor_attribute_names = ["block_size", "dtype"]
optional_tensor_attribute_names = [
"act_quant_kwargs",
@@ -74,6 +73,7 @@ def __new__(
block_size: List[int],
dtype: torch.dtype,
act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None,
activation_scale=None,
):
kwargs = {
"device": qdata.device,
@@ -89,13 +89,15 @@ def __init__(
block_size: List[int],
dtype: torch.dtype,
act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None,
activation_scale=None,
):
super().__init__()
self.qdata = qdata
self.scale = scale
self.block_size = block_size
# don't set dtype because this gets done in __new__
self.act_quant_kwargs = act_quant_kwargs
self.activation_scale = activation_scale

def __repr__(self):
return (
@@ -116,22 +118,34 @@ def from_hp(
granularity: Granularity,
act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None,
mapping_type=MappingType.SYMMETRIC,
scale: Optional[torch.Tensor] = None,
activation_scale: Optional[torch.Tensor] = None,
):
"""Create Int8Tensor from high-precision tensor"""
block_size = get_block_size(hp_tensor.shape, granularity)
block_size = list(block_size)

scale, zero_point = choose_qparams_affine(
input=hp_tensor,
mapping_type=mapping_type,
block_size=block_size,
target_dtype=torch.int8,
quant_min=-128,
quant_max=127,
scale_dtype=hp_tensor.dtype,
zero_point_dtype=torch.int8,
keepdim=True,
)
if scale is None:
scale, zero_point = choose_qparams_affine(
input=hp_tensor,
mapping_type=mapping_type,
block_size=block_size,
target_dtype=torch.int8,
quant_min=-128,
quant_max=127,
scale_dtype=hp_tensor.dtype,
zero_point_dtype=torch.int8,
keepdim=True,
)
else:
# Scale can be provided in the case of static quant
assert scale.ndim == hp_tensor.ndim
if isinstance(granularity, PerTensor):
Contributor (Author):
note: I changed these checks in #3468

jerryzh168 (Contributor), Dec 8, 2025:
I think we typically also check the shape of the scale tensor as well, like these:

def _is_rowwise_scaled(x: torch.Tensor) -> bool:
    """Checks if a quantized tensor is rowwise scaled
    Args:
        x: quantized tensor (should have `block_size` attribute)
    """
    assert hasattr(x, "block_size"), "Expecting input to have `block_size` attribute"
    return tuple(x.block_size) == (1,) * (x.dim() - 1) + (x.shape[-1],)

def _is_tensorwise_scaled(x: torch.Tensor) -> bool:
    """Checks if a quantized tensor is tensorwise scaled
    Args:
        x: quantized tensor (should have `block_size` attribute)
    """
    assert hasattr(x, "block_size"), "Expecting input to have `block_size` attribute"
    return all(
        x.block_size[i] == -1 or x.block_size[i] == x.shape[i] for i in range(x.ndim)
    )

assert scale.numel() == 1
elif isinstance(granularity, PerRow):
assert scale.numel() == block_size[-1]

zero_point = torch.zeros_like(scale, dtype=torch.int8)

int_data = quantize_affine(
hp_tensor,
@@ -147,6 +161,7 @@
block_size,
hp_tensor.dtype,
act_quant_kwargs=act_quant_kwargs,
activation_scale=activation_scale,
)

def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor:
@@ -185,7 +200,9 @@ def _(func, types, args, kwargs):

if weight_tensor.act_quant_kwargs is not None:
activation_tensor = _choose_quant_func_and_quantize_tensor(
activation_tensor, weight_tensor.act_quant_kwargs
activation_tensor,
weight_tensor.act_quant_kwargs,
scale=weight_tensor.activation_scale,
)
# Dynamic activation quantization path

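Finally, a tensor-level sketch of the static activation path added in this file, assuming this PR's Int8Tensor.from_hp signature and the shapes used in the test (square 32x32 tensors, so the per-row scale checks line up). The 0.02 scale value is an illustrative stand-in for a calibrated scale; the point is that the weight tensor carries both act_quant_kwargs and the frozen activation_scale, and the linear dispatch forwards that scale when quantizing the incoming activation.

import torch

from torchao.quantization.granularity import PerRow
from torchao.quantization.quant_primitives import MappingType
from torchao.quantization.quantize_.workflows.int8.int8_tensor import (
    Int8Tensor,
    QuantizeTensorToInt8Kwargs,
)

weight_hp = torch.randn(32, 32, dtype=torch.bfloat16, device="cuda")
# Hypothetical calibrated per-row activation scale (one entry per activation row).
act_scale = torch.full((32, 1), 0.02, dtype=torch.bfloat16, device="cuda")

qweight = Int8Tensor.from_hp(
    weight_hp,
    granularity=PerRow(),
    act_quant_kwargs=QuantizeTensorToInt8Kwargs(
        granularity=PerRow(), mapping_type=MappingType.SYMMETRIC
    ),
    activation_scale=act_scale,  # stored on the weight; not used to quantize the weight itself
)

x = torch.randn(32, 32, dtype=torch.bfloat16, device="cuda")
# At dispatch time the activation is quantized with act_scale rather than with
# scales derived from x.
y = torch.nn.functional.linear(x, qweight)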