diff --git a/neural_compressor/torch/algorithms/fp8_quant/_core/scale_handler.py b/neural_compressor/torch/algorithms/fp8_quant/_core/scale_handler.py index 5589e90e715..27a7ac56fc6 100644 --- a/neural_compressor/torch/algorithms/fp8_quant/_core/scale_handler.py +++ b/neural_compressor/torch/algorithms/fp8_quant/_core/scale_handler.py @@ -51,7 +51,7 @@ def create_scale_tensor(orig_scales, scale_format): if isinstance(orig_scales, (torch.Tensor, float)): return scale_creation_func(orig_scales) elif isinstance(orig_scales, list): - return [scale_creation_func(x) for x in orig_scales] + return torch.nn.ParameterList([scale_creation_func(x) for x in orig_scales]) else: raise ValueError("unexpected scale format value {}".format(scale_format)) @@ -78,6 +78,8 @@ def scale_to_scalar(scale): def get_scale_dtype(scale): if isinstance(scale, torch.Tensor): # tensor case return scale.dtype + if isinstance(scale, torch.nn.ParameterList): # ParameterList case: dtype taken from first element + return scale[0].dtype elif isinstance(scale, float): # already scalar case return type(scale).__name__ elif scale is None: # possible dynamic scalar case