@@ -65,7 +65,7 @@ def default_augmentation(image_size: Tuple[int, int] = (224, 224)) -> nn.Module:
6565# Defining the models.
6666class MLPHead (nn .Module ):
6767 def __init__ (self , in_channels : int , projection_size : int = 256 , hidden_size : int = 4096 ):
68- super (MLPHead , self ).__init__ ()
68+ super ().__init__ ()
6969
7070 self .net = nn .Sequential (
7171 nn .Linear (in_channels , hidden_size ),
@@ -81,7 +81,7 @@ def forward(self, x):
8181# Defining resnet encoders.
8282class ResnetEncoder (nn .Module ):
8383 def __init__ (self , pretrained , mlp_params ):
84- super (ResnetEncoder , self ).__init__ ()
84+ super ().__init__ ()
8585 resnet = torchvision .models .resnet18 (pretrained = pretrained )
8686 self .encoder = torch .nn .Sequential (* list (resnet .children ())[:- 1 ])
8787 self .projector = MLPHead (in_channels = resnet .fc .in_features , ** mlp_params )
@@ -95,7 +95,7 @@ def forward(self, x):
9595# Defining custom training method required as required by Bootstrap your own latent.(SSL)
9696class BYOLExperiment (Experiment ):
9797 def __init__ (self , momentum , augmentation_fn , image_size , ** kwargs ):
98- super (BYOLExperiment , self ).__init__ (** kwargs )
98+ super ().__init__ (** kwargs )
9999 self .momentum = momentum
100100 self .augmentation_fn = augmentation_fn (image_size )
101101
0 commit comments