Add int8 static quantization workflow #3442

Conversation
Summary: This PR creates a new `Int8Tensor` and updates the configs to use the new `Int8Tensor` flow.

Test Plan:

To ensure BC:
```
pytest test/quantization/test_quant_api.py
```
To test the new `Int8Tensor`:
```
pytest test/quantization/quantize_/workflows/int8/test_int8_tensor.py
```
🔗 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3442
✅ No failures as of commit b5309eb with merge base c4273fe.
```python
else:
    # Scale can be provided in the case of static quant
    assert scale.ndim == hp_tensor.ndim
    if isinstance(granularity, PerTensor):
```
note: I changed these checks in #3468
I think we typically also check the shape of the scale tensor as well, like these:

ao/torchao/float8/inference.py, lines 181 to 198 in 08e5e20:
```python
def _is_rowwise_scaled(x: torch.Tensor) -> bool:
    """Checks if a quantized tensor is rowwise scaled

    Args:
        x: quantized tensor (should have `block_size` attribute)
    """
    assert hasattr(x, "block_size"), "Expecting input to have `block_size` attribute"
    return tuple(x.block_size) == (1,) * (x.dim() - 1) + (x.shape[-1],)


def _is_tensorwise_scaled(x: torch.Tensor) -> bool:
    """Checks if a quantized tensor is tensorwise scaled

    Args:
        x: quantized tensor (should have `block_size` attribute)
    """
    assert hasattr(x, "block_size"), "Expecting input to have `block_size` attribute"
    return all(
        x.block_size[i] == -1 or x.block_size[i] == x.shape[i] for i in range(x.ndim)
    )
```
torchao/quantization/quant_api.py (outdated):

```python
if isinstance(self.granularity, PerTensor):
    assert self.scale.numel() == 1
```
nit: also check the shapes, and check PerRow as well
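A possible shape of those checks, sketched stdlib-only with granularity modeled as a string (the real code uses torchao's `PerTensor`/`PerRow` granularity objects and tensor attributes): a PerTensor scale has exactly one element, and a PerRow scale has one element per output row, i.e. shape `(N, 1)` for an `(N, K)` weight.

```python
def check_scale_shape(scale_shape, weight_shape, granularity):
    """Return True if scale_shape is valid for the given granularity."""
    if granularity == "per_tensor":
        # A per-tensor scale must contain exactly one element.
        numel = 1
        for d in scale_shape:
            numel *= d
        return numel == 1
    elif granularity == "per_row":
        # One scale per row, broadcastable over the row dimension.
        return tuple(scale_shape) == (weight_shape[0], 1)
    raise ValueError(f"unsupported granularity: {granularity}")
```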
I think we might want to allow scale to be None, so that `Int8StaticActivationInt8WeightConfig()` can be passed as a base config; we can discuss on #3468.
OK, makes sense, sounds good. This is easier for the user, I think; otherwise they would have to run a separate flow to get the scale here.
jerryzh168 left a comment:
lg, see some comments inline
This PR adds a new static quant workflow based on `Int8Tensor` (#3407).

It introduces a new config, `Int8StaticActivationInt8WeightConfig`, which requires a scale tensor and a granularity. Currently only PerRow and PerTensor symmetric quant are supported.

This scale tensor is stored on the weight `Int8Tensor` under `activation_scale`, and is used to create a new activation `Int8Tensor` for static quantization.

It would be nice to store this scale tensor in `QuantizeTensorToInt8Kwargs`, but unfortunately this breaks dynamo tracing: we store the quant kwargs as an object on the weight tensor and are unable to fakeify them properly. As a result, we need to keep track of the scale and pass it outside of this kwargs object.
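The core difference from dynamic quant is that the activation scale is precomputed (e.g. via calibration) and reused at inference, rather than derived from each input. A minimal stdlib sketch of that static flow, with illustrative names that are not the torchao API:

```python
def quantize_static_int8(values, scale):
    """Quantize floats to int8 codes using a fixed, precomputed scale."""
    out = []
    for v in values:
        q = round(v / scale)
        out.append(max(-128, min(127, q)))  # clamp to the int8 range
    return out


def dequantize_int8(qvalues, scale):
    """Recover approximate float values from int8 codes."""
    return [q * scale for q in qvalues]
```

With `scale=0.01`, `quantize_static_int8([0.5, -1.0, 2.0], 0.01)` yields `[50, -100, 127]` (the last value saturates), so values outside roughly `±1.28` lose information; this is why the calibrated scale stored under `activation_scale` must match the activation range well.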