2121import numpy as np
2222
2323from tensorflow .python import keras
24+ from tensorflow .python .keras import backend as K
2425from tensorflow .python .platform import test
2526from tensorflow_model_optimization .python .core .quantization .keras import quantize_annotate as quant_annotate
2627from tensorflow_model_optimization .python .core .quantization .keras import quantize_aware_activation
@@ -195,6 +196,20 @@ def _get_annotated_functional_model(self):
195196
196197 return keras .Model (inputs = inputs , outputs = results )
197198
199+ def _assert_weights_equal_value (self , annotated_weights , emulated_weights ):
200+ annotated_weight_values = K .batch_get_value (annotated_weights )
201+ emulated_weight_values = K .batch_get_value (emulated_weights )
202+
203+ self .assertEqual (len (annotated_weight_values ), len (emulated_weight_values ))
204+ for aw , ew in zip (annotated_weight_values , emulated_weight_values ):
205+ self .assertAllClose (aw , ew )
206+
207+ def _assert_weights_different_objects (
208+ self , annotated_weights , emulated_weights ):
209+ self .assertEqual (len (annotated_weights ), len (emulated_weights ))
210+ for aw , ew in zip (annotated_weights , emulated_weights ):
211+ self .assertNotEqual (id (aw ), id (ew ))
212+
198213 def _assert_layer_emulated (
199214 self , annotated_layer , emulated_layer , exclude_keys = None ):
200215 self .assertIsInstance (emulated_layer , QuantizeEmulateWrapper )
@@ -216,6 +231,20 @@ def _assert_layer_emulated(
216231
217232 self .assertEqual (annotated_config , emulated_config )
218233
234+ def _sort_weights (weights ):
235+ # Variables are named `quantize_annotate0/kernel:0` and
236+ # `quantize_emulate0/kernel:0`. Strip layer name to sort.
237+ return sorted (weights , key = lambda w : w .name .split ('/' )[1 ])
238+
239+ annotated_weights = _sort_weights (annotated_layer .trainable_weights )
240+ emulated_weights = _sort_weights (emulated_layer .trainable_weights )
241+
242+ # Quantized model should pick the same weight values from the original
243+ # model. However, they should not be the same weight objects. We don't
244+ # want training the quantized model to change weights in the original model.
245+ self ._assert_weights_different_objects (annotated_weights , emulated_weights )
246+ self ._assert_weights_equal_value (annotated_weights , emulated_weights )
247+
219248 def _assert_model_emulated (
220249 self , annotated_model , emulated_model , exclude_keys = None ):
221250 for annotated_layer , emulated_layer in zip (annotated_model .layers ,
0 commit comments