
Commit fc6bb11

fix set_lr bugs
1 parent 2725103 commit fc6bb11

File tree

4 files changed: +394 −52 lines changed

Lines changed: 239 additions & 0 deletions
@@ -0,0 +1,239 @@
# -*- coding: utf-8 -*-
"""
 @Time    : 2019/2/20 22:16
 @Author  : Wang Xin
 @Email   : wangxin_buaa@163.com
"""

import numpy as np

import torch
import torch.nn as nn


class DeformConv2D(nn.Module):
    def __init__(self, inc, outc, kernel_size=3, padding=1, stride=1, bias=None):
        super(DeformConv2D, self).__init__()
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride
        self.zero_padding = nn.ZeroPad2d(padding)
        self.conv_kernel = nn.Conv2d(inc, outc, kernel_size=kernel_size, stride=kernel_size, bias=bias)

    def forward(self, x, offset):
        dtype = offset.data.type()
        ks = self.kernel_size
        N = offset.size(1) // 2

        # Change the offset's order from [x1, x2, ..., y1, y2, ...] to [x1, y1, x2, y2, ...].
        # The code below is only there to reproduce the results of the MXNet implementation;
        # you can remove it without influencing the module's performance.
        offsets_index = torch.cat([torch.arange(0, 2 * N, 2), torch.arange(1, 2 * N + 1, 2)]).type_as(x).long()
        offsets_index.requires_grad = False
        offsets_index = offsets_index.unsqueeze(dim=0).unsqueeze(dim=-1).unsqueeze(dim=-1).expand(*offset.size())
        offset = torch.gather(offset, dim=1, index=offsets_index)
        # ------------------------------------------------------------------------

        if self.padding:
            x = self.zero_padding(x)

        # (b, 2N, h, w)
        p = self._get_p(offset, dtype)

        # (b, h, w, 2N)
        p = p.contiguous().permute(0, 2, 3, 1)

        """
        If q is float, using bilinear interpolation, it has four corresponding integer positions.
        The four positions are left top, right top, left bottom and right bottom,
        defined as q_lt, q_rb, q_lb and q_rt.
        """
        # (b, h, w, 2N)
        q_lt = p.detach().floor()

        """
        Because the shape of x is (b, c, h, w), the pixel position is (y, x):

        *┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄→ y
        ┊   .(y, x)      .(y+1, x)
        ┊
        ┊   .(y, x+1)    .(y+1, x+1)
        ┊
        x

        For the right-bottom point, its x = left top's x + 1 and its y = left top's y + 1.
        """
        q_rb = q_lt + 1

        """
        x.size(2) is h, x.size(3) is w; make 0 <= p_y <= h - 1 and 0 <= p_x <= w - 1.
        """
        q_lt = torch.cat([torch.clamp(q_lt[..., :N], 0, x.size(2) - 1), torch.clamp(q_lt[..., N:], 0, x.size(3) - 1)],
                         dim=-1).long()

        """
        x.size(2) is h, x.size(3) is w; make 0 <= p_y <= h - 1 and 0 <= p_x <= w - 1.
        """
        q_rb = torch.cat([torch.clamp(q_rb[..., :N], 0, x.size(2) - 1), torch.clamp(q_rb[..., N:], 0, x.size(3) - 1)],
                         dim=-1).long()

        """
        For the left-bottom point, its x is equal to the right bottom's and its y is equal to the left top's.
        Therefore its y comes from q_lt and its x comes from q_rb.
        """
        q_lb = torch.cat([q_lt[..., :N], q_rb[..., N:]], -1)

        """
        For the right-top point, its x is equal to the left top's and its y is equal to the right bottom's.
        Therefore its y comes from q_rb and its x comes from q_lt.
        """
        q_rt = torch.cat([q_rb[..., :N], q_lt[..., N:]], -1)

        """
        Find p_y <= padding or p_y >= h - 1 - padding, and p_x <= padding or p_x >= w - 1 - padding,
        i.e. the points that fall into the padding area, where the pixel value is not meaningful.
        """
        # (b, h, w, N)
        mask = torch.cat([p[..., :N].lt(self.padding) + p[..., :N].gt(x.size(2) - 1 - self.padding),
                          p[..., N:].lt(self.padding) + p[..., N:].gt(x.size(3) - 1 - self.padding)], dim=-1).type_as(p)
        mask = mask.detach()
        # print('mask:', mask)

        floor_p = torch.floor(p)
        # print('floor_p = ', floor_p)

        """
        When mask is 1, take floor_p; when mask is 0, take the original p.
        When the point is in the padding area, interpolation is not meaningful, so we take the nearest
        point, which is the most likely to have a meaningful value.
        """
        p = p * (1 - mask) + floor_p * mask
        p = torch.cat([torch.clamp(p[..., :N], 0, x.size(2) - 1), torch.clamp(p[..., N:], 0, x.size(3) - 1)], dim=-1)

        """
        In the paper, G(q, p) = g(q_x, p_x) * g(q_y, p_y) with g(a, b) = max(0, 1 - |a - b|).
        """
        # bilinear kernel (b, h, w, N)
        g_lt = (1 + (q_lt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_lt[..., N:].type_as(p) - p[..., N:]))
        g_rb = (1 - (q_rb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_rb[..., N:].type_as(p) - p[..., N:]))
        g_lb = (1 + (q_lb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_lb[..., N:].type_as(p) - p[..., N:]))
        g_rt = (1 - (q_rt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_rt[..., N:].type_as(p) - p[..., N:]))

        # print('g_lt size is ', g_lt.size())
        # print('g_lt unsqueeze size:', g_lt.unsqueeze(dim=1).size())

        # Pixel values at the four integer neighbours: (b, c, h, w, N)
        x_q_lt = self._get_x_q(x, q_lt, N)
        x_q_rb = self._get_x_q(x, q_rb, N)
        x_q_lb = self._get_x_q(x, q_lb, N)
        x_q_rt = self._get_x_q(x, q_rt, N)

        """
        In the paper, x(p) = Σ G(p, q) * x(q), where G is the bilinear kernel.
        """
        # (b, c, h, w, N)
        x_offset = g_lt.unsqueeze(dim=1) * x_q_lt + \
                   g_rb.unsqueeze(dim=1) * x_q_rb + \
                   g_lb.unsqueeze(dim=1) * x_q_lb + \
                   g_rt.unsqueeze(dim=1) * x_q_rt

        x_offset = self._reshape_x_offset(x_offset, ks)

        out = self.conv_kernel(x_offset)
        return x_offset

    def _get_p_n(self, N, dtype):
        """
        In torch 0.4.1: grid_x, grid_y = torch.meshgrid([x, y])
        In torch 1.0:   grid_x, grid_y = torch.meshgrid(x, y)
        """
        p_n_x, p_n_y = torch.meshgrid(
            [torch.arange(-(self.kernel_size - 1) // 2, (self.kernel_size - 1) // 2 + 1),
             torch.arange(-(self.kernel_size - 1) // 2, (self.kernel_size - 1) // 2 + 1)])
        # (2N, 1)
        p_n = torch.cat([torch.flatten(p_n_x), torch.flatten(p_n_y)], 0)
        p_n = p_n.view(1, 2 * N, 1, 1).type(dtype)
        p_n.requires_grad = False
        # print('requires_grad:', p_n.requires_grad)

        return p_n

    def _get_p_0(self, h, w, N, dtype):
        p_0_x, p_0_y = torch.meshgrid([
            torch.arange(1, h * self.stride + 1, self.stride),
            torch.arange(1, w * self.stride + 1, self.stride)])
        p_0_x = torch.flatten(p_0_x).view(1, 1, h, w).repeat(1, N, 1, 1)
        p_0_y = torch.flatten(p_0_y).view(1, 1, h, w).repeat(1, N, 1, 1)
        p_0 = torch.cat([p_0_x, p_0_y], 1).type(dtype)
        p_0.requires_grad = False

        return p_0

    def _get_p(self, offset, dtype):
        N, h, w = offset.size(1) // 2, offset.size(2), offset.size(3)

        # (1, 2N, 1, 1)
        p_n = self._get_p_n(N, dtype)

        # (1, 2N, h, w)
        p_0 = self._get_p_0(h, w, N, dtype)

        p = p_0 + p_n + offset
        return p

    def _get_x_q(self, x, q, N):
        b, h, w, _ = q.size()
        padded_w = x.size(3)
        c = x.size(1)
        # (b, c, h*w)
        x = x.contiguous().view(b, c, -1)

        # (b, h, w, N)
        index = q[..., :N] * padded_w + q[..., N:]  # offset_x*w + offset_y
        # (b, c, h*w*N)
        index = index.contiguous().unsqueeze(dim=1).expand(-1, c, -1, -1, -1).contiguous().view(b, c, -1)

        x_offset = x.gather(dim=-1, index=index).contiguous().view(b, c, h, w, N)

        return x_offset

    @staticmethod
    def _reshape_x_offset(x_offset, ks):
        b, c, h, w, N = x_offset.size()
        x_offset = torch.cat([x_offset[..., s:s + ks].contiguous().view(b, c, h, w * ks) for s in range(0, N, ks)],
                             dim=-1)
        x_offset = x_offset.contiguous().view(b, c, h * ks, w * ks)

        return x_offset


from network.deform_conv.deform_conv import DeformConv2D as DeformConv2D_ori
from time import time

if __name__ == '__main__':
    x = torch.randn(4, 3, 255, 255)

    p_conv = nn.Conv2d(3, 2 * 3 * 3, kernel_size=3, padding=1, stride=1)
    conv = nn.Conv2d(3, 64, kernel_size=3, stride=3, bias=False)

    d_conv1 = DeformConv2D(3, 64)
    d_conv2 = DeformConv2D_ori(3, 64)

    offset = p_conv(x)

    end = time()
    y1 = conv(d_conv1(x, offset))
    end = time() - end
    print('#1 speed = ', end)

    end = time()
    y2 = conv(d_conv2(x, offset))
    end = time() - end
    print('#2 speed = ', end)

    # mask = (y1 == y2)
    # print(mask)
    # print(torch.max(mask))
    # print(torch.min(mask))
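As an aside (an illustration only, not part of the commit): the bilinear weights g_lt, g_rb, g_lb and g_rt computed in the forward pass above implement G(q, p) = g(q_x, p_x) * g(q_y, p_y) with g(a, b) = max(0, 1 - |a - b|) from the deformable convolution paper. Below is a minimal sketch of that formula on a single fractional sampling point; all names in it are hypothetical and not from the module.

import torch

def bilinear_weight(q, p):
    # G(q, p) = g(q_y, p_y) * g(q_x, p_x), with g(a, b) = max(0, 1 - |a - b|)
    return torch.clamp(1 - (q - p).abs(), min=0).prod(dim=-1)

p = torch.tensor([2.3, 5.7])                  # one fractional sampling position
q_lt = p.floor()                              # left-top integer neighbour
q_rb = q_lt + 1                               # right-bottom integer neighbour
corners = [q_lt, q_rb,
           torch.stack([q_lt[0], q_rb[1]]),   # the two mixed corners
           torch.stack([q_rb[0], q_lt[1]])]

weights = torch.stack([bilinear_weight(q, p) for q in corners])
print(weights, weights.sum())                 # non-negative weights that sum to 1

This mirrors how g_lt.unsqueeze(dim=1) * x_q_lt + ... combines the four gathered pixel maps into x_offset in the module above.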

network/deform_conv/deform_conv.py

Lines changed: 39 additions & 28 deletions
@@ -53,21 +53,23 @@ def forward(self, x, offset):
         print('p size:', p.size())
 
         """
-        if q is float, using bilinear interpolate, it has four integer position.
+        if q is float, using bilinear interpolate, it has four integer corresponding position.
         The four position is left top, right top, left bottom, right bottom, defined as q_lt, q_rb, q_lb, q_rt
         """
         q_lt = Variable(p.data, requires_grad=False).floor()
+        print('q_lt size:', q_lt.size())
 
         """
-        *┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄→x
-        ┊   .(x, y)    .(x+1, y)
+        Because the shape of x is N, b, h, w, the pixel position is (y, x)
+        *┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄→y
+        ┊   .(y, x)    .(y+1, x)
 
-        ┊   .(x, y+1)  .(x+1, y+1)
+        ┊   .(y, x+1)  .(y+1, x+1)
 
 
-        y
+        x
 
-        for right bottom point, it'x = left top'y + 1, it'y = left top'y + 1
+        For right bottom point, it'x = left top'y + 1, it'y = left top'y + 1
         """
         q_rb = q_lt + 1
 
@@ -78,47 +80,42 @@ def forward(self, x, offset):
                          dim=-1).long()
 
         """
-
+        x.size(2) is h, x.size(3) is w, make 0 <= p_y <= h - 1, 0 <= p_x <= w-1
         """
         q_rb = torch.cat([torch.clamp(q_rb[..., :N], 0, x.size(2) - 1), torch.clamp(q_rb[..., N:], 0, x.size(3) - 1)],
                          dim=-1).long()
 
         """
-        For the left bottom point, it'x is equal to left top, it'y is equal to right bottom
+        For the left bottom point, it'x is equal to right bottom, it'y is equal to left top
+        Therefore, it's y is from q_lt, it's x is from q_rb
         """
         q_lb = torch.cat([q_lt[..., :N], q_rb[..., N:]], -1)
 
         """
         y from q_rb, x from q_lt
-        for right top point,
+        For right top point, it's x is equal t to left top, it's y is equal to right bottom
         """
         q_rt = torch.cat([q_rb[..., :N], q_lt[..., N:]], -1)
 
-        print('q_lt size:', q_lt.size())
-        print('q_rb size:', q_rb.size())
-        print('q_lb size:', q_lb.size())
-        print('q_rt size:', q_rt.size())
-        print('N = ', N)
-        print('q_lt[..., :] size:', q_lt[..., :N].size())
-
-
         """
         find p_y <= padding or p_y >= h - 1 - padding, find p_x <= padding or p_x >= x - 1 - padding
-        which make the point in the area where the pixel value is meaningful.
+        This is to find the points in the area where the pixel value is meaningful.
         """
         # (b, h, w, N)
         mask = torch.cat([p[..., :N].lt(self.padding) + p[..., :N].gt(x.size(2) - 1 - self.padding),
                           p[..., N:].lt(self.padding) + p[..., N:].gt(x.size(3) - 1 - self.padding)], dim=-1).type_as(p)
         mask = mask.detach()
-        print('mask:', mask)
+        # print('mask:', mask)
 
         floor_p = torch.floor(p)
-        print('floor_p = ', floor_p)
+        # print('floor_p = ', floor_p)
 
 
         """
         when mask is 1, take floor_p;
         when mask is 0, take original p.
+        When thr point in the padding area, interpolation is not meaningful and we can take the nearest
+        point which is the most possible to have meaningful value.
         """
         p = p * (1 - mask) + floor_p * mask
         p = torch.cat([torch.clamp(p[..., :N], 0, x.size(2) - 1), torch.clamp(p[..., N:], 0, x.size(3) - 1)], dim=-1)
@@ -133,14 +130,20 @@ def forward(self, x, offset):
         g_lb = (1 + (q_lb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_lb[..., N:].type_as(p) - p[..., N:]))
         g_rt = (1 - (q_rt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_rt[..., N:].type_as(p) - p[..., N:]))
 
-        print('g_lt = ', g_lt)
+        print('g_lt size is ', g_lt.size())
+        print('g_lt unsqueeze size:', g_lt.unsqueeze(dim=1).size())
 
         # (b, c, h, w, N)
+        """
+        To get the pixel value in left top, left bottom, right top and right bottom.
+        """
         x_q_lt = self._get_x_q(x, q_lt, N)
         x_q_rb = self._get_x_q(x, q_rb, N)
         x_q_lb = self._get_x_q(x, q_lb, N)
         x_q_rt = self._get_x_q(x, q_rt, N)
 
+        print('x_q_lt size:', x_q_lt.size())
+
 
         """
         In the paper, x(p) = ΣG(p, q) * x(q), G is bilinear kernal
@@ -155,9 +158,12 @@ def forward(self, x, offset):
 
         x_offset = self._reshape_x_offset(x_offset, ks)
         print('#02 x_offset size:', x_offset.size())
-
+        """
+        x_offset is kernel_size * kernel_size x.
+        """
         out = self.conv_kernel(x_offset)
-        return out
+        print('out size:', out.size())
+        return x_offset
 
     def _get_p_n(self, N, dtype):
         p_n_x, p_n_y = np.meshgrid(range(-(self.kernel_size - 1) // 2, (self.kernel_size - 1) // 2 + 1),
@@ -198,15 +204,20 @@ def _get_x_q(self, x, q, N):
         padded_w = x.size(3)
         c = x.size(1)
         # (b, c, h*w)
+        print('#01 x size:', x.size())
         x = x.contiguous().view(b, c, -1)
+        print('#02 x size:', x.size())
 
         # (b, h, w, N)
         index = q[..., :N] * padded_w + q[..., N:]  # offset_x*w + offset_y
+        print('q size:', q.size())
+        print('#01 index size:', index.size())
         # (b, c, h*w*N)
         index = index.contiguous().unsqueeze(dim=1).expand(-1, c, -1, -1, -1).contiguous().view(b, c, -1)
+        print('#02 index size:', index.size())
 
         x_offset = x.gather(dim=-1, index=index).contiguous().view(b, c, h, w, N)
-
+        print('x offset size:', x_offset.size())
         return x_offset
 
     @staticmethod
@@ -228,7 +239,7 @@ def _reshape_x_offset(x_offset, ks):
     offset = p_conv(x)
     y = d_conv2(x, offset)
 
-    print(y.size())
-    print('y = ', y)
-    print("offset:")
-    print(offset)
+    # print(y.size())
+    # print('y = ', y)
+    # print("offset:")
+    # print(offset)
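A side note on the reshaping trick shared by both versions (a sketch under made-up sizes, not part of the commit): _reshape_x_offset tiles the N = ks * ks sampled values of every output location into an (h*ks, w*ks) plane, so conv_kernel, declared with kernel_size=ks and stride=ks, consumes exactly one ks x ks block per original location and produces a (b, outc, h, w) map. The sizes below are hypothetical, chosen only to show the shape bookkeeping.

import torch
import torch.nn as nn

b, c, h, w, ks = 2, 3, 5, 5, 3          # hypothetical sizes for illustration
N = ks * ks

x_offset = torch.randn(b, c, h, w, N)   # one interpolated value per kernel position
x_offset = torch.cat([x_offset[..., s:s + ks].contiguous().view(b, c, h, w * ks)
                      for s in range(0, N, ks)], dim=-1)
x_offset = x_offset.contiguous().view(b, c, h * ks, w * ks)   # (2, 3, 15, 15)

conv = nn.Conv2d(c, 8, kernel_size=ks, stride=ks, bias=False)
print(conv(x_offset).size())            # torch.Size([2, 8, 5, 5]) -> (b, outc, h, w)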
