@@ -28,42 +28,72 @@ def quantize(
     """
 
     with unset_fake_temporarily():
-        if isinstance(input_tensor, TRTTensor) and input_tensor.dtype not in (
-            trt.float32,
-            trt.float16,
-        ):
-            raise ValueError(
-                f"quantize converter received an input of {input_tensor.dtype} type. Supported types: float32 | float16"
-            )
+        if isinstance(input_tensor, (torch.Tensor, TRTTensor)):
+            if input_tensor.dtype not in (
+                trt.float32,
+                trt.float16,
+                trt.bfloat16,
+                torch.bfloat16,
+                torch.float16,
+                torch.float32,
+            ):
+                raise ValueError(
+                    f"quantize converter received an input of {input_tensor.dtype} type. Supported types: float32 | float16 | bfloat16"
+                )
             if num_bits != 8 or exponent_bits not in (0, 4):
                 raise ValueError(
                     f"quantize converter currently only accept INT8 or FP8 based quantize, got {num_bits=}, {exponent_bits=}"
                 )
+        else:
+            raise ValueError(
+                f"quantize converter received an input of {type(input_tensor)} type. Supported types: torch.Tensor | TRTTensor"
+            )
+
         if num_bits == 8 and exponent_bits == 0:
+            dtype = trt.DataType.INT8
             max_bound = 127
         elif num_bits == 8 and exponent_bits == 4:
+            dtype = trt.DataType.FP8
             max_bound = 448
 
         amax = to_torch(amax, None)
+        axis = None
+        # int8 weight quantization is per-channel quantization (it can have one or multiple amax values)
+        if dtype == trt.DataType.INT8 and amax.numel() > 1:
+            # if amax has more than one element, calculate the axis; otherwise the axis value is ignored
+            amax_init_shape = amax.shape
+            amax = amax.squeeze().data
+            assert (
+                len(amax.shape) == 1
+            ), f"TensorRT does not support multi-axis quantization. {name=} {amax_init_shape=} {amax.shape=}"
+            axis = list(amax_init_shape).index(list(amax.shape)[0])
+            assert (
+                axis == 0
+            ), f"{name=} {amax=} is per-channel quantization, expected axis to be 0, but got {axis=}"
+        else:
+            # int8 activation and fp8 weight/activation quantization is per-tensor quantization; it can only have a single amax value
+            assert (
+                amax.numel() == 1
+            ), f"{name=} is per-tensor quantization, expected amax is a singular value, but got {amax.shape=}"
         scale = torch.divide(amax, max_bound)
+        scale.masked_fill_(scale == 0, 1.0)
         scale = get_trt_tensor(ctx, scale, name + "_scale")
-        # Add Q node
-        quantize_layer = ctx.net.add_quantize(input_tensor, scale)
-        if num_bits == 8 and exponent_bits == 0:
-            quantize_layer.set_output_type(0, trt.DataType.INT8)
-        elif num_bits == 8 and exponent_bits == 4:
-            quantize_layer.set_output_type(0, trt.DataType.FP8)
+        input_tensor = get_trt_tensor(ctx, input_tensor, name)
 
+        # Add Q node
+        quantize_layer = ctx.net.add_quantize(input_tensor, scale, dtype)
+        if axis is not None:
+            quantize_layer.axis = axis
         set_layer_name(quantize_layer, target, name + "_quantize", source_ir)
         q_output = quantize_layer.get_output(0)
         # Add DQ node
-        dequantize_layer = ctx.net.add_dequantize(q_output, scale)
+        dequantize_layer = ctx.net.add_dequantize(
+            q_output, scale, output_type=input_tensor.dtype
+        )
+        dequantize_layer.to_type = input_tensor.dtype
+        if axis is not None:
+            dequantize_layer.axis = axis
         set_layer_name(dequantize_layer, target, name + "_dequantize", source_ir)
-        if num_bits == 8 and exponent_bits == 0:
-            dequantize_layer.precision = trt.DataType.INT8
-        elif num_bits == 8 and exponent_bits == 4:
-            # Set DQ layer precision to FP8
-            dequantize_layer.precision = trt.DataType.FP8
         dq_output = dequantize_layer.get_output(0)
 
         return dq_output
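For readers of this diff, here is a minimal standalone sketch of the amax-to-(scale, axis) logic the new code adds, in pure PyTorch with no TensorRT involved. `compute_scale` is a hypothetical helper name introduced only for illustration; the 127 and 448 bounds are the INT8 and FP8 (E4M3) maxima used by the converter above.

```python
import torch


def compute_scale(amax: torch.Tensor, max_bound: int):
    """Hypothetical helper mirroring the converter's scale/axis derivation.

    max_bound is 127 for INT8 and 448 for FP8 (the E4M3 maximum).
    """
    axis = None
    if amax.numel() > 1:
        # Per-channel case: amax must squeeze down to exactly one axis,
        # and TensorRT expects that axis to be 0 (the output-channel dim).
        init_shape = amax.shape
        amax = amax.squeeze()
        assert len(amax.shape) == 1, "multi-axis quantization is unsupported"
        axis = list(init_shape).index(amax.shape[0])
        assert axis == 0, "per-channel quantization expects axis 0"
    scale = amax / max_bound
    # Mirrors the masked_fill_ guard in the diff: a channel whose amax is 0
    # would otherwise produce a zero scale and divide-by-zero in Q/DQ.
    scale = scale.masked_fill(scale == 0, 1.0)
    return scale, axis


# Per-tensor (activations, FP8): a single amax value, axis stays None.
print(compute_scale(torch.tensor(3.5), 448))

# Per-channel (INT8 weights): one amax per output channel, shape (out_ch, 1).
print(compute_scale(torch.ones(4, 1) * 2.0, 127))

# The Q/DQ pair the converter emits is numerically a fake-quant round trip;
# for INT8 it is approximately the following (round, clamp, rescale):
x = torch.randn(4, 8)
scale, _ = compute_scale(x.abs().amax(), 127)
x_qdq = torch.clamp(torch.round(x / scale), -128, 127) * scale
```

This is only a sketch of the arithmetic under those assumptions; the actual converter builds `IQuantizeLayer`/`IDequantizeLayer` nodes via `ctx.net` as shown in the diff rather than computing values eagerly.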