Skip to content

Commit 58d1634

Browse files
committed
first commit
0 parents  commit 58d1634

File tree

7 files changed

+472
-0
lines changed

7 files changed

+472
-0
lines changed

.gitignore

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
# Byte-compiled / optimized / DLL files
2+
__pycache__/
3+
*.py[cod]
4+
*$py.class
5+
6+
# C extensions
7+
*.so
8+
9+
# Distribution / packaging
10+
.Python
11+
build/
12+
develop-eggs/
13+
dist/
14+
downloads/
15+
eggs/
16+
.eggs/
17+
lib/
18+
lib64/
19+
parts/
20+
sdist/
21+
var/
22+
wheels/
23+
share/python-wheels/
24+
*.egg-info/
25+
.installed.cfg
26+
*.egg
27+
MANIFEST
28+
.idea/
29+
/.idea
30+
*.iml
31+
*.ppt
32+
*.pptx
33+
*.caffemodel
34+
result/
35+
36+
# PyInstaller
37+
# Usually these files are written by a python script from a template
38+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
39+
*.manifest
40+
*.spec
41+
42+
# Installer logs
43+
pip-log.txt
44+
pip-delete-this-directory.txt
45+
46+
# Unit test / coverage reports
47+
htmlcov/
48+
.tox/
49+
.nox/
50+
.coverage
51+
.coverage.*
52+
.cache
53+
nosetests.xml
54+
coverage.xml
55+
*.cover
56+
.hypothesis/
57+
.pytest_cache/
58+
59+
# Translations
60+
*.mo
61+
*.pot
62+
63+
# Django stuff:
64+
*.log
65+
local_settings.py
66+
db.sqlite3
67+
68+
# Flask stuff:
69+
instance/
70+
.webassets-cache
71+
72+
# Scrapy stuff:
73+
.scrapy
74+
75+
# Sphinx documentation
76+
docs/_build/
77+
78+
# PyBuilder
79+
target/
80+
81+
# Jupyter Notebook
82+
.ipynb_checkpoints
83+
84+
# IPython
85+
profile_default/
86+
ipython_config.py
87+
88+
# pyenv
89+
.python-version
90+
91+
# celery beat schedule file
92+
celerybeat-schedule
93+
94+
# SageMath parsed files
95+
*.sage.py
96+
97+
# Environments
98+
.env
99+
.venv
100+
env/
101+
venv/
102+
ENV/
103+
env.bak/
104+
venv.bak/
105+
106+
# Spyder project settings
107+
.spyderproject
108+
.spyproject
109+
110+
# Rope project settings
111+
.ropeproject
112+
113+
# mkdocs documentation
114+
/site
115+
116+
# mypy
117+
.mypy_cache/
118+
.dmypy.json
119+
dmypy.json
120+
121+
# Pyre type checker
122+
.pyre/
123+
124+
# JPG PNG
125+
*.jpg
126+
*.png

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# Deformable_ConvNet_pytorch

main.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
@Time : 2019/2/19 16:06
4+
@Author : Wang Xin
5+
@Email : wangxin_buaa@163.com
6+
"""

network/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
@Time : 2019/2/19 16:06
4+
@Author : Wang Xin
5+
@Email : wangxin_buaa@163.com
6+
"""

