11import torch
22import torch .nn .functional as F
33
4- from NeuFlow import backbone_v8
4+ from NeuFlow import backbone_v7
55from NeuFlow import transformer
66from NeuFlow import matching
77from NeuFlow import corr
@@ -17,7 +17,7 @@ class NeuFlow(torch.nn.Module):
1717 def __init__ (self ):
1818 super (NeuFlow , self ).__init__ ()
1919
20- self .backbone = backbone_v8 .CNNEncoder (config .feature_dim_s16 , config .context_dim_s16 , config .feature_dim_s8 , config .context_dim_s8 )
20+ self .backbone = backbone_v7 .CNNEncoder (config .feature_dim_s16 , config .context_dim_s16 , config .feature_dim_s8 , config .context_dim_s8 )
2121
2222 self .cross_attn_s16 = transformer .FeatureAttention (config .feature_dim_s16 + config .context_dim_s16 , num_layers = 2 , ffn = True , ffn_dim_expansion = 1 , post_norm = True )
2323
@@ -30,16 +30,18 @@ def __init__(self):
3030
3131 self .merge_s8 = torch .nn .Sequential (torch .nn .Conv2d (config .feature_dim_s16 + config .feature_dim_s8 , config .feature_dim_s8 , kernel_size = 3 , stride = 1 , padding = 1 , bias = False ),
3232 torch .nn .GELU (),
33- torch .nn .Conv2d (config .feature_dim_s8 , config .feature_dim_s8 , kernel_size = 3 , stride = 1 , padding = 1 , bias = False ))
33+ torch .nn .Conv2d (config .feature_dim_s8 , config .feature_dim_s8 , kernel_size = 3 , stride = 1 , padding = 1 , bias = False ),
34+ torch .nn .BatchNorm2d (config .feature_dim_s8 ))
3435
3536 self .context_merge_s8 = torch .nn .Sequential (torch .nn .Conv2d (config .context_dim_s16 + config .context_dim_s8 , config .context_dim_s8 , kernel_size = 3 , stride = 1 , padding = 1 , bias = False ),
3637 torch .nn .GELU (),
37- torch .nn .Conv2d (config .context_dim_s8 , config .context_dim_s8 , kernel_size = 3 , stride = 1 , padding = 1 , bias = False ))
38+ torch .nn .Conv2d (config .context_dim_s8 , config .context_dim_s8 , kernel_size = 3 , stride = 1 , padding = 1 , bias = False ),
39+ torch .nn .BatchNorm2d (config .context_dim_s8 ))
3840
3941 self .refine_s16 = refine .Refine (config .context_dim_s16 , config .iter_context_dim_s16 , num_layers = 5 , levels = 1 , radius = 4 , inter_dim = 128 )
4042 self .refine_s8 = refine .Refine (config .context_dim_s8 , config .iter_context_dim_s8 , num_layers = 5 , levels = 1 , radius = 4 , inter_dim = 96 )
4143
42- self .conv_s8 = backbone_v8 .ConvBlock (3 , config .feature_dim_s1 , kernel_size = 8 , stride = 8 , padding = 0 )
44+ self .conv_s8 = backbone_v7 .ConvBlock (3 , config .feature_dim_s1 , kernel_size = 8 , stride = 8 , padding = 0 )
4345 self .upsample_s8 = upsample .UpSample (config .feature_dim_s1 , upsample_factor = 8 )
4446
4547 for p in self .parameters ():
@@ -70,7 +72,7 @@ def split_features(self, features, context_dim, feature_dim):
7072
7173 return features , torch .relu (context )
7274
73- def forward (self , img0 , img1 , iters_s16 = 3 , iters_s8 = 7 ):
75+ def forward (self , img0 , img1 , iters_s16 = 2 , iters_s8 = 7 ):
7476
7577 flow_list = []
7678
@@ -122,7 +124,6 @@ def forward(self, img0, img1, iters_s16=3, iters_s8=7):
122124
123125 context_s16 = F .interpolate (context_s16 , scale_factor = 2 , mode = 'nearest' )
124126
125- context_s8 = torch .zeros_like (context_s8 )
126127 context_s8 = self .context_merge_s8 (torch .cat ([context_s8 , context_s16 ], dim = 1 ))
127128
128129 iter_context_s8 = self .init_iter_context_s8
0 commit comments