Skip to content

Commit 2e88459

Browse files
author
Study-is-happy
committed
t
1 parent 2b182f7 commit 2e88459

File tree

10 files changed

+222
-54
lines changed

10 files changed

+222
-54
lines changed

NeuFlow/backbone_v6.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,7 @@ def init_pos(self, batch_size, height, width, device, amp):
5454
return pos[None].repeat(batch_size,1,1,1)
5555

5656
def init_bhwd(self, batch_size, height, width, device, amp):
57-
self.pos_s8 = self.init_pos(batch_size, height//8, width//8, device, amp)
58-
self.pos_s16 = self.init_pos(batch_size, height//16, width//16, device, amp)
57+
self.pos_s16 = self.init_pos(batch_size, height, width, device, amp)
5958

6059
def forward(self, img):
6160

@@ -75,6 +74,5 @@ def forward(self, img):
7574
x_16 = self.block_cat_16(torch.cat([x_16, x_16_2], dim=1))
7675

7776
x_16 = torch.cat([x_16, self.pos_s16], dim=1)
78-
x_8 = torch.cat([x_8, self.pos_s8], dim=1)
7977

8078
return x_16, x_8

NeuFlow/backbone_v7.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
import torch
2+
import torch.nn.functional as F
3+
4+
5+
class ConvBlock(torch.nn.Module):
    """Two consecutive Conv2d -> BatchNorm2d -> LeakyReLU(0.1) stages.

    The first convolution takes its kernel size, stride and padding from the
    caller. The second is a fixed 3x3, stride-1, padding-1 convolution that
    keeps the channel count at ``out_planes``. Neither convolution has a bias
    term.
    """

    def __init__(self, in_planes, out_planes, kernel_size, stride, padding):
        super().__init__()

        # Attribute names and registration order are left untouched so that
        # state_dict keys and parameter ordering stay the same.
        self.conv1 = torch.nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            padding_mode='zeros',
            bias=False,
        )
        self.conv2 = torch.nn.Conv2d(
            out_planes,
            out_planes,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
        )
        self.relu = torch.nn.LeakyReLU(negative_slope=0.1, inplace=False)
        self.norm1 = torch.nn.BatchNorm2d(out_planes)
        self.norm2 = torch.nn.BatchNorm2d(out_planes)

    def forward(self, x):
        """Run both conv/norm/activation stages on ``x`` and return the result."""
        stage_one = self.relu(self.norm1(self.conv1(x)))
        stage_two = self.relu(self.norm2(self.conv2(stage_one)))
        return stage_two
31+
32+
class CNNEncoder(torch.nn.Module):
    """Convolutional encoder producing a coarse and a fine feature map.

    ``forward`` returns ``(x_16, x_8)``. ``x_8`` has
    ``feature_dim_s8 + context_dim_s8`` channels. ``x_16`` has
    ``feature_dim_s16 + context_dim_s16`` channels, the last two of which are
    the position grid stored in ``self.pos_s16``. ``init_bhwd`` must be called
    before ``forward`` so that this grid exists.
    """

    def __init__(self, feature_dim_s16, context_dim_s16, feature_dim_s8, context_dim_s8):
        super().__init__()

        width_s8 = feature_dim_s8 + context_dim_s8
        # Two channels are left free for the position grid concatenated in forward().
        width_s16 = feature_dim_s16 + context_dim_s16 - 2

        # Submodule names and creation order are unchanged (state_dict compatibility).
        self.block_8_1 = ConvBlock(3, feature_dim_s8 * 2, kernel_size=8, stride=4, padding=2)
        self.block_8_2 = ConvBlock(3, feature_dim_s8, kernel_size=6, stride=2, padding=2)
        self.block_cat_8 = ConvBlock(feature_dim_s8 * 3, width_s8, kernel_size=3, stride=1, padding=1)
        self.block_16_1 = ConvBlock(3, feature_dim_s16, kernel_size=6, stride=2, padding=2)
        self.block_8_16 = ConvBlock(width_s8, feature_dim_s16, kernel_size=6, stride=2, padding=2)
        self.block_cat_16 = ConvBlock(feature_dim_s16 * 2, width_s16, kernel_size=3, stride=1, padding=1)

    def init_pos(self, batch_size, height, width, device, amp):
        """Build a centred coordinate grid of shape (batch_size, 2, height, width).

        Channel 0 holds row indices minus ``height / 2``; channel 1 holds
        column indices minus ``width / 2``. The dtype is half precision when
        ``amp`` is truthy, otherwise float32.
        """
        dtype = torch.half if amp else torch.float
        rows = torch.arange(height, dtype=dtype, device=device)
        cols = torch.arange(width, dtype=dtype, device=device)
        ys, xs = torch.meshgrid(rows, cols, indexing='ij')
        grid = torch.stack([ys - height / 2, xs - width / 2])
        return grid.unsqueeze(0).repeat(batch_size, 1, 1, 1)

    def init_bhwd(self, batch_size, height, width, device, amp):
        """Cache the position grid that forward() appends to the coarse map.

        NOTE(review): ``height``/``width`` presumably have to equal the
        spatial size of ``x_16`` for the concatenation to succeed - confirm
        against the caller.
        """
        self.pos_s16 = self.init_pos(batch_size, height, width, device, amp)

    def forward(self, img):
        """Encode ``img`` and return ``(x_16, x_8)``."""
        # Each avg_pool2d halves the resolution of the image fed to the next branch.
        half = F.avg_pool2d(img, kernel_size=2, stride=2)
        branch_a = self.block_8_1(half)

        quarter = F.avg_pool2d(half, kernel_size=2, stride=2)
        branch_b = self.block_8_2(quarter)

        x_8 = self.block_cat_8(torch.cat([branch_a, branch_b], dim=1))

        eighth = F.avg_pool2d(quarter, kernel_size=2, stride=2)
        coarse_from_img = self.block_16_1(eighth)
        coarse_from_x8 = self.block_8_16(x_8)

        x_16 = self.block_cat_16(torch.cat([coarse_from_img, coarse_from_x8], dim=1))
        x_16 = torch.cat([x_16, self.pos_s16], dim=1)

        return x_16, x_8

NeuFlow/backbone_v8.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import torch
2+
import torch.nn.functional as F
3+
4+
5+
class ConvBlock(torch.nn.Module):
    """Two consecutive Conv2d -> BatchNorm2d -> LeakyReLU(0.1) stages.

    The first convolution takes its kernel size, stride and padding from the
    caller. The second is a fixed 3x3, stride-1, padding-1 convolution that
    keeps the channel count at ``out_planes``. Neither convolution has a bias
    term.
    """

    def __init__(self, in_planes, out_planes, kernel_size, stride, padding):
        super().__init__()

        # Attribute names and registration order are left untouched so that
        # state_dict keys and parameter ordering stay the same.
        self.conv1 = torch.nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            padding_mode='zeros',
            bias=False,
        )
        self.conv2 = torch.nn.Conv2d(
            out_planes,
            out_planes,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
        )
        self.relu = torch.nn.LeakyReLU(negative_slope=0.1, inplace=False)
        self.norm1 = torch.nn.BatchNorm2d(out_planes)
        self.norm2 = torch.nn.BatchNorm2d(out_planes)

    def forward(self, x):
        """Run both conv/norm/activation stages on ``x`` and return the result."""
        stage_one = self.relu(self.norm1(self.conv1(x)))
        stage_two = self.relu(self.norm2(self.conv2(stage_one)))
        return stage_two