network/deform_conv/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
@Time : 2019/2/18 15:19
4+
@Author : Wang Xin
5+
@Email : wangxin_buaa@163.com
6+
"""

network/deform_conv/deform_conv.py

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
@Time : 2019/2/18 15:23
4+
@Author : Wang Xin
5+
@Email : wangxin_buaa@163.com
6+
"""
7+
8+
import numpy as np
9+
10+
import torch
11+
import torch.nn as nn
12+
from torch.autograd import Variable
13+
14+
15+
"""
16+
https://github.com/ChunhuanLin/deform_conv_pytorch/blob/master/deform_conv.py
17+
"""
18+
19+
20+
class DeformConv2D(nn.Module):
21+
def __init__(self, inc, outc, kernel_size=3, padding=1, bias=None):
22+
super(DeformConv2D, self).__init__()
23+
self.kernel_size = kernel_size
24+
self.padding = padding
25+
self.zero_padding = nn.ZeroPad2d(padding)
26+
self.conv_kernel = nn.Conv2d(inc, outc, kernel_size=kernel_size, stride=kernel_size, bias=bias)
27+
28+
def forward(self, x, offset):
29+
dtype = offset.data.type()
30+
ks = self.kernel_size
31+
N = offset.size(1) // 2
32+
33+
# Change offset's order from [x1, x2, ..., y1, y2, ...] to [x1, y1, x2, y2, ...]
34+
# Codes below are written to make sure same results of MXNet implementation.
35+
# You can remove them, and it won't influence the module's performance.
36+
offsets_index = Variable(torch.cat([torch.arange(0, 2 * N, 2), torch.arange(1, 2 * N + 1, 2)]),
37+
requires_grad=False).type_as(x).long()
38+
offsets_index = offsets_index.unsqueeze(dim=0).unsqueeze(dim=-1).unsqueeze(dim=-1).expand(*offset.size())
39+
offset = torch.gather(offset, dim=1, index=offsets_index)
40+
# ------------------------------------------------------------------------
41+
42+
if self.padding:
43+
x = self.zero_padding(x)
44+
45+
# (b, 2N, h, w)
46+
p = self._get_p(offset, dtype)
47+
48+
# (b, h, w, 2N)
49+
p = p.contiguous().permute(0, 2, 3, 1)
50+
q_lt = Variable(p.data, requires_grad=False).floor()
51+
q_rb = q_lt + 1
52+
53+
q_lt = torch.cat([torch.clamp(q_lt[..., :N], 0, x.size(2) - 1), torch.clamp(q_lt[..., N:], 0, x.size(3) - 1)],
54+
dim=-1).long()
55+
q_rb = torch.cat([torch.clamp(q_rb[..., :N], 0, x.size(2) - 1), torch.clamp(q_rb[..., N:], 0, x.size(3) - 1)],
56+
dim=-1).long()
57+
q_lb = torch.cat([q_lt[..., :N], q_rb[..., N:]], -1)
58+
q_rt = torch.cat([q_rb[..., :N], q_lt[..., N:]], -1)
59+
60+
# (b, h, w, N)
61+
mask = torch.cat([p[..., :N].lt(self.padding) + p[..., :N].gt(x.size(2) - 1 - self.padding),
62+
p[..., N:].lt(self.padding) + p[..., N:].gt(x.size(3) - 1 - self.padding)], dim=-1).type_as(p)
63+
mask = mask.detach()
64+
floor_p = p - (p - torch.floor(p))
65+
p = p * (1 - mask) + floor_p * mask
66+
p = torch.cat([torch.clamp(p[..., :N], 0, x.size(2) - 1), torch.clamp(p[..., N:], 0, x.size(3) - 1)], dim=-1)
67+
68+
# bilinear kernel (b, h, w, N)
69+
g_lt = (1 + (q_lt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_lt[..., N:].type_as(p) - p[..., N:]))
70+
g_rb = (1 - (q_rb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_rb[..., N:].type_as(p) - p[..., N:]))
71+
g_lb = (1 + (q_lb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_lb[..., N:].type_as(p) - p[..., N:]))
72+
g_rt = (1 - (q_rt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_rt[..., N:].type_as(p) - p[..., N:]))
73+
74+
# (b, c, h, w, N)
75+
x_q_lt = self._get_x_q(x, q_lt, N)
76+
x_q_rb = self._get_x_q(x, q_rb, N)
77+
x_q_lb = self._get_x_q(x, q_lb, N)
78+
x_q_rt = self._get_x_q(x, q_rt, N)
79+
80+
# (b, c, h, w, N)
81+
x_offset = g_lt.unsqueeze(dim=1) * x_q_lt + \
82+
g_rb.unsqueeze(dim=1) * x_q_rb + \
83+
g_lb.unsqueeze(dim=1) * x_q_lb + \
84+
g_rt.unsqueeze(dim=1) * x_q_rt
85+
86+
x_offset = self._reshape_x_offset(x_offset, ks)
87+
out = self.conv_kernel(x_offset)
88+
89+
return out
90+
91+
def _get_p_n(self, N, dtype):
92+
p_n_x, p_n_y = np.meshgrid(range(-(self.kernel_size - 1) // 2, (self.kernel_size - 1) // 2 + 1),
93+
range(-(self.kernel_size - 1) // 2, (self.kernel_size - 1) // 2 + 1), indexing='ij')
94+
# (2N, 1)
95+
p_n = np.concatenate((p_n_x.flatten(), p_n_y.flatten()))
96+
p_n = np.reshape(p_n, (1, 2 * N, 1, 1))
97+
p_n = Variable(torch.from_numpy(p_n).type(dtype), requires_grad=False)
98+
99+
return p_n
100+
101+
@staticmethod
102+
def _get_p_0(h, w, N, dtype):
103+
p_0_x, p_0_y = np.meshgrid(range(1, h + 1), range(1, w + 1), indexing='ij')
104+
p_0_x = p_0_x.flatten().reshape(1, 1, h, w).repeat(N, axis=1)
105+
p_0_y = p_0_y.flatten().reshape(1, 1, h, w).repeat(N, axis=1)
106+
p_0 = np.concatenate((p_0_x, p_0_y), axis=1)
107+
p_0 = Variable(torch.from_numpy(p_0).type(dtype), requires_grad=False)
108+
109+
return p_0
110+
111+
def _get_p(self, offset, dtype):
112+
N, h, w = offset.size(1) // 2, offset.size(2), offset.size(3)
113+
114+
# (1, 2N, 1, 1)
115+
p_n = self._get_p_n(N, dtype)
116+
# (1, 2N, h, w)
117+
p_0 = self._get_p_0(h, w, N, dtype)
118+
p = p_0 + p_n + offset
119+
return p
120+
121+
def _get_x_q(self, x, q, N):
122+
b, h, w, _ = q.size()
123+
padded_w = x.size(3)
124+
c = x.size(1)
125+
# (b, c, h*w)
126+
x = x.contiguous().view(b, c, -1)
127+
128+
# (b, h, w, N)
129+
index = q[..., :N] * padded_w + q[..., N:] # offset_x*w + offset_y
130+
# (b, c, h*w*N)
131+
index = index.contiguous().unsqueeze(dim=1).expand(-1, c, -1, -1, -1).contiguous().view(b, c, -1)
132+
133+
x_offset = x.gather(dim=-1, index=index).contiguous().view(b, c, h, w, N)
134+
135+
return x_offset
136+
137+
@staticmethod
138+
def _reshape_x_offset(x_offset, ks):
139+
b, c, h, w, N = x_offset.size()
140+
x_offset = torch.cat([x_offset[..., s:s + ks].contiguous().view(b, c, h, w * ks) for s in range(0, N, ks)],
141+
dim=-1)
142+
x_offset = x_offset.contiguous().view(b, c, h * ks, w * ks)
143+
144+
return x_offset
145+
146+
147+
if __name__ == '__main__':
148+
x = torch.randn(4, 3, 5, 5)
149+
150+
p_conv = nn.Conv2d(3, 2 * 3 * 3, kernel_size=3, padding=1, stride=1)
151+
d_conv2 = DeformConv2D(3, 64)
152+
153+
offset = p_conv(x)
154+
y = d_conv2(x, offset)
155+
156+
print(y.size())
157+
print('y = ', y)
158+
print("offset:")
159+
print(offset)

0 commit comments

Comments
 (0)