2020
2121import tensorflow as tf
2222
23- # b/(139939526): update assign ops to v2 API.
24- from tensorflow .python .ops import state_ops
2523from tensorflow .python .ops import summary_ops_v2
2624from tensorflow .python .summary import summary as summary_ops_v1
2725from tensorflow_model_optimization .python .core .sparsity .keras import pruning_utils
2826
27+
2928class Pruning (object ):
3029 """Implementation of magnitude-based weight pruning."""
3130
@@ -54,6 +53,13 @@ def __init__(self, training_step_fn, pruning_vars, pruning_schedule,
5453
5554 self ._validate_block ()
5655
56+ @staticmethod
57+ def _assign (ref , value ):
58+ if tf .__version__ [0 ] == '1' :
59+ return tf .assign (ref , value )
60+ else :
61+ return ref .assign (value )
62+
5763 def _validate_block (self ):
5864 if self ._block_size != [1 , 1 ]:
5965 for weight , _ , _ in self ._pruning_vars :
@@ -144,8 +150,15 @@ def _maybe_update_block_mask(self, weights):
144150 squeezed_weights .get_shape ()[1 ]])
145151 return new_threshold , tf .reshape (sliced_mask , tf .shape (weights ))
146152
147- def _get_weight_assign_ops (self ):
148- """Gather the assign ops for assigning weights<=weights*mask."""
153+ def _weight_assign_objs (self ):
154+ """Gather the assign objs for assigning weights<=weights*mask.
155+
156+ The objs are ops for graph execution and tensors for eager
157+ execution.
158+
159+ Returns:
160+ group of objs for weight assignment.
161+ """
149162
150163 def update_fn (distribution , values_and_vars ):
151164 # TODO(yunluli): Need this ReduceOp because the weight is created by the
@@ -158,34 +171,34 @@ def update_fn(distribution, values_and_vars):
158171 values_and_vars = zip (reduced_values , var_list )
159172
160173 def update_var (variable , reduced_value ):
161- return state_ops . assign (variable , reduced_value )
174+ return self . _assign (variable , reduced_value )
162175
163- update_ops = []
176+ update_objs = []
164177 for value , var in values_and_vars :
165- update_ops .append (
178+ update_objs .append (
166179 distribution .extended .update (var , update_var , args = (value ,)))
167180
168- return tf .group (update_ops )
181+ return tf .group (update_objs )
169182
170- assign_ops = []
183+ assign_objs = []
171184
172185 if tf .distribute .get_replica_context ():
173186 values_and_vars = []
174187 for weight , mask , _ in self ._pruning_vars :
175188 masked_weight = tf .math .multiply (weight , mask )
176189 values_and_vars .append ((masked_weight , weight ))
177190 if values_and_vars :
178- assign_ops .append (tf .distribute .get_replica_context ().merge_call (
191+ assign_objs .append (tf .distribute .get_replica_context ().merge_call (
179192 update_fn , args = (values_and_vars ,)))
180193 else :
181194 for weight , mask , _ in self ._pruning_vars :
182195 masked_weight = tf .math .multiply (weight , mask )
183- assign_ops .append (state_ops . assign (weight , masked_weight ))
196+ assign_objs .append (self . _assign (weight , masked_weight ))
184197
185- return assign_ops
198+ return assign_objs
186199
187200 def weight_mask_op (self ):
188- return tf .group (self ._get_weight_assign_ops ())
201+ return tf .group (self ._weight_assign_objs ())
189202
190203 def conditional_mask_update (self ):
191204 """Returns an op to updates masks as per the pruning schedule."""
@@ -200,35 +213,39 @@ def mask_update():
200213 """Updates mask without distribution strategy."""
201214
202215 def update ():
203- assign_ops = []
216+ assign_objs = []
204217
205218 for weight , mask , threshold in self ._pruning_vars :
206219 new_threshold , new_mask = self ._maybe_update_block_mask (weight )
207- assign_ops .append (state_ops . assign (threshold , new_threshold ))
208- assign_ops .append (state_ops . assign (mask , new_mask ))
220+ assign_objs .append (self . _assign (threshold , new_threshold ))
221+ assign_objs .append (self . _assign (mask , new_mask ))
209222
210- return tf .group (assign_ops )
223+ return tf .group (assign_objs )
211224
212225 return tf .cond (maybe_update_masks (), update , no_update )
213226
214227 def mask_update_distributed (distribution ):
215228 """Updates mask with distribution strategy."""
216229
217230 def update (var , value ):
218- return state_ops . assign (var , value )
231+ return self . _assign (var , value )
219232
220233 def update_distributed ():
221- """Gather distributed update ops."""
222- assign_ops = []
234+ """Gather distributed update objs.
235+
236+ The objs are ops for graph execution and tensors for eager
237+ execution.
238+ """
239+ assign_objs = []
223240
224241 for weight , mask , threshold in self ._pruning_vars :
225242 new_threshold , new_mask = self ._maybe_update_block_mask (weight )
226- assign_ops .append (
243+ assign_objs .append (
227244 distribution .extended .update (mask , update , (new_mask ,)))
228- assign_ops .append (
245+ assign_objs .append (
229246 distribution .extended .update (threshold , update , (new_threshold ,)))
230247
231- return tf .group (assign_ops )
248+ return tf .group (assign_objs )
232249
233250 return tf .cond (maybe_update_masks (), update_distributed , no_update )
234251
0 commit comments