31+
32+
class CNNEncoder(torch.nn.Module):
    """Convolutional encoder producing a coarse and a fine feature map.

    ``forward`` returns ``(x_16, x_8)``. ``x_8`` has
    ``feature_dim_s8 + context_dim_s8`` channels. ``x_16`` is derived from
    ``x_8`` by a single stride-2 block and has
    ``feature_dim_s16 + context_dim_s16`` channels, the last two of which are
    the position grid stored in ``self.pos_s16``. ``init_bhwd`` must be called
    before ``forward`` so that this grid exists.
    """

    def __init__(self, feature_dim_s16, context_dim_s16, feature_dim_s8, context_dim_s8):
        super().__init__()

        width_s8 = feature_dim_s8 + context_dim_s8
        # Two channels are left free for the position grid concatenated in forward().
        width_s16 = feature_dim_s16 + context_dim_s16 - 2

        # Submodule names and creation order are unchanged (state_dict compatibility).
        self.block_8_1 = ConvBlock(3, feature_dim_s8 * 2, kernel_size=8, stride=4, padding=2)
        self.block_8_2 = ConvBlock(3, feature_dim_s8, kernel_size=6, stride=2, padding=2)
        self.block_cat_8 = ConvBlock(feature_dim_s8 * 3, width_s8, kernel_size=3, stride=1, padding=1)
        self.block_8_16 = ConvBlock(width_s8, width_s16, kernel_size=6, stride=2, padding=2)

    def init_pos(self, batch_size, height, width, device, amp):
        """Build a centred coordinate grid of shape (batch_size, 2, height, width).

        Channel 0 holds row indices minus ``height / 2``; channel 1 holds
        column indices minus ``width / 2``. The dtype is half precision when
        ``amp`` is truthy, otherwise float32.
        """
        dtype = torch.half if amp else torch.float
        rows = torch.arange(height, dtype=dtype, device=device)
        cols = torch.arange(width, dtype=dtype, device=device)
        ys, xs = torch.meshgrid(rows, cols, indexing='ij')
        grid = torch.stack([ys - height / 2, xs - width / 2])
        return grid.unsqueeze(0).repeat(batch_size, 1, 1, 1)

    def init_bhwd(self, batch_size, height, width, device, amp):
        """Cache the position grid that forward() appends to the coarse map.

        NOTE(review): ``height``/``width`` presumably have to equal the
        spatial size of ``x_16`` for the concatenation to succeed - confirm
        against the caller.
        """
        self.pos_s16 = self.init_pos(batch_size, height, width, device, amp)

    def forward(self, img):
        """Encode ``img`` and return ``(x_16, x_8)``."""
        # Each avg_pool2d halves the resolution of the image fed to the next branch.
        half = F.avg_pool2d(img, kernel_size=2, stride=2)
        branch_a = self.block_8_1(half)

        quarter = F.avg_pool2d(half, kernel_size=2, stride=2)
        branch_b = self.block_8_2(quarter)

        x_8 = self.block_cat_8(torch.cat([branch_a, branch_b], dim=1))

        coarse = self.block_8_16(x_8)
        x_16 = torch.cat([coarse, self.pos_s16], dim=1)

        return x_16, x_8

