
@six.add_metaclass(abc.ABCMeta)
class Quantizer(object):
-  """ABC interface which contains logic to quantize a tensor."""
+  """ABC interface which encapsulates the logic of how to quantize tensors.
+
+  A `Quantizer` is used by the library code to apply the mathematical
+  transformations which actually quantize a tensor, hence allowing the user
+  precise control over the algorithm with which tensors are quantized. When used
+  in conjunction with `QuantizeConfig` it controls how a layer is quantized.
+
+  Create a custom quantizer:
+
+  ```python
+  class FixedRangeQuantizer(Quantizer):
+    # Example quantizer which clips tensors in a fixed range.
+
+    def build(self, tensor_shape, name, layer):
+      range_var = layer.add_weight(
+        name + '_range',
+        initializer=keras.initializers.Constant(6.0),
+        trainable=False)
+
+      return {
+        'range_var': range_var,
+      }
+
+    def __call__(self, inputs, training, weights, **kwargs):
+      return tf.keras.backend.clip(
+          inputs, 0.0, weights['range_var'])
+
+    def get_config(self):
+      # Not needed. No __init__ parameters to serialize.
+      return {}
+  ```
+
+  For a full example, see
+  https://www.tensorflow.org/model_optimization/guide/quantization/training_comprehensive_guide.md
+  """

  @abc.abstractmethod
  def build(self, tensor_shape, name, layer):
-    """Constructs the weights required by the quantizer.
+    """Construct the weights required by the quantizer.
+
+    A quantizer may need to construct variables to hold the state for its
+    algorithm. This function is invoked during the `build` stage of the layer
+    that the quantizer is used for. Any variables constructed are under the
+    scope of the `layer` and serialized as part of the layer.

    Args:
      tensor_shape: Shape of tensor which needs to be quantized.
@@ -46,27 +85,30 @@ def build(self, tensor_shape, name, layer):
        to construct the weights, and is also the owner of the weights.

    Returns: Dictionary of constructed weights. This dictionary will be
-      unpacked and passed to the quantizer's __call__ function as kwargs.
+      passed to the quantizer's __call__ function as a `weights` dictionary.
    """

  @abc.abstractmethod
  def __call__(self, inputs, training, weights, **kwargs):
    """Apply quantization to the input tensor.

-    The `step` variable allows a user to design a custom quantizer which
-    modifies quantization behavior as training progresses.
+    This is the main function of the `Quantizer` which implements the core logic
+    to quantize the tensor. It is invoked during the `call` stage of the layer,
+    and allows modifying the tensors used in graph construction.

    Args:
      inputs: Input tensor to be quantized.
      training: Whether the graph is currently training.
      weights: Dictionary of weights the quantizer can use to quantize the
        tensor. This contains the weights created in the `build` function.
      **kwargs: Additional variables which may be passed to the quantizer.
+
    Returns: quantized tensor.
    """

  @abc.abstractmethod
  def get_config(self):
+    """Returns the config used to serialize the `Quantizer`."""
    raise NotImplementedError('Quantizer should implement get_config().')

  @classmethod
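The diff above only defines the abstract contract, so it can help to see that contract exercised end to end. The sketch below reuses the `FixedRangeQuantizer` from the new class docstring and drives `build` and `__call__` by hand; in real use a `QuantizeConfig` wires the quantizer into a layer and the library makes these calls itself, so the manual invocations, the `Dense` layer, and the `tfmot.quantization.keras.quantizers` import path are illustrative assumptions rather than prescribed usage.

```python
import tensorflow as tf
from tensorflow import keras
import tensorflow_model_optimization as tfmot

# Abstract base class documented in the diff above (public API path assumed).
Quantizer = tfmot.quantization.keras.quantizers.Quantizer


class FixedRangeQuantizer(Quantizer):
  """Example quantizer from the class docstring: clips tensors to a fixed range."""

  def build(self, tensor_shape, name, layer):
    # Variables are created on `layer`, so they are serialized with the layer.
    range_var = layer.add_weight(
        name + '_range',
        initializer=keras.initializers.Constant(6.0),
        trainable=False)
    return {'range_var': range_var}

  def __call__(self, inputs, training, weights, **kwargs):
    # `weights` is the dictionary returned by `build`.
    return tf.keras.backend.clip(inputs, 0.0, weights['range_var'])

  def get_config(self):
    return {}


# Illustration only: normally these calls happen inside the quantized layer's
# own build/call stages rather than in user code.
layer = keras.layers.Dense(4)
layer.build(input_shape=(None, 8))

quantizer = FixedRangeQuantizer()
weights = quantizer.build(tf.TensorShape([8, 4]), 'kernel', layer)  # build stage
quantized_kernel = quantizer(layer.kernel, training=True, weights=weights)  # call stage
```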
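The new `get_config` docstring states only the serialization requirement. As a hedged sketch that reuses the imports and `Quantizer` base class from the previous snippet, a quantizer that does take constructor arguments might implement it as below; `ThresholdQuantizer` and its `max_value` argument are invented for illustration. Returning every `__init__` argument lets Keras-style deserialization (conventionally a `from_config` classmethod calling `cls(**config)`) rebuild an equivalent quantizer, for example when a quantized model is reloaded under `tfmot.quantization.keras.quantize_scope`.

```python
class ThresholdQuantizer(Quantizer):
  """Hypothetical quantizer whose constructor arguments must be serializable."""

  def __init__(self, max_value=6.0):
    self.max_value = max_value

  def build(self, tensor_shape, name, layer):
    # No variables needed: the clipping range is a plain Python constant.
    return {}

  def __call__(self, inputs, training, weights, **kwargs):
    return tf.keras.backend.clip(inputs, 0.0, self.max_value)

  def get_config(self):
    # Everything passed to __init__ goes into the config so the quantizer
    # can be reconstructed identically when the model is deserialized.
    return {'max_value': self.max_value}
```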