
Commit 2725103: "add comment"
1 parent 58d1634

File tree

3 files changed (+95, -21 lines)

main.py

Lines changed: 3 additions & 1 deletion

@@ -3,4 +3,6 @@
 @Time : 2019/2/19 16:06
 @Author : Wang Xin
 @Email : wangxin_buaa@163.com
-"""
+"""
+
+# TODO: handwritten digit recognition

network/deform_conv/deform_conv.py

Lines changed: 78 additions & 3 deletions

@@ -11,7 +11,6 @@
 import torch.nn as nn
 from torch.autograd import Variable
 
-
 """
 https://github.com/ChunhuanLin/deform_conv_pytorch/blob/master/deform_conv.py
 """
@@ -45,47 +44,119 @@ def forward(self, x, offset):
         # (b, 2N, h, w)
         p = self._get_p(offset, dtype)
 
+        print('p size:', p.size())
+        print('p = ', p)
+
         # (b, h, w, 2N)
         p = p.contiguous().permute(0, 2, 3, 1)
+
+        print('p size:', p.size())
+
+        """
+        If q is fractional, bilinear interpolation is used; it involves four integer
+        positions: left-top (q_lt), right-bottom (q_rb), left-bottom (q_lb) and right-top (q_rt).
+        """
         q_lt = Variable(p.data, requires_grad=False).floor()
+
+        """
+        *┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄→ x
+        ┊   .(x, y)     .(x+1, y)
+        ┊
+        ┊   .(x, y+1)   .(x+1, y+1)
+        ↓
+        y
+
+        For the right-bottom point: its x = left-top's x + 1, its y = left-top's y + 1.
+        """
         q_rb = q_lt + 1
 
+        """
+        x.size(2) is h and x.size(3) is w; clamp so that 0 <= p_y <= h - 1 and 0 <= p_x <= w - 1.
+        """
         q_lt = torch.cat([torch.clamp(q_lt[..., :N], 0, x.size(2) - 1), torch.clamp(q_lt[..., N:], 0, x.size(3) - 1)],
                          dim=-1).long()
+
+        """
+        Clamp q_rb to the same valid range.
+        """
         q_rb = torch.cat([torch.clamp(q_rb[..., :N], 0, x.size(2) - 1), torch.clamp(q_rb[..., N:], 0, x.size(3) - 1)],
                          dim=-1).long()
+
+        """
+        For the left-bottom point, its x equals left-top's x and its y equals right-bottom's y.
+        """
         q_lb = torch.cat([q_lt[..., :N], q_rb[..., N:]], -1)
+
+        """
+        For the right-top point, its x comes from q_rb and its y from q_lt.
+        """
         q_rt = torch.cat([q_rb[..., :N], q_lt[..., N:]], -1)
 
+        print('q_lt size:', q_lt.size())
+        print('q_rb size:', q_rb.size())
+        print('q_lb size:', q_lb.size())
+        print('q_rt size:', q_rt.size())
+        print('N = ', N)
+        print('q_lt[..., :N] size:', q_lt[..., :N].size())
+
+        """
+        Find p_y <= padding or p_y >= h - 1 - padding, and p_x <= padding or p_x >= w - 1 - padding:
+        these points fall into the zero-padded border, outside the area where the pixel
+        value is meaningful.
+        """
         # (b, h, w, N)
         mask = torch.cat([p[..., :N].lt(self.padding) + p[..., :N].gt(x.size(2) - 1 - self.padding),
                           p[..., N:].lt(self.padding) + p[..., N:].gt(x.size(3) - 1 - self.padding)], dim=-1).type_as(p)
         mask = mask.detach()
-        floor_p = p - (p - torch.floor(p))
+        print('mask:', mask)
+
+        floor_p = torch.floor(p)
+        print('floor_p = ', floor_p)
+
+        """
+        Where mask is 1, take floor_p; where mask is 0, keep the original p.
+        """
         p = p * (1 - mask) + floor_p * mask
         p = torch.cat([torch.clamp(p[..., :N], 0, x.size(2) - 1), torch.clamp(p[..., N:], 0, x.size(3) - 1)], dim=-1)
 
+        """
+        In the paper, G(q, p) = g(q_x, p_x) * g(q_y, p_y) with g(a, b) = max(0, 1 - |a - b|).
+        """
         # bilinear kernel (b, h, w, N)
         g_lt = (1 + (q_lt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_lt[..., N:].type_as(p) - p[..., N:]))
         g_rb = (1 - (q_rb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_rb[..., N:].type_as(p) - p[..., N:]))
         g_lb = (1 + (q_lb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_lb[..., N:].type_as(p) - p[..., N:]))
         g_rt = (1 - (q_rt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_rt[..., N:].type_as(p) - p[..., N:]))
 
+        print('g_lt = ', g_lt)
+
         # (b, c, h, w, N)
         x_q_lt = self._get_x_q(x, q_lt, N)
         x_q_rb = self._get_x_q(x, q_rb, N)
         x_q_lb = self._get_x_q(x, q_lb, N)
         x_q_rt = self._get_x_q(x, q_rt, N)
 
+        """
+        In the paper, x(p) = Σ G(p, q) * x(q), where G is the bilinear kernel.
+        """
         # (b, c, h, w, N)
         x_offset = g_lt.unsqueeze(dim=1) * x_q_lt + \
                    g_rb.unsqueeze(dim=1) * x_q_rb + \
                    g_lb.unsqueeze(dim=1) * x_q_lb + \
                    g_rt.unsqueeze(dim=1) * x_q_rt
 
+        print('#01 x_offset size:', x_offset.size())
+
         x_offset = self._reshape_x_offset(x_offset, ks)
-        out = self.conv_kernel(x_offset)
+        print('#02 x_offset size:', x_offset.size())
 
+        out = self.conv_kernel(x_offset)
         return out
 
     def _get_p_n(self, N, dtype):
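
The comments added in this hunk describe plain bilinear interpolation. As a sanity check, here is a minimal standalone sketch of the same kernel, separate from the module above (the helper name bilinear_sample is local to this example, not part of the commit):

import torch

def bilinear_sample(x, py, px):
    """Sample a 2-D tensor x at a fractional point (py, px) via its four corners."""
    h, w = x.size(0), x.size(1)
    y0, x0 = int(py), int(px)                           # left-top corner (floor, assuming py, px >= 0)
    y1, x1 = min(y0 + 1, h - 1), min(x0 + 1, w - 1)     # right-bottom corner, clamped into the image
    corners = [(y0, x0), (y0, x1), (y1, x0), (y1, x1)]  # lt, rt, lb, rb
    g = lambda a, b: max(0.0, 1.0 - abs(a - b))         # g(a, b) = max(0, 1 - |a - b|)
    # x(p) = sum over the four corners q of G(q, p) * x(q), with G(q, p) = g(q_y, p_y) * g(q_x, p_x)
    return sum(g(qy, py) * g(qx, px) * x[qy, qx].item() for qy, qx in corners)

x = torch.arange(16.0).view(4, 4)
print(bilinear_sample(x, 1.5, 2.5))  # 8.5, the mean of the four surrounding pixels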
@@ -113,8 +184,12 @@ def _get_p(self, offset, dtype):
 
         # (1, 2N, 1, 1)
         p_n = self._get_p_n(N, dtype)
+        print('p_n:', p_n)
+        print('p_n size:', p_n.size())
         # (1, 2N, h, w)
         p_0 = self._get_p_0(h, w, N, dtype)
+        print('p_0:', p_0)
+        print('p_0 size:', p_0.size())
         p = p_0 + p_n + offset
         return p
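
For reference, a standalone sketch of the sampling grid that _get_p assembles as p = p_0 + p_n + offset, here for a 3x3 kernel and stride 1; the shapes mirror the comments in the hunk above, and the zero offset stands in for the learned one produced by the offset branch:

import torch

kernel_size, h, w = 3, 2, 2
N = kernel_size * kernel_size  # N = 9 sampling points per output pixel

# p_n: the fixed kernel offsets (-1..1) x (-1..1), shaped (1, 2N, 1, 1)
r = torch.arange(-(kernel_size - 1) // 2, (kernel_size - 1) // 2 + 1)
p_n_x, p_n_y = torch.meshgrid(r, r)
p_n = torch.cat([torch.flatten(p_n_x), torch.flatten(p_n_y)]).view(1, 2 * N, 1, 1).float()

# p_0: the base coordinate of every output pixel, shaped (1, 2N, h, w)
p_0_x, p_0_y = torch.meshgrid(torch.arange(1, h + 1), torch.arange(1, w + 1))
p_0 = torch.cat([torch.flatten(p_0_x).view(1, 1, h, w).repeat(1, N, 1, 1),
                 torch.flatten(p_0_y).view(1, 1, h, w).repeat(1, N, 1, 1)], dim=1).float()

# offset would come from the offset conv in the module; zeros here for illustration only
offset = torch.zeros(1, 2 * N, h, w)
p = p_0 + p_n + offset  # (1, 2N, h, w) sampling locations
print(p.size())         # torch.Size([1, 18, 2, 2])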

network/deform_conv/deform_conv_v2.py

Lines changed: 14 additions & 17 deletions

@@ -27,13 +27,13 @@ def __init__(self, inc, outc, kernel_size=3, padding=1, stride=1, bias=None, mod
         self.zero_padding = nn.ZeroPad2d(padding)
         self.conv = nn.Conv2d(inc, outc, kernel_size=kernel_size, stride=kernel_size, bias=bias)
 
-        self.p_conv = nn.Conv2d(inc, 2 * kernel_size * kernel_size, kernel_size=3, padding=1, stride=stride)
+        self.p_conv = nn.Conv2d(inc, 2*kernel_size*kernel_size, kernel_size=3, padding=1, stride=stride)
         nn.init.constant_(self.p_conv.weight, 0)
         self.p_conv.register_backward_hook(self._set_lr)
 
         self.modulation = modulation
         if modulation:
-            self.m_conv = nn.Conv2d(inc, kernel_size * kernel_size, kernel_size=3, padding=1, stride=stride)
+            self.m_conv = nn.Conv2d(inc, kernel_size*kernel_size, kernel_size=3, padding=1, stride=stride)
             nn.init.constant_(self.m_conv.weight, 0.5)
             self.m_conv.register_backward_hook(self._set_lr)
 
@@ -62,15 +62,13 @@ def forward(self, x):
         q_lt = p.detach().floor()
         q_rb = q_lt + 1
 
-        q_lt = torch.cat([torch.clamp(q_lt[..., :N], 0, x.size(2) - 1), torch.clamp(q_lt[..., N:], 0, x.size(3) - 1)],
-                         dim=-1).long()
-        q_rb = torch.cat([torch.clamp(q_rb[..., :N], 0, x.size(2) - 1), torch.clamp(q_rb[..., N:], 0, x.size(3) - 1)],
-                         dim=-1).long()
+        q_lt = torch.cat([torch.clamp(q_lt[..., :N], 0, x.size(2)-1), torch.clamp(q_lt[..., N:], 0, x.size(3)-1)], dim=-1).long()
+        q_rb = torch.cat([torch.clamp(q_rb[..., :N], 0, x.size(2)-1), torch.clamp(q_rb[..., N:], 0, x.size(3)-1)], dim=-1).long()
         q_lb = torch.cat([q_lt[..., :N], q_rb[..., N:]], dim=-1)
         q_rt = torch.cat([q_rb[..., :N], q_lt[..., N:]], dim=-1)
 
         # clip p
-        p = torch.cat([torch.clamp(p[..., :N], 0, x.size(2) - 1), torch.clamp(p[..., N:], 0, x.size(3) - 1)], dim=-1)
+        p = torch.cat([torch.clamp(p[..., :N], 0, x.size(2)-1), torch.clamp(p[..., N:], 0, x.size(3)-1)], dim=-1)
 
         # bilinear kernel (b, h, w, N)
         g_lt = (1 + (q_lt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_lt[..., N:].type_as(p) - p[..., N:]))
@@ -104,26 +102,26 @@ def forward(self, x):
 
     def _get_p_n(self, N, dtype):
         p_n_x, p_n_y = torch.meshgrid(
-            torch.arange(-(self.kernel_size - 1) // 2, (self.kernel_size - 1) // 2 + 1),
-            torch.arange(-(self.kernel_size - 1) // 2, (self.kernel_size - 1) // 2 + 1))
+            torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1),
+            torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1))
         # (2N, 1)
         p_n = torch.cat([torch.flatten(p_n_x), torch.flatten(p_n_y)], 0)
-        p_n = p_n.view(1, 2 * N, 1, 1).type(dtype)
+        p_n = p_n.view(1, 2*N, 1, 1).type(dtype)
 
         return p_n
 
     def _get_p_0(self, h, w, N, dtype):
         p_0_x, p_0_y = torch.meshgrid(
-            torch.arange(1, h * self.stride + 1, self.stride),
-            torch.arange(1, w * self.stride + 1, self.stride))
+            torch.arange(1, h*self.stride+1, self.stride),
+            torch.arange(1, w*self.stride+1, self.stride))
         p_0_x = torch.flatten(p_0_x).view(1, 1, h, w).repeat(1, N, 1, 1)
         p_0_y = torch.flatten(p_0_y).view(1, 1, h, w).repeat(1, N, 1, 1)
         p_0 = torch.cat([p_0_x, p_0_y], 1).type(dtype)
 
         return p_0
 
     def _get_p(self, offset, dtype):
-        N, h, w = offset.size(1) // 2, offset.size(2), offset.size(3)
+        N, h, w = offset.size(1)//2, offset.size(2), offset.size(3)
 
         # (1, 2N, 1, 1)
         p_n = self._get_p_n(N, dtype)
@@ -140,7 +138,7 @@ def _get_x_q(self, x, q, N):
         x = x.contiguous().view(b, c, -1)
 
         # (b, h, w, N)
-        index = q[..., :N] * padded_w + q[..., N:]  # offset_x*w + offset_y
+        index = q[..., :N]*padded_w + q[..., N:]  # offset_x*w + offset_y
         # (b, c, h*w*N)
         index = index.contiguous().unsqueeze(dim=1).expand(-1, c, -1, -1, -1).contiguous().view(b, c, -1)
 
@@ -151,9 +149,8 @@ def _get_x_q(self, x, q, N):
     @staticmethod
     def _reshape_x_offset(x_offset, ks):
         b, c, h, w, N = x_offset.size()
-        x_offset = torch.cat([x_offset[..., s:s + ks].contiguous().view(b, c, h, w * ks) for s in range(0, N, ks)],
-                             dim=-1)
-        x_offset = x_offset.contiguous().view(b, c, h * ks, w * ks)
+        x_offset = torch.cat([x_offset[..., s:s+ks].contiguous().view(b, c, h, w*ks) for s in range(0, N, ks)], dim=-1)
+        x_offset = x_offset.contiguous().view(b, c, h*ks, w*ks)
 
         return x_offset
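
A hedged usage sketch for the module defined in this file: the class name DeformConv2d does not appear in this diff and is an assumption, while the constructor arguments (inc, outc, kernel_size, padding, stride, modulation) are taken from the __init__ hunk above:

# Usage sketch only; DeformConv2d is an assumed class name for the module in
# deform_conv_v2.py (the name itself is not shown in this diff).
import torch
from network.deform_conv.deform_conv_v2 import DeformConv2d

layer = DeformConv2d(inc=16, outc=32, kernel_size=3, padding=1, stride=1,
                     modulation=True)  # modulation=True enables the m_conv branch
x = torch.randn(2, 16, 24, 24)         # (b, c, h, w)
y = layer(x)                           # offsets come from p_conv internally
print(y.size())                        # expected torch.Size([2, 32, 24, 24]) for stride=1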
