Skip to content

Commit ca30f59

Browse files
committed
update
1 parent fc6bb11 commit ca30f59

File tree

2 files changed

+53
-23
lines changed

2 files changed

+53
-23
lines changed

network/deform_conv/_deform_conv.py

Lines changed: 52 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,37 @@
1212

1313

1414
class DeformConv2D(nn.Module):
15-
def __init__(self, inc, outc, kernel_size=3, padding=1, stride=1, bias=None):
15+
def __init__(self, inc, outc, kernel_size=3, padding=1, stride=1, bias=None, lr_ratio=1.0):
1616
super(DeformConv2D, self).__init__()
1717
self.kernel_size = kernel_size
1818
self.padding = padding
1919
self.stride = stride
2020
self.zero_padding = nn.ZeroPad2d(padding)
21-
self.conv_kernel = nn.Conv2d(inc, outc, kernel_size=kernel_size, stride=kernel_size, bias=bias)
2221

23-
def forward(self, x, offset):
22+
self.offset_conv = nn.Conv2d(inc, 2 * kernel_size * kernel_size, kernel_size=3, padding=1, stride=stride)
23+
nn.init.constant_(self.offset_conv.weight, 0) # the offset learning are initialized with zero weights
24+
self.offset_conv.register_backward_hook(self._set_lr)
25+
26+
self.conv = nn.Conv2d(inc, outc, kernel_size=kernel_size, stride=kernel_size, bias=bias)
27+
28+
self.lr_ratio = lr_ratio
29+
30+
def _set_lr(self, module, grad_input, grad_output):
31+
# print('grad input:', grad_input)
32+
new_grad_input = []
33+
34+
for i in range(len(grad_input)):
35+
if grad_input[i] is not None:
36+
new_grad_input.append(grad_input[i] * self.lr_ratio)
37+
else:
38+
new_grad_input.append(grad_input[i])
39+
40+
new_grad_input = tuple(new_grad_input)
41+
# print('new grad input:', new_grad_input)
42+
return new_grad_input
43+
44+
def forward(self, x):
45+
offset = self.offset_conv(x)
2446
dtype = offset.data.type()
2547
ks = self.kernel_size
2648
N = offset.size(1) // 2
@@ -140,8 +162,8 @@ def forward(self, x, offset):
140162

141163
x_offset = self._reshape_x_offset(x_offset, ks)
142164

143-
out = self.conv_kernel(x_offset)
144-
return x_offset
165+
out = self.conv(x_offset)
166+
return out
145167

146168
def _get_p_n(self, N, dtype):
147169
"""
@@ -214,26 +236,34 @@ def _reshape_x_offset(x_offset, ks):
214236
if __name__ == '__main__':
215237
x = torch.randn(4, 3, 255, 255)
216238

217-
p_conv = nn.Conv2d(3, 2 * 3 * 3, kernel_size=3, padding=1, stride=1)
218-
conv = nn.Conv2d(3, 64, kernel_size=3, stride=3, bias=False)
219-
220-
d_conv1 = DeformConv2D(3, 64)
221-
d_conv2 = DeformConv2D_ori(3, 64)
222-
223-
offset = p_conv(x)
224-
225-
end = time()
226-
y1 = conv(d_conv1(x, offset))
227-
end = time() - end
228-
print('#1 speed = ', end)
229-
230-
end = time()
231-
y2 = conv(d_conv2(x, offset))
232-
end = time() - end
233-
print('#2 speed = ', end)
239+
# p_conv = nn.Conv2d(3, 2 * 3 * 3, kernel_size=3, padding=1, stride=1)
240+
# conv = nn.Conv2d(3, 64, kernel_size=3, stride=3, bias=False)
241+
#
242+
# d_conv1 = DeformConv2D(3, 64)
243+
# d_conv2 = DeformConv2D_ori(3, 64)
244+
#
245+
# offset = p_conv(x)
246+
#
247+
# end = time()
248+
# y1 = conv(d_conv1(x, offset))
249+
# end = time() - end
250+
# print('#1 speed = ', end)
251+
#
252+
# end = time()
253+
# y2 = conv(d_conv2(x, offset))
254+
# end = time() - end
255+
# print('#2 speed = ', end)
234256

235257
# mask = (y1 == y2)
236258
# print(mask)
237259
# print(torch.max(mask))
238260
# print(torch.min(mask))
239261

262+
x = torch.randn(4, 3, 255, 255)
263+
d_conv = DeformConv2D(3, 64)
264+
265+
end = time()
266+
y = d_conv(x)
267+
end = time() - end
268+
print('speed = ', end)
269+
print(y.size())

network/deform_conv/deform_conv_v2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def _set_lr(self, module, grad_input, grad_output):
4444

4545
for i in range(len(grad_input)):
4646
if grad_input[i] is not None:
47-
new_grad_input.append(grad_input[i] * 0.1)
47+
new_grad_input.append(grad_input[i] * self.lr_ratio)
4848
else:
4949
new_grad_input.append(grad_input[i])
5050

0 commit comments

Comments
 (0)