@@ -184,44 +184,111 @@ def testConvolutionalWeightsCA(self, clustering_centroids, pulling_indices,
184184 self ._check_pull_values (clustering_algo , pulling_indices , expected_output )
185185
186186 @parameterized .parameters (
187- ([[0. , 1 , 2 ], [3 , 4 , 5 ]],
188- [[[[0 ], [0 ]], [[0 ], [1 ]]],
189- [[[0 ], [2 ]], [[1 ], [0 ]]]],
190- [[[[0 ], [0 ]], [[0 ], [0 ]]],
191- [[[0 ], [0 ]], [[1 ], [1 ]]]]))
192- def testConvolutionalWeightsPerChannelCA (self ,
187+ (
188+ 'channels_last' ,
189+ [[1 , 2 ], [3 , 4 ], [5 , 6 ]], # 3 channels and 2 cluster per channel
190+ # pulling indices has shape (2, 2, 1, 3)
191+ [[[[0 , 1 , 0 ]], [[0 , 1 , 1 ]]], [[[1 , 0 , 1 ]], [[0 , 1 , 0 ]]]],
192+ [[[[1 , 4 , 5 ]], [[1 , 4 , 6 ]]], [[[2 , 3 , 6 ]], [[1 , 4 , 5 ]]]]),
193+ (
194+ 'channels_first' ,
195+ [[1 , 2 ], [3 , 4 ], [4 , 5 ], [6 , 7 ]
196+ ], # 4 channels and 2 clusters per channel
197+ # pulling indices has shape (1, 4, 2, 2)
198+ [[[[0 , 1 ], [1 , 1 ]], [[0 , 0 ], [0 , 1 ]], [[1 , 0 ], [0 , 0 ]],
199+ [[1 , 1 ], [0 , 0 ]]]],
200+ [[[[1 , 2 ], [2 , 2 ]], [[3 , 3 ], [3 , 4 ]], [[5 , 4 ], [4 , 4 ]],
201+ [[7 , 7 ], [6 , 6 ]]]]))
202+ def testConvolutionalWeightsPerChannelCA (self , data_format ,
193203 clustering_centroids ,
194204 pulling_indices ,
195205 expected_output ):
196- """Verifies that PerChannelCA works as expected."""
206+ """Verifies that get_clustered_weight function works as expected."""
197207 clustering_centroids = tf .Variable (clustering_centroids , dtype = tf .float32 )
198- clustering_algo = clustering_registry .PerChannelCA (
199- clustering_centroids , GradientAggregation .SUM
208+ clustering_algo = clustering_registry .ClusteringAlgorithmPerChannel (
209+ clustering_centroids , GradientAggregation .SUM , data_format
200210 )
211+ # Note that clustered_weights has the same shape as pulling_indices,
212+ # because they are defined inside of the check function.
201213 self ._check_pull_values (clustering_algo , pulling_indices , expected_output )
202214
203215 @parameterized .parameters (
204- (GradientAggregation .AVG ,
205- [[[[0 ], [0 ]], [[0 ], [1 ]]],
206- [[[0 ], [2 ]], [[1 ], [0 ]]]], [[1 , 1 , 0 ], [1 , 1 , 1 ]]),
207- (GradientAggregation .SUM ,
208- [[[[0 ], [0 ]], [[0 ], [1 ]]],
209- [[[0 ], [2 ]], [[1 ], [0 ]]]], [[3 , 1 , 0 ], [2 , 1 , 1 ]])
210- )
211- def testConvolutionalPerChannelCAGrad (self ,
212- cluster_gradient_aggregation ,
213- pulling_indices ,
214- expected_grad_centroids ):
215- """Verifies that the gradients of convolutional layer work as expected."""
216+ (
217+ 'channels_last' ,
218+ [[1 , 2 ], [3 , 4 ], [5 , 6 ]], # 3 channels and 2 cluster per channel
219+ # weight has shape (2, 2, 1, 3)
220+ [[[[1.1 , 3.2 , 5.2 ]], [[2.0 , 4.1 , 5.2 ]]],
221+ [[[2.1 , 2. , 6.1 ]], [[1. , 5. , 5. ]]]],
222+ # expected pulling indices
223+ [[[[0 , 0 , 0 ]], [[1 , 1 , 0 ]]], [[[1 , 0 , 1 ]], [[0 , 1 , 0 ]]]]),
224+ (
225+ 'channels_first' ,
226+ # 4 channels and 2 clusters per channel
227+ [[1 , 2 ], [3 , 4 ], [4 , 5 ], [6 , 7 ]],
228+ # weight has shape (1, 4, 2, 2)
229+ [[[[0.1 , 1.5 ], [2.0 , 1.1 ]], [[0. , 3.5 ], [4.4 , 4. ]],
230+ [[4.1 , 4.2 ], [5.3 , 6. ]], [[7. , 7.1 ], [6.1 , 5.8 ]]]],
231+ # expected pulling indices
232+ [[[[0 , 0 ], [1 , 0 ]], [[0 , 0 ], [1 , 1 ]], [[0 , 0 ], [1 , 1 ]],
233+ [[1 , 1 ], [0 , 0 ]]]]))
234+ def testConvolutionalPullingIndicesPerChannelCA (self , data_format ,
235+ clustering_centroids , weight ,
236+ expected_output ):
237+ """Verifies that get_pulling_indices function works as expected."""
238+ clustering_centroids = tf .Variable (clustering_centroids , dtype = tf .float32 )
239+ clustering_algo = clustering_registry .ClusteringAlgorithmPerChannel (
240+ clustering_centroids , GradientAggregation .SUM , data_format
241+ )
242+ weight = tf .convert_to_tensor (weight )
243+ pulling_indices = clustering_algo .get_pulling_indices (weight )
216244
217- clustering_centroids = tf .Variable ([[0. , 1 , 2 ], [3 , 4 , 5 ]],
245+ # check that pulling_indices has the same shape as weight
246+ self .assertEqual (pulling_indices .shape , weight .shape )
247+ self .assertAllEqual (pulling_indices , expected_output )
248+
249+ @parameterized .parameters (
250+ (GradientAggregation .AVG , [
251+ [[[0 , 0 , 0 ]], [[1 , 1 , 0 ]]], [[[1 , 0 , 1 ]], [[0 , 1 , 0 ]]]],
252+ [[1 , 1 ], [1 , 1 ], [1 , 1 ]]),
253+ (GradientAggregation .SUM , [
254+ [[[0 , 0 , 0 ]], [[1 , 1 , 0 ]]], [[[1 , 0 , 1 ]], [[0 , 1 , 0 ]]]],
255+ [[2 , 2 ], [2 , 2 ], [3 , 1 ]]))
256+ def testConvolutionalPerChannelCAGradChannelsLast (
257+ self , cluster_gradient_aggregation , pulling_indices ,
258+ expected_grad_centroids ):
259+ """Verifies that the gradients of convolutional layer works."""
260+
261+ clustering_centroids = tf .Variable ([[1 , 2 ], [3 , 4 ], [5 , 6 ]],
218262 dtype = tf .float32 )
219- weight = tf .constant ([[[[0 .1 , 3.0 ]], [[0.2 , 0.1 ]]],
220- [[[0 .1 , 3.0 ]], [[0.2 , 0.1 ]]]])
263+ weight = tf .constant ([[[[1 .1 , 3.2 , 5.2 ]], [[2.0 , 4.1 , 5.2 ]]],
264+ [[[2 .1 , 2. , 6.1 ]], [[1. , 5. , 5. ]]]])
221265
222- clustering_algo = clustering_registry .PerChannelCA (
223- clustering_centroids , cluster_gradient_aggregation
266+ clustering_algo = clustering_registry .ClusteringAlgorithmPerChannel (
267+ clustering_centroids , cluster_gradient_aggregation , 'channels_last' )
268+ self ._check_gradients_clustered_weight (
269+ clustering_algo ,
270+ weight ,
271+ pulling_indices ,
272+ expected_grad_centroids ,
224273 )
274+
275+ @parameterized .parameters ((GradientAggregation .AVG , [
276+ [[[0 , 0 ], [1 , 0 ]], [[0 , 0 ], [1 , 1 ]], [[0 , 0 ], [1 , 1 ]], [[1 , 1 ], [0 , 0 ]]]
277+ ], [[1 , 1 ], [1 , 1 ], [1 , 1 ], [1 , 1 ]]), (GradientAggregation .SUM , [
278+ [[[0 , 0 ], [1 , 0 ]], [[0 , 0 ], [1 , 1 ]], [[0 , 0 ], [1 , 1 ]], [[1 , 1 ], [0 , 0 ]]]
279+ ], [[3 , 1 ], [2 , 2 ], [2 , 2 ], [2 , 2 ]]))
280+ def testConvolutionalPerChannelCAGradChannelsFirst (
281+ self , cluster_gradient_aggregation , pulling_indices ,
282+ expected_grad_centroids ):
283+ """Verifies that the gradients of convolutional layer works."""
284+
285+ clustering_centroids = tf .Variable ([[1 , 2 ], [3 , 4 ], [4 , 5 ], [6 , 7 ]],
286+ dtype = tf .float32 )
287+ weight = tf .constant ([[[[0.1 , 1.5 ], [2.0 , 1.1 ]], [[0. , 3.5 ], [4.4 , 4. ]],
288+ [[4.1 , 4.2 ], [5.3 , 6. ]], [[7. , 7.1 ], [6.1 , 5.8 ]]]])
289+
290+ clustering_algo = clustering_registry .ClusteringAlgorithmPerChannel (
291+ clustering_centroids , cluster_gradient_aggregation , 'channels_first' )
225292 self ._check_gradients_clustered_weight (
226293 clustering_algo ,
227294 weight ,
0 commit comments