22
33from functools import partial
44from typing import Literal , Callable
5+ from collections .abc import Sequence
56
67import torch
78from torch import nn
@@ -102,6 +103,7 @@ def forward(self, ctx, tree_spec, *tree_nodes):
102103 net = package ['net' ]
103104 params , buffers = package ['params_buffers' ]
104105 filter_gradients_fn = package ['filter_gradients_fn' ]
106+ exclude_from_filtering = package ['exclude_from_filtering' ]
105107 inp_tensor , args , kwargs = package ['inputs' ]
106108
107109 batch = inp_tensor .shape [0 ]
@@ -115,28 +117,54 @@ def fn(params, buffers, inp_tensor):
115117
116118 output , vjpfunc = vjp (fn , params , buffers , inp_tensor )
117119
118- ctx ._saved_info_for_backwards = (vjpfunc , filter_gradients_fn , args , kwargs )
120+ ctx ._saved_info_for_backwards = (
121+ vjpfunc ,
122+ filter_gradients_fn ,
123+ args ,
124+ kwargs ,
125+ exclude_from_filtering
126+ )
127+
119128 return output
120129
@classmethod
def backward(self, ctx, do):
    """Backward pass for GAF.

    Unpacks the vjp closure saved by ``forward``, computes raw gradients,
    applies ``filter_gradients_fn`` to every parameter gradient whose name
    is not listed in ``exclude_from_filtering``, then re-flattens the
    result into the tree-node layout autograd expects.
    """

    # everything stashed on the context during the forward pass
    (
        vjp_func,
        filter_gradients_fn,
        args,
        kwargs,
        exclude_from_filtering
    ) = ctx._saved_info_for_backwards

    # raw gradients w.r.t. params, buffers, and the input tensor
    dparams, dbuffers, dinp = vjp_func(do)

    # filter each parameter gradient, passing excluded names through untouched
    filtered_dparams = {
        name: grad if name in exclude_from_filtering else filter_gradients_fn(grad)
        for name, grad in dparams.items()
    }

    # rebuild the package mirroring the forward-pass layout, then flatten
    # back out so the gradients line up with the forward tree nodes

    package = dict(
        net = None,
        params_buffers = (filtered_dparams, dbuffers),
        inputs = (dinp, None, None),
        filter_gradients_fn = None,
        exclude_from_filtering = None
    )

    tree_nodes, _ = tree_flatten(package)

    return (None, *tree_nodes)
140168
141169gaf_function = GAF .apply
142170
@@ -151,12 +179,15 @@ def __init__(
151179 net : Module ,
152180 filter_distance_thres = 0.97 ,
153181 filter_gradients = True ,
154- filter_gradients_fn : Callable | None = None
182+ filter_gradients_fn : Callable | None = None ,
183+ exclude_from_filtering : Sequence [str ] = ()
155184 ):
156185 super ().__init__ ()
157186
158187 self .net = net
159188
189+ self .exclude_from_filtering = set (exclude_from_filtering )
190+
160191 # gradient agreement filtering related
161192
162193 self .filter_gradients = filter_gradients
@@ -185,7 +216,8 @@ def forward(
185216 net = self .net ,
186217 params_buffers = (params , buffers ),
187218 inputs = (inp_tensor , args , kwargs ),
188- filter_gradients_fn = self .filter_gradients_fn
219+ filter_gradients_fn = self .filter_gradients_fn ,
220+ exclude_from_filtering = self .exclude_from_filtering
189221 )
190222
191223 tree_nodes , tree_spec = tree_flatten (package )
0 commit comments