
Commit 4eb9ac9

able to exclude certain parameters from having their gradients filtered
1 parent c9e82ab commit 4eb9ac9

3 files changed (+42, −10 lines)

GAF_microbatch_pytorch/GAF.py

Lines changed: 40 additions & 8 deletions

@@ -2,6 +2,7 @@

 from functools import partial
 from typing import Literal, Callable
+from collections.abc import Sequence

 import torch
 from torch import nn
@@ -102,6 +103,7 @@ def forward(self, ctx, tree_spec, *tree_nodes):
         net = package['net']
         params, buffers = package['params_buffers']
         filter_gradients_fn = package['filter_gradients_fn']
+        exclude_from_filtering = package['exclude_from_filtering']
         inp_tensor, args, kwargs = package['inputs']

         batch = inp_tensor.shape[0]
@@ -115,28 +117,54 @@ def fn(params, buffers, inp_tensor):

         output, vjpfunc = vjp(fn, params, buffers, inp_tensor)

-        ctx._saved_info_for_backwards = (vjpfunc, filter_gradients_fn, args, kwargs)
+        ctx._saved_info_for_backwards = (
+            vjpfunc,
+            filter_gradients_fn,
+            args,
+            kwargs,
+            exclude_from_filtering
+        )
+
         return output

     @classmethod
     def backward(self, ctx, do):

-        vjp_func, filter_gradients_fn, args, kwargs = ctx._saved_info_for_backwards
+        (
+            vjp_func,
+            filter_gradients_fn,
+            args,
+            kwargs,
+            exclude_from_filtering
+        ) = ctx._saved_info_for_backwards

         dparams, dbuffers, dinp = vjp_func(do)

-        filtered_dparams = {name: filter_gradients_fn(dparam) for name, dparam in dparams.items()}
+        # filter gradients for each parameter tensor
+        # unless it is in `exclude_from_filtering`
+
+        filtered_dparams = dict()
+
+        for name, dparam in dparams.items():
+            if name in exclude_from_filtering:
+                filtered_dparams[name] = dparam
+                continue
+
+            filtered_dparams[name] = filter_gradients_fn(dparam)
+
+        # tree flatten back out

         package = dict(
             net = None,
             params_buffers = (filtered_dparams, dbuffers),
-            inputs = (dinp, None, None)
+            inputs = (dinp, None, None),
+            filter_gradients_fn = None,
+            exclude_from_filtering = None
         )

         tree_nodes, _ = tree_flatten(package)

-        output = (None, *tree_nodes)
-        return output
+        return (None, *tree_nodes)

 gaf_function = GAF.apply

@@ -151,12 +179,15 @@ def __init__(
         net: Module,
         filter_distance_thres = 0.97,
         filter_gradients = True,
-        filter_gradients_fn: Callable | None = None
+        filter_gradients_fn: Callable | None = None,
+        exclude_from_filtering: Sequence[str] = ()
     ):
         super().__init__()

         self.net = net

+        self.exclude_from_filtering = set(exclude_from_filtering)
+
         # gradient agreement filtering related

         self.filter_gradients = filter_gradients
@@ -185,7 +216,8 @@ def forward(
             net = self.net,
             params_buffers = (params, buffers),
             inputs = (inp_tensor, args, kwargs),
-            filter_gradients_fn = self.filter_gradients_fn
+            filter_gradients_fn = self.filter_gradients_fn,
+            exclude_from_filtering = self.exclude_from_filtering
         )

         tree_nodes, tree_spec = tree_flatten(package)
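A minimal usage sketch of the new `exclude_from_filtering` argument. The toy network, parameter names, and import path below are assumptions for illustration (parameter names are whatever the wrapped module's `named_parameters()` yields); only the `GAFWrapper` keyword arguments come from this commit.

```python
# illustrative sketch only - network, parameter names and import path are assumptions

import torch
from torch import nn

from GAF_microbatch_pytorch import GAFWrapper  # assumed import path for this package

net = nn.Sequential(
    nn.Linear(16, 32),
    nn.ReLU(),
    nn.Linear(32, 10)
)

gaf_net = GAFWrapper(
    net,
    filter_distance_thres = 0.97,
    exclude_from_filtering = ('0.weight', '0.bias')  # gradients for these parameter names pass through unfiltered
)

x = torch.randn(8, 16)     # forward as usual; gradient agreement filtering happens in the custom backward
out = gaf_net(x)
out.sum().backward()       # excluded parameters keep their raw gradients, the rest are filtered
```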

README.md

Lines changed: 1 addition & 1 deletion

@@ -70,7 +70,7 @@ gaf_net = GAFWrapper(
 ## Todo

 - [ ] replicate cifar results on single machine
-- [ ] allow for excluding certain parameters from being filtered
+- [x] allow for excluding certain parameters from being filtered

 ## Citations

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 [project]
 name = "GAF-microbatch-pytorch"
-version = "0.0.1"
+version = "0.0.2"
 description = "Gradient Agreement Filtering"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
