import torch
import torch.nn.functional as F

# Backbone v8 is the version actually used below (v6 was the pre-diff
# version; keeping both would shadow the wrong encoder).
from NeuFlow import backbone_v8
from NeuFlow import transformer
from NeuFlow import matching
from NeuFlow import corr
@@ -17,38 +17,29 @@ class NeuFlow(torch.nn.Module):
1717 def __init__ (self ):
1818 super (NeuFlow , self ).__init__ ()
1919
20- self .backbone = backbone_v6 .CNNEncoder (config .feature_dim_s16 , config .feature_dim_s8 )
20+ self .backbone = backbone_v8 .CNNEncoder (config .feature_dim_s16 , config .context_dim_s16 , config . feature_dim_s8 , config . context_dim_s8 )
2121
22- self .cross_attn_s16 = transformer .FeatureAttention (config .feature_dim_s16 , num_layers = 2 , ffn = True , ffn_dim_expansion = 1 , post_norm = True )
22+ self .cross_attn_s16 = transformer .FeatureAttention (config .feature_dim_s16 + config . context_dim_s16 , num_layers = 2 , ffn = True , ffn_dim_expansion = 1 , post_norm = True )
2323
2424 self .matching_s16 = matching .Matching ()
2525
2626 # self.flow_attn_s16 = transformer.FlowAttention(config.feature_dim_s16)
27-
28- self .merge_s8 = torch .nn .Sequential (torch .nn .Conv2d (config .feature_dim_s16 + config .feature_dim_s8 + 2 , config .feature_dim_s8 , kernel_size = 3 , stride = 1 , padding = 1 , bias = False ),
29- torch .nn .GELU (),
30- torch .nn .Conv2d (config .feature_dim_s8 , config .feature_dim_s8 , kernel_size = 3 , stride = 1 , padding = 1 , bias = False ),
31- torch .nn .Tanh ())
3227
3328 self .corr_block_s16 = corr .CorrBlock (radius = 4 , levels = 1 )
3429 self .corr_block_s8 = corr .CorrBlock (radius = 4 , levels = 1 )
30+
31+ self .merge_s8 = torch .nn .Sequential (torch .nn .Conv2d (config .feature_dim_s16 + config .feature_dim_s8 , config .feature_dim_s8 , kernel_size = 3 , stride = 1 , padding = 1 , bias = False ),
32+ torch .nn .GELU (),
33+ torch .nn .Conv2d (config .feature_dim_s8 , config .feature_dim_s8 , kernel_size = 3 , stride = 1 , padding = 1 , bias = False ))
3534
36- self .context_s16 = torch .nn .Sequential (torch .nn .Conv2d (config .feature_dim_s16 , config .hidden_dim_s16 , kernel_size = 3 , stride = 1 , padding = 1 , bias = False ),
37- torch .nn .GELU (),
38- torch .nn .Conv2d (config .hidden_dim_s16 , config .hidden_dim_s16 , kernel_size = 3 , stride = 1 , padding = 1 , bias = False ))
39-
40- self .context_merge_s8 = torch .nn .Sequential (torch .nn .Conv2d (config .hidden_dim_s16 + config .feature_dim_s8 , config .hidden_dim_s8 , kernel_size = 3 , stride = 1 , padding = 1 , bias = False ),
35+ self .context_merge_s8 = torch .nn .Sequential (torch .nn .Conv2d (config .context_dim_s16 + config .context_dim_s8 , config .context_dim_s8 , kernel_size = 3 , stride = 1 , padding = 1 , bias = False ),
4136 torch .nn .GELU (),
42- torch .nn .Conv2d (config .hidden_dim_s8 , config .hidden_dim_s8 , kernel_size = 3 , stride = 1 , padding = 1 , bias = False ),
43- torch .nn .Tanh ())
44-
45- self .refine_s16 = refine .Refine (config .hidden_dim_s16 , num_layers = 6 , levels = 1 , radius = 4 )
46- self .refine_s8 = refine .Refine (config .hidden_dim_s8 , num_layers = 6 , levels = 1 , radius = 4 )
37+ torch .nn .Conv2d (config .context_dim_s8 , config .context_dim_s8 , kernel_size = 3 , stride = 1 , padding = 1 , bias = False ))
4738
48- # self.conv_s16 = backbone_v6.ConvBlock(3 , config.feature_dim_s1 * 2, kernel_size=16, stride=16, padding=0 )
49- # self.upsample_s16 = upsample.UpSample (config.feature_dim_s1 * 2, upsample_factor=16 )
39+ self .refine_s16 = refine . Refine ( config . context_dim_s16 , config .iter_context_dim_s16 , num_layers = 5 , levels = 1 , radius = 4 , inter_dim = 128 )
40+ self .refine_s8 = refine . Refine (config .context_dim_s8 , config . iter_context_dim_s8 , num_layers = 5 , levels = 1 , radius = 4 , inter_dim = 96 )
5041
51- self .conv_s8 = backbone_v6 .ConvBlock (3 , config .feature_dim_s1 , kernel_size = 8 , stride = 8 , padding = 0 )
42+ self .conv_s8 = backbone_v8 .ConvBlock (3 , config .feature_dim_s1 , kernel_size = 8 , stride = 8 , padding = 0 )
5243 self .upsample_s8 = upsample .UpSample (config .feature_dim_s1 , upsample_factor = 8 )
5344
5445 for p in self .parameters ():
@@ -57,7 +48,7 @@ def __init__(self):
5748
5849 def init_bhwd (self , batch_size , height , width , device , amp = True ):
5950
60- self .backbone .init_bhwd (batch_size * 2 , height , width , device , amp )
51+ self .backbone .init_bhwd (batch_size * 2 , height // 16 , width // 16 , device , amp )
6152
6253 self .matching_s16 .init_bhwd (batch_size , height // 16 , width // 16 , device , amp )
6354
@@ -67,7 +58,19 @@ def init_bhwd(self, batch_size, height, width, device, amp=True):
6758 self .refine_s16 .init_bhwd (batch_size , height // 16 , width // 16 , device , amp )
6859 self .refine_s8 .init_bhwd (batch_size , height // 8 , width // 8 , device , amp )
6960
70- def forward (self , img0 , img1 , iters_s16 = 1 , iters_s8 = 6 ):
61+ self .init_iter_context_s16 = torch .zeros (batch_size , config .iter_context_dim_s16 , height // 16 , width // 16 , device = device , dtype = torch .half if amp else torch .float )
62+ self .init_iter_context_s8 = torch .zeros (batch_size , config .iter_context_dim_s8 , height // 8 , width // 8 , device = device , dtype = torch .half if amp else torch .float )
63+
64+ def split_features (self , features , context_dim , feature_dim ):
65+
66+ context , features = torch .split (features , [context_dim , feature_dim ], dim = 1 )
67+
68+ context , _ = context .chunk (chunks = 2 , dim = 0 )
69+ feature0 , feature1 = features .chunk (chunks = 2 , dim = 0 )
70+
71+ return features , torch .relu (context )
72+
73+ def forward (self , img0 , img1 , iters_s16 = 3 , iters_s8 = 7 ):
7174
7275 flow_list = []
7376
@@ -78,6 +81,9 @@ def forward(self, img0, img1, iters_s16=1, iters_s8=6):
7881
7982 features_s16 = self .cross_attn_s16 (features_s16 )
8083
84+ features_s16 , context_s16 = self .split_features (features_s16 , config .context_dim_s16 , config .feature_dim_s16 )
85+ features_s8 , context_s8 = self .split_features (features_s8 , config .context_dim_s8 , config .feature_dim_s8 )
86+
8187 feature0_s16 , feature1_s16 = features_s16 .chunk (chunks = 2 , dim = 0 )
8288
8389 flow0 = self .matching_s16 .global_correlation_softmax (feature0_s16 , feature1_s16 )
@@ -86,18 +92,17 @@ def forward(self, img0, img1, iters_s16=1, iters_s8=6):
8692
8793 corr_pyr_s16 = self .corr_block_s16 .init_corr_pyr (feature0_s16 , feature1_s16 )
8894
89- context_s16 = self .context_s16 (feature0_s16 )
90- iter_context_s16 = context_s16 .clone ()
95+ iter_context_s16 = self .init_iter_context_s16
9196
9297 for i in range (iters_s16 ):
9398
9499 if self .training and i > 0 :
95100 flow0 = flow0 .detach ()
96- # iter_feature0_s16 = iter_feature0_s16 .detach()
101+ # iter_context_s16 = iter_context_s16 .detach()
97102
98103 corrs = self .corr_block_s16 (corr_pyr_s16 , flow0 )
99104
100- iter_context_s16 , delta_flow = self .refine_s16 (corrs , iter_context_s16 , flow0 )
105+ iter_context_s16 , delta_flow = self .refine_s16 (corrs , context_s16 , iter_context_s16 , flow0 )
101106
102107 flow0 = flow0 + delta_flow
103108
@@ -117,18 +122,20 @@ def forward(self, img0, img1, iters_s16=1, iters_s8=6):
117122
118123 context_s16 = F .interpolate (context_s16 , scale_factor = 2 , mode = 'nearest' )
119124
120- context_s8 = self .context_merge_s8 (torch .cat ([feature0_s8 , context_s16 ], dim = 1 ))
121- iter_context_s8 = context_s8 .clone ()
125+ context_s8 = torch .zeros_like (context_s8 )
126+ context_s8 = self .context_merge_s8 (torch .cat ([context_s8 , context_s16 ], dim = 1 ))
127+
128+ iter_context_s8 = self .init_iter_context_s8
122129
123130 for i in range (iters_s8 ):
124131
125132 if self .training and i > 0 :
126133 flow0 = flow0 .detach ()
127- # iter_feature0_s8 = iter_feature0_s8 .detach()
134+ # iter_context_s8 = iter_context_s8 .detach()
128135
129136 corrs = self .corr_block_s8 (corr_pyr_s8 , flow0 )
130137
131- iter_context_s8 , delta_flow = self .refine_s8 (corrs , iter_context_s8 , flow0 )
138+ iter_context_s8 , delta_flow = self .refine_s8 (corrs , context_s8 , iter_context_s8 , flow0 )
132139
133140 flow0 = flow0 + delta_flow
134141
0 commit comments