NeuFlow/config.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1-
feature_dim_s16 = 256
2-
hidden_dim_s16 = 128
1+
feature_dim_s16 = 128
2+
context_dim_s16 = 64
3+
iter_context_dim_s16 = 64
34
feature_dim_s8 = 128
4-
hidden_dim_s8 = 96
5+
context_dim_s8 = 64
6+
iter_context_dim_s8 = 64
57
feature_dim_s1 = 128

NeuFlow/neuflow.py

Lines changed: 38 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import torch
22
import torch.nn.functional as F
33

4-
from NeuFlow import backbone_v6
4+
from NeuFlow import backbone_v8
55
from NeuFlow import transformer
66
from NeuFlow import matching
77
from NeuFlow import corr
@@ -17,38 +17,29 @@ class NeuFlow(torch.nn.Module):
1717
def __init__(self):
1818
super(NeuFlow, self).__init__()
1919

20-
self.backbone = backbone_v6.CNNEncoder(config.feature_dim_s16, config.feature_dim_s8)
20+
self.backbone = backbone_v8.CNNEncoder(config.feature_dim_s16, config.context_dim_s16, config.feature_dim_s8, config.context_dim_s8)
2121

22-
self.cross_attn_s16 = transformer.FeatureAttention(config.feature_dim_s16, num_layers=2, ffn=True, ffn_dim_expansion=1, post_norm=True)
22+
self.cross_attn_s16 = transformer.FeatureAttention(config.feature_dim_s16+config.context_dim_s16, num_layers=2, ffn=True, ffn_dim_expansion=1, post_norm=True)
2323

