Skip to content

[bug] `PerGroup(1)` raises RuntimeError in `quantize_` #3458

@Freed-Wu

Description

@Freed-Wu
import torch
from torchao.quantization import IntxWeightOnlyConfig, quantize_
from torchao.quantization.quant_primitives import MappingType
from torchao.quantization.granularity import PerGroup


class ToyLinearModel(torch.nn.Module):
    """Minimal two-layer, bias-free linear network used to reproduce the issue.

    The input is projected from ``m`` to ``n`` features by ``linear1`` and
    then from ``n`` to ``k`` features by ``linear2``.
    """

    def __init__(self, m: int, n: int, k: int) -> None:
        """Create the two bias-free linear layers.

        Args:
            m: Number of input features of the first layer.
            n: Number of output features of the first layer (and input
                features of the second layer).
            k: Number of output features of the second layer.
        """
        super().__init__()
        # Layers are created in this order so parameter initialisation
        # consumes the RNG in the same sequence as before.
        self.linear1 = torch.nn.Linear(in_features=m, out_features=n, bias=False)
        self.linear2 = torch.nn.Linear(in_features=n, out_features=k, bias=False)

    def forward(self, x):
        """Apply ``linear1`` followed by ``linear2`` to ``x`` and return the result."""
        hidden = self.linear1(x)
        return self.linear2(hidden)


# Build a 32 -> 32 -> 32 toy model and switch it to eval mode.
model = ToyLinearModel(32, 32, 32).eval()

# Optional (disabled in this repro): compile the model for faster inference
# and generation, and keep an unquantized copy for comparison.
# model = torch.compile(model, mode="max-autotune", fullgraph=True)
# model_bf16 = copy.deepcopy(model)

# Weight-only int4 config with PerGroup(1) granularity and asymmetric mapping.
# NOTE(review): PerGroup(1) presumably means a group size of 1 -- confirm
# against the torchao granularity docs.
config = IntxWeightOnlyConfig(torch.int4, PerGroup(1), mapping_type=MappingType.ASYMMETRIC)
# This call is where the reported failure occurs:
#   RuntimeError: shape '[32, 32]' is invalid for input of size 1
quantize_(model, config)
# Not reached in the failing run, because quantize_ raises above.
inp = torch.ones(2, 32)
model(inp)
$ python a.py
Traceback (most recent call last):
  File "/home/wzy/Desktop/ao/a.py", line 25, in <module>
    quantize_(model, config)
    ~~~~~~~~~^^^^^^^^^^^^^^^
  File "/home/wzy/Desktop/ao/torchao/quantization/quant_api.py", line 498, in quantize_
    _replace_with_custom_fn_if_matches_filter(
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^
        model,
        ^^^^^^
    ...<3 lines>...
        extra_args=(config,),
        ^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/home/wzy/Desktop/ao/torchao/quantization/quant_api.py", line 214, in _replace_with_custom_fn_if_matches_filter
    new_child = _replace_with_custom_fn_if_matches_filter(
        child,
    ...<4 lines>...
        extra_args,
    )
  File "/home/wzy/Desktop/ao/torchao/quantization/quant_api.py", line 209, in _replace_with_custom_fn_if_matches_filter
    model = replacement_fn(model, *extra_args)
  File "/home/wzy/Desktop/ao/torchao/quantization/quant_api.py", line 2375, in _intx_weight_only_transform
    new_weight = _intx_weight_only_quantize_tensor(
        module.weight,
    ...<2 lines>...
        custom_zero_point=custom_zero_point,
    )
  File "/home/wzy/Desktop/ao/torchao/quantization/quant_api.py", line 2320, in _intx_weight_only_quantize_tensor
    new_weight = IntxUnpackedToInt8Tensor.from_hp(
        weight,
    ...<5 lines>...
        intx_choose_qparams_algorithm=intx_choose_qparams_algorithm,
    )
  File "/home/wzy/Desktop/ao/torchao/quantization/quantize_/workflows/intx/intx_unpacked_to_int8_tensor.py", line 233, in from_hp
    qdata = quantize_affine(
        hp_tensor,
    ...<5 lines>...
        quant_max=qmax,
    )
  File "/usr/lib/python3.13/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/wzy/Desktop/ao/torchao/quantization/quant_primitives.py", line 357, in quantize_affine
    return _quantize_affine(
        input,
    ...<5 lines>...
        quant_max,
    )
  File "/usr/lib/python3.13/site-packages/torch/_ops.py", line 1158, in __call__
    return self._op(*args, **(kwargs or {}))
           ~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wzy/Desktop/ao/torchao/quantization/quant_primitives.py", line 403, in _quantize_affine
    return _quantize_affine_no_dtype_cast(
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^
        input,
        ^^^^^^
    ...<4 lines>...
        quant_max,
        ^^^^^^^^^^
    ).to(output_dtype)
    ^
  File "/home/wzy/Desktop/ao/torchao/quantization/quant_primitives.py", line 460, in _quantize_affine_no_dtype_cast
    scale = scale.view(shape_after_reduction)
RuntimeError: shape '[32, 32]' is invalid for input of size 1

Other `PerGroup(x)` values work without error.

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions