|
12 | 12 |
|
13 | 13 |
|
class DeformConv2D(nn.Module):
    """Deformable 2D convolution (DCN-v1 style).

    An internal 3x3 conv predicts per-position (x, y) sampling offsets for
    each of the ``kernel_size**2`` kernel locations; the input is resampled
    at those offset positions and a regular conv is applied on top.
    """

    def __init__(self, inc, outc, kernel_size=3, padding=1, stride=1, bias=None, lr_ratio=1.0):
        """
        Args:
            inc: number of input channels.
            outc: number of output channels.
            kernel_size: deformable kernel size; the offset branch predicts
                ``2 * kernel_size**2`` channels (an x and a y shift per tap).
            padding: zero-padding applied to the input before sampling.
            stride: stride of the offset-prediction conv.
            bias: bias flag forwarded to the output conv (``None`` is falsy,
                i.e. no bias — kept for backward compatibility).
            lr_ratio: gradient scale for the offset branch (see ``_set_lr``);
                1.0 means the offsets train at the normal rate.
        """
        super(DeformConv2D, self).__init__()
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride
        self.zero_padding = nn.ZeroPad2d(padding)

        # Offset branch: 2 * ks * ks channels = one (x, y) pair per kernel tap.
        self.offset_conv = nn.Conv2d(inc, 2 * kernel_size * kernel_size, kernel_size=3, padding=1, stride=stride)
        # The offset branch must start as the identity (all-zero offsets).
        # Zeroing only the weights is not enough: a randomly-initialized bias
        # still produces non-zero offsets at step 0, so zero the bias as well.
        nn.init.constant_(self.offset_conv.weight, 0)
        if self.offset_conv.bias is not None:
            nn.init.constant_(self.offset_conv.bias, 0)
        # NOTE(review): register_backward_hook is deprecated in recent PyTorch,
        # but register_full_backward_hook changes grad_input semantics (input
        # grads only, no weight/bias grads), which would silently break the
        # lr_ratio scaling in _set_lr — keep the legacy hook deliberately.
        self.offset_conv.register_backward_hook(self._set_lr)

        # Output conv over the resampled map; stride equals kernel_size —
        # presumably because each kernel window is tiled into the spatial
        # plane by _reshape_x_offset (defined elsewhere in this file).
        self.conv = nn.Conv2d(inc, outc, kernel_size=kernel_size, stride=kernel_size, bias=bias)

        self.lr_ratio = lr_ratio
| 30 | + def _set_lr(self, module, grad_input, grad_output): |
| 31 | + # print('grad input:', grad_input) |
| 32 | + new_grad_input = [] |
| 33 | + |
| 34 | + for i in range(len(grad_input)): |
| 35 | + if grad_input[i] is not None: |
| 36 | + new_grad_input.append(grad_input[i] * self.lr_ratio) |
| 37 | + else: |
| 38 | + new_grad_input.append(grad_input[i]) |
| 39 | + |
| 40 | + new_grad_input = tuple(new_grad_input) |
| 41 | + # print('new grad input:', new_grad_input) |
| 42 | + return new_grad_input |
| 43 | + |
| 44 | + def forward(self, x): |
| 45 | + offset = self.offset_conv(x) |
24 | 46 | dtype = offset.data.type() |
25 | 47 | ks = self.kernel_size |
26 | 48 | N = offset.size(1) // 2 |
@@ -140,8 +162,8 @@ def forward(self, x, offset): |
140 | 162 |
|
141 | 163 | x_offset = self._reshape_x_offset(x_offset, ks) |
142 | 164 |
|
143 | | - out = self.conv_kernel(x_offset) |
144 | | - return x_offset |
| 165 | + out = self.conv(x_offset) |
| 166 | + return out |
145 | 167 |
|
146 | 168 | def _get_p_n(self, N, dtype): |
147 | 169 | """ |
@@ -214,26 +236,34 @@ def _reshape_x_offset(x_offset, ks): |
if __name__ == '__main__':
    # Quick benchmark: one DeformConv2D forward pass on a random image batch.
    # (The old two-stage benchmark — external offset conv fed into
    # DeformConv2D(x, offset) — is gone: offsets are now produced internally
    # by the module's own offset_conv.)
    x = torch.randn(4, 3, 255, 255)

    x = torch.randn(4, 3, 255, 255)
    d_conv = DeformConv2D(3, 64)

    start = time()
    y = d_conv(x)
    elapsed = time() - start
    print('speed = ', elapsed)
    print(y.size())
0 commit comments