2424
self.matching_s16 = matching.Matching()
2525

2626
# self.flow_attn_s16 = transformer.FlowAttention(config.feature_dim_s16)
27-
28-
self.merge_s8 = torch.nn.Sequential(torch.nn.Conv2d(config.feature_dim_s16 + config.feature_dim_s8 + 2, config.feature_dim_s8, kernel_size=3, stride=1, padding=1, bias=False),
29-
torch.nn.GELU(),
30-
torch.nn.Conv2d(config.feature_dim_s8, config.feature_dim_s8, kernel_size=3, stride=1, padding=1, bias=False),
31-
torch.nn.Tanh())
3227

3328
self.corr_block_s16 = corr.CorrBlock(radius=4, levels=1)
3429
self.corr_block_s8 = corr.CorrBlock(radius=4, levels=1)
30+
31+
self.merge_s8 = torch.nn.Sequential(torch.nn.Conv2d(config.feature_dim_s16 + config.feature_dim_s8, config.feature_dim_s8, kernel_size=3, stride=1, padding=1, bias=False),
32+
torch.nn.GELU(),
33+
torch.nn.Conv2d(config.feature_dim_s8, config.feature_dim_s8, kernel_size=3, stride=1, padding=1, bias=False))
3534

36-
self.context_s16 = torch.nn.Sequential(torch.nn.Conv2d(config.feature_dim_s16, config.hidden_dim_s16, kernel_size=3, stride=1, padding=1, bias=False),
37-
torch.nn.GELU(),
38-
torch.nn.Conv2d(config.hidden_dim_s16, config.hidden_dim_s16, kernel_size=3, stride=1, padding=1, bias=False))
39-
40-
self.context_merge_s8 = torch.nn.Sequential(torch.nn.Conv2d(config.hidden_dim_s16 + config.feature_dim_s8, config.hidden_dim_s8, kernel_size=3, stride=1, padding=1, bias=False),
35+
self.context_merge_s8 = torch.nn.Sequential(torch.nn.Conv2d(config.context_dim_s16 + config.context_dim_s8, config.context_dim_s8, kernel_size=3, stride=1, padding=1, bias=False),
4136
torch.nn.GELU(),
42-
torch.nn.Conv2d(config.hidden_dim_s8, config.hidden_dim_s8, kernel_size=3, stride=1, padding=1, bias=False),
43-
torch.nn.Tanh())
44-
45-
self.refine_s16 = refine.Refine(config.hidden_dim_s16, num_layers=6, levels=1, radius=4)
46-
self.refine_s8 = refine.Refine(config.hidden_dim_s8, num_layers=6, levels=1, radius=4)
37+
torch.nn.Conv2d(config.context_dim_s8, config.context_dim_s8, kernel_size=3, stride=1, padding=1, bias=False))
4738

48-
# self.conv_s16 = backbone_v6.ConvBlock(3, config.feature_dim_s1 * 2, kernel_size=16, stride=16, padding=0)
49-
# self.upsample_s16 = upsample.UpSample(config.feature_dim_s1 * 2, upsample_factor=16)
39+
self.refine_s16 = refine.Refine(config.context_dim_s16, config.iter_context_dim_s16, num_layers=5, levels=1, radius=4, inter_dim=128)
40+
self.refine_s8 = refine.Refine(config.context_dim_s8, config.iter_context_dim_s8, num_layers=5, levels=1, radius=4, inter_dim=96)
5041

51-
self.conv_s8 = backbone_v6.ConvBlock(3, config.feature_dim_s1, kernel_size=8, stride=8, padding=0)
42+
self.conv_s8 = backbone_v8.ConvBlock(3, config.feature_dim_s1, kernel_size=8, stride=8, padding=0)
5243
self.upsample_s8 = upsample.UpSample(config.feature_dim_s1, upsample_factor=8)
5344

5445
for p in self.parameters():
@@ -57,7 +48,7 @@ def __init__(self):
5748

5849
def init_bhwd(self, batch_size, height, width, device, amp=True):
5950

