File tree Expand file tree Collapse file tree 3 files changed +22
-2
lines changed
Expand file tree Collapse file tree 3 files changed +22
-2
lines changed Original file line number Diff line number Diff line change @@ -224,3 +224,20 @@ def forward(
224224
225225 out = gaf_function (tree_spec , * tree_nodes )
226226 return out
227+
228+ # helper functions for disabling GAF wrappers within a network
229+ # for handy ablation, in the case subnetworks within a neural network were wrapped
230+
231+ def set_filter_gradients_ (
232+ m : Module ,
233+ filter_gradients : bool ,
234+ filter_distance_thres = None
235+ ):
236+ for module in m .modules ():
237+ if not isinstance (module , GAFWrapper ):
238+ continue
239+
240+ module .filter_gradients = filter_gradients
241+
242+ if exists (filter_distance_thres ):
243+ module .filter_distance_thres = filter_distance_thres
Original file line number Diff line number Diff line change 1- from GAF_microbatch_pytorch .GAF import GAFWrapper
1+ from GAF_microbatch_pytorch .GAF import (
2+ GAFWrapper ,
3+ set_filter_gradients_
4+ )
Original file line number Diff line number Diff line change 11[project ]
22name = " GAF-microbatch-pytorch"
3- version = " 0.0.2 "
3+ version = " 0.0.3 "
44description = " Gradient Agreement Filtering"
55authors = [
66 { name = " Phil Wang" , email = " lucidrains@gmail.com" }
You can’t perform that action at this time.
0 commit comments