22from torch .nn import functional as F
33
44from .base import GlobalConvolutionalNetwork , BoundaryRefinement , DeconvConv2dBnRelu
5- from .encoders import ResNetEncoders
5+ from .encoders import get_encoder_channel_nr
66
77
88class LargeKernelMatters (nn .Module ):
@@ -11,34 +11,29 @@ class LargeKernelMatters(nn.Module):
1111 https://arxiv.org/pdf/1703.02719.pdf
1212 """
1313
14- def __init__ (self , encoder_depth , num_classes , kernel_size = 9 , internal_channels = 21 , use_relu = False , pool0 = False ,
15- pretrained = False , dropout_2d = 0.0 ):
14+ def __init__ (self , encoder , num_classes , kernel_size = 9 , internal_channels = 21 , use_relu = False , pool0 = False ,
15+ dropout_2d = 0.0 ):
1616 super ().__init__ ()
1717
1818 self .dropout_2d = dropout_2d
19+ self .pool0 = pool0
1920
20- self .encoders = ResNetEncoders (encoder_depth , pretrained = pretrained , pool0 = pool0 )
21+ self .encoder = encoder
22+ encoder_channel_nr = get_encoder_channel_nr (self .encoder )
2123
22- if encoder_depth in [18 , 34 ]:
23- bottom_channel_nr = 512
24- elif encoder_depth in [50 , 101 , 152 ]:
25- bottom_channel_nr = 2048
26- else :
27- raise NotImplementedError ('only 18, 34, 50, 101, 152 version of Resnet are implemented' )
28-
29- self .gcn2 = GlobalConvolutionalNetwork (in_channels = bottom_channel_nr // 8 ,
24+ self .gcn2 = GlobalConvolutionalNetwork (in_channels = encoder_channel_nr [0 ],
3025 out_channels = internal_channels ,
3126 kernel_size = kernel_size ,
3227 use_relu = use_relu )
33- self .gcn3 = GlobalConvolutionalNetwork (in_channels = bottom_channel_nr // 4 ,
28+ self .gcn3 = GlobalConvolutionalNetwork (in_channels = encoder_channel_nr [ 1 ] ,
3429 out_channels = internal_channels ,
3530 kernel_size = kernel_size ,
3631 use_relu = use_relu )
37- self .gcn4 = GlobalConvolutionalNetwork (in_channels = bottom_channel_nr // 2 ,
32+ self .gcn4 = GlobalConvolutionalNetwork (in_channels = encoder_channel_nr [ 2 ] ,
3833 out_channels = internal_channels ,
3934 kernel_size = kernel_size ,
4035 use_relu = use_relu )
41- self .gcn5 = GlobalConvolutionalNetwork (in_channels = bottom_channel_nr ,
36+ self .gcn5 = GlobalConvolutionalNetwork (in_channels = encoder_channel_nr [ 3 ] ,
4237 out_channels = internal_channels ,
4338 kernel_size = kernel_size ,
4439 use_relu = use_relu )
@@ -79,10 +74,18 @@ def __init__(self, encoder_depth, num_classes, kernel_size=9, internal_channels=
7974 self .deconv3 = DeconvConv2dBnRelu (in_channels = internal_channels , out_channels = internal_channels )
8075 self .deconv2 = DeconvConv2dBnRelu (in_channels = internal_channels , out_channels = internal_channels )
8176
77+ self .deconv1 = DeconvConv2dBnRelu (in_channels = internal_channels , out_channels = internal_channels )
78+ self .dec_br0_1 = BoundaryRefinement (in_channels = internal_channels ,
79+ out_channels = internal_channels ,
80+ kernel_size = 3 )
81+ self .dec_br0_2 = BoundaryRefinement (in_channels = internal_channels ,
82+ out_channels = internal_channels ,
83+ kernel_size = 3 )
84+
8285 self .final = nn .Conv2d (internal_channels , num_classes , kernel_size = 1 , padding = 0 )
8386
8487 def forward (self , x ):
85- encoder2 , encoder3 , encoder4 , encoder5 = self .encoders (x )
88+ encoder2 , encoder3 , encoder4 , encoder5 = self .encoder (x )
8689 encoder5 = F .dropout2d (encoder5 , p = self .dropout_2d )
8790
8891 gcn2 = self .enc_br2 (self .gcn2 (encoder2 ))
@@ -95,4 +98,7 @@ def forward(self, x):
9598 decoder3 = self .deconv3 (self .dec_br3 (decoder4 + gcn3 ))
9699 decoder2 = self .dec_br1 (self .deconv2 (self .dec_br2 (decoder3 + gcn2 )))
97100
101+ if self .pool0 :
102+ decoder2 = self .dec_br0_2 (self .deconv1 (self .dec_br0_1 (decoder2 )))
103+
98104 return self .final (decoder2 )
0 commit comments