60-
self.backbone.init_bhwd(batch_size*2, height, width, device, amp)
51+
self.backbone.init_bhwd(batch_size*2, height//16, width//16, device, amp)
6152

6253
self.matching_s16.init_bhwd(batch_size, height//16, width//16, device, amp)
6354

@@ -67,7 +58,19 @@ def init_bhwd(self, batch_size, height, width, device, amp=True):
6758
self.refine_s16.init_bhwd(batch_size, height//16, width//16, device, amp)
6859
self.refine_s8.init_bhwd(batch_size, height//8, width//8, device, amp)
6960

70-
def forward(self, img0, img1, iters_s16=1, iters_s8=6):
61+
self.init_iter_context_s16 = torch.zeros(batch_size, config.iter_context_dim_s16, height//16, width//16, device=device, dtype=torch.half if amp else torch.float)
62+
self.init_iter_context_s8 = torch.zeros(batch_size, config.iter_context_dim_s8, height//8, width//8, device=device, dtype=torch.half if amp else torch.float)
63+
64+
def split_features(self, features, context_dim, feature_dim):
65+
66+
context, features = torch.split(features, [context_dim, feature_dim], dim=1)
67+
68+
context, _ = context.chunk(chunks=2, dim=0)
69+
feature0, feature1 = features.chunk(chunks=2, dim=0)
70+
71+
return features, torch.relu(context)
72+
73+
def forward(self, img0, img1, iters_s16=3, iters_s8=7):
7174

7275
flow_list = []
7376

@@ -78,6 +81,9 @@ def forward(self, img0, img1, iters_s16=1, iters_s8=6):
7881

7982
features_s16 = self.cross_attn_s16(features_s16)
8083

84+
features_s16, context_s16 = self.split_features(features_s16, config.context_dim_s16, config.feature_dim_s16)
85+
features_s8, context_s8 = self.split_features(features_s8, config.context_dim_s8, config.feature_dim_s8)
86+
8187
feature0_s16, feature1_s16 = features_s16.chunk(chunks=2, dim=0)
8288

8389
flow0 = self.matching_s16.global_correlation_softmax(feature0_s16, feature1_s16)
@@ -86,18 +92,17 @@ def forward(self, img0, img1, iters_s16=1, iters_s8=6):
8692

8793
corr_pyr_s16 = self.corr_block_s16.init_corr_pyr(feature0_s16, feature1_s16)
8894

89-
context_s16 = self.context_s16(feature0_s16)
90-
iter_context_s16 = context_s16.clone()
95+
iter_context_s16 = self.init_iter_context_s16
9196

9297
for i in range(iters_s16):
9398

9499
if self.training and i > 0:
95100
flow0 = flow0.detach()
96-
# iter_feature0_s16 = iter_feature0_s16.detach()
101+
# iter_context_s16 = iter_context_s16.detach()
97102

98103
corrs = self.corr_block_s16(corr_pyr_s16, flow0)
99104

100-
iter_context_s16, delta_flow = self.refine_s16(corrs, iter_context_s16, flow0)
105+
iter_context_s16, delta_flow = self.refine_s16(corrs, context_s16, iter_context_s16, flow0)
101106

102107
flow0 = flow0 + delta_flow
103108

@@ -117,18 +122,20 @@ def forward(self, img0, img1, iters_s16=1, iters_s8=6):
117122

118123
context_s16 = F.interpolate(context_s16, scale_factor=2, mode='nearest')
119124

120-
context_s8 = self.context_merge_s8(torch.cat([feature0_s8, context_s16], dim=1))
121-
iter_context_s8 = context_s8.clone()
125+
context_s8 = torch.zeros_like(context_s8)
126+
context_s8 = self.context_merge_s8(torch.cat([context_s8, context_s16], dim=1))
127+
128+
iter_context_s8 = self.init_iter_context_s8
122129

123130
for i in range(iters_s8):
124131

125132
if self.training and i > 0:
126133
flow0 = flow0.detach()
127-
# iter_feature0_s8 = iter_feature0_s8.detach()
134+
# iter_context_s8 = iter_context_s8.detach()
128135

129136
corrs = self.corr_block_s8(corr_pyr_s8, flow0)
130137

131-
iter_context_s8, delta_flow = self.refine_s8(corrs, iter_context_s8, flow0)
138+
iter_context_s8, delta_flow = self.refine_s8(corrs, context_s8, iter_context_s8, flow0)
132139

133140
flow0 = flow0 + delta_flow
134141

NeuFlow/refine.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,49 @@
11
import torch
2+
from NeuFlow import utils
23

34

45
class ConvBlock(torch.nn.Module):
56
def __init__(self, in_planes, out_planes, kernel_size, stride, padding):
67
super(ConvBlock, self).__init__()
78

8-
self.conv = torch.nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, padding_mode='zeros', bias=True)
9+
self.conv = torch.nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, padding_mode='zeros', bias=False)
910
self.relu = torch.nn.LeakyReLU(negative_slope=0.1, inplace=False)
1011

1112
def forward(self, x):
1213
return self.relu(self.conv(x))
1314

1415
class Refine(torch.nn.Module):
15-
def __init__(self, feature_dim, num_layers, levels, radius):
16+
def __init__(self, context_dim, iter_context_dim, num_layers, levels, radius, inter_dim):
1617
super(Refine, self).__init__()
1718

1819
self.radius = radius
1920

20-
self.conv1 = ConvBlock((radius*2+1)**2*levels+feature_dim+2+1, feature_dim, kernel_size=3, stride=1, padding=1)
21+
self.conv1 = ConvBlock((radius*2+1)**2*levels+context_dim+iter_context_dim+2+1, context_dim+iter_context_dim, kernel_size=3, stride=1, padding=1)
2122

22-
self.conv_layers = torch.nn.ModuleList([ConvBlock(feature_dim, feature_dim, kernel_size=3, stride=1, padding=1)
23+
self.conv2 = ConvBlock(context_dim+iter_context_dim, inter_dim, kernel_size=3, stride=1, padding=1)
24+
25+
self.conv_layers = torch.nn.ModuleList([ConvBlock(inter_dim, inter_dim, kernel_size=3, stride=1, padding=1)
2326
for i in range(num_layers)])
2427

25-
self.conv2 = torch.nn.Conv2d(feature_dim, feature_dim+2, kernel_size=3, stride=1, padding=1, padding_mode='zeros', bias=True)
28+
self.conv3 = torch.nn.Conv2d(inter_dim, iter_context_dim+2, kernel_size=3, stride=1, padding=1, padding_mode='zeros', bias=True)
29+
30+
self.hidden_act = torch.nn.Tanh()
31+
# self.hidden_norm = torch.nn.BatchNorm2d(feature_dim)
2632

2733
def init_bhwd(self, batch_size, height, width, device, amp):
2834
self.radius_emb = torch.tensor(self.radius, dtype=torch.half if amp else torch.float, device=device).view(1,-1,1,1).expand([batch_size,1,height,width])
2935

30-
def forward(self, corrs, feature0, flow0):
36+
def forward(self, corrs, context, iter_context, flow0):
3137

32-
x = torch.cat([corrs, feature0, flow0, self.radius_emb], dim=1)
38+
x = torch.cat([corrs, context, iter_context, flow0, self.radius_emb], dim=1)
3339

3440
x = self.conv1(x)
3541

42+
x = self.conv2(x)
43+
3644
for layer in self.conv_layers:
3745
x = layer(x)
3846

39-
x = self.conv2(x)
47+
x = self.conv3(x)
4048

41-
return torch.tanh(x[:,2:]), x[:,:2]
49+
return self.hidden_act(x[:,2:]), x[:,:2]

NeuFlow/utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import torch
22
import torch.nn.functional as F
33

4-
def normalize_img(img, mean, std):
5-
return (img / 255. - mean) / std
4+
# def normalize(x):
5+
# x_min = x.min()
6+
# return (x - x_min) / (x.max() - x_min)
67

78
def coords_grid(b, h, w, device, amp):
89
ys, xs = torch.meshgrid(torch.arange(h, dtype=torch.half if amp else torch.float, device=device), torch.arange(w, dtype=torch.half if amp else torch.float, device=device), indexing='ij') # [H, W]

0 commit comments

Comments
 (0)